Source code for hypernets.core.search_space

# -*- coding:utf-8 -*-


import numpy as np
import hashlib
import threading
import contextlib
import queue
import copy
import time
from collections import OrderedDict
from .mutables import Mutable, MutableScope
from ..utils.common import generate_id, combinations
from ..utils import logging
from .random_state import get_random_state

# Module-level logger, obtained from the project's own logging utilities
# (`..utils.logging`), keyed by this module's name.
logger = logging.get_logger(__name__)

class HyperNode(Mutable):
    """A mutable node that registers itself with a `HyperSpace` on creation."""

    def __init__(self, space=None, name=None):
        """Bind the node to `space` (or the current default space) and register it.

        Args:
            space: The owning space; when None, `get_default_space()` is used.
            name: Optional name forwarded to `Mutable.__init__`.
        """
        if space is None:
            space = get_default_space()
        self._space = space
        Mutable.__init__(self, self._space.scope, name)
        self._space.add_node(self)

    def attach_to_space(self, space=None, name=None):
        """Re-bind the node to `space` (or the default space) and register it there."""
        if space is None:
            space = get_default_space()
        self._space = space
        self.attach_to_scope(self._space.scope, name)
        self._space.add_node(self)

    @property
    def space(self):
        """The space this node is currently attached to."""
        return self._space
class HyperSpace(Mutable):
    """A directed graph of module nodes plus the hyper-parameters attached to them.

    Modules are stored in `self.modules`, parameters in `self.hyper_params`
    and connections as `(from_module, to_module)` tuples in `self.edges`.
    """

    def __init__(self, scope=None, name=None):
        """Create an empty space.

        Args:
            scope: Scope passed to `Mutable.__init__`; a fresh `MutableScope`
                is created when None.
            name: Optional name passed to `Mutable.__init__`.
        """
        if scope is None:
            scope = MutableScope()
        Mutable.__init__(self, scope, name)
        self.edges = set()
        self.modules = set()
        self.hyper_params = set()
        self._inputs = set()
        self._outputs = set()
        self._assigned_params_stack = []
        self._is_compiled = False
        self.space_id = generate_id()

    @property
    def type(self):
        return 'DAG'

    def as_default(self):
        """Return a context manager that makes this space the default one."""
        return _default_space_stack.get_controller(self)

    @property
    def all_assigned(self):
        """True when every module visited by a forward traversal reports `all_assigned`."""
        return self.traverse(lambda m: m.all_assigned, direction='forward',
                             discard_isolated_node=False)

    @property
    def assigned_params_stack(self):
        return self._assigned_params_stack

    def push_assigned_param(self, param):
        """Append `param` to the list of assigned parameters."""
        self._assigned_params_stack.append(param)

    @property
    def params_iterator(self):
        """Yield parameters that still need a value until the space is fully assigned.

        Non-mutable parameters are sampled here and never yielded. The
        generator stops silently once the same parameter would be yielded for
        the 11th time without having been assigned by the consumer.
        """
        visited = {}
        while not self.all_assigned:
            for p in self.get_unassigned_params():
                if not p.is_mutable:
                    # Only sample once: `assign()` asserts that the parameter
                    # is not assigned yet, so re-sampling on a later pass of
                    # the outer loop would raise.
                    if not p.assigned:
                        p.random_sample()
                    continue
                if not p.assigned:
                    if visited.get(p):
                        visited[p] += 1
                        if visited[p] > 10:
                            return
                            # raise RuntimeError('Too many attempts to get assignable params')
                    else:
                        visited[p] = 1
                    yield p

    def add_node(self, node):
        """Register a module or parameter node with this space.

        Raises:
            ValueError: If `node` is neither a `ModuleSpace` nor a `ParameterSpace`.
        """
        if isinstance(node, ModuleSpace):
            self.modules.add(node)
        elif isinstance(node, ParameterSpace):
            self.hyper_params.add(node)
        else:
            raise ValueError(f"Not supported node:{node}")
        # NOTE(review): the subscript was lost in the extracted source
        # (`self.__dict__[] = node`); `node.id` is the reconstructed key --
        # confirm against upstream.
        self.__dict__[node.id] = node

    def compile(self, deepcopy=True):
        """Compile the space and return the compiled instance.

        Args:
            deepcopy: When True, a deep copy is compiled and returned and this
                instance is left untouched; otherwise `self` is compiled.
        """
        space = copy.deepcopy(self) if deepcopy else self
        space._compile_space()
        return space

    def _compile_space(self):
        """Compile every module in forward order and record the sink modules as outputs."""
        assert not self._is_compiled, 'HyperSpace does not allow to compile repeatedly.'
        space_out = []
        counter = 0
        start_ts = time.time()

        def compile_module(module):
            nonlocal counter
            module.compile()
            counter += 1
            # A module without outgoing edges is an output of the space.
            if len(self.get_outputs(module)) <= 0:
                space_out.append(module)
            return True

        self.traverse(compile_module, direction='forward')
        end_ts = time.time()
        # NOTE(review): the guard checks the info level while the message is
        # emitted at debug level; kept as in the original.
        if logger.is_info_enabled():
            logger.debug(f'Compile Space: compiled {counter} modules in {end_ts - start_ts} seconds.')
        self._is_compiled = True
        self._outputs = set(space_out)

    def forward(self, inputs=None):
        """Run every module in forward order and return the outputs of the output modules.

        Modules without predecessors receive `inputs`; a module with a single
        predecessor receives that predecessor's output; otherwise it receives
        the list of its predecessors' outputs.
        """
        counter = []
        start_ts = time.time()

        def forward_module(module):
            input_modules = self.get_inputs(module)
            if len(input_modules) <= 0:
                # The input module of space
                module.forward(inputs)
            elif len(input_modules) == 1:
                module.forward(inputs=input_modules[0].output)
            else:
                module.forward(inputs=[m.output for m in input_modules])
            counter.append(0)
            return True

        self.traverse(forward_module, direction='forward')
        end_ts = time.time()
        if logger.is_info_enabled():
            logger.debug(f'Forward Space: forwarded {len(counter)} modules in {end_ts - start_ts} seconds.')
        return [output.output for output in self.get_outputs()]

    def compile_and_forward(self, inputs=None, deepcopy=True):
        """Compile the space, run it forward, and return `(compiled_space, outputs)`."""
        space = self.compile(deepcopy)
        outputs = space.forward(inputs)
        return space, outputs

    def traverse(self, fn, direction='forward', start_modules=None, discard_isolated_node=True):
        """Visit modules in dependency order, calling `fn` on each.

        A module is visited only after all of its predecessors (with respect
        to `direction`) have been visited.

        Args:
            fn: Callable taking a module; a falsy return value aborts the walk.
            direction: 'forward' (inputs to outputs) or 'backward'.
            start_modules: Modules to start from; when None or empty the
                space's inputs (or outputs for 'backward') are used.
            discard_isolated_node: Forwarded to `get_inputs`/`get_outputs`
                when the start modules are looked up.

        Returns:
            False if `fn` aborted the traversal, True otherwise.

        Raises:
            ValueError: If `direction` is not supported.
        """
        if direction == 'forward':
            fn_inputs = self.get_inputs
            fn_outputs = self.get_outputs
        elif direction == 'backward':
            fn_inputs = self.get_outputs
            fn_outputs = self.get_inputs
        else:
            raise ValueError(f'Not supported direction:{direction}')

        standby = queue.Queue()
        visited = set()
        finished = set()

        if start_modules is None or len(start_modules) <= 0:
            start_modules = fn_inputs(discard_isolated_node=discard_isolated_node)
        for m in start_modules:
            standby.put(m)
            visited.add(m)

        while not standby.empty():
            m_todo = standby.get()
            if any(mi not in finished for mi in fn_inputs(m_todo)):
                # Not ready yet: forget it so a later predecessor re-queues it.
                visited.remove(m_todo)
                continue
            if not fn(m_todo):
                return False
            finished.add(m_todo)
            for m in fn_outputs(m_todo):
                if m not in visited:
                    standby.put(m)
                    visited.add(m)
        return True

    def connect(self, from_module, to_module):
        """Add the edge `from_module -> to_module`."""
        self.edges.add((from_module, to_module))

    def disconnect(self, from_module, to_module):
        """Remove the edge `from_module -> to_module` if it exists."""
        self.edges.discard((from_module, to_module))

    def disconnect_all(self, module):
        """Remove every edge that starts or ends at `module`."""
        found = {(f, t) for f, t in self.edges if f == module or t == module}
        for edge in found:
            self.edges.remove(edge)

    def reroute_to(self, old_module, new_module):
        """Redirect every edge ending at `old_module` to `new_module` (a module or list)."""
        assert isinstance(new_module, (list, ModuleSpace))
        found = {(f, t) for f, t in self.edges if t == old_module}
        for f, t in found:
            self.edges.remove((f, t))
            if isinstance(new_module, ModuleSpace):
                self.edges.add((f, new_module))
            else:
                for m in new_module:
                    self.edges.add((f, m))

    def reroute_from(self, old_module, new_module):
        """Make every edge starting at `old_module` start at `new_module` (a module or list)."""
        assert isinstance(new_module, (list, ModuleSpace))
        found = {(f, t) for f, t in self.edges if f == old_module}
        for f, t in found:
            self.edges.remove((f, t))
            if isinstance(new_module, ModuleSpace):
                self.edges.add((new_module, t))
            else:
                for m in new_module:
                    self.edges.add((m, t))

    def replace_route(self, old_module, new_module):
        """Replace `old_module` by `new_module` on both incoming and outgoing edges."""
        self.reroute_to(old_module, new_module)
        self.reroute_from(old_module, new_module)

    def is_isolated_module(self, module):
        """Return True when `module` is not an endpoint of any edge."""
        assert module is not None
        for from_module, to_module in self.edges:
            if module == from_module or module == to_module:
                return False
        return True

    def get_sub_graph_outputs(self, module_in_subgraph):
        return self.get_sub_graph_end_modules(module_in_subgraph, direction='forward')

    def get_sub_graph_inputs(self, module_in_subgraph):
        return self.get_sub_graph_end_modules(module_in_subgraph, direction='backward')

    def get_sub_graph_end_modules(self, module_in_subgraph, direction='forward'):
        """Find the terminal modules of the sub-graph containing `module_in_subgraph`.

        Args:
            module_in_subgraph: A module, or a list whose first element is used.
            direction: 'forward' collects modules without outputs, 'backward'
                collects modules without inputs.

        Returns:
            The terminal modules sorted by id.

        Raises:
            ValueError: If `direction` is not supported.
        """
        assert isinstance(module_in_subgraph, (ModuleSpace, list))
        if isinstance(module_in_subgraph, list):
            module_in_subgraph = module_in_subgraph[0]

        if direction == 'forward':
            # get outputs
            get_upstream = self.get_outputs
            get_downstream = self.get_inputs
        elif direction == 'backward':
            # get inputs
            get_upstream = self.get_inputs
            get_downstream = self.get_outputs
        else:
            raise ValueError(f'Not supported direction:{direction}')

        standby = queue.Queue()
        visited = {module_in_subgraph}
        finished = {module_in_subgraph}
        end_modules = set()

        if len(get_upstream(module_in_subgraph)) <= 0:
            end_modules.add(module_in_subgraph)
        for m in get_upstream(module_in_subgraph):
            standby.put(m)
            visited.add(m)
        for m in get_downstream(module_in_subgraph):
            standby.put(m)
            visited.add(m)

        while not standby.empty():
            m_todo = standby.get()
            ready = True
            for mi in get_downstream(m_todo):
                if mi not in finished:
                    # Process the unfinished neighbour first.
                    standby.put(mi)
                    visited.add(mi)
                    ready = False
                    break
            if not ready:
                visited.discard(m_todo)
                continue
            finished.add(m_todo)
            m_upstream = get_upstream(m_todo)
            if len(m_upstream) <= 0:
                end_modules.add(m_todo)
                continue
            for m in m_upstream:
                if m not in visited:
                    standby.put(m)
                    visited.add(m)
        # NOTE(review): the sort key was truncated in the extracted source;
        # `m.id` is the reconstructed key -- confirm against upstream.
        return sorted(end_modules, key=lambda m: m.id)

    def set_inputs(self, modules):
        """Explicitly declare the input module(s) of the space."""
        assert modules is not None
        assert isinstance(modules, (ModuleSpace, list))
        if isinstance(modules, ModuleSpace):
            self._inputs = {modules}
        else:
            assert len(modules) > 0
            assert all([isinstance(m, ModuleSpace) for m in modules])
            self._inputs = set(modules)

    def set_outputs(self, modules):
        """Explicitly declare the output module(s) of the space."""
        assert modules is not None
        assert isinstance(modules, (ModuleSpace, list))
        if isinstance(modules, ModuleSpace):
            self._outputs = {modules}
        else:
            assert len(modules) > 0
            assert all([isinstance(m, ModuleSpace) for m in modules])
            self._outputs = set(modules)

    def get_inputs(self, module=None, discard_isolated_node=True):
        """Return input modules sorted by id.

        With `module` given, returns its direct predecessors. Without it,
        returns the explicitly set inputs if any, otherwise every module that
        has no incoming edge and either has an outgoing edge or (when
        `discard_isolated_node` is False) has an empty `path`.

        Raises:
            ValueError: If no input can be determined and the space does not
                consist of exactly one module.
        """
        # NOTE(review): the `key=lambda m:` expressions were truncated in the
        # extracted source; `m.id` is the reconstructed key -- confirm upstream.
        inputs = set()
        if module is None:
            if len(self._inputs) > 0:
                return sorted(self._inputs, key=lambda m: m.id)
            for m in self.modules:
                has_in = False
                has_out = False
                for from_module, to_module in self.edges:
                    if m == from_module:
                        has_out = True
                    if m == to_module:
                        has_in = True
                        break
                if not has_in:
                    if has_out or (not discard_isolated_node and m.path == ''):
                        inputs.add(m)
            if len(inputs) == 0:
                if len(self.modules) == 1:
                    inputs = self.modules.copy()
                else:
                    raise ValueError('Graph is not connected.')
        else:
            for from_module, to_module in self.edges:
                if module == to_module:
                    inputs.add(from_module)
        return sorted(inputs, key=lambda m: m.id)

    def get_outputs(self, module=None, discard_isolated_node=True):
        """Return output modules sorted by id; the mirror image of `get_inputs`.

        Raises:
            ValueError: If no output can be determined and the space does not
                consist of exactly one module.
        """
        outputs = set()
        if module is None:
            if len(self._outputs) > 0:
                return sorted(self._outputs, key=lambda m: m.id)
            for m in self.modules:
                has_in = False
                has_out = False
                for from_module, to_module in self.edges:
                    if m == from_module:
                        has_out = True
                        break
                    if m == to_module:
                        has_in = True
                if not has_out:
                    if has_in or (not discard_isolated_node and m.path == ''):
                        outputs.add(m)
            if len(outputs) == 0:
                if len(self.modules) == 1:
                    outputs = self.modules.copy()
                else:
                    raise ValueError('Graph is not connected.')
        else:
            for from_module, to_module in self.edges:
                if module == from_module:
                    outputs.add(to_module)
        return sorted(outputs, key=lambda m: m.id)

    def random_sample(self):
        """Assign a random value to every parameter yielded by `params_iterator`."""
        for hp in self.params_iterator:
            hp.random_sample()

    def get_unassigned_params(self, traverse_direction='forward'):
        """Collect, in traversal order and without duplicates, what each module
        returns from `get_assignable_params()`."""
        assignables = []

        def append_params(m):
            for p in m.get_assignable_params():
                if p not in assignables:
                    assignables.append(p)
            return True

        self.traverse(append_params, direction=traverse_direction, discard_isolated_node=False)
        return assignables

    def get_assigned_params(self):
        """Return the assigned-parameter stack; asserts the space is fully assigned."""
        assert self.all_assigned
        return self._assigned_params_stack

    def get_assigned_param_values(self, traverse_direction='forward'):
        """Return the values of all assigned parameters as a dict.

        `traverse_direction` is accepted but not used.
        """
        ps = self.get_assigned_params()
        # NOTE(review): the dict key was lost in the extracted source (it read
        # `{ p.value for p in ps}`); `p.id` is the reconstructed key -- confirm
        # against upstream.
        return {p.id: p.value for p in ps}

    def get_all_params(self):
        """Return every registered hyper-parameter as a list."""
        return list(self.hyper_params)

    def params_summary(self, only_assignable=True, line_width=60, LR='\n'):
        """Render a plain-text table of parameter aliases and values.

        Args:
            only_assignable: When True list the assigned parameters, otherwise
                all registered parameters.
            line_width: Width used to right-align the values.
            LR: Line separator used to join the rows.
        """
        rule = f'{(line_width + 2) * "-"}'
        outputs = [f'\n{rule}']
        params = self.get_assigned_params() if only_assignable else self.get_all_params()
        for i, hp in enumerate(params):
            padding = (line_width - len(str(i) + "() " + hp.alias + str(hp.value))) * " "
            outputs.append(f'({i}) {hp.alias}:{padding}{hp.value}')
        outputs.append(rule)
        return LR.join(outputs)

    @property
    def signature(self):
        """MD5 hex digest of the ';'-joined labels of the assigned parameters."""
        assert self.all_assigned
        key = ';'.join(p.label for p in self._assigned_params_stack)
        return hashlib.md5(key.encode('utf-8')).hexdigest()

    @property
    def vectors(self):
        """Numeric encoding of each assigned parameter's value, in assignment order."""
        assert self.all_assigned
        return [p.value2numeric(p.value) for p in self._assigned_params_stack]

    def assign_by_vectors(self, vectors):
        """Assign parameters from their numeric encodings, in iteration order.

        Raises:
            ValueError: If `vectors` has fewer or more entries than there are
                parameters to assign.
        """
        i = 0
        for p in self.params_iterator:
            if not p.is_mutable:
                p.random_sample()
                continue
            if i >= len(vectors):
                raise ValueError('`vector` and `space` does not match.')
            p.assign(p.numeric2value(vectors[i]))
            i += 1
        if len(vectors) != i:
            raise ValueError('`vector` and `space` does not match.')

    @property
    def combinations(self):
        """Product of `choice_num` over all registered hyper-parameters."""
        count = 1
        for hp in self.hyper_params:
            count *= hp.choice_num
        return count

    def _repr_html_(self):
        """Return an HTML table with the signature, vectors and assigned parameters."""
        rows = [('signature', self.signature), ('vectors', self.vectors)]
        rows.extend((f'{i}-{hp.alias}', hp.value)
                    for i, hp in enumerate(self.get_assigned_params()))
        body = '\n'.join(
            f'<tr>\n<td>{key}</td>\n<td>{value}</td>\n</tr>' for key, value in rows)
        # The extracted original left a dangling `<tr>` after each parameter
        # row and closed a `<div>` it never opened; the markup is balanced here.
        return (
            '<div>\n'
            '<table border="1" class="dataframe">\n'
            '<thead>\n'
            '<tr style="text-align: right;">\n'
            '<th>key</th>\n'
            '<th>value</th>\n'
            '</tr>\n'
            '</thead>\n'
            '<tbody>\n'
            f'{body}\n'
            '</tbody>\n'
            '</table>\n'
            '</div>'
        )
class DefaultStack(threading.local):
    """A stack of default objects; being a `threading.local`, state is per thread."""

    def __init__(self):
        super().__init__()
        self.enforce_nesting = True
        self.stack = []

    def get_default(self):
        """Return the top of the stack, or None when the stack is empty."""
        if self.stack:
            return self.stack[-1]
        return None

    def reset(self):
        """Drop every entry from the stack."""
        self.stack = []

    def is_cleared(self):
        """Return True when the stack holds no entries."""
        return not self.stack

    @contextlib.contextmanager
    def get_controller(self, default):
        """Context manager that pushes `default` on entry and removes it on exit.

        Raises:
            AssertionError: On exit, if nesting is enforced and `default` is
                no longer the top of the stack.
        """
        self.stack.append(default)
        try:
            yield default
        finally:
            # Nothing to unwind if the stack was emptied inside the context.
            if self.stack:
                if not self.enforce_nesting:
                    self.stack.remove(default)
                elif self.stack[-1] is default:
                    self.stack.pop()
                else:
                    raise AssertionError(
                        "Nesting violated for default stack of %s objects" % type(default))
class DefaultSpaceStack(DefaultStack):
    """Default stack that falls back to a lazily created global `HyperSpace`."""

    def __init__(self):
        super().__init__()
        self._global_default_space = None

    def get_default(self):
        """Override that returns a global default if the stack is empty."""
        top = super().get_default()
        if top is not None:
            return top
        return self._global_default()

    def _global_default(self):
        """Return the fallback space, creating it on first use."""
        if self._global_default_space is None:
            self._global_default_space = HyperSpace()
        return self._global_default_space

    def reset(self):
        """Clear the stack and forget the fallback space."""
        super().reset()
        self._global_default_space = None
# Module-wide stack tracking the current default HyperSpace. Its state is kept
# per thread because DefaultStack derives from threading.local.
_default_space_stack = DefaultSpaceStack()
def get_default_space():
    """Return the current default space from the module-level default stack."""
    current = _default_space_stack.get_default()
    return current
class ParameterSpace(HyperNode):
    """Base class for hyper-parameters that can be sampled and assigned a value.

    Subclasses provide `config_keys`, `_random_sample`, `value2numeric`,
    `numeric2value`, `expansion` and `_get_choice_num`.
    """

    def __init__(self, space=None, name=None, random_state=None):
        """Create an unassigned parameter.

        Args:
            space: Owning space, forwarded to `HyperNode.__init__`.
            name: Optional name, forwarded to `HyperNode.__init__`.
            random_state: Random source; `get_random_state()` is used when None.
        """
        HyperNode.__init__(self, space, name)
        self._assigned = False
        self._value = None
        self.random_state = random_state if random_state is not None else get_random_state()
        # Mutables to notify (via `update()`) once this parameter is assigned.
        self.references = set()

    @property
    def is_mutable(self):
        return True

    @property
    def type(self):
        return 'Param'

    @property
    def assigned(self):
        return self._assigned

    @property
    def value(self):
        return self._value

    @property
    def config_keys(self):
        """Names of the instance attributes that define this parameter's configuration."""
        raise NotImplementedError

    @property
    def label(self):
        """String combining the parameter id with its configuration values."""
        vs = [str(self.__dict__[key]) for key in self.config_keys]
        # NOTE(review): the first placeholder was empty in the extracted
        # source (`f"{}-..."`, a syntax error); `self.id` is the reconstructed
        # expression -- confirm against upstream.
        return f"{self.id}-{'-'.join(vs)}"

    def value2numeric(self, value):
        raise NotImplementedError

    def numeric2value(self, numeric):
        raise NotImplementedError

    def random_sample(self, assign=True):
        """Draw a random value and, unless `assign` is False, assign it.

        Returns:
            The sampled value.
        """
        value = self._random_sample()
        if assign:
            self.assign(value)
        return value

    def _random_sample(self):
        raise NotImplementedError

    def assign(self, value):
        """Validate and store `value`, then notify references if this parameter is mutable.

        A parameter may only be assigned once; a second call trips the assertion.
        """
        assert not self._assigned
        self._check(value)
        self._assigned = True
        self._value = value
        if self.is_mutable:
            for m in self.references:
                m.update()

    def attach(self, mutable, alias=None):
        """Register `mutable` as a reference and extend this parameter's alias.

        When `alias` is given, '<mutable name>.<alias>' is appended to the
        comma-separated `self.alias`.
        """
        self.references.add(mutable)
        if alias is not None:
            aliases = [] if self.alias is None else [self.alias]
            # NOTE(review): the left operand was lost in the extracted source
            # (`all.append( + '.' + alias)`); `mutable.name` is the
            # reconstructed expression -- confirm against upstream.
            aliases.append(mutable.name + '.' + alias)
            self.alias = ','.join(aliases)

    def detach(self, mutable):
        """Remove `mutable` from the references (KeyError if it is not registered)."""
        self.references.remove(mutable)

    def _check(self, value):
        """Hook for subclasses to validate a value before assignment."""
        pass

    def same_config(self, other):
        """Return True when `other` has the same class, alias and config values."""
        if self.__class__ != other.__class__ or self.alias != other.alias:
            return False
        for key in self.config_keys:
            if self.__dict__[key] != other.__dict__[key]:
                return False
        return True

    def expansion(self, sample_num):
        raise NotImplementedError

    @property
    def choice_num(self):
        """Number of distinct choices; 1 for non-mutable parameters."""
        if self.is_mutable:
            return self._get_choice_num()
        return 1

    def _get_choice_num(self):
        raise NotImplementedError
class Int(ParameterSpace):
    """Integer parameter in `[low, high]`, optionally snapped to a `step` grid."""

    def __init__(self, low, high, step=1, random_state=None, space=None, name=None):
        ParameterSpace.__init__(self, space, name, random_state)
        assert isinstance(low, int) and isinstance(high, int), '`low` and `high` must be a int.'
        assert low < high, '`low` must less than `high`.'
        self.low = low
        self.high = high
        self.step = step

    def _random_sample(self):
        """Draw from `random_state.randint(low, high)` and snap to the nearest grid point."""
        drawn = self.random_state.randint(self.low, self.high)
        if self.step is None:
            return drawn
        grid = np.arange(self.low, self.high + self.step, step=self.step)
        nearest = np.abs(grid - drawn).argmin()
        return grid[nearest]

    def _check(self, value):
        assert value >= self.low and value <= self.high

    def value2numeric(self, value):
        """Integers are their own numeric encoding."""
        return value

    def numeric2value(self, numeric):
        return numeric

    @property
    def config_keys(self):
        return ['low', 'high', 'step']

    def expansion(self, sample_num):
        """Return up to `sample_num` assigned deep copies with distinct values, sorted by value.

        A non-positive or too large `sample_num` is replaced by the number of
        available choices.
        """
        limit = self._get_choice_num()
        if sample_num > limit or sample_num <= 0:
            sample_num = limit
        clones = []
        seen = []
        while len(clones) < sample_num:
            candidate = self._random_sample()
            if candidate in seen:
                continue
            clone = copy.deepcopy(self)
            clone.assign(candidate)
            clones.append(clone)
            seen.append(candidate)
        clones.sort(key=lambda s: s.value)
        return clones

    def _get_choice_num(self):
        if self.step is None:
            return self.high - self.low
        grid = np.arange(self.low, self.high + self.step, step=self.step)
        return len(grid) - 1
class Real(ParameterSpace):
    """Floating-point parameter sampled from a uniform, log-uniform or
    quantised-uniform prior, optionally snapped to a `step` grid."""

    def __init__(self, low, high, q=None, prior="uniform", step=0.01, max_expansion=100, random_state=None,
                 space=None, name=None):
        """Create a real-valued parameter.

        Args:
            low, high: Bounds (converted to float); `low` must be below `high`.
            q: Quantum used by the "q_uniform" prior.
            prior: One of "uniform", "log_uniform", "q_uniform".
            step: Grid spacing applied after sampling; None disables the grid.
            max_expansion: Choice count used when no grid is configured and
                default sample count for `expansion`.
        """
        ParameterSpace.__init__(self, space, name, random_state)
        low = float(low)
        high = float(high)
        assert low < high, '`low` must less than `high`.'
        self.low = low
        self.high = high
        self.q = q
        self.prior = prior
        # The rest of the class branches on `self.step is not None`, so None
        # must be accepted here rather than crashing in float(None).
        self.step = float(step) if step is not None else None
        self.max_expansion = max_expansion

    def _value_bounds(self):
        """Return (low, high) in value space; exponentiated for the log-uniform prior."""
        if self.prior == 'log_uniform':
            return np.exp(self.low), np.exp(self.high)
        return self.low, self.high

    def _snap_to_grid(self, value):
        """Return the grid point nearest to `value`, capped at the upper bound."""
        low, high = self._value_bounds()
        grid = np.round(np.arange(low, high + self.step, step=self.step), 8)
        snapped = grid[np.abs(grid - value).argmin()]
        # `arange` may overshoot the upper bound by up to one step.
        if snapped > high:
            snapped = high
        return snapped

    def _random_sample(self):
        """Draw a value according to `prior`, then snap it to the grid if one is set.

        Raises:
            ValueError: If `prior` is not supported.
        """
        if self.prior == "uniform":
            assert self.high >= self.low, 'Upper bound must be larger than lower bound'
            value = self.random_state.uniform(self.low, self.high)
            self._check(value)
        elif self.prior == "log_uniform":
            assert self.low > 0, 'Lower bound must be positive'
            value = np.exp(self.random_state.uniform(self.low, self.high))
        elif self.prior == "q_uniform":
            assert self.q is not None, 'q cannot be None'
            value = np.clip(
                np.round(self.random_state.uniform(self.low, self.high) / self.q) * self.q,
                self.low, self.high)
            self._check(value)
        else:
            raise ValueError(f'Not supported prior:{self.prior}')

        if self.step is not None:
            value = self._snap_to_grid(value)
        return value

    def _check(self, value):
        if self.prior == "log_uniform":
            assert value >= np.exp(self.low) and value <= np.exp(self.high)
        else:
            assert value >= self.low and value <= self.high

    @property
    def config_keys(self):
        return ['low', 'high', 'q', 'prior', 'step']

    def value2numeric(self, value):
        """Reals are their own numeric encoding."""
        return value

    def numeric2value(self, numeric):
        return numeric

    def expansion(self, sample_num):
        """Return assigned deep copies with distinct values, sorted by value.

        A non-positive `sample_num` means `max_expansion`; the count is capped
        at the number of available choices.
        """
        if sample_num <= 0:
            sample_num = self.max_expansion
        sample_num = min(sample_num, self._get_choice_num())
        values = []
        samples = []
        while len(samples) < sample_num:
            v = self._random_sample()
            if v in values:
                continue
            sample = copy.deepcopy(self)
            sample.assign(v)
            samples.append(sample)
            values.append(v)
        return sorted(samples, key=lambda s: s.value)

    def _get_choice_num(self):
        """Number of grid points, or `max_expansion` when no grid is configured."""
        if self.step is None:
            return self.max_expansion
        low, high = self._value_bounds()
        return len(np.arange(low, high + self.step, step=self.step))
class Choice(ParameterSpace):
    """Parameter that takes exactly one value out of a non-empty list of options."""

    def __init__(self, options, random_state=None, space=None, name=None):
        ParameterSpace.__init__(self, space, name, random_state)
        assert isinstance(options, list), '`options` must be a List.'
        assert len(options) > 0, '`options` contains at least one item.'
        self.options = options

    @property
    def is_mutable(self):
        """A choice with a single option has nothing to vary."""
        return len(self.options) > 1

    def _random_sample(self):
        """Pick one option by drawing a random index."""
        positions = range(len(self.options))
        picked = self.random_state.choice(positions)
        return self.options[picked]

    def _check(self, value):
        assert value in self.options

    @property
    def config_keys(self):
        return ['options']

    def value2numeric(self, value):
        """Encode a value as its index in `options`."""
        return self.options.index(value)

    def numeric2value(self, numeric):
        """Decode an index back to the option at that position."""
        return self.options[numeric]

    def expansion(self, sample_num=0):
        """Return one assigned deep copy per option; `sample_num` is not used."""
        clones = []
        for candidate in self.options:
            clone = copy.deepcopy(self)
            clone.assign(candidate)
            clones.append(clone)
        return clones

    def _get_choice_num(self):
        return len(self.options)
class MultipleChoice(ParameterSpace):
    """Parameter whose value is a list of several options chosen without repetition."""

    def __init__(self, options, num_chosen_most=0, num_chosen_least=1, random_state=None, space=None, name=None):
        ParameterSpace.__init__(self, space, name, random_state)
        assert isinstance(options, list), '`options` must be a List.'
        assert len(options) >= num_chosen_least, f'`options` contains at least {num_chosen_least} item.'
        self.options = options
        self.num_chosen_most = num_chosen_most
        self.num_chosen_least = num_chosen_least

    def _random_sample(self):
        """Draw a random-sized subset of options, returned in `options` order.

        A non-positive `num_chosen_most` means "up to all options".
        """
        upper = self.num_chosen_most
        if upper <= 0:
            upper = len(self.options)
        positions = range(0, len(self.options))
        how_many = self.random_state.randint(self.num_chosen_least, upper + 1)
        picked = self.random_state.choice(positions, how_many, False)
        return [self.options[index] for index in sorted(picked)]

    def _check(self, value):
        assert isinstance(value, list)
        assert len(value) >= self.num_chosen_least, f'value contains at least {self.num_chosen_least} item.'
        # NOTE(review): this compares `num_chosen_most` with the number of
        # options, not with len(value) -- confirm that is intended.
        assert (self.num_chosen_most == 0 or self.num_chosen_most <= len(self.options))
        assert all([v in self.options for v in value])

    @property
    def config_keys(self):
        return ['options', 'num_chosen_most', 'num_chosen_least']

    def value2numeric(self, value):
        """Encode the selection as an integer bit mask, first option = most significant bit."""
        bits = ''.join(['1' if v in value else '0' for v in self.options])
        return int(bits, 2)

    def numeric2value(self, numeric):
        """Decode a bit mask produced by `value2numeric` back to a list of options."""
        bits = np.binary_repr(numeric, len(self.options))
        return [self.options[i] for i in range(len(bits)) if bits[i] == '1']

    def expansion(self, sample_num):
        """Return up to `sample_num` assigned deep copies with distinct values.

        A non-positive or too large `sample_num` is replaced by the number of
        available choices.
        """
        limit = self._get_choice_num()
        if sample_num > limit or sample_num <= 0:
            sample_num = limit
        seen = []
        clones = []
        while len(seen) < sample_num:
            candidate = self._random_sample()
            if candidate in seen:
                continue
            clone = copy.deepcopy(self)
            clone.assign(candidate)
            clones.append(clone)
            seen.append(candidate)
        return clones

    def _get_choice_num(self):
        # NOTE(review): the lower bound passed to the project helper
        # `combinations` is the literal 1, not `num_chosen_least` -- confirm
        # against the helper's contract.
        total = combinations(len(self.options), self.num_chosen_most, 1)
        return int(total)
class Bool(Choice):
    """A `Choice` over the two options `False` and `True`."""

    def __init__(self, random_state=None, space=None, name=None):
        Choice.__init__(self, options=[False, True], random_state=random_state, space=space, name=name)
class Constant(ParameterSpace):
    """A non-mutable parameter whose value is assigned at construction time."""

    def __init__(self, value, space=None, name=None):
        ParameterSpace.__init__(self, space, name)
        # Assigned immediately, so the parameter never needs sampling.
        self.assign(value)

    @property
    def is_mutable(self):
        return False

    @property
    def config_keys(self):
        """The stored value itself is the only configuration."""
        return ['_value']
class Dynamic(ParameterSpace):
    """Non-mutable parameter whose value is computed from other parameters.

    `lambda_fn` is called with the dependencies' values as keyword arguments
    once all of them are assigned.
    """

    def __init__(self, lambda_fn, space=None, name=None, **param_dict):
        """Attach to every dependency and compute the value if already possible.

        Raises:
            ValueError: If a dependency is itself a `Dynamic`.
        """
        ParameterSpace.__init__(self, space, name)
        self._lambda_fn = lambda_fn
        self._param_dict = {}
        for key, dependency in param_dict.items():
            if isinstance(dependency, Dynamic):
                raise ValueError('Dynamic cannot be nested.')
            self._param_dict[key] = dependency
            dependency.attach(self, key)
        self.update()

    @property
    def is_mutable(self):
        return False

    def update(self):
        """Compute and assign the value once every dependency is assigned."""
        dependencies = self._param_dict
        if not all(p.assigned for p in dependencies.values()):
            return
        kwargs = {name: p.value for name, p in dependencies.items()}
        self.assign(self._lambda_fn(**kwargs))

    @property
    def param_dict(self):
        return self._param_dict
class Cascade(ParameterSpace):
    """Non-mutable parameter whose value is another `ParameterSpace`, produced by
    `lambda_fn` once all of its dependencies are assigned."""

    def __init__(self, lambda_fn, space=None, name=None, **param_dict):
        """Attach to every dependency and resolve the value if already possible."""
        ParameterSpace.__init__(self, space, name)
        self._lambda_fn = lambda_fn
        self._param_dict = {}
        for n, p in param_dict.items():
            self._param_dict[n] = p
            p.attach(self, n)
        self.update()

    @property
    def is_mutable(self):
        return False

    @property
    def assigned(self):
        """True once a value exists and, if it is a parameter, that parameter is assigned."""
        if self.value is None:
            # The original fell through and returned None here; an explicit
            # False keeps the same truthiness.
            return False
        if isinstance(self.value, ParameterSpace):
            return self.value.assigned
        return True

    def update(self):
        """Resolve the cascaded parameter once every dependency is assigned.

        The produced parameter is added to each referencing module under the
        returned name, attached to each referencing parameter, and finally
        assigned as this parameter's value.
        """
        if all(p.assigned for p in self._param_dict.values()):
            args = {name: p.value for name, p in self._param_dict.items()}
            # NOTE(review): the second argument was truncated in the extracted
            # source (`self._lambda_fn(args,`); `self.space` is the
            # reconstructed argument -- confirm against upstream.
            name, value = self._lambda_fn(args, self.space)
            assert isinstance(value, ParameterSpace), 'The value of `Cascade` must be a ParameterSpace.'
            for m in self.references:
                if isinstance(m, ModuleSpace):
                    m.add_parameters(**{name: value})
                elif isinstance(m, ParameterSpace):
                    value.attach(m)
            self.assign(value)

    @property
    def param_dict(self):
        return self._param_dict
class ModuleSpace(HyperNode):
    """A graph node that owns named hyper-parameters and can be compiled and run.

    Subclasses implement `_compile` and `_forward`.
    """

    def __init__(self, space=None, name=None, **hyperparams):
        """Create the module and register `hyperparams` as its parameters."""
        HyperNode.__init__(self, space, name)
        self._hyper_params = OrderedDict()
        self._is_params_ready = False
        self._is_compiled = False
        self._output = None
        self.add_parameters(**hyperparams)
        self.update()

    def __call__(self, *args, **kwargs):
        """Connect the given module (or each module in a list) as an input of this one.

        Only the first positional argument is used.

        Returns:
            This module, so calls can be chained.
        """
        # NOTE(review): the `self.space.connect(` prefix of the two connect
        # calls was lost in the extracted source; reconstructed from the
        # surviving `, self)` argument tails -- confirm against upstream.
        assert len(args) > 0
        m = args[0]
        assert m is not None, 'The module to be connected cannot be None.'
        assert isinstance(m, (ModuleSpace, list))
        if isinstance(m, ModuleSpace):
            self.space.connect(m, self)
        else:
            for mi in m:
                assert mi is not None, 'The module to be connected cannot be None.'
                self.space.connect(mi, self)
        return self

    def connect(self, module_or_list):
        """Connect this module as an input of the given module or list of modules.

        Returns:
            This module.

        Raises:
            ValueError: If `module_or_list` is neither a module nor a list.
        """
        # NOTE(review): the `self.space.connect(self` prefix of the two calls
        # was lost in the extracted source; reconstructed -- confirm upstream.
        if isinstance(module_or_list, ModuleSpace):
            self.space.connect(self, module_or_list)
        elif isinstance(module_or_list, list):
            assert len(module_or_list) > 0, 'module_or_list contains at least 1 Module.'
            assert all([isinstance(m, ModuleSpace) for m in module_or_list]), \
                'module_or_list can only contain Module.'
            for m in module_or_list:
                self.space.connect(self, m)
        else:
            raise ValueError('module_or_list is neither Module nor List.')
        return self

    @property
    def type(self):
        return 'Module'

    @property
    def param_values(self):
        """Mapping of parameter name to its current value."""
        return {name: p.value for name, p in self._hyper_params.items()}

    @property
    def hyper_params(self):
        return self._hyper_params

    def get_assignable_params(self):
        """Return the parameters a searcher may assign, without duplicates.

        Constants are skipped; for `Dynamic`/`Cascade` parameters their
        non-derived dependencies are returned instead of the parameter itself.
        """
        assignables = []
        for p in self._hyper_params.values():
            if isinstance(p, Constant):
                continue
            if isinstance(p, (Dynamic, Cascade)):
                for dp in p.param_dict.values():
                    if not isinstance(dp, (Dynamic, Cascade)) and dp not in assignables:
                        assignables.append(dp)
            elif p not in assignables:
                assignables.append(p)
        return assignables

    def get_all_params(self):
        """Return every parameter plus the dependencies of `Dynamic`/`Cascade` ones."""
        params = []
        for p in self._hyper_params.values():
            if p not in params:
                params.append(p)
            if isinstance(p, (Dynamic, Cascade)):
                for dp in p.param_dict.values():
                    if dp not in params:
                        params.append(dp)
        return params

    @property
    def output(self):
        return self._output

    @property
    def is_compiled(self):
        return self._is_compiled

    @property
    def is_params_ready(self):
        return self._is_params_ready

    @property
    def all_assigned(self):
        """True when every parameter of this module is assigned."""
        for hp in self._hyper_params.values():
            if not hp.assigned:
                return False
        return True

    def compile(self):
        """Run `_compile` once; later calls are no-ops."""
        if not self.is_compiled:
            self._compile()
            self._is_compiled = True

    def _compile(self):
        raise NotImplementedError

    def forward(self, inputs=None):
        """Run `_forward`, remember the result as `output`, and return it."""
        self._output = self._forward(inputs)
        return self._output

    def _forward(self, inputs):
        raise NotImplementedError

    def compile_and_forward(self, inputs=None):
        self.compile()
        return self.forward(inputs)

    def _on_params_ready(self):
        """Hook called by `update()` when all parameters are assigned."""
        pass

    def add_parameters(self, **hyperparameters):
        """Register parameters by name, wrapping plain values in `Constant`.

        Raises:
            ValueError: If a parameter with the same name already exists.
        """
        for name, param in hyperparameters.items():
            if not isinstance(param, ParameterSpace):
                param = Constant(param)
            if self._hyper_params.get(name) is not None:
                raise ValueError(f'Parameter `{name}` has existed.')
            self._hyper_params[name] = param
            param.attach(self, name)

    def update(self):
        """Mark the parameters ready and fire `_on_params_ready` once all are assigned."""
        if all(p.assigned for p in self._hyper_params.values()):
            self._is_params_ready = True
            self._on_params_ready()