Source code for rnnr.callbacks

# Copyright 2019 Kemal Kurniawan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque
from typing import Any, Callable, Optional
from pathlib import Path
import logging
import pickle


def maybe_stop_early(patience: int = 5, *, check: str = "better", counter: str = "_counter"):
    """A callback factory for early stopping.

    The returned callback keeps a counter in ``state[counter]`` for the number of times
    ``state[check]`` is ``False``. If this counter exceeds ``patience``, the callback stops
    the runner by setting ``state['running'] = False``.

    Example:

        >>> valid_losses = [0.1, 0.2, 0.3]  # simulate validation batch losses
        >>> batches = range(10)
        >>>
        >>> from rnnr import Event, Runner
        >>> from rnnr.attachments import MeanReducer
        >>> from rnnr.callbacks import maybe_stop_early
        >>>
        >>> trainer = Runner()
        >>> @trainer.on(Event.EPOCH_STARTED)
        ... def print_epoch(state):
        ...     print('Epoch', state['epoch'], 'started')
        ...
        >>> @trainer.on(Event.EPOCH_FINISHED)
        ... def eval_on_valid(state):
        ...     def eval_fn(state):
        ...         state['output'] = state['batch']
        ...     evaluator = Runner()
        ...     evaluator.on(Event.BATCH, eval_fn)
        ...     MeanReducer(name='mean').attach_on(evaluator)
        ...     evaluator.run(valid_losses)
        ...     if state.get('best_loss', float('inf')) > evaluator.state['mean']:
        ...         state['better'] = True
        ...         state['best_loss'] = evaluator.state['mean']
        ...     else:
        ...         state['better'] = False
        ...
        >>> trainer.on(Event.EPOCH_FINISHED, maybe_stop_early(patience=2))
        >>> trainer.run(batches, max_epoch=7)
        Epoch 1 started
        Epoch 2 started
        Epoch 3 started
        Epoch 4 started

    Args:
        patience: Stop the runner when the counter exceeds this number.
        check: Increment the counter if ``state[check]`` is ``False``.
        counter: Store the counter in ``state[counter]``.

    Returns:
        Callback that does early stopping.
    """
    logger = logging.getLogger(f"{__name__}.early_stopping")

    def callback(state):
        n = (state.get(counter, 0) + 1) if not state[check] else 0
        state[counter] = n
        if state[counter] > patience:
            logger.info("Patience exceeded, stopping early")
            state["running"] = False

    return callback
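
# Usage sketch (illustrative, not part of the module itself): the ``check`` and
# ``counter`` key names can be customised so the callback reads a flag other than
# ``state['better']``. The ``improved`` and ``bad_epochs`` keys below are hypothetical,
# and the sketch assumes callbacks on the same event run in registration order.

from rnnr import Event, Runner
from rnnr.callbacks import maybe_stop_early

runner = Runner()

@runner.on(Event.EPOCH_FINISHED)
def check_improvement(state):
    # Pretend only the first two epochs improve.
    state["improved"] = state["epoch"] <= 2

# Stop once 3 consecutive epochs fail to improve; the counter is kept under
# state["bad_epochs"] instead of the default "_counter".
runner.on(
    Event.EPOCH_FINISHED,
    maybe_stop_early(patience=3, check="improved", counter="bad_epochs"),
)

runner.run(range(10), max_epoch=20)
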
def checkpoint(
    what: str,
    obj: Optional[Any] = None,
    *,
    under: Optional[Path] = None,
    at_most: int = 1,
    when: Optional[str] = None,
    using: Optional[Callable[[Any, Path], None]] = None,
    ext: str = "pkl",
    prefix_fmt: str = "{epoch}_",
    queue_fmt: str = "_saved_{what}",
):
    """A callback factory for checkpointing.

    Checkpointing means saving ``obj`` (or ``state[what]`` if ``obj`` is ``None``) during a
    run under the ``under`` directory with ``{prefix_fmt}{what}.{ext}`` as the filename.

    Example:

        >>> from pathlib import Path
        >>> from pprint import pprint
        >>> from rnnr import Event, Runner
        >>> from rnnr.callbacks import checkpoint
        >>>
        >>> batches = range(3)
        >>> tmp_dir = Path('/tmp')
        >>> runner = Runner()
        >>> @runner.on(Event.EPOCH_FINISHED)
        ... def store_checkpoint(state):
        ...     state['model'] = 'MODEL'
        ...     state['optimizer'] = 'OPTIMIZER'
        ...
        >>> runner.on(Event.EPOCH_FINISHED, checkpoint('model', under=tmp_dir, at_most=3))
        >>> runner.on(Event.EPOCH_FINISHED, checkpoint('optimizer', under=tmp_dir, at_most=3))
        >>> runner.run(batches, max_epoch=7)
        >>> pprint(sorted(list(tmp_dir.glob('*.pkl'))))
        [PosixPath('/tmp/5_model.pkl'),
         PosixPath('/tmp/5_optimizer.pkl'),
         PosixPath('/tmp/6_model.pkl'),
         PosixPath('/tmp/6_optimizer.pkl'),
         PosixPath('/tmp/7_model.pkl'),
         PosixPath('/tmp/7_optimizer.pkl')]

    Args:
        what: Name of the object to save.
        obj: Object to save. If ``None``, will be obtained from ``state[what]``.
        under: Save the object under this directory. Defaults to the current working
            directory if not given.
        at_most: Maximum number of files saved. When the number of files exceeds this
            number, older files will be deleted.
        when: If given, only save the object when ``state[when]`` is ``True``.
        using: Function to invoke to save the object. If given, this must be a callable
            accepting two arguments: an object to save and a `Path` to save it to. The
            default is to save the object using `pickle`.
        ext: Extension for the filename.
        prefix_fmt: Format for the filename prefix. Any string keys in ``state`` can be
            used as replacement fields.
        queue_fmt: Keeps track of the saved files for the object with a queue stored in
            ``state[queue_fmt.format(what=what)]``.

    Returns:
        Callback that does checkpointing.
    """
    if under is None:  # pragma: no cover
        under = Path.cwd()
    if using is None:
        using = _save_with_pickle

    qkey = queue_fmt.format(what=what)
    logger = logging.getLogger(f"{__name__}.checkpointing")

    def callback(state):
        q = state.get(qkey, deque())
        if when is None or state[when]:
            fmt = f"{prefix_fmt}{what}.{ext}"
            path = under / fmt.format(**state)
            logger.info("Saving to %s", path)
            using(state[what] if obj is None else obj, path)
            q.append(path)
            while len(q) > at_most:
                p = q.popleft()
                if p.exists():  # pragma: no cover
                    p.unlink()
        state[qkey] = q

    return callback
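
# Usage sketch (illustrative, not part of the module itself): ``when`` gates saving on a
# boolean flag in the state, and ``using`` swaps pickle for another serializer. The
# ``save_as_json`` helper and the ``metrics``/``better`` keys are hypothetical.

import json
from pathlib import Path

from rnnr import Event, Runner
from rnnr.callbacks import checkpoint


def save_as_json(obj, path: Path) -> None:
    # Custom saver: must accept (object, Path), mirroring _save_with_pickle at the
    # end of this module.
    with open(path, "w") as f:
        json.dump(obj, f)


runner = Runner()

@runner.on(Event.EPOCH_FINISHED)
def eval_epoch(state):
    state["metrics"] = {"epoch": state["epoch"]}
    state["better"] = True  # pretend every epoch improves

# Save metrics as JSON only on improving epochs, keeping at most the 2 newest files.
runner.on(
    Event.EPOCH_FINISHED,
    checkpoint("metrics", under=Path("/tmp"), at_most=2, when="better",
               using=save_as_json, ext="json"),
)

runner.run(range(3), max_epoch=5)
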
def save(*args, **kwargs):  # pragma: no cover
    """An alias for `checkpoint`."""
    return checkpoint(*args, **kwargs)

def _save_with_pickle(obj: Any, path: Path) -> None:
    with open(path, "wb") as f:
        pickle.dump(obj, f)
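
# Since the default saver writes plain pickle files, a checkpoint produced with the
# default ``using`` can be read back with ``pickle.load``. A minimal loading sketch
# (``load_checkpoint`` and the path below are illustrative; this module does not
# provide a loader):

def load_checkpoint(path: Path) -> Any:
    with open(path, "rb") as f:
        return pickle.load(f)

# e.g. model = load_checkpoint(Path('/tmp/7_model.pkl'))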