import java.util.*;
class Solution {
    /** The four axis-aligned moves: right, left, up, down. */
    private static final int[][] DIRS = {{0, 1}, {0, -1}, {-1, 0}, {1, 0}};

    /** Orders {@code {row, col, height}} triples by ascending height. */
    private static final Comparator<int[]> BY_HEIGHT = new Comparator<int[]>() {
        @Override
        public int compare(int[] a, int[] b) {
            // Integer.compare instead of subtraction: never overflows.
            return Integer.compare(a[2], b[2]);
        }
    };

    /**
     * Returns the minimum number of steps needed to cut every tree in height order.
     *
     * <p>A cell value of 0 is an obstacle, 1 is walkable ground, and any value greater than 1 is a
     * tree with that height. Trees are visited in ascending height, starting from cell (0, 0).
     * The answer is the sum of the shortest walking distances between consecutive targets.
     *
     * @param forest row-major grid; every row is assumed to have the length of row 0
     * @return the total number of steps, 0 when the grid is empty or holds no trees, or -1 if
     *     some tree cannot be reached
     */
    public int cutOffTree(List<List<Integer>> forest) {
        if (forest.isEmpty() || forest.get(0).isEmpty()) {
            return 0; // no cells, hence no trees to cut
        }
        // Copy once so every BFS reads primitives instead of boxed List lookups.
        int[][] grid = toGrid(forest);

        List<int[]> trees = new ArrayList<>();
        for (int i = 0; i < grid.length; i++) {
            for (int j = 0; j < grid[i].length; j++) {
                if (grid[i][j] > 1) {
                    trees.add(new int[]{i, j, grid[i][j]});
                }
            }
        }
        Collections.sort(trees, BY_HEIGHT);

        int total = 0;
        int row = 0;
        int col = 0;
        for (int[] tree : trees) {
            int steps = bfs(grid, row, col, tree[0], tree[1], null);
            if (steps < 0) {
                return -1;
            }
            total += steps;
            row = tree[0];
            col = tree[1];
        }
        return total;
    }

    /**
     * Returns the shortest walking distance from (x, y) to (nextx, nexty), or -1 if unreachable.
     *
     * <p>Kept for backward compatibility; {@link #cutOffTree} no longer needs the 4-D table.
     * When {@code f} is non-null, the distance to each cell discovered by the search is also
     * written to {@code f[x][y][row][col]}.
     *
     * @param forest row-major grid (0 = obstacle)
     * @param x start row
     * @param y start column
     * @param f optional distance table to record into; may be {@code null}
     * @param nextx target row
     * @param nexty target column
     * @return number of steps, 0 when start equals target, or -1 if the target is unreachable
     */
    public int dijkstra(List<List<Integer>> forest, int x, int y, int[][][][] f, int nextx, int nexty) {
        return bfs(toGrid(forest), x, y, nextx, nexty, f == null ? null : f[x][y]);
    }

    /** Copies the nested-list grid into a primitive array sized by row 0. */
    private static int[][] toGrid(List<List<Integer>> forest) {
        int m = forest.size();
        int n = forest.get(0).size();
        int[][] grid = new int[m][n];
        for (int i = 0; i < m; i++) {
            List<Integer> line = forest.get(i);
            for (int j = 0; j < n; j++) {
                grid[i][j] = line.get(j);
            }
        }
        return grid;
    }

    /**
     * Breadth-first search over non-zero cells. All moves cost 1, so BFS order equals distance order.
     *
     * @param record optional {@code [row][col]} array that receives the distance of every
     *     discovered cell; may be {@code null}
     * @return steps from (sx, sy) to (tx, ty), or -1 if the target is unreachable
     */
    private static int bfs(int[][] grid, int sx, int sy, int tx, int ty, int[][] record) {
        if (sx == tx && sy == ty) {
            return 0;
        }
        int m = grid.length;
        int n = grid[0].length;
        int[][] dist = new int[m][n];
        for (int[] line : dist) {
            Arrays.fill(line, -1); // -1 marks "not yet discovered"
        }
        dist[sx][sy] = 0;
        ArrayDeque<int[]> queue = new ArrayDeque<>();
        queue.offer(new int[]{sx, sy});
        while (!queue.isEmpty()) {
            int[] cell = queue.poll();
            int next = dist[cell[0]][cell[1]] + 1;
            for (int[] d : DIRS) {
                int nx = cell[0] + d[0];
                int ny = cell[1] + d[1];
                if (nx < 0 || nx >= m || ny < 0 || ny >= n) {
                    continue;
                }
                if (dist[nx][ny] != -1 || grid[nx][ny] == 0) {
                    continue; // already discovered, or an obstacle
                }
                dist[nx][ny] = next;
                if (record != null) {
                    record[nx][ny] = next;
                }
                if (nx == tx && ny == ty) {
                    // The first discovery in BFS is the shortest path.
                    return next;
                }
                queue.offer(new int[]{nx, ny});
            }
        }
        return -1;
    }
}