API reference

Event

class rnnr.Event[source]

An enumeration of events.

STARTED

Emitted once at the start of a run.

EPOCH_STARTED

Emitted at the start of every epoch.

BATCH

Emitted on every batch.

EPOCH_FINISHED

Emitted at the end of every epoch.

FINISHED

Emitted once at the end of a run.
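
The snippet below is a minimal sketch of the order in which these events fire, following the descriptions above, for a run over two batches and two epochs (Runner is documented further down):

>>> from rnnr import Event, Runner
>>> runner = Runner()
>>> def make_printer(name):
...     return lambda state: print(name)
...
>>> for event in Event:
...     runner.on(event, make_printer(event.name))
...
>>> runner.run(range(2), max_epoch=2)
STARTED
EPOCH_STARTED
BATCH
BATCH
EPOCH_FINISHED
EPOCH_STARTED
BATCH
BATCH
EPOCH_FINISHED
FINISHED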

Runner

class rnnr.Runner[source]

A neural network runner.

A runner provides a thin abstraction of iterating over batches for several epochs, as is typically done in neural network training. To customize the behavior during a run, a runner provides a way to listen to the events emitted during that run. To listen to an event, call Runner.on and provide a callback, which will be invoked when the event is emitted. An event callback is a callable that accepts a dict and returns nothing. This dict is the state of the run. By default, the state contains:

  • batches - Iterable of batches which constitutes an epoch.
  • max_epoch - Maximum number of epochs to run.
  • n_iters - Current number of batch iterations.
  • running - A boolean which equals True if the runner is still running. Can be set to False to stop the runner early; see the sketch after the Caution note below.
  • epoch - Current epoch number.
  • batch - Current batch retrieved from state['batches'].

state

Runner’s state that is passed to event callbacks.

Type:dict

Note

Callbacks for an event are called in the order they are passed to on.

Caution

All the state contents above are required for a runner to function properly. You are free to change their values to suit your use cases better, but be careful.
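
For instance, a callback can end a run before max_epoch is reached by flipping the running flag; a minimal sketch:

>>> from rnnr import Event, Runner
>>> runner = Runner()
>>> @runner.on(Event.EPOCH_STARTED)
... def print_epoch(state):
...     print('Epoch', state['epoch'], 'started')
...
>>> @runner.on(Event.EPOCH_FINISHED)
... def stop_after_two_epochs(state):
...     if state['epoch'] >= 2:
...         state['running'] = False
...
>>> runner.run(range(10), max_epoch=5)
Epoch 1 started
Epoch 2 started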

on(event, callbacks=None)[source]

Add single/multiple callback(s) to listen to an event.

If callbacks is None, this method returns a decorator which accepts a single callback for the event. If callbacks is a sequence of callbacks, they will all be added as listeners to the event in order.

Parameters:
  • event (Event) – Event to listen to.
  • callbacks – Callback(s) for the event.
Returns:

A decorator which accepts a callback, if callbacks is None.
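
Both forms can be used on one runner; a small sketch (the callback names are illustrative):

>>> from rnnr import Event, Runner
>>> runner = Runner()
>>> @runner.on(Event.STARTED)          # decorator form
... def announce(state):
...     print('run started')
...
>>> def first(state):
...     print('first')
...
>>> def second(state):
...     print('second')
...
>>> runner.on(Event.EPOCH_FINISHED, [first, second])  # sequence form
>>> runner.run(range(3))
run started
first
second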

resume(repeat_last_batch=False)[source]

Resume the runner from its current state.

Parameters:repeat_last_batch (bool) – Whether to repeat processing the last batch. Ignored if the last epoch is finished (i.e. the batches have been exhausted).
Return type:None
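
One plausible use is pausing a run by setting state['running'] to False and continuing it later; the sketch below assumes that resume clears that flag and picks up after the last finished epoch:

>>> from rnnr import Event, Runner
>>> runner = Runner()
>>> @runner.on(Event.EPOCH_FINISHED)
... def pause_after_first_epoch(state):
...     if state['epoch'] == 1:
...         state['running'] = False  # pause the run after the first epoch
...
>>> runner.run(range(10), max_epoch=5)
>>> runner.resume()  # continue the remaining epochs from the current state
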
run(batches, max_epoch=1)[source]

Run on batches for a number of epochs.

Parameters:
  • batches (Iterable[Any]) – Batches to iterate over in an epoch.
  • max_epoch (int) – Maximum number of epochs to run.
Return type:

None

Callbacks

rnnr.callbacks.maybe_stop_early(patience=5, *, check='better', counter='_counter')[source]

A callback factory for early stopping.

The returned callback keeps a counter in state[counter] of the number of times state[check] is False. If this counter exceeds patience, the callback stops the runner by setting state['running'] = False.

Example

>>> valid_losses = [0.1, 0.2, 0.3]  # simulate validation batch losses
>>> batches = range(10)
>>>
>>> from rnnr import Event, Runner
>>> from rnnr.attachments import MeanReducer
>>> from rnnr.callbacks import maybe_stop_early
>>>
>>> trainer = Runner()
>>> @trainer.on(Event.EPOCH_STARTED)
... def print_epoch(state):
...     print('Epoch', state['epoch'], 'started')
...
>>> @trainer.on(Event.EPOCH_FINISHED)
... def eval_on_valid(state):
...     def eval_fn(state):
...         state['output'] = state['batch']
...     evaluator = Runner()
...     evaluator.on(Event.BATCH, eval_fn)
...     MeanReducer(name='mean').attach_on(evaluator)
...     evaluator.run(valid_losses)
...     if state.get('best_loss', float('inf')) > evaluator.state['mean']:
...         state['better'] = True
...         state['best_loss'] = evaluator.state['mean']
...     else:
...         state['better'] = False
...
>>> trainer.on(Event.EPOCH_FINISHED, maybe_stop_early(patience=2))
>>> trainer.run(batches, max_epoch=7)
Epoch 1 started
Epoch 2 started
Epoch 3 started
Epoch 4 started
Parameters:
  • patience (int) – Stop the runner when the counter exceeds this number.
  • check (str) – Increment counter if state[check] is False.
  • counter (str) – Store the counter in state[counter].
Returns:

Callback that does early stopping.

rnnr.callbacks.checkpoint(what, obj=None, *, under=None, at_most=1, when=None, using=None, ext='pkl', prefix_fmt='{epoch}_', queue_fmt='_saved_{what}')[source]

A callback factory for checkpointing.

Checkpointing means saving obj (or state[what] if obj is None) during a run in the under directory, using {prefix_fmt}{what}.{ext} as the filename.

Example

>>> from pathlib import Path
>>> from pprint import pprint
>>> from rnnr import Event, Runner
>>> from rnnr.callbacks import checkpoint
>>>
>>> batches = range(3)
>>> tmp_dir = Path('/tmp')
>>> runner = Runner()
>>> @runner.on(Event.EPOCH_FINISHED)
... def store_checkpoint(state):
...     state['model'] = 'MODEL'
...     state['optimizer'] = 'OPTIMIZER'
...
>>> runner.on(Event.EPOCH_FINISHED, checkpoint('model', under=tmp_dir, at_most=3))
>>> runner.on(Event.EPOCH_FINISHED, checkpoint('optimizer', under=tmp_dir, at_most=3))
>>> runner.run(batches, max_epoch=7)
>>> pprint(sorted(list(tmp_dir.glob('*.pkl'))))
[PosixPath('/tmp/5_model.pkl'),
 PosixPath('/tmp/5_optimizer.pkl'),
 PosixPath('/tmp/6_model.pkl'),
 PosixPath('/tmp/6_optimizer.pkl'),
 PosixPath('/tmp/7_model.pkl'),
 PosixPath('/tmp/7_optimizer.pkl')]
Parameters:
  • what (str) – Name of the object to save.
  • obj (Optional[Any]) – Object to save. If None, will be obtained from state[what].
  • under (Optional[Path]) – Save the object under this directory. Defaults to the current working directory if not given.
  • at_most (int) – Maximum number of files saved. When the number of files exceeds this number, older files will be deleted.
  • when (Optional[str]) – If given, only save the object when state[when] is True.
  • using (Optional[Callable[[Any, Path], None]]) – Function to invoke to save the object. If given, this must be a callable accepting two arguments: an object to save and a Path to save it to. The default is to save the object using pickle.
  • ext (str) – Extension for the filename.
  • prefix_fmt (str) – Format for the filename prefix. Any string keys in state can be used as replacement fields.
  • queue_fmt (str) – Format for the state key under which a queue of the files saved for the object is kept, i.e. state[queue_fmt.format(what=what)].
Returns:

Callback that does checkpointing.
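
The when and using parameters restrict when a checkpoint is written and how it is written. The sketch below assumes a 'better' flag like the one set in the maybe_stop_early example above and swaps in a toy text-based save function in place of pickle:

>>> from pathlib import Path
>>> from pprint import pprint
>>> from rnnr import Event, Runner
>>> from rnnr.callbacks import checkpoint
>>>
>>> def save_as_text(obj, path):  # toy stand-in for a real save function
...     path.write_text(str(obj))
...
>>> tmp_dir = Path('/tmp')
>>> runner = Runner()
>>> @runner.on(Event.EPOCH_FINISHED)
... def evaluate(state):
...     state['model'] = 'MODEL'
...     state['better'] = state['epoch'] % 2 == 1  # pretend odd epochs improve
...
>>> runner.on(Event.EPOCH_FINISHED, checkpoint(
...     'model', under=tmp_dir, at_most=2, when='better', using=save_as_text, ext='txt'))
>>> runner.run(range(3), max_epoch=4)
>>> pprint(sorted(tmp_dir.glob('*.txt')))
[PosixPath('/tmp/1_model.txt'), PosixPath('/tmp/3_model.txt')]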

rnnr.callbacks.save(*args, **kwargs)[source]

An alias for checkpoint.

Attachments

class rnnr.attachments.Attachment[source]

Bases: abc.ABC

An abstract base class for an attachment.

attach_on(runner)[source]

Attach to a runner.

Parameters:runner (Runner) – Runner to attach to.
Return type:None
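
Concrete attachments implement attach_on and register their callbacks on the runner there. A minimal sketch of a hypothetical attachment that counts the batches processed during a run:

>>> from rnnr import Event, Runner
>>> from rnnr.attachments import Attachment
>>>
>>> class BatchCounter(Attachment):  # hypothetical, for illustration only
...     def __init__(self, name='n_batches'):
...         self._name = name
...     def attach_on(self, runner):
...         runner.on(Event.STARTED, self._reset)
...         runner.on(Event.BATCH, self._increment)
...     def _reset(self, state):
...         state[self._name] = 0
...     def _increment(self, state):
...         state[self._name] += 1
...
>>> runner = Runner()
>>> BatchCounter().attach_on(runner)
>>> runner.run(range(4), max_epoch=2)
>>> runner.state['n_batches']
8
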
class rnnr.attachments.EpochTimer[source]

An attachment to time epochs.

Epochs are only timed when state['max_epoch'] is greater than 1. At the start and end of every epoch, logging messages are written at the INFO log level.
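
The timing messages only show up if logging is configured to handle INFO records; a minimal sketch:

>>> import logging
>>> from rnnr import Runner
>>> from rnnr.attachments import EpochTimer
>>> logging.basicConfig(level=logging.INFO)
>>> runner = Runner()
>>> EpochTimer().attach_on(runner)
>>> runner.run(range(100), max_epoch=3)  # epoch timings are logged to stderr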

class rnnr.attachments.ProgressBar(*, n_items='n_items', stats=None, tqdm_cls=None, **kwargs)[source]

An attachment to display a progress bar.

The progress bar is implemented using tqdm.

Example

>>> from rnnr import Runner
>>> from rnnr.attachments import ProgressBar
>>> runner = Runner()
>>> ProgressBar().attach_on(runner)
>>> runner.run(range(10), max_epoch=10)
Parameters:
  • n_items (str) – Get the number of items in a batch from state[n_items] to update the progress bar with. If the state has no such key, the number of items defaults to 1.
  • stats (Optional[str]) – Get the batch statistics from state[stats] and display them alongside the progress bar. The statistics dictionary maps statistic names to their values.
  • **kwargs – Keyword arguments passed to the tqdm class.
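
When each batch is itself a collection of items, n_items and stats can make the bar advance by item count and show running statistics. A sketch, where the 'loss' statistic is made up and the bar is attached after the batch callback so that the values are already in the state when the bar updates (an ordering assumption):

>>> from rnnr import Event, Runner
>>> from rnnr.attachments import ProgressBar
>>>
>>> batches = [[1, 2, 3], [4, 5], [6]]
>>> runner = Runner()
>>> @runner.on(Event.BATCH)
... def process_batch(state):
...     state['n_items'] = len(state['batch'])
...     state['stats'] = {'loss': sum(state['batch']) / len(state['batch'])}
...
>>> ProgressBar(stats='stats').attach_on(runner)
>>> runner.run(batches)
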
class rnnr.attachments.LambdaReducer(name, reduce_fn, *, value='output')[source]

An attachment to compute a reduction over batches.

This attachment gets the value of each batch and computes a reduction over them at the end of every epoch.

Example

>>> from rnnr import Event, Runner
>>> from rnnr.attachments import LambdaReducer
>>> runner = Runner()
>>> LambdaReducer('product', lambda x, y: x * y).attach_on(runner)
>>> @runner.on(Event.BATCH)
... def on_batch(state):
...     state['output'] = state['batch']
...
>>> runner.run([10, 20, 30])
>>> runner.state['product']
6000
Parameters:
  • name (str) – Name of this attachment to be used as the key in the runner’s state dict to store the reduction result.
  • reduce_fn (Callable[[Any, Any], Any]) – Reduction function. It should accept two batch values and return their reduction result.
  • value (str) – Get the value of a batch from state[value].
class rnnr.attachments.SumReducer(name, *, value='output')[source]

An attachment to compute a sum over batch statistics.

This attachment gets the value from each batch and computes their sum at the end of every epoch.

Parameters:
  • name (str) – Name of this attachment to be used as the key in the runner’s state dict to store the sum.
  • value (str) – Get the value of a batch from state[value].
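
A short sketch, analogous to the MeanReducer example below, where the sum of the batch outputs ends up under the given name in the runner's state:

>>> from rnnr import Event, Runner
>>> from rnnr.attachments import SumReducer
>>> runner = Runner()
>>> SumReducer('total').attach_on(runner)
>>> @runner.on(Event.BATCH)
... def on_batch(state):
...     state['output'] = state['batch']
...
>>> runner.run([1, 2, 3])
>>> runner.state['total']
6
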
class rnnr.attachments.MeanReducer(name, *, value='output', size='size')[source]

An attachment to compute a mean over batch statistics.

This attachment gets the value from each batch and computes their mean at the end of every epoch.

Example

>>> from rnnr import Event, Runner
>>> from rnnr.attachments import MeanReducer
>>> runner = Runner()
>>> MeanReducer('mean').attach_on(runner)
>>> @runner.on(Event.BATCH)
... def on_batch(state):
...     state['output'] = state['batch']
...
>>> runner.run([1, 2, 3])
>>> runner.state['mean']
2.0
Parameters:
  • name (str) – Name of this attachment to be used as the key in the runner’s state dict to store the mean value.
  • value (str) – Get the value of a batch from state[value].
  • size (str) – Get the size of a batch from state[size]. If the state has no such key, the size defaults to 1. The sum of all these batch sizes is the divisor when computing the mean.