""" My Monte Carlo Tree Search Demo """
import argparse
import math
import random
from copy import deepcopy
from typing_extensions import Self
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line arguments for the MCTS demo.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behavior; passing a
            list allows programmatic use and testing.

    Returns:
        Namespace with ``seed``, ``tape_length``, ``sample_times_limit`` and
        ``exploration_constant`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, help="Fix random seed", default=0)
    parser.add_argument("--tape_length", type=int, help="Tape length", default=50)
    parser.add_argument(
        "--sample_times_limit", type=int, help="Sample times limit", default=100
    )
    parser.add_argument(
        "--exploration_constant", type=float, help="Exploration constant", default=1.0
    )
    return parser.parse_args(argv)
def set_seed(seed: int) -> None:
    """Seed the module-level ``random`` generator so a run can be reproduced."""
    random.seed(a=seed)
class Action:
    """A move that writes a 1 into tape cell ``write_position``.

    Actions compare and hash by position. Two Action objects for the same cell
    are therefore interchangeable, e.g. as dict keys. This matters because
    copying a state with ``deepcopy`` produces new Action objects.
    """

    def __init__(self, write_position):
        # Index of the tape cell this action writes to.
        self.write_position = write_position

    def __eq__(self, other):
        if not isinstance(other, Action):
            return NotImplemented
        return self.write_position == other.write_position

    def __hash__(self):
        # Must agree with __eq__: equal positions -> equal hashes.
        return hash(self.write_position)

    def __repr__(self):
        return f"{type(self).__name__}(write_position={self.write_position!r})"
class State:
    """Game state: a tape of 0/1 cells on which each move writes a 1.

    The game ends after ``tape_length`` moves, and the reward is the number of
    cells set to 1. Writing to a cell that is already 1 is allowed; it simply
    wastes the move.
    """

    def __init__(self, tape_length: int) -> None:
        self.tape = [0] * tape_length
        self.tape_length = tape_length
        # One action per tape cell. States derived via take_action share this
        # list instead of copying it, so it must not be mutated.
        self.possible_actions = [Action(write_position=i) for i in range(tape_length)]
        # Number of moves taken so far.
        self.written_times = 0

    def get_possible_actions(self) -> list:
        """Return the legal actions, or an empty list once the game is over.

        The returned list is the state's own (shared) list; do not mutate it.
        """
        if self.is_terminal():
            return []
        return self.possible_actions

    def take_action(self, action: "Action | None") -> "Self":
        """Return a new state with ``action`` applied, leaving ``self`` unchanged.

        A ``None`` action is a no-op and returns ``self`` itself.
        """
        if action is None:
            return self
        # Pre-seeding the deepcopy memo makes the copy reuse the action list
        # rather than deep-copying every Action object on every move. Everything
        # else (notably the tape) is still copied deeply.
        new_state = deepcopy(self, {id(self.possible_actions): self.possible_actions})
        new_state.tape[action.write_position] = 1
        new_state.written_times = self.written_times + 1
        return new_state

    def is_terminal(self) -> bool:
        """Return True once at least ``tape_length`` moves have been made.

        ``>=`` rather than ``==`` keeps a finished game terminal even if
        take_action is called on it directly. Otherwise, random rollouts from
        such a state would never stop.
        """
        return self.written_times >= self.tape_length

    def get_reward(self) -> int:
        """Return the number of tape cells set to 1."""
        return sum(self.tape)

    def show_tape(self) -> None:
        """Print the tape."""
        print(self.tape)
class TreeNode:
    """A node of the MCTS search tree, wrapping one game state.

    ``num_visits`` and ``total_reward`` accumulate rollout statistics, and
    ``children`` maps each expanded action to its child node. MCTS reads and
    updates these attributes directly.
    """

    def __init__(self, state: "State", parent: "TreeNode | None") -> None:
        self.state = state
        # Cached once at construction. States are not mutated in this file
        # (take_action returns a copy), so the cached value stays valid.
        self.is_terminal = state.is_terminal()
        # A terminal node has nothing left to expand.
        self.is_fully_expanded = self.is_terminal
        # None for the root of the tree.
        self.parent = parent
        self.num_visits = 0
        self.total_reward = 0
        # action -> child TreeNode, filled in as the node is expanded.
        self.children = {}

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(visits={self.num_visits}, "
            f"total_reward={self.total_reward}, children={len(self.children)})"
        )
class MCTS:
    """Monte Carlo Tree Search using the UCT selection rule.

    Each search() call builds a fresh tree rooted at the given state and runs
    ``iteration_limit`` select / expand / rollout / backpropagate iterations.
    It then returns the root action whose child has the best mean reward.
    """

    def __init__(self, iteration_limit: int, exploration_constant: float) -> None:
        # Number of iterations per search() call.
        self.search_limit = iteration_limit
        # Weight of the UCT exploration term used while descending the tree.
        self.exploration_constant = exploration_constant

    def search(self, initial_state: "State") -> "Action | None":
        """Return the best action from ``initial_state``.

        Returns ``None`` if ``initial_state`` is terminal. If the budget was too
        small to expand any child of the root (e.g. ``iteration_limit <= 0``),
        a uniformly random legal action is returned instead.
        """
        if initial_state.is_terminal():
            return None
        root = TreeNode(initial_state, None)
        for _ in range(self.search_limit):
            node = self.select_node(root)
            reward = self.rollout(node.state)
            self.back_propogate(node, reward)
        if not root.children:
            # No statistics were gathered, so there is nothing to exploit.
            return random.choice(initial_state.get_possible_actions())
        # Exploration weight 0: choose purely by mean reward among expanded children.
        return self.get_best_action_child(root, 0.0)[0]

    def select_node(self, node: "TreeNode") -> "TreeNode":
        """Descend from ``node`` to the node the next rollout should start from.

        Follows the UCT-best child through fully expanded nodes. Stops at the
        first node with an unexpanded action, where it expands one and returns
        the new child, or at a terminal node.
        """
        while not node.is_terminal:
            if node.is_fully_expanded:
                _, node = self.get_best_action_child(node, self.exploration_constant)
            else:
                return self.expand(node)
        return node

    def get_best_action_child(self, node: "TreeNode", exploration_value: float) -> tuple:
        """Return the ``(action, child)`` pair of ``node`` with the highest UCT value.

        Only children already present in ``node.children`` are considered. This
        keeps the method valid on a partially expanded node, such as the root
        when the search budget is smaller than the number of actions. An
        unvisited child is returned immediately. Ties are broken uniformly at
        random.

        Raises:
            ValueError: if ``node`` has no children.
        """
        if not node.children:
            raise ValueError("node has no expanded children to choose from")
        best_value = float("-inf")
        best_actions_children = []
        # Insertion order matches the order in which expand() created children.
        for action, child in node.children.items():
            if child.num_visits == 0:
                return action, child
            child_value = (
                child.total_reward / child.num_visits
                + exploration_value
                * math.sqrt(2 * math.log(node.num_visits) / child.num_visits)
            )
            if child_value > best_value:
                best_value = child_value
                best_actions_children = [(action, child)]
            elif child_value == best_value:
                best_actions_children.append((action, child))
        return random.choice(best_actions_children)

    def rollout(self, state: "State") -> int:
        """Play uniformly random moves from ``state`` to the end; return the final reward."""
        while not state.is_terminal():
            action = random.choice(state.get_possible_actions())
            state = state.take_action(action)
        return state.get_reward()

    def back_propogate(self, node: "TreeNode | None", reward: int) -> None:
        """Backpropagate: add one visit and ``reward`` to ``node`` and all its ancestors.

        The method name keeps its original spelling so existing callers still work.
        """
        while node is not None:
            node.num_visits += 1
            node.total_reward += reward
            node = node.parent

    def expand(self, node: "TreeNode") -> "TreeNode | None":
        """Create and return the child for the first not-yet-expanded action of ``node``.

        Marks ``node`` fully expanded when that was its last missing child.
        Returns ``None`` if every action already has a child.
        """
        actions = node.state.get_possible_actions()
        for action in actions:
            if action not in node.children:
                new_node = TreeNode(node.state.take_action(action), node)
                node.children[action] = new_node
                if len(actions) == len(node.children):
                    node.is_fully_expanded = True
                return new_node
        return None
def main() -> None:
    """Play one game of the tape demo, choosing every move with MCTS.

    Prints the tape after each move and the final reward at the end.
    """
    args = parse_args()
    set_seed(args.seed)
    game_state = State(args.tape_length)
    searcher = MCTS(
        iteration_limit=args.sample_times_limit,
        exploration_constant=args.exploration_constant,
    )
    # One search per move; the game ends after tape_length moves.
    for _ in range(args.tape_length):
        best_action = searcher.search(initial_state=game_state)
        game_state = game_state.take_action(best_action)
        game_state.show_tape()
    print("Final reward:", game_state.get_reward())


if __name__ == "__main__":
    main()