#pragma warning(disable:4786)
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<queue>
#include<set>
using namespace std;
// Offsets for moving the blank, as {row delta, column delta}:
// up, down, left, right.
int gOrient[4][2] = {
    {-1,  0},
    { 1,  0},
    { 0, -1},
    { 0,  1}
};
//-----------------------------------------------------------------------------
// One search node of the 8-puzzle: a board plus the A* bookkeeping attached to it.
class State
{
public:
    // Board cells in row-major order (index = row*3 + col); 0 marks the blank.
    int a[9];
    // Number of moves made since the start board (the "g" cost).
    int step;
    // Heuristic estimate of the remaining cost, written by distance().
    int dis;
    // Integer code of the board, written by calHash(); used to skip repeated boards.
    int hash;

    // Index of the first cell holding n, or -1 when no cell does.
    int findPos(int n) const;
    // Inverted comparison on step + dis so std::priority_queue yields the smallest first.
    bool operator < (const State& s) const;
    // Writes the board to cout.
    void print() const;
    // Recomputes `dis` against the goal board `ans`.
    void distance(int ans[]);
    // Recomputes `hash` from `a`.
    void calHash();
    // Bookkeeping after a move: increments `step`, then calls distance() and calHash().
    void calc(int ans[]);
};
//-----------------------------------------------------------------------------
// A* solver for the 3x3 sliding puzzle.
class EightFigure
{
public:
    // `init` is copied as the start board; `ans` is the goal board (NULL means 0,1,...,8).
    EightFigure(int* init, int* ans = NULL);

    // Inversion-parity comparison of start and goal; astar() gives up when it is false.
    bool hasAnswer() const;
    // NOTE(review): declared but not defined anywhere in this file; calling it will not link.
    void setInitState(int init[]);
    // Replaces the goal board.
    void setAnswer(int ans[]);
    // Runs the search, printing each board as it is taken off the open list.
    void astar();

private:
    // Goal board, row-major, 0 = blank.
    int answer[9];
    // Hashes of every board already queued, so each board is pushed at most once.
    set<int> calculated;
    // Open list; State::operator< keeps the smallest step + dis on top.
    priority_queue<State> states;
};
//-----------------------------------------------------------------------------
// Returns the index of the first cell of `a` that holds n, or -1 if none does.
int State::findPos(int n) const
{
    int pos = 0;
    while (pos < 9 && a[pos] != n)
        ++pos;
    return pos < 9 ? pos : -1;
}
// std::priority_queue keeps its "largest" element on top, so this operator is
// deliberately inverted: it answers "is f = step + dis of *this greater than s's?".
// As a result the state with the smallest f is popped first.
bool State::operator < (const State& s) const
{
    const int fThis = step + dis;
    const int fOther = s.step + s.dis;
    return fThis > fOther;
}
// Writes the board to cout as three rows of space-separated values.
// A line break precedes every row (the first one included), and one more
// line break follows the last row.
void State::print() const
{
    for (int row = 0; row < 3; ++row)
    {
        cout << endl;
        for (int col = 0; col < 3; ++col)
            cout << a[row * 3 + col] << ' ';
    }
    cout << endl;
}
// A* heuristic: stores in `dis` the total Manhattan distance of every numbered
// tile from the cell it occupies in the goal board `ans`. The blank (0) is not counted.
//
// Each move slides exactly one tile by one cell, so this never overestimates the
// number of remaining moves. The previous sum of |ans[i] - a[i]| over cell values
// did overestimate: a board one move from the goal scored 16, which made A*
// return non-shortest solutions.
//
// `dis` is 0 exactly when every tile sits in its goal cell, which keeps astar()'s
// `dis == 0` goal test valid.
// Assumes `a` and `ans` each hold 0..8 exactly once (calHash() relies on that too).
void State::distance(int ans[])
{
    dis = 0;
    for (int i = 0; i < 9; i++)
    {
        if (a[i] == 0)
            continue;  // the blank is not a tile; counting it would overestimate
        for (int j = 0; j < 9; j++)
        {
            if (ans[j] == a[i])
            {
                dis += abs(i / 3 - j / 3) + abs(i % 3 - j % 3);
                break;
            }
        }
    }
}
// Stores in `hash` an integer in [0, 9!) that differs for every distinct board.
// It is used by EightFigure to detect boards it has already queued.
//
// For each value v = a[i], the digit is the count of smaller values to its right.
// That count is at most v, so weighting it by v! gives a factorial-base number
// (an inversion-table code), which is unique per permutation.
// Assumes `a` holds each of 0..8 exactly once, since the values index `fac`.
void State::calHash()
{
    // k! for k = 0..8. The previous table had 4050 in place of 7! = 5040.
    // 4050 < 6! * 7 (= 5040) lets different boards map to the same hash, so
    // unvisited boards were discarded as duplicates.
    static const int fac[] = {1,1,2,6,24,120,720,5040,40320};
    hash = 0;
    for (int i = 0; i < 9; i++)
    {
        int reverse = 0;
        for (int j = i + 1; j < 9; j++)
            if (a[j] < a[i])
                reverse++;
        hash += reverse * fac[a[i]];
    }
}
// Refreshes the bookkeeping after a move has been applied to `a`:
// one more step taken, a new heuristic against `ans`, and a new hash.
void State::calc(int ans[])
{
    step += 1;
    distance(ans);
    calHash();
}
//-----------------------------------------------------------------------------
// Builds a solver whose open list holds only the start board.
// @param init  9 cells in row-major order, 0 = blank; copied.
// @param ans   goal board in the same format, copied; NULL selects 0,1,...,8.
EightFigure::EightFigure(int* init, int* ans)
{
    if (ans == NULL)
    {
        for (int i = 0; i < 9; i++)
            answer[i] = i;
    }
    else
    {
        memcpy(answer, ans, 9 * sizeof(int));
    }

    State s;
    memcpy(s.a, init, 9 * sizeof(int));
    s.step = 0;
    // Score against the stored goal, not the raw parameter: `ans` may be NULL here
    // (it is by default), and passing it on dereferenced a null pointer.
    s.distance(answer);
    s.calHash();
    states.push(s);
    calculated.insert(s.hash);
}
//判断是否无解
bool EightFigure::hasAnswer() const
{
State now = states.top();
int reverseInit = 0, reverseAns = 0;
for (int i=0; i<9; i++)
{
int countInit = 0, countAns = 0;
for (int j=i+1; j<9; j++)
{
if (now.a[j] < now.a[i])
countInit++;
if (answer[j] < answer[i])
countAns++;
}
reverseInit += countInit;
reverseAns += countAns;
}
return reverseInit%2 == reverseAns%2;
}
void EightFigure::setAnswer(int ans[])
{
memcpy(answer, ans, 9*sizeof(int));
}
void EightFigure::astar()
{
if (!hasAnswer())
{
cout<<"无解\n";
return;
}
while (!states.empty())
{
State now = states.top();
states.pop();
now.print();
if (now.dis == 0)
break;
int zero = now.findPos(0);
int x = zero/3;
int y = zero%3;
for (int i=0; i<4; i++)
{
int xx = x + gOrient[i][0];
int yy = y + gOrient[i][1];
if (xx < 0 || xx >= 3 || yy < 0 || yy >= 3)
continue;
State s = now;
s.a[zero] = s.a[xx*3+yy];
s.a[xx*3+yy] = 0;
s.calc(answer);
if (calculated.find(s.hash) == calculated.end())
{
calculated.insert(s.hash);
states.push(s);
}
}
}
}
//-----------------------------------------------------------------------------
// Demo: solve one fixed 8-puzzle board towards 1..8 with the blank last.
// `void main()` is not valid standard C++; the return type must be int.
int main()
{
    int init[9] = {3,5,1,2,8,6,7,4,0};
    int ans[] = {1,2,3,4,5,6,7,8,0};
    EightFigure eightFigure(init, ans);
    eightFigure.astar();
    return 0;
}