Hooks

A hook is a utility for training. All hooks are derived from nnabla_rl.hook.Hook.

Hook

class nnabla_rl.hook.Hook(timing: int = 1000)

Base class of hooks for Algorithm classes.

A hook is called every ‘timing’ iterations during training. ‘timing’ is specified when the hook is instantiated.
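As a rough illustration of the ‘timing’ rule, the sketch below lists the iterations at which a hook would fire. It assumes, without confirmation from this page, that iteration numbers start at 1 and that a hook fires whenever the current iteration number is a multiple of ‘timing’; firing_iterations is a hypothetical helper, not part of nnabla_rl.

```python
from typing import List


def firing_iterations(timing: int, total_iterations: int) -> List[int]:
    """Return the iterations at which a hook with the given timing would fire.

    Assumes (hypothetically) 1-based iteration numbers and that the hook fires
    when iteration_num % timing == 0.
    """
    if timing <= 0:
        raise ValueError("timing must be a positive integer")
    return [i for i in range(1, total_iterations + 1) if i % timing == 0]


# With the default timing of 1000, a 3500-iteration run fires three times.
print(firing_iterations(1000, 3500))  # [1000, 2000, 3000]
```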

abstract on_hook_called(algorithm: Algorithm)

Called every ‘timing’ iterations, where ‘timing’ is set when the hook instance is created. Runs an additional periodic operation during training (see each class’s documentation).

Parameters:

algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.

setup(algorithm: Algorithm, total_iterations: int)

Called before training starts.

Parameters:
  • algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.

  • total_iterations (int) – Total number of training iterations.

teardown(algorithm: Algorithm, total_iterations: int)

Called after training ends.

Parameters:
  • algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.

  • total_iterations (int) – Total number of training iterations.
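The lifecycle described above (setup, periodic on_hook_called, teardown) can be sketched without nnabla_rl installed. StubAlgorithm, the stand-in Hook, RecordingHook and the train loop are all hypothetical: they only mirror the documented method names, and the multiple-of-‘timing’ firing rule is an assumption, not the library's code.

```python
from abc import ABC, abstractmethod
from typing import List


class StubAlgorithm:
    """Hypothetical stand-in for an Algorithm; only tracks the iteration count."""

    def __init__(self) -> None:
        self.iteration_num = 0


class Hook(ABC):
    """Minimal stand-in mirroring the documented Hook interface (not the library code)."""

    def __init__(self, timing: int = 1000) -> None:
        self._timing = timing

    def __call__(self, algorithm: StubAlgorithm) -> None:
        # Assumed rule: fire only when the iteration number is a multiple of timing.
        if algorithm.iteration_num % self._timing == 0:
            self.on_hook_called(algorithm)

    @abstractmethod
    def on_hook_called(self, algorithm: StubAlgorithm) -> None:
        """Periodic operation; subclasses must implement it."""

    def setup(self, algorithm: StubAlgorithm, total_iterations: int) -> None:
        """Optional: runs before training starts."""

    def teardown(self, algorithm: StubAlgorithm, total_iterations: int) -> None:
        """Optional: runs after training ends."""


class RecordingHook(Hook):
    """Example subclass that records each lifecycle call."""

    def __init__(self, timing: int = 1000) -> None:
        super().__init__(timing)
        self.events: List[str] = []

    def setup(self, algorithm: StubAlgorithm, total_iterations: int) -> None:
        self.events.append(f"setup({total_iterations})")

    def on_hook_called(self, algorithm: StubAlgorithm) -> None:
        self.events.append(f"called@{algorithm.iteration_num}")

    def teardown(self, algorithm: StubAlgorithm, total_iterations: int) -> None:
        self.events.append(f"teardown({total_iterations})")


def train(algorithm: StubAlgorithm, hooks: List[Hook], total_iterations: int) -> None:
    """Hypothetical training loop showing when each hook method runs."""
    for hook in hooks:
        hook.setup(algorithm, total_iterations)
    for _ in range(total_iterations):
        algorithm.iteration_num += 1
        # ... one training iteration would run here ...
        for hook in hooks:
            hook(algorithm)
    for hook in hooks:
        hook.teardown(algorithm, total_iterations)


hook = RecordingHook(timing=2)
train(StubAlgorithm(), [hook], total_iterations=5)
print(hook.events)  # ['setup(5)', 'called@2', 'called@4', 'teardown(5)']
```

Only on_hook_called is abstract on this page, so in nnabla_rl itself a subclass of nnabla_rl.hook.Hook must implement on_hook_called, while overriding setup and teardown is optional, as in the sketch.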

List of Hooks

class nnabla_rl.hooks.EvaluationHook(env, evaluator=<nnabla_rl.utils.evaluator.EpisodicEvaluator object>, timing=1000, writer=None)

Bases: Hook

Hook to run evaluation during training.

Parameters:
  • env (gym.Env) – Environment on which to run the evaluation.

  • evaluator (Callable[[nnabla_rl.algorithm.Algorithm, gym.Env], List[float]]) – Evaluator which runs the actual evaluation. Defaults to EpisodicEvaluator.

  • timing (int) – Evaluation interval. Defaults to 1000 iterations.

  • writer (nnabla_rl.writer.Writer, optional) – Writer instance to save/print the evaluation results. Defaults to None.
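The evaluator parameter accepts a callable of the documented shape (algorithm, env) -> List[float]. The sketch below is a hypothetical custom evaluator that returns one total reward per episode, plus one illustrative way to reduce those returns to scalars for a writer. CountdownEnv, ConstantPolicy, the compute_eval_action method name, and the four-value step() return are assumptions made to keep the example self-contained; they are not taken from this page.

```python
import statistics
from typing import Dict, List


class CountdownEnv:
    """Hypothetical toy environment: every episode lasts 3 steps with reward 1.0 per step.

    Mimics a reset()/step() interface but does not depend on gym.
    """

    def reset(self) -> int:
        self._steps_left = 3
        return 0  # dummy observation

    def step(self, action: int):
        self._steps_left -= 1
        done = self._steps_left == 0
        return 0, 1.0, done, {}  # (observation, reward, done, info)


class ConstantPolicy:
    """Hypothetical stand-in for an algorithm; compute_eval_action is an assumed name."""

    def compute_eval_action(self, state: int) -> int:
        return 0


def run_episodes(algorithm, env, num_episodes: int = 2) -> List[float]:
    """Custom evaluator with the documented shape: (algorithm, env) -> List[float] of returns."""
    returns = []
    for _ in range(num_episodes):
        state = env.reset()
        done, episode_return = False, 0.0
        while not done:
            state, reward, done, _ = env.step(algorithm.compute_eval_action(state))
            episode_return += reward
        returns.append(episode_return)
    return returns


def summarize(returns: List[float]) -> Dict[str, float]:
    """Reduce per-episode returns to scalars a writer could record (illustrative choice)."""
    return {
        "mean": statistics.mean(returns),
        "median": statistics.median(returns),
        "min": min(returns),
        "max": max(returns),
    }


returns = run_episodes(ConstantPolicy(), CountdownEnv())
print(returns)  # [3.0, 3.0]
print(summarize(returns))
```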

on_hook_called(algorithm)

Called every ‘timing’ iterations, where ‘timing’ is set when the hook instance is created. Runs an additional periodic operation during training (see each class’s documentation).

Parameters:

algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.

class nnabla_rl.hooks.IterationNumHook(timing=1)

Bases: Hook

Hook to periodically print the current training iteration number.

Parameters:

timing (int) – Printing interval. Defaults to 1 iteration.
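A minimal sketch of what an IterationNumHook-like hook does, assuming it fires when the iteration number is a multiple of ‘timing’. PrintIterationHook and StubAlgorithm are hypothetical, and the printed message format is illustrative rather than the library's actual output. Output goes to an optional stream so it can be captured.

```python
import io
from typing import Optional, TextIO


class StubAlgorithm:
    """Hypothetical stand-in exposing only an iteration counter."""

    def __init__(self) -> None:
        self.iteration_num = 0


class PrintIterationHook:
    """Sketch of an IterationNumHook-like hook (not the library's implementation)."""

    def __init__(self, timing: int = 1, out: Optional[TextIO] = None) -> None:
        self._timing = timing
        self._out = out  # None means sys.stdout, as with print()

    def __call__(self, algorithm: StubAlgorithm) -> None:
        # Print only on iterations that are multiples of timing.
        if algorithm.iteration_num % self._timing == 0:
            print(f"current iteration: {algorithm.iteration_num}", file=self._out)


buffer = io.StringIO()
hook = PrintIterationHook(timing=2, out=buffer)
algorithm = StubAlgorithm()
for _ in range(5):
    algorithm.iteration_num += 1
    hook(algorithm)
print(buffer.getvalue(), end="")
```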

on_hook_called(algorithm)

Called every ‘timing’ iterations, where ‘timing’ is set when the hook instance is created. Runs an additional periodic operation during training (see each class’s documentation).

Parameters:

algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.

class nnabla_rl.hooks.IterationStateHook(writer=None, timing=1000)

Bases: Hook

Hook that retrieves the iteration state and prints/saves the training status through the writer.

Parameters:
  • timing (int) – Retrieval interval. Defaults to 1000 iterations.

  • writer (nnabla_rl.writer.Writer, optional) – Writer instance to save/print the iteration states. Defaults to None.
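The sketch below shows the general retrieve-then-write pattern of IterationStateHook. Everything in it is hypothetical: the latest_iteration_state dictionary on StubAlgorithm and the write_scalar method on ListWriter stand in for whatever the real Algorithm and nnabla_rl.writer.Writer expose. The print fallback when writer is None is a choice made for this sketch.

```python
from typing import Dict, List, Optional, Tuple


class StubAlgorithm:
    """Hypothetical stand-in: an iteration counter plus the latest training scalars."""

    def __init__(self) -> None:
        self.iteration_num = 0
        self.latest_iteration_state: Dict[str, Dict[str, float]] = {"scalar": {}}


class ListWriter:
    """Hypothetical writer that stores (iteration, scalars) pairs instead of writing files."""

    def __init__(self) -> None:
        self.records: List[Tuple[int, Dict[str, float]]] = []

    def write_scalar(self, iteration_num: int, scalars: Dict[str, float]) -> None:
        self.records.append((iteration_num, dict(scalars)))  # copy to freeze the values


class IterationStateSketchHook:
    """Sketch of an IterationStateHook-like hook; forwards scalars to the writer, else prints."""

    def __init__(self, writer: Optional[ListWriter] = None, timing: int = 1000) -> None:
        self._writer = writer
        self._timing = timing

    def __call__(self, algorithm: StubAlgorithm) -> None:
        if algorithm.iteration_num % self._timing != 0:
            return
        scalars = algorithm.latest_iteration_state["scalar"]
        if self._writer is not None:
            self._writer.write_scalar(algorithm.iteration_num, scalars)
        else:
            print(algorithm.iteration_num, scalars)


writer = ListWriter()
hook = IterationStateSketchHook(writer=writer, timing=2)
algorithm = StubAlgorithm()
for step in range(1, 5):
    algorithm.iteration_num = step
    algorithm.latest_iteration_state["scalar"] = {"loss": 1.0 / step}
    hook(algorithm)
print(writer.records)  # [(2, {'loss': 0.5}), (4, {'loss': 0.25})]
```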

on_hook_called(algorithm)

Called every ‘timing’ iterations, where ‘timing’ is set when the hook instance is created. Runs an additional periodic operation during training (see each class’s documentation).

Parameters:

algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.

class nnabla_rl.hooks.SaveSnapshotHook(outdir, timing=1000)

Bases: Hook

Hook to save a training snapshot of the current algorithm.

Parameters:
  • outdir – Output directory in which the snapshots are saved.

  • timing (int) – Saving interval. Defaults to 1000 iterations.
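A sketch of periodic snapshotting into outdir. SnapshotSketchHook and StubAlgorithm are hypothetical; the per-snapshot ‘iteration-<n>’ directory name and the JSON file format are assumptions for illustration, not necessarily the layout nnabla_rl produces.

```python
import json
import tempfile
from pathlib import Path


class StubAlgorithm:
    """Hypothetical stand-in: an iteration counter and some state worth saving."""

    def __init__(self) -> None:
        self.iteration_num = 0
        self.params = {"w": 0.0}


class SnapshotSketchHook:
    """Sketch of a SaveSnapshotHook-like hook (not the library's code)."""

    def __init__(self, outdir: str, timing: int = 1000) -> None:
        self._outdir = Path(outdir)
        self._timing = timing

    def __call__(self, algorithm: StubAlgorithm) -> None:
        if algorithm.iteration_num % self._timing != 0:
            return
        # One sub-directory per snapshot, named after the iteration (assumed layout).
        snapshot_dir = self._outdir / f"iteration-{algorithm.iteration_num}"
        snapshot_dir.mkdir(parents=True, exist_ok=True)
        state = {"iteration_num": algorithm.iteration_num, "params": algorithm.params}
        (snapshot_dir / "snapshot.json").write_text(json.dumps(state))


with tempfile.TemporaryDirectory() as outdir:
    hook = SnapshotSketchHook(outdir, timing=2)
    algorithm = StubAlgorithm()
    for step in range(1, 6):
        algorithm.iteration_num = step
        algorithm.params["w"] = float(step)
        hook(algorithm)
    saved = sorted(p.name for p in Path(outdir).iterdir())
    restored = json.loads((Path(outdir) / "iteration-4" / "snapshot.json").read_text())

print(saved)  # ['iteration-2', 'iteration-4']
print(restored)  # {'iteration_num': 4, 'params': {'w': 4.0}}
```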

on_hook_called(algorithm)

Called every ‘timing’ iterations, where ‘timing’ is set when the hook instance is created. Runs an additional periodic operation during training (see each class’s documentation).

Parameters:

algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.

class nnabla_rl.hooks.ProgressBarHook(timing: int = 1)

Bases: Hook

Hook to show a progress bar.

Parameters:

timing (int) – Updating interval. Defaults to 1 iteration.
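ProgressBarHook is one of the hooks documented here that also overrides setup and teardown. The sketch below shows why that is useful: the total iteration count, needed to size the bar, only arrives in setup, and teardown can finalize the display. TextProgressBarHook renders a plain-text bar into a list rather than a terminal; it and StubAlgorithm are hypothetical and not the library's implementation.

```python
from typing import List


class StubAlgorithm:
    """Hypothetical stand-in exposing only an iteration counter."""

    def __init__(self) -> None:
        self.iteration_num = 0


class TextProgressBarHook:
    """Sketch of a progress-bar hook using the setup/call/teardown lifecycle."""

    def __init__(self, timing: int = 1, width: int = 10) -> None:
        self._timing = timing
        self._width = width
        self._total = 0
        self.lines: List[str] = []

    def setup(self, algorithm: StubAlgorithm, total_iterations: int) -> None:
        # The total is only known at setup time, so the bar is sized here.
        self._total = total_iterations

    def __call__(self, algorithm: StubAlgorithm) -> None:
        if algorithm.iteration_num % self._timing == 0:
            self.lines.append(self._render(algorithm.iteration_num))

    def teardown(self, algorithm: StubAlgorithm, total_iterations: int) -> None:
        self.lines.append("done")

    def _render(self, current: int) -> str:
        filled = self._width * current // self._total
        return "[" + "#" * filled + "." * (self._width - filled) + f"] {current}/{self._total}"


hook = TextProgressBarHook(timing=2, width=4)
algorithm = StubAlgorithm()
hook.setup(algorithm, total_iterations=4)
for _ in range(4):
    algorithm.iteration_num += 1
    hook(algorithm)
hook.teardown(algorithm, total_iterations=4)
print("\n".join(hook.lines))
```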

on_hook_called(algorithm: Algorithm)

Called every ‘timing’ iterations, where ‘timing’ is set when the hook instance is created. Runs an additional periodic operation during training (see each class’s documentation).

Parameters:

algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.

setup(algorithm: Algorithm, total_iterations: int)

Called before training starts.

Parameters:
  • algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.

  • total_iterations (int) – Total number of training iterations.

teardown(algorithm: Algorithm, total_iterations: int)

Called after training ends.

Parameters:
  • algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.

  • total_iterations (int) – Total number of training iterations.

class nnabla_rl.hooks.TimeMeasuringHook(timing=1)

Bases: Hook

Hook to measure and print the actual time spent running the iteration(s).

Parameters:

timing (int) – Measuring interval. Defaults to 1 iteration.
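A sketch of measuring elapsed time between hook firings. TimerSketchHook and StubAlgorithm are hypothetical, and measuring from the previous firing (or from construction, for the first firing) is an assumption about how the interval is counted. The clock is injectable so the example can be checked deterministically; by default it uses time.perf_counter.

```python
import time
from typing import Callable, List


class StubAlgorithm:
    """Hypothetical stand-in exposing only an iteration counter."""

    def __init__(self) -> None:
        self.iteration_num = 0


class TimerSketchHook:
    """Sketch of a TimeMeasuringHook-like hook (not the library's implementation)."""

    def __init__(self, timing: int = 1, clock: Callable[[], float] = time.perf_counter) -> None:
        self._timing = timing
        self._clock = clock
        self._last = clock()  # reference point for the first measurement
        self.elapsed: List[float] = []

    def __call__(self, algorithm: StubAlgorithm) -> None:
        if algorithm.iteration_num % self._timing != 0:
            return
        now = self._clock()
        delta = now - self._last  # time spent since the previous firing
        self.elapsed.append(delta)
        print(f"iteration {algorithm.iteration_num}: {delta:.3f} s")
        self._last = now


# Deterministic fake clock: readings at construction, iteration 2, and iteration 4.
ticks = iter([0.0, 1.5, 4.0])
hook = TimerSketchHook(timing=2, clock=lambda: next(ticks))
algorithm = StubAlgorithm()
for _ in range(4):
    algorithm.iteration_num += 1
    hook(algorithm)
print(hook.elapsed)  # [1.5, 2.5]
```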

on_hook_called(algorithm)

Called every ‘timing’ iterations, where ‘timing’ is set when the hook instance is created. Runs an additional periodic operation during training (see each class’s documentation).

Parameters:

algorithm (Algorithm) – Algorithm instance on which to perform the additional operation.