class MetricsTracker:
    """
    Tracks metrics for a controller.

    Each metric is a callable of ``(loss_value, controller)``. Every call to
    :meth:`run_metrics` evaluates all registered metrics and appends the results
    to a per-metric history, which can be read back by indexing the tracker.

    .. warning::
        Passing ``lambda`` functions is discouraged. Results are reported under
        each metric's ``__name__``, and every ``lambda`` shares the same name, so
        each ``lambda`` function will overwrite the previous. Instead, create
        distinct functions for your metric.

    :Example:
        >>> from vanguard.base.metrics import loss
        >>>
        >>> tracker = MetricsTracker(loss)
        >>> for loss_value in range(5):
        ...     tracker.run_metrics(float(loss_value), controller=None)

        >>> with tracker.print_metrics():
        ...     for loss_value in range(5):
        ...         tracker.run_metrics(float(loss_value), controller=None)
        iteration: 6, loss: 0.0
        iteration: 7, loss: 1.0
        iteration: 8, loss: 2.0
        iteration: 9, loss: 3.0
        iteration: 10, loss: 4.0

        >>> with tracker.print_metrics(every=2):
        ...     for loss_value in range(5):
        ...         tracker.run_metrics(float(loss_value), controller=None)
        iteration: 12, loss: 1.0
        iteration: 14, loss: 3.0

        >>> with tracker.print_metrics(every=2, format_string="loss: {loss}"):
        ...     for loss_value in range(5):
        ...         tracker.run_metrics(float(loss_value), controller=None)
        loss: 1.0
        loss: 3.0
    """

    def __init__(self, *metrics: Callable) -> None:
        """
        Initialise self.

        A metric takes the form of a function of (loss, controller) -> real number.
        The simplest and most obvious metric simply returns the loss value, e.g. see
        the function `vanguard.base.metrics.loss`. Other common examples might extract
        parameters from the controller, e.g. a kernel's lengthscale, and return that.

        :param metrics: The metric functions to track from the first iteration.
        """
        # Maps each metric callable to the list of values it has produced so far.
        self._metric_outputs = {}
        self._iteration = 0
        self.add_metrics(*metrics)
        # Printing is disabled by default: ``n % nan`` is ``nan``, which never equals 0.
        self._every = float("nan")
        self._print_format_string = ""
        self._counter = itertools.count(1)

    def __getitem__(self, item):
        """
        Return the recorded outputs at ``item`` for every metric.

        :param item: An index or slice applied to each metric's list of outputs.
        :returns: A dictionary mapping each metric's ``__name__`` to the selected output(s).
        """
        return {
            metric.__name__: output_values[item]
            for metric, output_values in self._metric_outputs.items()
        }

    def __len__(self):
        """Return the number of metrics registered with the tracker."""
        return len(self._metric_outputs)

    @property
    def _default_format_string(self) -> str:
        """Get the default format string used for printing."""
        format_string_components = ["iteration: {iteration}"]
        for metric in self._metric_outputs:
            # Produces e.g. "loss: {loss}", filled in later by ``str.format``.
            component = f"{metric.__name__}: {{{metric.__name__}}}"
            format_string_components.append(component)
        return ", ".join(format_string_components)

    def reset(self) -> None:
        """Remove the stored metrics outputs and reset the iteration count."""
        self._metric_outputs = {metric: [] for metric in self._metric_outputs}
        self._iteration = 0

    def add_metrics(self, *metrics: Callable) -> None:
        """
        Add metrics to the tracker.

        A metric is a function of (loss, controller) -> real number. A metric added
        after some iterations have already run has its history padded with ``nan``
        for those iterations, so all histories stay aligned by iteration.

        :param metrics: The metric functions to add.
        """
        for metric in metrics:
            self._metric_outputs[metric] = [float("nan")] * self._iteration

    def run_metrics(
        self,
        loss_value: float,
        controller: Optional["vanguard.base.basecontroller.BaseGPController"],
        **additional_info,
    ) -> None:
        """
        Register the components of an iteration.

        Each metric in the tracker will be run on the arguments of this method, and
        then stored for future reference. Iterations do not need to be passed, as the
        tracker counts them itself. Additional information passed as keyword
        arguments can be displayed to the user when combined with
        :meth:`print_metrics` and a customised format string.

        If the active format string references a key which is neither a metric name,
        ``iteration``, nor one of the keyword arguments, the loss value is printed
        together with a note naming the missing key.

        :param loss_value: The loss.
        :param controller: The controller instance.
        :param additional_info: Extra values made available to the format string.
        """
        self._iteration += 1

        for metric, outputs in self._metric_outputs.items():
            metric_value = metric(loss_value, controller)
            outputs.append(metric_value)
            additional_info[metric.__name__] = metric_value

        additional_info["iteration"] = self._iteration

        # While printing is disabled ``self._every`` is nan, so this is never true.
        if next(self._counter) % self._every == 0:
            try:
                output_string = self._print_format_string.format(**additional_info)
            except KeyError as error:
                output_string = f"{loss_value} (Could not find values for {repr(error.args[0])})"
            print(output_string)

    @contextmanager
    def print_metrics(
        self,
        every: int = 1,
        format_string: Optional[str] = None,
    ) -> Iterator[None]:
        """
        Temporarily enable printing the metrics within a context manager.

        On exit, the printing configuration which was active on entry is restored,
        so these contexts can be nested: leaving an inner context resumes the
        settings (and print cadence) of the enclosing one.

        :param every: How often to print the output. Does not start on the first
            iteration. Defaults to 1 (print always).
        :param format_string: Used to format the output. Keys passed here must match
            with information passed to the :meth:`run_metrics` method. If None, all
            metrics will be printed.
        :raises ValueError: If ``every`` is zero.
        """
        if every == 0:
            # Fail here rather than with a ZeroDivisionError inside run_metrics.
            raise ValueError("'every' must be non-zero.")

        if format_string is None:
            format_string = self._default_format_string

        previous_state = (self._every, self._print_format_string, self._counter)

        self._every = every
        self._print_format_string = format_string
        self._counter = itertools.count(1)
        try:
            yield
        finally:
            # Outside of any context this restores the "disabled" defaults.
            self._every, self._print_format_string, self._counter = previous_state
def loss(
    loss_value: float,
    controller: Optional["vanguard.base.basecontroller.BaseGPController"],
) -> float:  # pylint: disable=unused-argument
    """
    Return the loss value.

    This is the identity metric: the controller is accepted only so that the
    function matches the ``(loss_value, controller)`` metric signature.

    :param loss_value: The loss.
    :param controller: The controller instance (ignored).
    :returns: ``loss_value``, unchanged.
    """
    return loss_value