{ "cells": [ { "cell_type": "raw", "id": "black-nitrogen", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Creating a Decorator\n", "====================" ] }, { "cell_type": "code", "execution_count": null, "id": "a3dd9d20df386d16", "metadata": {}, "outputs": [], "source": [ "# © Crown Copyright GCHQ\n", "#\n", "# Licensed under the GNU General Public License, version 3 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.gnu.org/licenses/gpl-3.0.en.html\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "raw", "id": "15be1965b0e230bd", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Although Vanguard has a number of out-of-the-box decorators to allow for advanced Gaussian process techniques, one\n", "might need something more specialist. Luckily, decorators in Vanguard are designed to be as extensible as possible.\n", "This walkthrough will explain how to create a new decorator to shuffle the input data passed to a controller." 
] }, { "cell_type": "code", "execution_count": null, "id": "decreased-monaco", "metadata": {}, "outputs": [], "source": [ "from collections.abc import Iterable\n", "from typing import Any, Callable, TypeVar, Union\n", "\n", "import numpy as np\n", "import torch\n", "from gpytorch.kernels import RBFKernel\n", "from gpytorch.likelihoods import FixedNoiseGaussianLikelihood\n", "from gpytorch.means import ConstantMean\n", "from gpytorch.mlls import ExactMarginalLogLikelihood\n", "from numpy.typing import ArrayLike, NDArray\n", "\n", "from vanguard.base import GPController\n", "from vanguard.decoratorutils import Decorator, process_args, wraps_class\n", "from vanguard.optimise import SmartOptimiser\n", "from vanguard.uncertainty import GaussianUncertaintyGPController" ] }, { "cell_type": "code", "execution_count": null, "id": "9534690a-e184-4e19-b2eb-6b379a022e0b", "metadata": {}, "outputs": [], "source": [ "T = TypeVar(\"T\")\n", "SeedT = Union[ArrayLike, np.random.BitGenerator, None]" ] }, { "cell_type": "raw", "id": "accomplished-sandwich", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Recapping Python Decorators\n", "---------------------------\n", "\n", "In Python, a decorator is a function which takes a function as its input and returns another function. 
Consider the following function:" ] }, { "cell_type": "code", "execution_count": null, "id": "infinite-distribution", "metadata": {}, "outputs": [], "source": [ "def is_py_file(file_path: str) -> bool:\n", "    \"\"\"\n", "    Determine if a path points to a Python file.\n", "\n", "    :param file_path: Path to query\n", "    :return: :data:`True` if ``file_path`` has a Python extension\n", "    \"\"\"\n", "    return file_path.endswith(\".py\")\n", "\n", "\n", "is_py_file(\"foo.py\"), is_py_file(\"bar.js\")" ] }, { "cell_type": "raw", "id": "therapeutic-scope", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "This function will fail if it is passed anything other than a string, raising an :class:`AttributeError` because the\n", "input has no ``endswith`` method." ] }, { "cell_type": "raw", "id": "pretty-healthcare", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "It would be preferable for the function to raise a :class:`TypeError`, which could be achieved with a simple\n", "``try``/``except`` block, but it's possible that ``is_py_file`` cannot be edited, or that this change needs to be made\n", "to multiple functions which could each take a different number of inputs. 
Instead, a decorator can be used to check that\n", "``file_path`` is a string, without mutating ``is_py_file``:" ] }, { "cell_type": "code", "execution_count": null, "id": "speaking-employment", "metadata": {}, "outputs": [], "source": [ "CallableStringT = TypeVar(\"CallableStringT\", bound=Callable[[str, ...], Any])\n", "\n", "\n", "def check_string(func: CallableStringT) -> CallableStringT:\n", "    \"\"\"Check that the input is a string.\"\"\"\n", "\n", "    def inner_function(*args: str) -> Any:\n", "        for arg in args:\n", "            if not isinstance(arg, str):\n", "                raise TypeError(\"All inputs must be strings.\")\n", "        return func(*args)\n", "\n", "    return inner_function" ] }, { "cell_type": "raw", "id": "tribal-miniature", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The decorator can then be applied in the following fashion:" ] }, { "cell_type": "code", "execution_count": null, "id": "brilliant-arctic", "metadata": {}, "outputs": [], "source": [ "@check_string  # equivalent to: is_py_file = check_string(is_py_file)\n", "def is_py_file(file_path: str) -> bool:\n", "    \"\"\"\n", "    Determine if a path points to a Python file.\n", "\n", "    :param file_path: Path to query\n", "    :return: :data:`True` if ``file_path`` has a Python extension\n", "    \"\"\"\n", "    return str(file_path).endswith(\".py\")" ] }, { "cell_type": "raw", "id": "alone-boring", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Sometimes it is helpful for a decorator to accept some arguments to adjust its behaviour. 
In this case, the function in\n", "question just needs to return a *decorator*:" ] }, { "cell_type": "code", "execution_count": null, "id": "duplicate-mailing", "metadata": {}, "outputs": [], "source": [ "CallableTT = TypeVar(\"CallableTT\", bound=Callable[[T, ...], Any])\n", "\n", "\n", "def check_type(t: type[T]) -> Callable[[CallableTT], CallableTT]:\n", "    \"\"\"Check that the input is of a certain type.\"\"\"\n", "\n", "    def decorator(func: CallableTT) -> CallableTT:\n", "        def inner_function(*args: T) -> Any:\n", "            for arg in args:\n", "                if not isinstance(arg, t):\n", "                    raise TypeError(f\"All inputs must be of type {t}.\")\n", "            return func(*args)\n", "\n", "        return inner_function\n", "\n", "    return decorator" ] }, { "cell_type": "code", "execution_count": null, "id": "identified-genre", "metadata": {}, "outputs": [], "source": [ "@check_type(str)  # equivalent to: is_py_file = check_type(str)(is_py_file)\n", "def is_py_file(file_path: str) -> bool:\n", "    \"\"\"\n", "    Determine if a path points to a Python file.\n", "\n", "    :param file_path: Path to query\n", "    :return: :data:`True` if ``file_path`` has a Python extension\n", "    \"\"\"\n", "    return str(file_path).endswith(\".py\")" ] }, { "cell_type": "raw", "id": "progressive-belgium", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Decorators in Vanguard\n", "----------------------\n", "\n", "All decorators should inherit from :class:`~vanguard.decoratorutils.basedecorator.Decorator` in order to ensure\n", "consistency, and to make use of the in-built features. The :class:`~vanguard.decoratorutils.basedecorator.Decorator`\n", "requires a ``framework_class`` argument, which should be an (uninstantiated) subclass of\n", ":class:`~vanguard.base.gpcontroller.GPController`. Any new features added by a decorator should be relative to its\n", "framework class. 
If the decorator is applied to a different :class:`~vanguard.base.gpcontroller.GPController` subclass,\n", "then checks will be run to ensure that this class does not define any new methods, nor overwrite any existing ones. The\n", "reason for this is to avoid any potential issues with any extended features, forcing the user to explicitly ignore such\n", "problems if they are certain it will not affect the validity of the decorator.\n", "\n", "Vanguard decorators also take a ``required_decorators`` parameter (usually a set but can be any iterable), which\n", "references a number of uninstantiated decorator classes which must be applied before a particular decorator can be\n", "applied. This allows for maximum separation between functionality, and the majority of decorators do not have any\n", "requirements." ] }, { "cell_type": "raw", "id": "trying-childhood", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Creating a Decorator: Shuffling Inputs\n", "--------------------------------------\n", "\n", "Consider the following function:" ] }, { "cell_type": "code", "execution_count": null, "id": "sharp-parliament", "metadata": {}, "outputs": [], "source": [ "def consistent_shuffle(*arrays: NDArray[float], seed: SeedT = None) -> list[NDArray[float]]:\n", "    \"\"\"Shuffle all arrays into the same order, to maintain consistency.\"\"\"\n", "    rng = np.random.RandomState(seed=seed)\n", "    indices = np.arange(len(arrays[0]))\n", "    rng.shuffle(indices)\n", "\n", "    shuffled_arrays = [array[indices] for array in arrays]\n", "    return shuffled_arrays" ] }, { "cell_type": "code", "execution_count": null, "id": "crude-sellers", "metadata": {}, "outputs": [], "source": [ "x = np.array([1, 2, 3, 4, 5])\n", "y = np.array([1, 4, 9, 16, 25])" ] }, { "cell_type": "code", "execution_count": null, "id": "signal-increase", "metadata": {}, "outputs": [], "source": [ "consistent_shuffle(x, y, seed=1)" ] }, { "cell_type": "raw", "id": "legitimate-allen", "metadata": { 
"raw_mimetype": "text/restructuredtext" }, "source": [ "This function will be applied to the ``train_x``, ``train_y`` and ``y_std`` inputs to the newly decorated class. In\n", "order to work with these parameters, the :func:`~vanguard.decoratorutils.wrapping.process_args` function comes in handy:" ] }, { "cell_type": "code", "execution_count": null, "id": "spectacular-words", "metadata": {}, "outputs": [], "source": [ "process_args(\n", " GPController.__init__,\n", " None,\n", " x,\n", " y,\n", " RBFKernel,\n", " mean_class=ConstantMean,\n", " y_std=0.1,\n", " likelihood_class=FixedNoiseGaussianLikelihood,\n", " marginal_log_likelihood_class=ExactMarginalLogLikelihood,\n", " optimiser_class=torch.optim.Adam,\n", " smart_optimiser_class=SmartOptimiser,\n", ")" ] }, { "cell_type": "raw", "id": "significant-olympus", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "This returns a parameter mapping, essentially ensuring that all parameters are treated as keyword arguments, even the\n", "ones which were passed as positional arguments. 
This function can be used in the\n", ":meth:`~vanguard.decoratorutils.basedecorator.Decorator._decorate_class` method of the decorator to intercept arguments\n", "passed to the decorated class:" ] }, { "cell_type": "code", "execution_count": null, "id": "dense-idaho", "metadata": {}, "outputs": [], "source": [ "class ShuffleDecorator(Decorator):\n", "    \"\"\"Shuffles input data.\"\"\"\n", "\n", "    def __init__(self, **kwargs: Any) -> None:\n", "        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)\n", "\n", "    def _decorate_class(self, cls: type[T]) -> type[T]:\n", "        class InnerClass(cls):\n", "            \"\"\"An inner class.\"\"\"\n", "\n", "            def __init__(self, *args: Any, **kwargs: Any) -> None:\n", "                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)\n", "\n", "                old_train_x = all_parameters_as_kwargs.pop(\"train_x\")\n", "                old_train_y = all_parameters_as_kwargs.pop(\"train_y\")\n", "                old_y_std = all_parameters_as_kwargs.pop(\"y_std\")  # pop to avoid duplication\n", "\n", "                if isinstance(old_y_std, (float, int)):\n", "                    old_y_std = np.ones_like(old_train_x) * old_y_std\n", "\n", "                new_train_x, new_train_y, new_y_std = consistent_shuffle(old_train_x, old_train_y, old_y_std)\n", "\n", "                super().__init__(train_x=new_train_x, train_y=new_train_y, y_std=new_y_std, **all_parameters_as_kwargs)\n", "\n", "        return InnerClass" ] }, { "cell_type": "raw", "id": "excessive-enforcement", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "There are a few things to note here:\n", "\n", "* A call to :meth:`~vanguard.decoratorutils.basedecorator.Decorator.verify_decorated_class` will be made in the\n", "  ``__call__`` method to run checks for any new or overwritten methods in the decorated class. 
In special circumstances\n", "  this can be ignored, although it is not recommended.\n", "* Since this code is using ``super()``, a value for ``self`` doesn't need to be passed to\n", "  :func:`~vanguard.decoratorutils.wrapping.process_args`.\n", "* Parameters are \"popped\" rather than simply referenced in order to avoid forgetting to set them before passing them\n", "  forward, and to avoid any duplication." ] }, { "cell_type": "raw", "id": "progressive-underground", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The decorator can now be applied to a controller class in one of two ways. The latter is recommended for readability and\n", "extension." ] }, { "cell_type": "code", "execution_count": null, "id": "atlantic-conviction", "metadata": {}, "outputs": [], "source": [ "ShuffledGPController = ShuffleDecorator()(GPController)\n", "\n", "\n", "@ShuffleDecorator()\n", "class ShuffledGPController(GPController):  # noqa: F811\n", "    \"\"\"Shuffles inputs to the controller.\"\"\"\n", "\n", "    pass" ] }, { "cell_type": "raw", "id": "uniform-sauce", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Class Wrapping\n", "--------------\n", "\n", "Although the new ``ShuffledGPController`` will now work as expected, there are some inconsistencies in the docstrings\n", "and the names. 
This is best observed using :func:`help`, but can be seen by inspecting the ``__name__`` and ``__doc__``\n", "attributes:" ] }, { "cell_type": "code", "execution_count": null, "id": "moving-johns", "metadata": {}, "outputs": [], "source": [ "print(ShuffledGPController.__name__)\n", "print(ShuffledGPController.__doc__)" ] }, { "cell_type": "raw", "id": "continuing-abortion", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "This can be fixed by using the :func:`~vanguard.decoratorutils.wrapping.wraps_class` decorator, which behaves a lot like\n", ":func:`functools.wraps`, only for classes:" ] }, { "cell_type": "code", "execution_count": null, "id": "accepting-search", "metadata": {}, "outputs": [], "source": [ "class ShuffleDecorator(Decorator):\n", "    \"\"\"Shuffles input data.\"\"\"\n", "\n", "    def __init__(self, **kwargs: Any) -> None:\n", "        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)\n", "\n", "    def _decorate_class(self, cls: type[T]) -> type[T]:\n", "        @wraps_class(cls)\n", "        class InnerClass(cls):\n", "            \"\"\"An inner class.\"\"\"\n", "\n", "            def __init__(self, *args: Any, **kwargs: Any) -> None:\n", "                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)\n", "\n", "                old_train_x = all_parameters_as_kwargs.pop(\"train_x\")\n", "                old_train_y = all_parameters_as_kwargs.pop(\"train_y\")\n", "                old_y_std = all_parameters_as_kwargs.pop(\"y_std\")  # pop to avoid duplication\n", "\n", "                if isinstance(old_y_std, (float, int)):\n", "                    old_y_std = np.ones_like(old_train_x) * old_y_std\n", "\n", "                new_train_x, new_train_y, new_y_std = consistent_shuffle(old_train_x, old_train_y, old_y_std)\n", "\n", "                super().__init__(train_x=new_train_x, train_y=new_train_y, y_std=new_y_std, **all_parameters_as_kwargs)\n", "\n", "        return InnerClass" ] }, { "cell_type": "code", "execution_count": null, "id": "gothic-flush", "metadata": {}, "outputs": [], "source": [ "@ShuffleDecorator()\n", "class 
ShuffledGPController(GPController):\n", "    \"\"\"Shuffles inputs to the controller.\"\"\"\n", "\n", "    pass\n", "\n", "\n", "print(ShuffledGPController.__name__)\n", "print(ShuffledGPController.__doc__)" ] }, { "cell_type": "raw", "id": "crucial-lover", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ ".. warning::\n", "    When a decorator is applied in-line with ``NewClass = Decorator()(OldClass)``, then the values ``NewClass.__name__``\n", "    and ``NewClass.__doc__`` will correspond to ``OldClass.__name__`` and ``OldClass.__doc__`` respectively. This is\n", "    often not expected behaviour, so should be done with care." ] }, { "cell_type": "raw", "id": "joined-secretary", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ ".. note::\n", "\n", "    :func:`~vanguard.decoratorutils.wrapping.wraps_class` will also take care of the names and docstrings of methods\n", "    within the wrapped class." ] }, { "cell_type": "raw", "id": "cubic-hanging", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Decorator Parameters\n", "--------------------\n", "\n", "Sometimes it is necessary to implement additional arguments to allow a user to adjust the behaviour of the decorator.\n", "Since ``consistent_shuffle`` takes a ``seed`` parameter, it would be good to allow the decorator to make use of it:" ] }, { "cell_type": "code", "execution_count": null, "id": "random-diameter", "metadata": {}, "outputs": [], "source": [ "class ShuffleDecorator(Decorator):\n", "    \"\"\"Shuffles input data.\"\"\"\n", "\n", "    def __init__(self, seed: SeedT = None, **kwargs: Any) -> None:\n", "        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)\n", "        self.seed = seed\n", "\n", "    def _decorate_class(self, cls: type[T]) -> type[T]:\n", "        seed = self.seed\n", "\n", "        @wraps_class(cls)\n", "        class InnerClass(cls):\n", "            \"\"\"An inner class.\"\"\"\n", "\n", "            def __init__(self, *args: Any, **kwargs: Any) -> None:\n", " 
               all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)\n", "\n", "                old_train_x = all_parameters_as_kwargs.pop(\"train_x\")\n", "                old_train_y = all_parameters_as_kwargs.pop(\"train_y\")\n", "                old_y_std = all_parameters_as_kwargs.pop(\"y_std\")  # pop to avoid duplication\n", "\n", "                if isinstance(old_y_std, (float, int)):\n", "                    old_y_std = np.ones_like(old_train_x) * old_y_std\n", "\n", "                new_train_x, new_train_y, new_y_std = consistent_shuffle(old_train_x, old_train_y, old_y_std, seed=seed)\n", "\n", "                super().__init__(train_x=new_train_x, train_y=new_train_y, y_std=new_y_std, **all_parameters_as_kwargs)\n", "\n", "        return InnerClass" ] }, { "cell_type": "raw", "id": "acquired-poultry", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Note the defining of the intermediate value ``seed``, before entering ``InnerClass``. This is necessary because within\n", "the scope of ``InnerClass``, ``self`` no longer refers to the decorator instance." ] }, { "cell_type": "raw", "id": "adult-symposium", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Handling Different Controllers\n", "------------------------------\n", "\n", "A good decorator would ideally be re-usable for many different components. However, note what happens when\n", "``ShuffleDecorator`` is applied to the :class:`~vanguard.uncertainty.GaussianUncertaintyGPController` class." 
] }, { "cell_type": "code", "execution_count": null, "id": "directed-luxembourg", "metadata": {}, "outputs": [], "source": [ "@ShuffleDecorator()\n", "class ShuffledGaussianUncertaintyGPController(GaussianUncertaintyGPController):\n", "    \"\"\"Shuffles inputs to the controller.\"\"\"\n", "\n", "    pass" ] }, { "cell_type": "raw", "id": "excessive-liechtenstein", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Applying the decorator emits warnings about methods which are new relative to the framework class. To acknowledge\n", "that these methods are not expected to affect the behaviour of the decorator, they must be explicitly ignored:" ] }, { "cell_type": "code", "execution_count": null, "id": "hybrid-organic", "metadata": {}, "outputs": [], "source": [ "@ShuffleDecorator(\n", "    ignore_methods={\n", "        \"predict_at_point\",\n", "        \"_get_additive_grad_noise\",\n", "        \"_noise_transform\",\n", "        \"_append_constant_to_infinite_generator\",\n", "    }\n", ")\n", "class ShuffledGaussianUncertaintyGPController(GaussianUncertaintyGPController):  # noqa: F811\n", "    \"\"\"Shuffles inputs to the controller.\"\"\"\n", "\n", "    pass" ] }, { "cell_type": "raw", "id": "97ef08145874c1e1", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ ".. note::\n", "    It is possible to ignore all of these warnings by passing ``ignore_all=True`` to the decorator, although this is\n", "    only recommended if one is certain that changing the decorated controller will not cause any new errors. Also, passing\n", "    ``raise_instead=True`` will raise an error instead of emitting a warning, which will cause the program to stop\n", "    completely." ] }, { "cell_type": "raw", "id": "controlled-brisbane", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The remaining warnings concern methods which are expected, but have been overwritten. Most of these methods are not\n", "expected to affect the decorator either, with the exception of ``__init__``. 
Although ``__init__`` could be ignored and the code would run,\n", ":class:`~vanguard.uncertainty.GaussianUncertaintyGPController` takes a ``train_x_std`` parameter which would also need to\n", "be shuffled. This would be a problem for a user of the decorator, and can be avoided by adding the ability to pass\n", "additional parameters to be shuffled:" ] }, { "cell_type": "code", "execution_count": null, "id": "functioning-ballot", "metadata": {}, "outputs": [], "source": [ "class ShuffleDecorator(Decorator):\n", "    \"\"\"Shuffles input data.\"\"\"\n", "\n", "    def __init__(self, seed: SeedT = None, additional_params_to_shuffle: Iterable[str] = (), **kwargs: Any) -> None:\n", "        if additional_params_to_shuffle:\n", "            kwargs[\"ignore_methods\"] = set(kwargs.get(\"ignore_methods\", ())) | {\"__init__\"}\n", "\n", "        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)\n", "\n", "        self.seed = seed\n", "        self.params_to_shuffle = set.union({\"train_x\", \"train_y\", \"y_std\"}, set(additional_params_to_shuffle))\n", "\n", "    def _decorate_class(self, cls: type[T]) -> type[T]:\n", "        seed = self.seed\n", "        params_to_shuffle = self.params_to_shuffle\n", "\n", "        @wraps_class(cls)\n", "        class InnerClass(cls):\n", "            \"\"\"An inner class.\"\"\"\n", "\n", "            def __init__(self, *args: Any, **kwargs: Any) -> None:\n", "                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)\n", "\n", "                array_for_reference = all_parameters_as_kwargs[\"train_x\"]\n", "\n", "                pre_shuffled_args = [all_parameters_as_kwargs.pop(param) for param in params_to_shuffle]\n", "                pre_shuffled_args_as_arrays = [\n", "                    np.ones_like(array_for_reference) * arg if isinstance(arg, (float, int)) else arg\n", "                    for arg in pre_shuffled_args\n", "                ]\n", "                shuffled_args = consistent_shuffle(*pre_shuffled_args_as_arrays, seed=seed)\n", "\n", "                shuffled_params_as_kwargs = dict(zip(params_to_shuffle, shuffled_args))\n", "\n", "                super().__init__(**shuffled_params_as_kwargs, 
**all_parameters_as_kwargs)\n", "\n", "        return InnerClass" ] }, { "cell_type": "raw", "id": "complimentary-juvenile", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "There are a few changes to unpack here; take note of the following:\n", "\n", "* If a user passes ``additional_params_to_shuffle``, then it can be assumed that they have properly checked\n", "  ``__init__``, and it can be automatically ignored by the decorator.\n", "* The popping and array-converting of parameters now needs to be less constrained, and done more programmatically." ] }, { "cell_type": "code", "execution_count": null, "id": "characteristic-exemption", "metadata": {}, "outputs": [], "source": [ "ignore_methods = {\n", "    \"_get_posterior_over_fuzzy_point_in_eval_mode\",\n", "    \"__init__\",\n", "    \"_sgd_round\",\n", "    \"_process_x_std\",\n", "    \"_set_requires_grad\",\n", "    \"predict_at_point\",\n", "    \"_get_additive_grad_noise\",\n", "    \"_noise_transform\",\n", "    \"_append_constant_to_infinite_generator\",\n", "}\n", "\n", "\n", "@ShuffleDecorator(seed=1, additional_params_to_shuffle={\"train_x_std\"}, ignore_methods=ignore_methods)\n", "class ShuffledGaussianUncertaintyGPController(GaussianUncertaintyGPController):  # noqa: F811\n", "    \"\"\"Shuffles inputs to the controller.\"\"\"\n", "\n", "    pass" ] }, { "cell_type": "raw", "id": "distinct-mason", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "There are plenty of other ways in which ``ShuffleDecorator`` can be improved or made more extendable, but the concepts\n", "are more or less the same." 
] }, { "cell_type": "code", "execution_count": null, "id": "convenient-insight", "metadata": {}, "outputs": [], "source": [ "train_x = np.array([1, 2, 3, 4, 5])\n", "train_x_std = np.array([0.01, 0.02, 0.03, 0.04, 0.05])\n", "train_y = np.array([1, 4, 9, 16, 25])\n", "y_std = np.array([0.02, 0.04, 0.06, 0.08, 0.1])" ] }, { "cell_type": "code", "execution_count": null, "id": "binding-journalist", "metadata": {}, "outputs": [], "source": [ "controller = ShuffledGaussianUncertaintyGPController(\n", " train_x,\n", " train_x_std,\n", " train_y,\n", " y_std,\n", " kernel_class=RBFKernel,\n", " mean_class=ConstantMean,\n", " likelihood_class=FixedNoiseGaussianLikelihood,\n", " marginal_log_likelihood_class=ExactMarginalLogLikelihood,\n", " optimiser_class=torch.optim.Adam,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "south-treasury", "metadata": {}, "outputs": [], "source": [ "print(controller.train_x.T)\n", "print(controller.train_x_std.T)\n", "print(controller.train_y.T)\n", "print(controller._y_variance.T)" ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 }