Source code for hypernets.searchers.mcts_core

# -*- coding:utf-8 -*-
"""

"""
import numpy as np
import math
from collections import OrderedDict
from ..utils.common import generate_id
from ..core import get_random_state


[docs]class MCNode(object): def __init__(self, id, name, param_sample, parent=None, tree=None, is_terminal=False, random_state=None): self.visits = 0 self.reward = 0.0 self.rewards = [] self.simulation_rewards = [] assert name is not None if name != 'ROOT': assert param_sample.assigned self.name = name self.param_sample = param_sample self.parent = parent if parent is not None: assert parent.tree is not None self.tree = parent.tree else: assert tree is not None self.tree = tree self._is_terminal = is_terminal self._children = [] self.id = id self.random_state = random_state if random_state is not None else get_random_state() @property def is_terminal(self): return self._is_terminal
[docs] def set_terminal(self): self._is_terminal = True return self._is_terminal
@property def expanded(self): return all([n.visits > 0 for n in self.children]) @property def is_leaf(self): return len(self._children) <= 0 @property def children(self): return self._children @property def depth(self): d = 0 node = self.parent while node is not None: d += 1 node = node.parent return d
[docs] def add_child(self, param_sample): self.children.append(MCNode(generate_id(), param_sample.label, param_sample, parent=self))
[docs] def set_parent(self, parent): assert self.parent is not None self.parent = parent
[docs] def expansion(self, param_space, max_space): # print(f'Node expanssion: param_space:{param_space.label}') assert not param_space.assigned samples = param_space.expansion(sample_num=max_space) for param_sample in samples: self.add_child(param_sample)
[docs] def random_sample(self): child = self.random_state.choice(self._children) return child
[docs] def info(self): return f'Reward:{self.reward}, Visits:{self.visits}, Name:{self.name}, Alias:{self.param_sample.alias if self.param_sample is not None else None}, value:{self.param_sample.value if self.param_sample is not None else None}, depth:{self.depth}, children:{len(self._children)}, is_terminal:{self.is_terminal}'
def __str__(self, level=0): indent = "\t" * level str = f'{indent}{self.info()}\n' for child in self.children: str += child.__str__(level + 1) return str
[docs]class MCTree(object): def __init__(self, space_fn, policy, max_node_space): self.space_fn = space_fn self.policy = policy self.max_node_space = max_node_space self.root = MCNode(generate_id(), 'ROOT', param_sample=None, tree=self) self._current_node = self.root @property def current_node(self): return self._current_node
[docs] def selection_and_expansion(self): node = self.root while not node.is_terminal: if node.is_leaf: space_sample, child = self.expansion(node) if child != node: return space_sample, child else: # print(f'Tree selection:{node.info()}') node = self.policy.selection(node) if node.visits <= 0: break space_sample = self.node_to_space(node) return space_sample, node
[docs] def path_to_node(self, node): nodes = [] while node.parent is not None: nodes.insert(0, node) node = node.parent return nodes
[docs] def node_to_space(self, node): space_sample = self.space_fn() nodes = self.path_to_node(node) i = 0 for hp in space_sample.params_iterator: if i >= len(nodes): break hp.assign(nodes[i].param_sample.value) i += 1 return space_sample
[docs] def expansion(self, node): # print(f'Tree expansion:{node.info()}') space_sample = self.space_fn() nodes = self.path_to_node(node) i = 0 for hp in space_sample.params_iterator: if i < len(nodes): hp.assign(nodes[i].param_sample.value) else: node.expansion(hp, self.max_node_space) child = node.random_sample() hp.assign(child.param_sample.value) return space_sample, child i += 1 node.set_terminal() return space_sample, node
[docs] def simulation(self, node): raise NotImplementedError
[docs] def back_propagation(self, node, reward, is_simulation=False): while node is not None: node.reward, node.visits = self.policy.back_propagation(node, reward, is_simulation) node = node.parent
[docs] def roll_out(self, space_sample, node): # print(f'Tree roll out:{node.info()}') terminal = True for hp in space_sample.params_iterator: terminal = False hp.random_sample() if terminal: node.set_terminal() return space_sample
[docs]class BasePolicy(object):
[docs] def selection(self, node): raise NotImplementedError
[docs] def back_propagation(self, node, reward): raise NotImplementedError
[docs]class UCT(BasePolicy): def __init__(self, exploration_bonus=0.6): self.exploration_bonus = exploration_bonus
[docs] def selection(self, node): node_log_nt = 0 if node.visits <= 0 else math.log10(node.visits) scores = [(child.reward + self.exploration_bonus * math.sqrt( node_log_nt / child.visits)) if child.visits > 0 else np.inf for child in node.children] details = [(child.param_sample.value, np.sum(child.rewards), child.reward, len(child.rewards)) for child in node.children] index = np.argmax(scores) selected = node.children[index] # print(f'UCT selection: scores:{scores}, index:{index}, selected:{selected.info()}') # print(f'Detials: {details}') # print('*******************************************************************************') return selected
[docs] def back_propagation(self, node, reward, is_simulation=False): if is_simulation: node.simulation_rewards.append(reward) # return np.average(node.rewards + [reward]), len(node.rewards) + 1 else: node.rewards.append(reward) # return np.average(node.rewards), len(node.rewards) avg_reward = np.average(node.rewards) if len(node.rewards) > 0 else 0.0 if len(node.simulation_rewards) > 0: avg_sim_reward = np.average(node.simulation_rewards) new_reward = (avg_reward + avg_sim_reward) / 2 visits = len(node.simulation_rewards) else: new_reward = avg_reward visits = len(node.rewards) return new_reward, visits