#include<bits/stdc++.h>
#define inf 1e18
#define ll long long
#define ull unsigned long long
#define int long long
#define PI acos(-1.0)
#define PII pair<int,int>
using namespace std;
const int N =2e3+7 , M = 1e6+7;
const int mod = 1e9+7;
// Reads the next whitespace-delimited token from stdin (the same token that
// scanf("%s") would read, but with no fixed buffer length). Returns an empty
// string if input is exhausted before any non-space character is seen.
static std::string read_token(){
    std::string token;
    int ch = std::getchar();
    while (ch != EOF && std::isspace(ch)) {
        ch = std::getchar();
    }
    while (ch != EOF && !std::isspace(ch)) {
        token.push_back(static_cast<char>(ch));
        ch = std::getchar();
    }
    return token;
}

// Counts the axis-aligned sub-rectangles of `grid` in which every row segment
// and every column segment consists of pairwise-distinct characters.
//
// The column count is taken from grid[0]. A row shorter than that is treated
// as padded with '\0'; extra characters of a longer row are ignored.
// Time O(R*C*(R+C)), memory O(R*C) for R rows and C columns.
static long long count_distinct_submatrices(const std::vector<std::string>& grid){
    if (grid.empty() || grid[0].empty()) {
        return 0;
    }
    const int rows = static_cast<int>(grid.size());
    const int cols = static_cast<int>(grid[0].size());
    auto cell = [&grid](int row, int col) -> unsigned char {
        const std::string& line = grid[row];
        return col < static_cast<int>(line.size()) ? static_cast<unsigned char>(line[col]) : 0;
    };

    // next_pos[ch]: nearest already-scanned index (to the right / below) holding
    // ch. The sentinel `cols` / `rows` means "no later occurrence".
    std::array<int, 256> next_pos{};

    // right[i][j]: largest column c such that grid[i][j..c] has no repeated character.
    std::vector<std::vector<int>> right(rows, std::vector<int>(cols));
    for (int i = 0; i < rows; ++i) {
        next_pos.fill(cols);
        int limit = cols - 1;
        for (int j = cols - 1; j >= 0; --j) {
            const unsigned char ch = cell(i, j);
            // A segment starting at j must stop before the next copy of ch, and
            // can never extend past where the segment starting at j+1 stops.
            limit = std::min(limit, next_pos[ch] - 1);
            right[i][j] = limit;
            next_pos[ch] = j;
        }
    }

    // down[i][j]: largest row r such that column j over rows i..r has no repeated character.
    std::vector<std::vector<int>> down(rows, std::vector<int>(cols));
    for (int j = 0; j < cols; ++j) {
        next_pos.fill(rows);
        int limit = rows - 1;
        for (int i = rows - 1; i >= 0; --i) {
            const unsigned char ch = cell(i, j);
            limit = std::min(limit, next_pos[ch] - 1);
            down[i][j] = limit;
            next_pos[ch] = i;
        }
    }

    long long total = 0;
    // For the current top-left corner (i, j) and c in j..right[i][j]:
    // deepest[c] is the lowest bottom row for which every column j..c stays
    // distinct starting from row i. Allocated once and reused for every corner.
    std::vector<int> deepest(cols);
    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j) {
            const int widest = right[i][j];
            const int tallest = down[i][j];
            int depth = tallest;
            for (int c = j; c <= widest; ++c) {
                depth = std::min(depth, down[i][c]);
                deepest[c] = depth;
            }
            // Sweep the bottom edge downwards; the widest valid right edge only
            // shrinks. It never drops below j, because deepest[j] == tallest
            // >= bottom and right[bottom][j] >= j.
            int edge = widest;
            for (int bottom = i; bottom <= tallest; ++bottom) {
                edge = std::min(edge, right[bottom][j]);  // every row must stay distinct
                while (deepest[edge] < bottom) {          // every column must stay distinct
                    --edge;
                }
                total += edge - j + 1;
            }
        }
    }
    return total;
}

// Reads "n m" followed by n grid rows from stdin and prints the number of
// sub-rectangles whose rows and columns each contain no repeated character.
// Prints 0 if the dimensions are missing or non-positive.
void solve(){
    long long rows = 0;
    long long cols = 0;
    if (std::scanf("%lld%lld", &rows, &cols) != 2 || rows <= 0 || cols <= 0) {
        std::printf("%lld\n", 0LL);
        return;
    }
    std::vector<std::string> grid(static_cast<std::size_t>(rows));
    for (std::string& line : grid) {
        line = read_token();
        // Normalise every row to exactly `cols` characters so the counter sees
        // a rectangular grid (missing cells become '\0').
        line.resize(static_cast<std::size_t>(cols), '\0');
    }
    std::printf("%lld\n", count_distinct_submatrices(grid));
}
// Program entry point: runs solve() once per test case. The case count is
// fixed at 1; the commented-out scanf is the hook for multi-test input.
signed main(){
    int cases = 1;
    // scanf("%lld",&cases);
    for (; cases != 0; --cases) {
        solve();
    }
    return 0;
}