井字棋小游戏AI(蒙特卡洛搜索)
刚把《强化学习》的第一部分写完,突发奇想,想写一个井字棋小游戏AI:整体采用MCTS(蒙特卡洛树搜索)算法,其中以UCT算法作为树内策略,以等概率随机落子作为树外(rollout)策略。
代码:
#include <bits/stdc++.h>
using namespace std;
// Capacity of the per-state arrays in this file. init() only fills states
// 0..19682, i.e. the 3^9 base-3 encodings of a 3x3 board.
const int maxn = 20010;
// Multiplier applied to the sqrt(log(tot) / visits) term inside UCT().
double UCT_C = 2.0;
// Running statistics for one board state.
//   x : accumulated reward (MCTS() adds to it during backpropagation)
//   y : accumulated visit weight (MCTS() adds 2 per visit)
struct node {
    double x, y;

    // Average reward x / y. The division is performed unconditionally, so the
    // result is NaN while both fields are still zero.
    double to_double(void) {
        const double ratio = x / y;
        return ratio;
    }

    // Resets both accumulators to zero.
    void init() {
        x = y = 0;
    }
};
// Per-state statistics, indexed by the encoded board state (see rbuild()).
node V[maxn];
// Not referenced by any code visible in this file.
double eps = 1e-10;
// Next[s]: every successor state of s; filled by init() for non-terminal s.
vector<int> Next[maxn];
// Tree[s]: the successors of s that MCTS() has expanded so far.
vector<int> Tree[maxn];
// ed[s]: true when init() classified s as terminal (a win for either side,
// or no empty cell left).
bool ed[maxn];
// Working board; only rows/columns 0..2 are used by the code in this file.
// Cell values: 0 = empty, 'b' = black, 'w' = white.
char table[5][5];
// Random engine used by dfs() for rollouts, seeded from the wall clock.
// NOTE(review): glibc's <stdlib.h> declares a function named random(); this
// global may clash with it on such toolchains - confirm on the target compiler.
mt19937 random(time(0));
// Returns how many of the nine lowest base-3 digits of x are non-zero,
// i.e. the number of occupied cells in the encoded board.
int dep(int x) {
    int occupied = 0;
    int remaining = 9;
    while (remaining-- > 0) {
        occupied += (x % 3 != 0) ? 1 : 0;
        x /= 3;
    }
    return occupied;
}
// Encodes the global board `table` as a base-3 integer.
// Cell i (row i / 3, column i % 3) contributes digit * 3^i, where the digit is
// 0 for an empty cell, 1 for 'b', and 2 for anything else.
int rbuild(void) {
    int code = 0;
    int weight = 1;
    for (int cell = 0; cell < 9; ++cell) {
        const char mark = table[cell / 3][cell % 3];
        int digit;
        switch (mark) {
            case 0:
                digit = 0;
                break;
            case 'b':
                digit = 1;
                break;
            default:
                digit = 2;
                break;
        }
        code += weight * digit;
        weight *= 3;
    }
    return code;
}
// Decodes the base-3 state `st` into the global board `table`.
// Digit 0 -> empty (0), digit 1 -> 'b', any other remainder -> 'w'.
void build(int st) {
    for (int cell = 0; cell < 9; ++cell) {
        const int digit = st % 3;
        char mark;
        switch (digit) {
            case 0:
                mark = 0;
                break;
            case 1:
                mark = 'b';
                break;
            default:
                mark = 'w';
                break;
        }
        table[cell / 3][cell % 3] = mark;
        st /= 3;
    }
}
// Lists every state reachable from x by filling one empty cell.
// The digit placed is 1 when x has an even number of occupied cells and 2
// otherwise. Successors are returned in ascending cell order.
// Side effect: leaves the global board `table` decoded from x.
vector<int> find_next(int x) {
    build(x);
    const int mover_digit = (dep(x) % 2) + 1;

    vector<int> successors;
    int weight = 1;
    for (int cell = 0; cell < 9; ++cell, weight *= 3) {
        if (table[cell / 3][cell % 3] != 0) {
            continue;
        }
        successors.push_back(x + weight * mover_digit);
    }
    return successors;
}
// Returns true when state `st` contains a completed line of 'w' marks
// (any row, column or diagonal).
// Side effect: leaves the global board `table` decoded from st.
bool lose(int st) {
    build(st);
    auto all_white = [](char a, char b, char c) {
        return a == 'w' && b == 'w' && c == 'w';
    };

    for (int k = 0; k < 3; ++k) {
        if (all_white(table[k][0], table[k][1], table[k][2])) return true;  // row k
        if (all_white(table[0][k], table[1][k], table[2][k])) return true;  // column k
    }
    if (all_white(table[0][0], table[1][1], table[2][2])) return true;
    if (all_white(table[2][0], table[1][1], table[0][2])) return true;
    return false;
}
// Returns true when state `st` contains a completed line of 'b' marks
// (any row, column or diagonal).
// Side effect: leaves the global board `table` decoded from st.
bool vectory(int st) {
    build(st);
    auto all_black = [](char a, char b, char c) {
        return a == 'b' && b == 'b' && c == 'b';
    };

    for (int k = 0; k < 3; ++k) {
        if (all_black(table[k][0], table[k][1], table[k][2])) return true;  // row k
        if (all_black(table[0][k], table[1][k], table[2][k])) return true;  // column k
    }
    if (all_black(table[0][0], table[1][1], table[2][2])) return true;
    if (all_black(table[2][0], table[1][1], table[0][2])) return true;
    return false;
}
// Random rollout: starting at state x, keeps moving to a uniformly chosen
// successor until a terminal state (ed[] set) is reached.
// Returns 1 if that state is a 'b' line, 2 if it is a 'w' line, 3 otherwise.
int dfs(int x) {
    while (!ed[x]) {
        const int pick = random() % Next[x].size();
        x = Next[x][pick];
    }
    if (vectory(x)) {
        return 1;
    }
    if (lose(x)) {
        return 2;
    }
    return 3;
}
// UCT score of state x: its average reward plus UCT_C times the exploration
// term sqrt(log(tot) / V[x].y), where `tot` is the visit weight supplied by
// the caller (MCTS() passes the parent's V[].y).
double UCT(int x, double tot) {
    const double exploit = V[x].to_double();
    const double explore = sqrt(log(tot) / V[x].y);
    return exploit + UCT_C * explore;
}
// Runs one MCTS iteration from `root`: selection, expansion, rollout and
// backpropagation.
//
// root : encoded state to search from.
// flag : 0 when black ('b') is to move at `root`, 1 when white ('w') is.
//        It is toggled on every step down the tree, so a node visited with
//        flag == 1 is a position that black has just moved into.
//
// Statistics in V[] are stored from the point of view of the player who moved
// INTO a node, which lets the parent simply maximise over its children.
// Each visit adds 2 to V[].y; the reward added to V[].x is 2 for a win,
// 1 for a draw and 0 for a loss, so V[].to_double() lies in [0, 1].
void MCTS(int root, int flag) {
    int now = root;
    stack<int> path;
    path.push(now);

    // Selection: while every successor of `now` is already in the tree,
    // descend into the child with the highest UCT score.
    while (!ed[now] && Tree[now].size() == Next[now].size()) {
        // Seed the search with the first child instead of a magic
        // "score 0 / state 0" sentinel, and evaluate UCT once per child.
        int best_child = Tree[now].front();
        double best_score = UCT(best_child, V[now].y);
        for (auto t : Tree[now]) {
            const double score = UCT(t, V[now].y);
            if (score > best_score) {
                best_score = score;
                best_child = t;
            }
        }
        now = best_child;
        flag ^= 1;
        path.push(now);
    }

    // Expansion: add the next not-yet-expanded successor, in Next[] order.
    if (!ed[now]) {
        int x = Next[now][Tree[now].size()];
        Tree[now].push_back(x);
        flag ^= 1;
        // NOTE(review): this also wipes statistics that x may already have
        // gathered when it was expanded under a different parent (same board
        // reached by another move order) - confirm this is intended.
        V[x].init();
        path.push(x);
        now = x;
    }

    // Rollout: 1 = black line, 2 = white line, 3 = neither (draw).
    const int res = dfs(now);

    // Backpropagation from the leaf up to the root.
    while (!path.empty()) {
        now = path.top();
        path.pop();

        double reward;
        if (res == 3) {
            // A draw is worth half a win. The original credited the full 2
            // here, which made draws indistinguishable from wins.
            reward = 1;
        } else if ((res == 1 && flag) || (res == 2 && flag == 0)) {
            reward = 2;  // the player who moved into `now` won the rollout
        } else {
            reward = 0;
        }
        V[now].x += reward;
        V[now].y += 2;
        flag ^= 1;
    }
}
// Chooses a move from `root`: runs 500 MCTS iterations, then returns the
// successor in Next[root] with the highest average reward (the first one wins
// ties). Returns -1 if no successor has an average above -1.
// `flag` is forwarded unchanged to MCTS().
int solve(int root, bool flag) {
    int remaining = 500;
    while (remaining-- > 0) {
        MCTS(root, flag);
    }

    int best_state = -1;
    double best_rate = -1;
    for (size_t k = 0; k < Next[root].size(); ++k) {
        const int candidate = Next[root][k];
        const double rate = V[candidate].to_double();
        if (rate > best_rate) {
            best_rate = rate;
            best_state = candidate;
        }
    }
    return best_state;
}
void init() {
int tmp;
for (int i = 0; i < 19683; i++) {
bool x1 = vectory(i), x2 = lose(i);
tmp = 9 - dep(i);
if(x1 || x2 || (tmp == 0))
ed[i] = true;
else {
Next[i] = find_next(i);
}
}
}
// Prints the move prompt followed by the global board `table`, framed by
// dashed lines. Empty cells are shown as a space; rows are separated by a
// "| | |" line.
void print_table() {
    printf("请落子(比如 0 0):\n");
    printf("----------\n");
    for (int row = 0; row < 3; ++row) {
        for (int col = 0; col < 3; ++col) {
            const char mark = table[row][col];
            if (mark != 0) {
                printf("%c", mark);
            } else {
                printf(" ");
            }
            if (col != 2) {
                printf("-");
            }
        }
        if (row != 2) {
            // Separator line between two board rows.
            printf("\n");
            int k = 0;
            while (k < 5) {
                fputs((k % 2 == 0) ? "|" : " ", stdout);
                ++k;
            }
        }
        printf("\n");
    }
    printf("----------\n");
}
// Reads one human move for `piece` ('b' or 'w') from stdin and places it on
// the global board `table`. Re-prompts until the coordinates name an empty
// cell inside the 3x3 board. Returns false when stdin is exhausted.
static bool read_player_move(char piece) {
    for (;;) {
        int x = -1, y = -1;
        const int got = scanf("%d %d", &x, &y);
        if (got == EOF) {
            return false;
        }
        if (got != 2) {
            // Drop the token that could not be parsed so the next attempt
            // does not trip over it again.
            scanf("%*s");
        } else if (x >= 0 && x < 3 && y >= 0 && y < 3 && table[x][y] == 0) {
            table[x][y] = piece;
            return true;
        }
        printf("无效的落子, 请重新输入(比如 0 0):\n");
    }
}

// Interactive loop: plays up to 10 rounds of human vs. AI on stdin/stdout and
// prints the final tally. In each round the human picks a colour
// (0 = black, moves first; any other number = white) and the AI takes the
// other one. The loop stops early if stdin ends or cannot be parsed.
//
// Counters: a = rounds won by the AI, l = rounds won by the human, e = draws.
void play() {
    int s = 0;
    int e = 0, l = 0, a = 0;
    bool flag;
    int round = 0;
    int T = 10;
    bool input_ok = true;
    while (input_ok && T--) {
        round++;
        printf("第%d回合:\n", round);
        memset(table, 0, sizeof(table));
        int p = 0;
        printf("请决定执黑还是执白:\n0: 黑棋; 1: 白棋\n");
        if (scanf("%d", &p) != 1) {
            break;  // no usable colour choice: stop instead of guessing
        }
        const bool human_is_black = (p == 0);
        s = 0;
        flag = 0;
        print_table();
        while (!ed[s]) {
            // Black's turn.
            if (human_is_black) {
                if (!read_player_move('b')) {
                    input_ok = false;
                    break;
                }
                s = rbuild();
            } else {
                s = solve(s, flag);
                build(s);
            }
            print_table();
            if (ed[s]) {
                if (vectory(s)) {
                    printf("黑方胜利!\n");
                    // Credit whoever actually plays black.
                    if (human_is_black) l++;
                    else a++;
                }
                else {
                    printf("平局!\n");
                    e++;
                }
                break;
            }
            flag ^= 1;

            // White's turn.
            if (!human_is_black) {
                if (!read_player_move('w')) {
                    input_ok = false;
                    break;
                }
                s = rbuild();
            } else {
                s = solve(s, flag);
                build(s);
            }
            print_table();
            if (ed[s]) {
                if (lose(s)) {
                    printf("白方胜利!\n");
                    // Credit whoever actually plays white.
                    if (human_is_black) a++;
                    else l++;
                }
                else {
                    printf("平局!\n");
                    e++;
                }
                break;
            }
            flag ^= 1;
        }
    }
    printf("AI wins: %d\nplayer wins: %d\nequals: %d\n", a, l, e);
}
// Entry point: seeds the C random generator, precomputes the state tables,
// then starts the interactive game loop.
int main() {
    srand(time(0));
    init();
    play();
    return 0;
}
每步的模拟次数设为500时,已经基本能走出最优解了,可见MCTS比暴力搜索高效很多。

浙公网安备 33010602011771号