API reference
Event
Runner
class rnnr.Runner
A neural network runner.
A runner provides a thin abstraction for iterating over batches for several epochs, as is typically done in neural network training. To customize the behavior of a run, listen to the events emitted during it: call Runner.on and provide a callback, which will be called when the event is emitted. An event callback is a callable that accepts a dict and returns nothing. The dict is the state of the run. By default, the state contains:
- batches – Iterable of batches which constitutes an epoch.
- max_epoch – Maximum number of epochs to run.
- n_iters – Current number of batch iterations.
- running – A boolean which equals True if the runner is still running. Can be set to False to stop the runner earlier.
- epoch – Current epoch number.
- batch – Current batch retrieved from state['batches'].
Note
Callbacks for an event are called in the order they are passed to on.
Caution
All the state contents above are required for a runner to function properly. You are free to change their values to suit your use case, but do so with care.
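The state contract described above can be illustrated with a toy loop. This is only a sketch of the idea, not rnnr's implementation: the function run_toy, the string event names, and the handlers mapping are hypothetical stand-ins, and only the state keys are taken from the list above.

```python
from collections import defaultdict


def run_toy(batches, handlers, max_epoch=1):
    """Toy stand-in for a runner loop (hypothetical, not rnnr's code)."""
    state = {
        "batches": batches,
        "max_epoch": max_epoch,
        "n_iters": 0,
        "running": True,
        "epoch": 0,
        "batch": None,
    }

    def emit(event):
        # Callbacks fire in registration order and receive the state dict.
        for callback in handlers[event]:
            callback(state)

    while state["running"] and state["epoch"] < state["max_epoch"]:
        state["epoch"] += 1
        emit("epoch_started")
        for batch in state["batches"]:
            if not state["running"]:
                break
            state["batch"] = batch
            state["n_iters"] += 1
            emit("batch")
        emit("epoch_finished")
    return state


handlers = defaultdict(list)
seen = []


def record(state):
    seen.append((state["epoch"], state["batch"]))


def stop_after_four(state):
    # A callback can end the run early by flipping state['running'].
    if state["n_iters"] >= 4:
        state["running"] = False


handlers["batch"].extend([record, stop_after_four])
final_state = run_toy([10, 20, 30], handlers, max_epoch=5)
```

In this sketch the run ends during the second epoch, after four batch iterations, because the second callback sets state['running'] to False.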
on(event, callbacks=None)
Add single/multiple callback(s) to listen to an event.
If callbacks is None, this method returns a decorator which accepts a single callback for the event. If callbacks is a sequence of callbacks, they will all be added as listeners to the event in order.
Parameters:
- event (Event) – Event to listen to.
- callbacks – Callback(s) for the event.
Returns: A decorator which accepts a callback, if callbacks is None.
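The two calling styles of on can be sketched with a small stand-in class. ToyRunner and the string event name below are hypothetical and exist only for illustration. That the decorator returns the decorated function unchanged is an assumption of this sketch, not something the text above states.

```python
class ToyRunner:
    """Hypothetical stand-in used only to illustrate the `on` calling styles."""

    def __init__(self):
        self.callbacks = {}

    def on(self, event, callbacks=None):
        listeners = self.callbacks.setdefault(event, [])
        if callbacks is None:
            # Decorator form: register the decorated function.
            def decorator(callback):
                listeners.append(callback)
                return callback  # assumption: the function is returned as-is
            return decorator
        if callable(callbacks):
            callbacks = [callbacks]  # a single callback
        listeners.extend(callbacks)  # a sequence keeps its order
        return None


runner = ToyRunner()


@runner.on("batch")
def first(state):
    state.setdefault("calls", []).append("first")


def second(state):
    state["calls"].append("second")


def third(state):
    state["calls"].append("third")


runner.on("batch", second)   # direct form with a single callback
runner.on("batch", [third])  # direct form with a sequence of callbacks

state = {}
for callback in runner.callbacks["batch"]:
    callback(state)
```

After the loop, state['calls'] lists the callbacks in the order they were registered.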
Callbacks
rnnr.callbacks.maybe_stop_early(patience=5, *, check='better', counter='_counter')
A callback factory for early stopping.
The returned callback keeps a counter in state[counter] for the number of times state[check] is False. If this counter exceeds patience, the callback stops the runner by setting state['running'] = False.
Example
>>> valid_losses = [0.1, 0.2, 0.3]  # simulate validation batch losses
>>> batches = range(10)
>>>
>>> from rnnr import Event, Runner
>>> from rnnr.attachments import MeanReducer
>>> from rnnr.callbacks import maybe_stop_early
>>>
>>> trainer = Runner()
>>> @trainer.on(Event.EPOCH_STARTED)
... def print_epoch(state):
...     print('Epoch', state['epoch'], 'started')
...
>>> @trainer.on(Event.EPOCH_FINISHED)
... def eval_on_valid(state):
...     def eval_fn(state):
...         state['output'] = state['batch']
...     evaluator = Runner()
...     evaluator.on(Event.BATCH, eval_fn)
...     MeanReducer(name='mean').attach_on(evaluator)
...     evaluator.run(valid_losses)
...     if state.get('best_loss', float('inf')) > evaluator.state['mean']:
...         state['better'] = True
...         state['best_loss'] = evaluator.state['mean']
...     else:
...         state['better'] = False
...
>>> trainer.on(Event.EPOCH_FINISHED, maybe_stop_early(patience=2))
>>> trainer.run(batches, max_epoch=7)
Epoch 1 started
Epoch 2 started
Epoch 3 started
Epoch 4 started
Returns: Callback that does early stopping.
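The counting rule can be sketched without the library. The function toy_stop_early below is a hypothetical re-creation, not rnnr's code. Resetting the counter when state[check] is True is an assumption of this sketch, since the description above does not say what happens on an improvement.

```python
def toy_stop_early(patience=5, check="better", counter="_counter"):
    """Hypothetical re-creation of the counting rule described above."""

    def callback(state):
        if state[check]:
            state[counter] = 0  # assumption: an improvement resets the count
        else:
            state[counter] = state.get(counter, 0) + 1
        # Stop only once the counter exceeds (not merely reaches) patience.
        if state[counter] > patience:
            state["running"] = False

    return callback


stop = toy_stop_early(patience=2)
state = {"running": True}
stopped_at = None
# One improvement followed by epochs without improvement.
for epoch, better in enumerate([True, False, False, False, False], start=1):
    state["better"] = better
    stop(state)
    if not state["running"]:
        stopped_at = epoch
        break
```

With patience=2 and an improvement only in the first epoch, the sketch stops at epoch 4, which agrees with the four "Epoch ... started" lines printed in the example above.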
rnnr.callbacks.checkpoint(what, obj=None, *, under=None, at_most=1, when=None, using=None, ext='pkl', prefix_fmt='{epoch}_', queue_fmt='_saved_{what}')
A callback factory for checkpointing.
Checkpointing means saving obj (or state[what] if obj is None) during a run under the directory under, with {prefix_fmt}{what}.{ext} as the filename.
Example
>>> from pathlib import Path
>>> from pprint import pprint
>>> from rnnr import Event, Runner
>>> from rnnr.callbacks import checkpoint
>>>
>>> batches = range(3)
>>> tmp_dir = Path('/tmp')
>>> runner = Runner()
>>> @runner.on(Event.EPOCH_FINISHED)
... def store_checkpoint(state):
...     state['model'] = 'MODEL'
...     state['optimizer'] = 'OPTIMIZER'
...
>>> runner.on(Event.EPOCH_FINISHED, checkpoint('model', under=tmp_dir, at_most=3))
>>> runner.on(Event.EPOCH_FINISHED, checkpoint('optimizer', under=tmp_dir, at_most=3))
>>> runner.run(batches, max_epoch=7)
>>> pprint(sorted(list(tmp_dir.glob('*.pkl'))))
[PosixPath('/tmp/5_model.pkl'),
 PosixPath('/tmp/5_optimizer.pkl'),
 PosixPath('/tmp/6_model.pkl'),
 PosixPath('/tmp/6_optimizer.pkl'),
 PosixPath('/tmp/7_model.pkl'),
 PosixPath('/tmp/7_optimizer.pkl')]
Parameters:
- what (str) – Name of the object to save.
- obj (Optional[Any]) – Object to save. If None, will be obtained from state[what].
- under (Optional[Path]) – Save the object under this directory. Defaults to the current working directory if not given.
- at_most (int) – Maximum number of files saved. When the number of files exceeds this number, older files will be deleted.
- when (Optional[str]) – If given, only save the object when state[when] is True.
- using (Optional[Callable[[Any, Path], None]]) – Function to invoke to save the object. If given, this must be a callable accepting two arguments: an object to save and a Path to save it to. The default is to save the object using pickle.
- ext (str) – Extension for the filename.
- prefix_fmt (str) – Format for the filename prefix. Any string keys in state can be used as replacement fields.
- queue_fmt (str) – Keeps track of the saved files for the object with a queue stored in state[queue_fmt.format(what=what)].
Returns: Callback that does checkpointing.
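As an illustration of the using parameter, the sketch below defines a saver with the documented shape: it takes an object and a Path and returns nothing. The name save_as_json and the choice of JSON are assumptions made for this example. The function is exercised on its own so that the snippet does not depend on rnnr being installed.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


def save_as_json(obj: Any, path: Path) -> None:
    """Saver with the documented (object, Path) -> None shape (hypothetical)."""
    path.write_text(json.dumps(obj), encoding="utf-8")


# Exercise the saver directly; the filename only mimics the
# {prefix_fmt}{what}.{ext} pattern described above.
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "3_metrics.json"
    save_as_json({"loss": 0.25}, target)
    restored = json.loads(target.read_text(encoding="utf-8"))
```

Going by the signature above, such a saver would presumably be supplied as checkpoint('metrics', using=save_as_json, ext='json'). That call is not exercised here.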
rnnr.callbacks.save(*args, **kwargs)
An alias for checkpoint.
Attachments
class rnnr.attachments.EpochTimer
An attachment to time epochs.
Epochs are only timed when state['max_epoch'] is greater than 1. At the start and end of every epoch, a logging message is written at the INFO level.
class rnnr.attachments.ProgressBar(*, n_items='n_items', stats=None, tqdm_cls=None, **kwargs)
An attachment to display a progress bar.
The progress bar is implemented using tqdm.
Example
>>> from rnnr import Runner
>>> from rnnr.attachments import ProgressBar
>>> runner = Runner()
>>> ProgressBar().attach_on(runner)
>>> runner.run(range(10), max_epoch=10)
Parameters:
- n_items (str) – Get the number of items in a batch from state[n_items] to update the progress bar with. If not given, the default is to always set it to 1.
- stats (Optional[str]) – Get the batch statistics from state[stats] and display them along with the progress bar. The statistics dictionary has the names of the statistics as keys and the statistics as values.
- **kwargs – Keyword arguments to be passed to the tqdm class.
class rnnr.attachments.LambdaReducer(name, reduce_fn, *, value='output')
An attachment to compute a reduction over batches.
This attachment gets the value of each batch and computes a reduction over them at the end of each epoch.
Example
>>> from rnnr import Event, Runner
>>> from rnnr.attachments import LambdaReducer
>>> runner = Runner()
>>> LambdaReducer('product', lambda x, y: x * y).attach_on(runner)
>>> @runner.on(Event.BATCH)
... def on_batch(state):
...     state['output'] = state['batch']
...
>>> runner.run([10, 20, 30])
>>> runner.state['product']
6000
Parameters:
- name (str) – Name of this attachment to be used as the key in the runner’s state dict to store the reduction result.
- reduce_fn (Callable[[Any, Any], Any]) – Reduction function. It should accept two batch values and return their reduction result.
- value (str) – Get the value of a batch from state[value].
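A two-argument reduction function of this kind behaves like the function passed to functools.reduce from the standard library. The sketch below mirrors the product example above without using rnnr. That the attachment folds batch values from left to right is an assumption, and the variable names are illustrative only.

```python
from functools import reduce

# Batch values as they would appear in state['output'] over one epoch.
batch_values = [10, 20, 30]

# Same reduction function as in the example above.
product = reduce(lambda x, y: x * y, batch_values)

# Any other two-argument function works the same way, e.g. a running maximum.
largest = reduce(max, batch_values)
```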
class rnnr.attachments.SumReducer(name, *, value='output')
An attachment to compute a sum over batch statistics.
This attachment gets the value from each batch and computes their sum at the end of every epoch.
class rnnr.attachments.MeanReducer(name, *, value='output', size='size')
An attachment to compute a mean over batch statistics.
This attachment gets the value from each batch and computes their mean at the end of every epoch.
Example
>>> from rnnr import Event, Runner
>>> from rnnr.attachments import MeanReducer
>>> runner = Runner()
>>> MeanReducer('mean').attach_on(runner)
>>> @runner.on(Event.BATCH)
... def on_batch(state):
...     state['output'] = state['batch']
...
>>> runner.run([1, 2, 3])
>>> runner.state['mean']
2.0
Parameters:
- name (str) – Name of this attachment to be used as the key in the runner’s state dict to store the mean value.
- value (str) – Get the value of a batch from state[value].
- size (str) – Get the size of a batch from state[size]. If the state has no such key, the size defaults to 1. The sum of all these batch sizes is the divisor when computing the mean.
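The role of the size key can be sketched as follows. The helper toy_mean is hypothetical and not part of rnnr. It assumes the numerator is the plain sum of the batch values, which means each batch value should itself be a per-batch total (for example a summed loss) when sizes are larger than 1.

```python
def toy_mean(states, value="output", size="size"):
    """Hypothetical illustration: sum of batch values over sum of batch sizes."""
    total = 0.0
    count = 0
    for state in states:
        total += state[value]
        count += state.get(size, 1)  # a missing size counts as 1
    return total / count


# Without sizes this is the ordinary mean, as in the example above.
plain = toy_mean([{"output": 1}, {"output": 2}, {"output": 3}])

# With sizes, each output is a per-batch total over `size` items.
weighted = toy_mean([{"output": 6.0, "size": 4}, {"output": 2.0, "size": 1}])
```

Here plain is 2.0, matching the example above, and weighted is 8.0 divided by 5 items.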