Source code for rnnr.attachments

# Copyright 2019 Kemal Kurniawan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import timedelta
from typing import Any, Callable, Optional, Type
from warnings import warn
import abc
import logging
import time

from tqdm import tqdm

from .event import Event
from .runner import Runner


class Attachment(abc.ABC):
    """An abstract base class for an attachment.

    Concrete attachments implement :meth:`attach_on` to register their own
    callbacks on a :class:`Runner`.
    """

    @abc.abstractmethod
    def attach_on(self, runner: Runner) -> None:
        """Attach to a runner.

        Args:
            runner: Runner to attach to.
        """
        ...
class EpochTimer(Attachment):  # pragma: no cover
    """An attachment to time epoch.

    Epochs are only timed when ``state['max_epoch']`` is greater than 1. At the start
    and end of every epoch, logging messages are written with log level of INFO.

    The epoch start timestamp is kept in the runner's state under a private key. If
    that key is already present when an epoch starts, the epoch is reported as being
    resumed and the original start time is kept.
    """

    logger = logging.getLogger(f"{__name__}.epoch_timer")
    _epoch_start_time = "_epoch_start_time"

    def attach_on(self, runner: Runner) -> None:
        runner.on(Event._ETIMER_STARTED, self._start)
        runner.on(Event._ETIMER_FINISHED, self._finish)

    def _start(self, state):
        # Single-epoch runs are not timed at all.
        if not (state["max_epoch"] > 1):
            return
        resuming = self._epoch_start_time in state
        if not resuming:
            state[self._epoch_start_time] = time.time()
        template = "Resuming epoch %d/%d" if resuming else "Starting epoch %d/%d"
        self.logger.info(template, state["epoch"], state["max_epoch"])

    def _finish(self, state):
        if not (state["max_epoch"] > 1):
            return
        # Removing the start time lets the next epoch begin with a fresh timer.
        seconds_taken = time.time() - state.pop(self._epoch_start_time)
        duration = timedelta(seconds=seconds_taken)
        self.logger.info(
            "Epoch %d/%d done in %s", state["epoch"], state["max_epoch"], duration
        )
class ProgressBar(Attachment):
    """An attachment to display a progress bar.

    The progress bar is implemented using `tqdm`_.

    Example:

        >>> from rnnr import Runner
        >>> from rnnr.attachments import ProgressBar
        >>> runner = Runner()
        >>> ProgressBar().attach_on(runner)
        >>> runner.run(range(10), max_epoch=10)

    Args:
        n_items: Key of the runner's state holding the number of items in the current
            batch. After each batch the progress bar advances by ``state[n_items]``,
            or by 1 if the state has no such key.
        stats: Get the batch statistics from ``state[stats]`` and display it along
            with the progress bar. The statistics dictionary has the names of the
            statistics as keys and the statistics as values.
        tqdm_cls: Progress bar class to instantiate. Defaults to the `tqdm`_ class.
        **kwargs: Keyword arguments to be passed to the progress bar class.

    .. _tqdm: https://github.com/tqdm/tqdm
    """

    # State key tracking how many items have been counted so far. The count is
    # passed as ``initial`` when the bar is (re)created.
    _n_items_so_far = "_pbar_n_items_so_far"

    def __init__(
        self,
        *,
        n_items: str = "n_items",
        stats: Optional[str] = None,
        tqdm_cls: Optional[Type[tqdm]] = None,
        **kwargs,
    ) -> None:
        self._tqdm_cls = tqdm if tqdm_cls is None else tqdm_cls
        self._n_items = n_items
        self._stats = stats
        self._kwargs = kwargs
        self._pbar: tqdm

    def attach_on(self, runner: Runner) -> None:
        runner.on(Event._PBAR_CREATED, self._create)
        runner.on(Event._PBAR_UPDATED, self._update)
        runner.on(Event._PBAR_CLOSED, self._close)

    def _create(self, state: dict) -> None:
        counted = state.get(self._n_items_so_far, 0)
        self._pbar = self._tqdm_cls(state["batches"], initial=counted, **self._kwargs)
        state[self._n_items_so_far] = counted

    def _update(self, state: dict) -> None:
        if self._stats is not None:
            self._pbar.set_postfix(**state[self._stats])
        step = state.get(self._n_items, 1)
        self._pbar.update(step)
        state[self._n_items_so_far] += step

    def _close(self, state: dict) -> None:
        self._pbar.close()
        # Drop the counter so the next bar created starts from zero again.
        state.pop(self._n_items_so_far)
class LambdaReducer(Attachment):
    """An attachment to compute a reduction over batches.

    This attachment gets the value of each batch and compute a reduction over them
    at the end of each epoch.

    Example:

        >>> from rnnr import Event, Runner
        >>> from rnnr.attachments import LambdaReducer
        >>> runner = Runner()
        >>> LambdaReducer('product', lambda x, y: x * y).attach_on(runner)
        >>> @runner.on(Event.BATCH)
        ... def on_batch(state):
        ...     state['output'] = state['batch']
        ...
        >>> runner.run([10, 20, 30])
        >>> runner.state['product']
        6000

    Args:
        name: Name of this attachment to be used as the key in the runner's state
            dict to store the reduction result.
        reduce_fn: Reduction function. It should accept two batch values and return
            their reduction result.
        value: Get the value of a batch from ``state[value]``.

    Note:
        The first batch value becomes the initial accumulator without calling
        ``reduce_fn``. Internally, ``None`` marks "no value yet", so an epoch
        with no batches leaves ``state[name]`` set to ``None``.
    """

    def __init__(
        self,
        name: str,
        reduce_fn: Callable[[Any, Any], Any],
        *,
        value: str = "output",
    ) -> None:
        self.name = name
        self._reduce_fn = reduce_fn
        self._value = value

    def attach_on(self, runner: Runner) -> None:
        runner.on(Event._REDUCER_RESET, self._reset)
        runner.on(Event._REDUCER_UPDATED, self._update)
        runner.on(Event._REDUCER_COMPUTED, self._compute)

    @property
    def _result(self) -> str:
        # Private state key holding the running reduction for this reducer.
        return f"_{self.name}_reducer_result"

    def _reset(self, state: dict) -> None:
        # A leftover key means another reducer with the same name is active.
        if self._result in state:  # pragma: no cover
            warn(
                f"You may have multiple reducers with name={self.name!r}, so one will "
                "overwrite the other."
            )
        state[self._result] = None

    def _update(self, state: dict) -> None:
        so_far = state[self._result]
        incoming = state[self._value]
        state[self._result] = (
            incoming if so_far is None else self._reduce_fn(so_far, incoming)
        )

    def _compute(self, state: dict) -> None:
        # Publish the result under the public name and drop the private key.
        state[self.name] = state.pop(self._result)
class MeanReducer(LambdaReducer):
    """An attachment to compute a mean over batch statistics.

    This attachment gets the value from each batch and compute their mean at the end
    of every epoch. The result is ``sum(values) / sum(sizes)`` over the epoch's batches.
    If the epoch had no batches, the result is ``None``, matching what
    :class:`LambdaReducer` stores in that case.

    Example:

        >>> from rnnr import Event, Runner
        >>> from rnnr.attachments import MeanReducer
        >>> runner = Runner()
        >>> MeanReducer('mean').attach_on(runner)
        >>> @runner.on(Event.BATCH)
        ... def on_batch(state):
        ...     state['output'] = state['batch']
        ...
        >>> runner.run([1, 2, 3])
        >>> runner.state['mean']
        2.0

    Args:
        name: Name of this attachment to be used as the key in the runner's state
            dict to store the mean value.
        value: Get the value of a batch from ``state[value]``.
        size: Get the size of a batch from ``state[size]``. If the state has no such
            key, the size defaults to 1. The sum of all these batch sizes is the
            divisor when computing the mean.
    """

    def __init__(self, name: str, *, value: str = "output", size: str = "size") -> None:
        super().__init__(name, lambda x, y: x + y, value=value)
        self._size = size

    @property
    def _total_size(self) -> str:
        # Private state key holding the running sum of batch sizes.
        return f"_{self.name}_reducer_total_size"

    def _reset(self, state: dict) -> None:
        super()._reset(state)
        state[self._total_size] = 0

    def _update(self, state: dict) -> None:
        super()._update(state)
        state[self._total_size] += state.get(self._size, 1)

    def _compute(self, state: dict) -> None:
        super()._compute(state)
        # Always remove the private key, even when there is nothing to divide.
        total_size = state.pop(self._total_size)
        total = state[self.name]
        if total is None:
            # No batches were seen. Keep None instead of raising TypeError on None / 0.
            return
        # Out-of-place division: with a single batch the accumulated result is the
        # very object stored in ``state[value]``. An in-place ``/=`` would mutate it
        # for mutable array types, and it fails for integer-dtype arrays.
        state[self.name] = total / total_size
class SumReducer(LambdaReducer):  # pragma: no cover
    """An attachment to compute a sum over batch statistics.

    This attachment gets the value from each batch and compute their sum at the end
    of every epoch. If the epoch had no batches, the result is ``None``.

    Args:
        name: Name of this attachment to be used as the key in the runner's state
            dict to store the sum.
        value: Get the value of a batch from ``state[value]``.
    """

    def __init__(self, name: str, *, value: str = "output") -> None:
        super().__init__(name, lambda x, y: x + y, value=value)