#include<bits/stdc++.h>
using namespace std;
int n,m;
int ans=0;
int a[1000][1000];
int b[1000][1000];
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
scanf("%d",&a[i][j]);
b[i][j]=b[i][j-1]+a[i][j]+b[i-1][j]-b[i-1][j-1];
}
}
for(int i=2;i<=min(n,m);i++){
for(int x=1;x<=n-i+1;x++){
for(int y=1;y<=m-i+1;y++){
int cnt=0;
int xx=x+i-1;
int yy=y+i-1;
cnt=b[xx][yy]-b[x-1][yy]-b[xx][y-1]+b[x-1][y-1];
if(cnt-(i*i-cnt)<=1&&cnt-(i*i-cnt)>=-1){
ans++;
}
}
}
}
printf("%d",ans);
return 0;
}