API reference
Event
Runner
class rnnr.Runner
A neural network runner.
A runner provides a thin abstraction for iterating over batches for several epochs, as is typically done in neural network training. To customize the behavior during a run, a runner lets you listen to events emitted during the run. To listen to an event, call Runner.on and provide a callback, which will be called when the event is emitted. An event callback is a callable that accepts a dict and returns nothing. The dict is the state of the run. By default, the state contains:
- batches – Iterable of batches which constitutes an epoch.
- max_epoch – Maximum number of epochs to run.
- n_iters – Current number of batch iterations.
- running – A boolean which equals True if the runner is still running. Can be set to False to stop the runner early.
- epoch – Current epoch number.
- batch – Current batch retrieved from state['batches'].
Note: Callbacks for an event are called in the order they are passed to on.
Caution: All the state contents above are required for a runner to function properly. You are free to change their values to suit your use cases better, but do so with care.
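To make the event and state mechanics above concrete, here is a minimal, self-contained sketch of such a loop. It is not rnnr's implementation: the class name MiniRunner, the plain string event names, the point at which n_iters is incremented, and whether epoch_finished still fires after an early stop are all assumptions made for illustration.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List

Callback = Callable[[Dict[str, Any]], None]


class MiniRunner:
    """Hypothetical stand-in for a runner; NOT rnnr's implementation."""

    def __init__(self) -> None:
        self._callbacks: Dict[str, List[Callback]] = defaultdict(list)
        self.state: Dict[str, Any] = {}

    def on(self, event: str, callback: Callback) -> None:
        # Callbacks are stored, and later called, in registration order.
        self._callbacks[event].append(callback)

    def _emit(self, event: str) -> None:
        for cb in self._callbacks[event]:
            cb(self.state)

    def run(self, batches: Iterable[Any], max_epoch: int = 1) -> None:
        # Default state keys, mirroring the list above.
        self.state.update(batches=batches, max_epoch=max_epoch, n_iters=0,
                          running=True, epoch=0)
        while self.state['running'] and self.state['epoch'] < max_epoch:
            self.state['epoch'] += 1
            self._emit('epoch_started')
            for batch in self.state['batches']:
                self.state['batch'] = batch
                self._emit('batch')
                self.state['n_iters'] += 1
                if not self.state['running']:
                    break  # assumption: a stop request also ends the current epoch
            self._emit('epoch_finished')


runner = MiniRunner()
seen = []
runner.on('batch', lambda s: seen.append((s['epoch'], s['batch'])))


def stop_after_two_epochs(state):
    if state['epoch'] == 2:
        state['running'] = False  # stop early by flipping state['running']


runner.on('epoch_finished', stop_after_two_epochs)
runner.run(range(3), max_epoch=5)
print(runner.state['epoch'], runner.state['n_iters'])  # 2 6
```

Even though max_epoch is 5, the run ends after the second epoch because a callback set state['running'] to False.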
on(event, callbacks=None)
Add one or more callbacks to listen to an event.
If callbacks is None, this method returns a decorator which accepts a single callback for the event. If callbacks is a sequence of callbacks, they will all be added as listeners to the event in order.
Parameters:
- event (Event) – Event to listen to.
- callbacks – Callback(s) for the event.
Returns: A decorator which accepts a callback, if callbacks is None.
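The two calling styles of on (decorator when callbacks is None, direct registration otherwise) can be sketched as a standalone registry. EventRegistry is hypothetical, and returning the decorated function unchanged is an assumption; only the calling convention mirrors the description. Because the examples in this reference also pass a single callable, the sketch accepts that form too.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

Callback = Callable[[Dict[str, Any]], None]


class EventRegistry:
    """Hypothetical registry illustrating the calling styles of `on`."""

    def __init__(self) -> None:
        self.listeners: Dict[str, List[Callback]] = {}

    def on(self, event: str,
           callbacks: Optional[Union[Callback, Sequence[Callback]]] = None):
        if callbacks is None:
            # Decorator form: register the single decorated function.
            def decorator(cb: Callback) -> Callback:
                self.listeners.setdefault(event, []).append(cb)
                return cb  # assumption: hand the function back unchanged
            return decorator
        if callable(callbacks):
            callbacks = [callbacks]  # treat a lone callable as a 1-item sequence
        # Sequence form: register every callback, preserving order.
        self.listeners.setdefault(event, []).extend(callbacks)
        return None

    def emit(self, event: str, state: Dict[str, Any]) -> None:
        for cb in self.listeners.get(event, []):
            cb(state)


registry = EventRegistry()
calls: List[str] = []


@registry.on('batch')
def first(state):
    calls.append('first')


registry.on('batch', [lambda s: calls.append('second'),
                      lambda s: calls.append('third')])
registry.on('batch', lambda s: calls.append('fourth'))
registry.emit('batch', {})
print(calls)  # ['first', 'second', 'third', 'fourth']
```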
Callbacks
rnnr.callbacks.maybe_stop_early(patience=5, *, check='better', counter='_counter')
A callback factory for early stopping.
The returned callback keeps a counter in state[counter] for the number of times state[check] is False. If this counter exceeds patience, the callback stops the runner by setting state['running'] = False.
Example
>>> valid_losses = [0.1, 0.2, 0.3]  # simulate validation batch losses
>>> batches = range(10)
>>>
>>> from rnnr import Event, Runner
>>> from rnnr.attachments import MeanReducer
>>> from rnnr.callbacks import maybe_stop_early
>>>
>>> trainer = Runner()
>>> @trainer.on(Event.EPOCH_STARTED)
... def print_epoch(state):
...     print('Epoch', state['epoch'], 'started')
...
>>> @trainer.on(Event.EPOCH_FINISHED)
... def eval_on_valid(state):
...     def eval_fn(state):
...         state['output'] = state['batch']
...     evaluator = Runner()
...     evaluator.on(Event.BATCH, eval_fn)
...     MeanReducer(name='mean').attach_on(evaluator)
...     evaluator.run(valid_losses)
...     if state.get('best_loss', float('inf')) > evaluator.state['mean']:
...         state['better'] = True
...         state['best_loss'] = evaluator.state['mean']
...     else:
...         state['better'] = False
...
>>> trainer.on(Event.EPOCH_FINISHED, maybe_stop_early(patience=2))
>>> trainer.run(batches, max_epoch=7)
Epoch 1 started
Epoch 2 started
Epoch 3 started
Epoch 4 started
Returns: Callback that does early stopping.
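The counting rule can be checked without rnnr by replaying the example's epochs by hand. make_early_stopper is a hypothetical re-creation, not the library's code. It counts every False exactly as the description reads; whether rnnr resets the counter when state[check] becomes True again is not stated here, so the sketch does not assume a reset.

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]


def make_early_stopper(patience: int = 5, *, check: str = 'better',
                       counter: str = '_counter') -> Callable[[State], None]:
    """Hypothetical re-creation of the counting rule described above."""

    def callback(state: State) -> None:
        if not state[check]:
            state[counter] = state.get(counter, 0) + 1  # one more non-improving epoch
        if state.get(counter, 0) > patience:
            state['running'] = False  # stop the run

    return callback


# Replay the example: only the first validation result is an improvement.
stopper = make_early_stopper(patience=2)
state: State = {'running': True}
finished = 0
for better in [True, False, False, False, False, False, False]:
    if not state['running']:
        break
    finished += 1
    state['better'] = better
    stopper(state)
print(finished)  # 4
```

As in the example above, four epochs run: the third non-improving epoch pushes the counter past patience=2.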
rnnr.callbacks.checkpoint(what, obj=None, *, under=None, at_most=1, when=None, using=None, ext='pkl', prefix_fmt='{epoch}_', queue_fmt='_saved_{what}')
A callback factory for checkpointing.
Checkpointing means saving obj (or state[what] if obj is None) during a run in the directory under, using {prefix_fmt}{what}.{ext} as the filename.
Example
>>> from pathlib import Path
>>> from pprint import pprint
>>> from rnnr import Event, Runner
>>> from rnnr.callbacks import checkpoint
>>>
>>> batches = range(3)
>>> tmp_dir = Path('/tmp')
>>> runner = Runner()
>>> @runner.on(Event.EPOCH_FINISHED)
... def store_checkpoint(state):
...     state['model'] = 'MODEL'
...     state['optimizer'] = 'OPTIMIZER'
...
>>> runner.on(Event.EPOCH_FINISHED, checkpoint('model', under=tmp_dir, at_most=3))
>>> runner.on(Event.EPOCH_FINISHED, checkpoint('optimizer', under=tmp_dir, at_most=3))
>>> runner.run(batches, max_epoch=7)
>>> pprint(sorted(list(tmp_dir.glob('*.pkl'))))
[PosixPath('/tmp/5_model.pkl'),
 PosixPath('/tmp/5_optimizer.pkl'),
 PosixPath('/tmp/6_model.pkl'),
 PosixPath('/tmp/6_optimizer.pkl'),
 PosixPath('/tmp/7_model.pkl'),
 PosixPath('/tmp/7_optimizer.pkl')]
Parameters:
- what (str) – Name of the object to save.
- obj (Optional[Any]) – Object to save. If None, it is obtained from state[what].
- under (Optional[Path]) – Save the object under this directory. Defaults to the current working directory if not given.
- at_most (int) – Maximum number of files saved. When the number of files exceeds this number, older files will be deleted.
- when (Optional[str]) – If given, only save the object when state[when] is True.
- using (Optional[Callable[[Any, Path], None]]) – Function to invoke to save the object. If given, this must be a callable accepting two arguments: an object to save and a Path to save it to. The default is to save the object using pickle.
- ext (str) – Extension for the filename.
- prefix_fmt (str) – Format for the filename prefix. Any string keys in state can be used as replacement fields.
- queue_fmt (str) – Format for the state key holding the queue that tracks the saved files for the object; the queue is stored in state[queue_fmt.format(what=what)].
Returns: Callback that does checkpointing.
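The file naming and at_most rotation described above can be sketched with a queue of saved paths. make_checkpointer is hypothetical, not rnnr's code; it keeps the parameter names what, under, at_most, ext, prefix_fmt and queue_fmt from the signature but omits obj, when and using for brevity, and it always pickles state[what].

```python
import pickle
import tempfile
from collections import deque
from pathlib import Path
from typing import Any, Callable, Dict, Optional

State = Dict[str, Any]


def make_checkpointer(what: str, *, under: Optional[Path] = None, at_most: int = 1,
                      ext: str = 'pkl', prefix_fmt: str = '{epoch}_',
                      queue_fmt: str = '_saved_{what}') -> Callable[[State], None]:
    """Hypothetical sketch of the save-and-rotate behaviour described above."""
    directory = Path.cwd() if under is None else under

    def callback(state: State) -> None:
        # Only string keys of the state may serve as replacement fields.
        fields = {k: v for k, v in state.items() if isinstance(k, str)}
        path = directory / f'{prefix_fmt.format(**fields)}{what}.{ext}'
        with path.open('wb') as f:
            pickle.dump(state[what], f)
        queue = state.setdefault(queue_fmt.format(what=what), deque())
        queue.append(path)
        # Once more than `at_most` files exist, delete the oldest ones.
        while len(queue) > at_most:
            queue.popleft().unlink()

    return callback


with tempfile.TemporaryDirectory() as tmp:
    tmp_dir = Path(tmp)
    save_model = make_checkpointer('model', under=tmp_dir, at_most=3)
    state: State = {}
    for epoch in range(1, 8):  # seven epochs, as in the example above
        state.update(epoch=epoch, model='MODEL')
        save_model(state)
    remaining = sorted(p.name for p in tmp_dir.glob('*.pkl'))
print(remaining)  # ['5_model.pkl', '6_model.pkl', '7_model.pkl']
```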
rnnr.callbacks.save(*args, **kwargs)
An alias for checkpoint.
Attachments
class rnnr.attachments.EpochTimer
An attachment to time epochs.
Epochs are only timed when state['max_epoch'] is greater than 1. At the start and end of every epoch, logging messages are written with log level INFO.
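A rough sketch of that timing rule as a pair of epoch callbacks follows. EpochTimerSketch, the logger name and the message wording are hypothetical; the text above only specifies the max_epoch > 1 condition and the INFO level.

```python
import logging
import time
from typing import Any, Dict, List

State = Dict[str, Any]
logger = logging.getLogger('epoch_timer_sketch')  # hypothetical logger name


class EpochTimerSketch:
    """Hypothetical callbacks mimicking the EpochTimer rule described above."""

    def __init__(self) -> None:
        self._start = 0.0

    def on_epoch_started(self, state: State) -> None:
        if state['max_epoch'] > 1:  # epochs are only timed when max_epoch > 1
            self._start = time.perf_counter()
            logger.info('Starting epoch %d/%d', state['epoch'], state['max_epoch'])

    def on_epoch_finished(self, state: State) -> None:
        if state['max_epoch'] > 1:
            elapsed = time.perf_counter() - self._start
            logger.info('Epoch %d/%d done in %.2fs',
                        state['epoch'], state['max_epoch'], elapsed)


class ListHandler(logging.Handler):
    """Collects formatted log messages so they can be inspected."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: List[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


handler = ListHandler()
logger.addHandler(handler)
logger.setLevel(logging.INFO)
timer = EpochTimerSketch()
for max_epoch in (1, 2):  # the single-epoch run should log nothing
    for epoch in range(1, max_epoch + 1):
        state = {'epoch': epoch, 'max_epoch': max_epoch}
        timer.on_epoch_started(state)
        timer.on_epoch_finished(state)
print(len(handler.messages))  # 4
```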
class rnnr.attachments.ProgressBar(*, n_items='n_items', stats=None, tqdm_cls=None, **kwargs)
An attachment to display a progress bar.
The progress bar is implemented using tqdm.
Example
>>> from rnnr import Runner
>>> from rnnr.attachments import ProgressBar
>>> runner = Runner()
>>> ProgressBar().attach_on(runner)
>>> runner.run(range(10), max_epoch=10)
Parameters:
- n_items (str) – Get the number of items in a batch from state[n_items] to update the progress bar with. If not given, the default is to always set it to 1.
- stats (Optional[str]) – Get the batch statistics from state[stats] and display them along with the progress bar. The statistics dictionary has the names of the statistics as keys and the statistics as values.
- **kwargs – Keyword arguments to be passed to the tqdm class.
class rnnr.attachments.LambdaReducer(name, reduce_fn, *, value='output')
An attachment to compute a reduction over batches.
This attachment gets the value of each batch and computes a reduction over them at the end of each epoch.
Example
>>> from rnnr import Event, Runner
>>> from rnnr.attachments import LambdaReducer
>>> runner = Runner()
>>> LambdaReducer('product', lambda x, y: x * y).attach_on(runner)
>>> @runner.on(Event.BATCH)
... def on_batch(state):
...     state['output'] = state['batch']
...
>>> runner.run([10, 20, 30])
>>> runner.state['product']
6000
Parameters:
- name (str) – Name of this attachment, used as the key in the runner’s state dict to store the reduction result.
- reduce_fn (Callable[[Any, Any], Any]) – Reduction function. It should accept two batch values and return their reduction result.
- value (str) – Get the value of a batch from state[value].
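The per-batch collection and end-of-epoch reduction can be sketched with functools.reduce, which folds values pairwise from left to right. LambdaReducerSketch and its on_* method names are hypothetical, and storing all batch values before folding is an assumption (an implementation could equally fold incrementally).

```python
from functools import reduce
from typing import Any, Callable, Dict, List

State = Dict[str, Any]


class LambdaReducerSketch:
    """Hypothetical sketch: collect state[value] per batch, reduce at epoch end."""

    def __init__(self, name: str, reduce_fn: Callable[[Any, Any], Any], *,
                 value: str = 'output') -> None:
        self.name, self.reduce_fn, self.value = name, reduce_fn, value
        self._values: List[Any] = []

    def on_epoch_started(self, state: State) -> None:
        self._values = []  # start each epoch fresh

    def on_batch(self, state: State) -> None:
        self._values.append(state[self.value])

    def on_epoch_finished(self, state: State) -> None:
        # Fold the batch values pairwise, left to right.
        state[self.name] = reduce(self.reduce_fn, self._values)


reducer = LambdaReducerSketch('product', lambda x, y: x * y)
state: State = {}
reducer.on_epoch_started(state)
for batch in [10, 20, 30]:
    state['output'] = batch  # what the example's BATCH callback does
    reducer.on_batch(state)
reducer.on_epoch_finished(state)
print(state['product'])  # 6000
```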
class rnnr.attachments.SumReducer(name, *, value='output')
An attachment to compute a sum over batch statistics.
This attachment gets the value from each batch and computes their sum at the end of every epoch.
class rnnr.attachments.MeanReducer(name, *, value='output', size='size')
An attachment to compute a mean over batch statistics.
This attachment gets the value from each batch and computes their mean at the end of every epoch.
Example
>>> from rnnr import Event, Runner
>>> from rnnr.attachments import MeanReducer
>>> runner = Runner()
>>> MeanReducer('mean').attach_on(runner)
>>> @runner.on(Event.BATCH)
... def on_batch(state):
...     state['output'] = state['batch']
...
>>> runner.run([1, 2, 3])
>>> runner.state['mean']
2.0
Parameters:
- name (str) – Name of this attachment, used as the key in the runner’s state dict to store the mean value.
- value (str) – Get the value of a batch from state[value].
- size (str) – Get the size of a batch from state[size]. If the state has no such key, the size defaults to 1. The sum of all these batch sizes is the divisor when computing the mean.
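The divisor rule can be illustrated with a small function over per-batch states. batch_mean is hypothetical. The text above only fixes the divisor (the sum of batch sizes, each defaulting to 1); the sketch assumes the numerator is the plain sum of the per-batch values, which means each value is treated as a batch total rather than a batch average.

```python
from typing import Any, Dict, Iterable

State = Dict[str, Any]


def batch_mean(batch_states: Iterable[State], *, value: str = 'output',
               size: str = 'size') -> float:
    """Hypothetical sketch of the mean rule described above.

    Assumes the numerator is the plain sum of state[value]; only the divisor
    (the sum of batch sizes, each defaulting to 1) is specified in the text.
    """
    total, n = 0, 0
    for state in batch_states:
        total += state[value]
        n += state.get(size, 1)  # a missing size counts as 1
    return total / n


print(batch_mean([{'output': 1}, {'output': 2}, {'output': 3}]))  # 2.0
# With sizes, each value is treated as a batch total: (3 + 5) / (1 + 4)
print(batch_mean([{'output': 3, 'size': 1}, {'output': 5, 'size': 4}]))  # 1.6
```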