Creating Warp Functions

Vanguard provides only a limited number of built-in warp functions, so users may need to create their own.

Base Warp Function

All warp functions should subclass the WarpFunction class documented below.

class vanguard.warps.basefunction.MultitaskWarpFunction(*warps)

Module for multitask warp functions.

The warps are applied element-wise, task by task: the i-th warp is applied to every entry in the i-th column (dimension) of the input, so each dimension may have a different warp.

The 1d warps (instances of WarpFunction subclasses), one per dimension, are passed directly to this module's constructor.
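
The task-wise rule can be pictured with a minimal pure-Python sketch. It is not Vanguard's implementation: plain functions stand in for WarpFunction instances, nested lists stand in for tensors, and the helper apply_task_wise is hypothetical.

```python
from typing import Callable, Sequence

Warp = Callable[[float], float]


def apply_task_wise(warps: Sequence[Warp], rows: Sequence[Sequence[float]]) -> list[list[float]]:
    """Hypothetical helper: apply warps[i] to column i of every row."""
    for row in rows:
        if len(row) != len(warps):
            raise ValueError("each row needs exactly one value per warp")
    return [[warp(value) for warp, value in zip(warps, row)] for row in rows]


def add_two(y: float) -> float:
    return y + 2


def add_four(y: float) -> float:
    return y + 4


# Column 0 is shifted by 2 and column 1 by 4.
print(apply_task_wise([add_two, add_four], [[0, 1], [-2, -3]]))  # [[2, 5], [0, 1]]
```

This reproduces the result of the MultitaskWarpFunction example that follows, where add_two and add_four act on the first and second columns respectively.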

Example:
>>> import torch
>>> from vanguard.warps.basefunction import MultitaskWarpFunction, WarpFunction
>>>
>>> # New warp functions should inherit from the WarpFunction class:
>>> class AddTwo(WarpFunction):
...     def forward(self, y):
...         return y + 2
...
...     def inverse(self, x):
...         return x - 2
...
...     def deriv(self, y):
...         # The derivative of y + 2 is 1 everywhere, in the shape of y.
...         return torch.ones_like(y)
>>>
>>> # Warp functions can be composed together:
>>> add_two = AddTwo()
>>> add_two.forward(torch.Tensor([0]))
tensor([2.])
>>> add_four = AddTwo() @ AddTwo()
>>> add_four.forward(torch.Tensor([0]))
tensor([4.])
>>> # Column 0 is warped by add_two, column 1 by add_four:
>>> multitask_warp = MultitaskWarpFunction(add_two, add_four)
>>> multitask_warp.forward(torch.Tensor([[0, 1], [-2, -3]]))
tensor([[2., 5.],
        [0., 1.]])

Parameters:

warps (WarpFunction) – The warp functions to be applied, one per task.

__init__(*warps)

Initialise self.

Parameters:

warps (WarpFunction) – The warp functions to be applied.

compose(other)

Compose with another warp function.

Parameters:

other (WarpFunction) – The other warp function.

Return type:

MultitaskWarpFunction

Returns:

A new MultitaskWarpFunction instance whose functions are composed task-wise.

Note

For convenience, it is often easier to use the @ operator in place of compose().

Example:
>>> warp_1, warp_2 = MultitaskWarpFunction(), MultitaskWarpFunction()
>>>
>>> # This will be the equivalent of warp_1(warp_2(...)) task-wise.
>>> composed_warp = warp_1 @ warp_2
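
Task-wise composition can be sketched in plain Python, assuming that task i of the result applies warp_2's i-th component first and warp_1's i-th component second, and that both operands have the same number of tasks. Plain functions stand in for warp functions, and compose_task_wise is a hypothetical helper rather than Vanguard's API.

```python
from typing import Callable, Sequence

Warp = Callable[[float], float]


def compose_task_wise(outer: Sequence[Warp], inner: Sequence[Warp]) -> list[Warp]:
    """Hypothetical helper: task i of the result is outer[i](inner[i](y))."""
    if len(outer) != len(inner):
        raise ValueError("both multitask warps must have the same number of tasks")

    def make(f: Warp, g: Warp) -> Warp:
        # Bind f and g now so that each task keeps its own pair.
        return lambda y: f(g(y))

    return [make(f, g) for f, g in zip(outer, inner)]


warp_1 = [lambda y: y + 2, lambda y: 3 * y]  # task 0: add two, task 1: triple
warp_2 = [lambda y: 2 * y, lambda y: y - 1]  # task 0: double, task 1: subtract one

composed = compose_task_wise(warp_1, warp_2)
print([warp(1) for warp in composed])  # [4, 0]
```

The make() factory avoids Python's late binding of loop variables in closures, which would otherwise make every task use the last pair of warps.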

compose_with_self(n)

Repeatedly compose a warp function with itself.

Parameters:

n (int) – The number of times to compose.

Return type:

MultitaskWarpFunction

Returns:

A new MultitaskWarpFunction instance with composed functions.

Raises:

ValueError – If n is negative.

Note

When n == 0, this method will return the identity warp.

Warning

The warp functions are not copied before composition, meaning that each component of the returned warp function will be the same object. When applied to a controller class with the SetWarp decorator, the warp function (and its components) will be copied and this will no longer be an issue.

deriv(y)

Return the derivative of the warp function at a point, y.

Parameters:

y (Tensor) – An input tensor.

Return type:

Tensor

Returns:

A tensor of the same shape as y, containing the warp function's derivative at y.

forward(y)

Pass an input tensor through the warp function.

Parameters:

y (Tensor) – A stack of input tensors.

Return type:

Tensor

Returns:

A stack of tensors with the same shape as y.

inverse(x)

Return the inverse of the warp function at a point, x.

Parameters:

x (Tensor) – An input tensor.

Return type:

Tensor

Returns:

A tensor of the same shape as x, containing the warp function's inverse at x.

property num_tasks: int

Return the number of tasks this warp function operates on.

class vanguard.warps.basefunction.WarpFunction

Base module for warp functions.

Subclasses must implement the forward() and inverse() methods. Implementing the deriv() method is optional; if it is not implemented, the derivative is computed with autograd, which is significantly slower.
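
The subclassing contract can be illustrated with a stdlib-only sketch. ToyWarpFunction, Affine and Cube are hypothetical and operate on plain floats; they are not Vanguard classes. forward() and inverse() are abstract, while deriv() falls back to a central finite difference. That fallback is a swapped-in technique (Vanguard uses autograd), chosen here only so the sketch runs without PyTorch.

```python
import math
from abc import ABC, abstractmethod


class ToyWarpFunction(ABC):
    """Hypothetical stand-in for WarpFunction, operating on plain floats."""

    @abstractmethod
    def forward(self, y: float) -> float: ...

    @abstractmethod
    def inverse(self, x: float) -> float: ...

    def deriv(self, y: float, h: float = 1e-6) -> float:
        # Fallback when a subclass gives no closed form: a central
        # finite difference (Vanguard uses autograd instead).
        return (self.forward(y + h) - self.forward(y - h)) / (2 * h)


class Affine(ToyWarpFunction):
    """x = a * y + b, relying on the numerical deriv() fallback."""

    def __init__(self, a: float, b: float) -> None:
        self.a, self.b = a, b

    def forward(self, y: float) -> float:
        return self.a * y + self.b

    def inverse(self, x: float) -> float:
        return (x - self.b) / self.a


class Cube(ToyWarpFunction):
    """x = y ** 3, with an exact deriv() override."""

    def forward(self, y: float) -> float:
        return y ** 3

    def inverse(self, x: float) -> float:
        # Real cube root, keeping the sign of x.
        return math.copysign(abs(x) ** (1 / 3), x)

    def deriv(self, y: float) -> float:
        return 3 * y ** 2


affine = Affine(2.0, 1.0)
print(affine.inverse(affine.forward(3.0)))  # 3.0
print(round(affine.deriv(3.0), 6))          # 2.0 (numerical fallback)
print(Cube().deriv(2.0))                    # 12.0 (closed form)
```

The fallback costs extra forward passes and is only approximate, which mirrors the reason given above for overriding deriv() whenever a closed form is available.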

Example:
>>> import torch
>>> from vanguard.warps.basefunction import WarpFunction
>>>
>>> # New warp functions should inherit from the WarpFunction class:
>>> class AddTwo(WarpFunction):
...     def forward(self, y):
...         return y + 2
...
...     def inverse(self, x):
...         return x - 2
...
...     def deriv(self, y):
...         # The derivative of y + 2 is 1 everywhere, in the shape of y.
...         return torch.ones_like(y)
>>>
>>> # Warp functions can be composed together:
>>> add_four = AddTwo() @ AddTwo()
>>> add_four.forward(torch.Tensor([0]))
tensor([4.])
>>>
>>> # You can also compose a function with itself several times:
>>> add_ten = AddTwo() @ 5
>>> add_ten.forward(torch.Tensor([0]))
tensor([10.])

property components: list[WarpFunction]

Get the components of the composition.

compose(other)

Compose with another warp function.

Parameters:

other (WarpFunction) – The other warp function.

Return type:

WarpFunction

Returns:

A new WarpFunction instance with composed functions.

Note

For convenience, it is often easier to use the @ operator in place of compose().

Example:
>>> warp_1, warp_2 = WarpFunction(), WarpFunction()
>>>
>>> # This will be the equivalent of warp_1(warp_2(...))
>>> composed_warp = warp_1 @ warp_2
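
Composition order matters whenever the components do not commute. In the stdlib-only sketch below, ToyWarp and its instances are hypothetical, not Vanguard's API. The right-hand operand is applied first, matching warp_1(warp_2(...)), and the derivative of a composition follows the chain rule, (f ∘ g)'(y) = f'(g(y)) · g'(y).

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyWarp:
    """Hypothetical warp on floats: a forward map and its derivative."""
    forward: Callable[[float], float]
    deriv: Callable[[float], float]

    def __matmul__(self, other: "ToyWarp") -> "ToyWarp":
        # self @ other applies other first, then self.
        return ToyWarp(
            forward=lambda y: self.forward(other.forward(y)),
            # Chain rule: (f o g)'(y) = f'(g(y)) * g'(y).
            deriv=lambda y: self.deriv(other.forward(y)) * other.deriv(y),
        )


add_two = ToyWarp(forward=lambda y: y + 2, deriv=lambda y: 1.0)
square = ToyWarp(forward=lambda y: y * y, deriv=lambda y: 2 * y)

print((add_two @ square).forward(3.0))  # 11.0, i.e. 3**2 + 2
print((square @ add_two).forward(3.0))  # 25.0, i.e. (3 + 2)**2
print((square @ add_two).deriv(3.0))    # 10.0, i.e. 2 * (3 + 2) * 1
```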

compose_with_self(n)

Repeatedly compose a warp function with itself.

Parameters:

n (int) – The number of times to compose.

Return type:

WarpFunction

Returns:

A new WarpFunction instance with composed functions.

Raises:

ValueError – If n is negative.

Note

When n == 0, this method will return the identity warp.

Warning

The warp functions are not copied before composition, meaning that each component of the returned warp function will be the same object. When applied to a controller class with the SetWarp decorator, the warp function (and its components) will be copied and this will no longer be an issue.
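
The documented behaviour of compose_with_self() (identity for n == 0, ValueError for negative n, and components that are all the same object) can be mimicked in a few lines. This is a hypothetical sketch: Warp, compose_with_self and apply_all below are illustrative stand-ins, and copy.deepcopy plays the role of copy() or of the copying done by SetWarp.

```python
import copy


class Warp:
    """Hypothetical warp holding a single parameter."""

    def __init__(self, shift: float) -> None:
        self.shift = shift

    def forward(self, y: float) -> float:
        return y + self.shift


def compose_with_self(warp: Warp, n: int) -> list[Warp]:
    """Return the components of warp composed with itself n times."""
    if n < 0:
        raise ValueError("n must be non-negative")
    # n == 0 gives no components, i.e. the identity warp.
    return [warp] * n  # every entry is the *same* object


def apply_all(components: list[Warp], y: float) -> float:
    # Apply the innermost (last) component first.
    for component in reversed(components):
        y = component.forward(y)
    return y


shared = compose_with_self(Warp(2.0), 3)
print(apply_all(shared, 0.0))                          # 6.0
shared[0].shift = 10.0                                 # changes all three components
print(apply_all(shared, 0.0))                          # 30.0
print(apply_all(compose_with_self(Warp(2.0), 0), 1.0))  # 1.0 (identity)

# Copying each component gives independent parameters.
independent = [copy.deepcopy(c) for c in compose_with_self(Warp(2.0), 3)]
independent[0].shift = 10.0
print(apply_all(independent, 0.0))                     # 14.0
```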

copy()

Return a copy guaranteed to have distinct parameters.

Return type:

Self

deriv(y)

Return the derivative of the warp function at a point, y.

Parameters:

y (Tensor) – An input tensor.

Return type:

Tensor

Returns:

A tensor of the same shape as y, containing the warp function's derivative at y.

forward(y)

Pass an input tensor through the warp function.

Parameters:

y (Tensor) – An input tensor.

Return type:

Tensor

Returns:

A tensor of the same shape as y.

freeze()

Return a copy of the warp with frozen parameters.

Note

We override torch.nn.Module.parameters to return an empty generator, so no amount of setting requires_grad=True will make the parameters trainable again. This is the most reliable way to freeze parameters and keep them frozen in downstream usage.

Return type:

Self

Returns:

A copy of self with parameters frozen.
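
The freezing strategy described in the note can be illustrated without PyTorch. In this hypothetical sketch, Param, ToyModule and sgd_step are stand-ins: ToyModule.parameters() plays the role of torch.nn.Module.parameters, and the optimiser is reduced to a loop over whatever parameters() yields. Because the frozen copy yields nothing, flipping its requires_grad flag back on has no effect.

```python
import copy
from typing import Iterator


class Param:
    """Hypothetical parameter with a requires_grad flag."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.requires_grad = True


class ToyModule:
    def __init__(self, scale: float) -> None:
        self.scale = Param(scale)

    def parameters(self) -> Iterator[Param]:
        yield self.scale

    def freeze(self) -> "ToyModule":
        frozen = copy.deepcopy(self)
        # Shadow parameters() on this copy so that it yields nothing.
        frozen.parameters = lambda: iter(())
        return frozen


def sgd_step(module: ToyModule, grad: float, lr: float = 0.1) -> None:
    # An optimiser only ever sees what parameters() yields.
    for p in module.parameters():
        if p.requires_grad:
            p.value -= lr * grad


warp = ToyModule(1.0)
frozen = warp.freeze()
frozen.scale.requires_grad = True  # re-enabling the flag changes nothing
sgd_step(warp, grad=1.0)
sgd_step(frozen, grad=1.0)
print(warp.scale.value, frozen.scale.value)  # 0.9 1.0
```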

inverse(x)

Return the inverse of the warp function at a point, x.

Parameters:

x (Tensor) – An input tensor.

Return type:

Tensor

Returns:

A tensor of the same shape as x, containing the warp function's inverse at x.

Intermediate Warp Functions

Enable lazy initialisation in controllers.

Some warp functions need the input data passed to the controller class in order to initialise properly. To avoid having to supply this data ahead of time, the require_controller_input() decorator allows a warp function to be initialised lazily, becoming a full warp function only once it is activated.
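
The lazy-initialisation pattern can be sketched in plain Python. Everything here is hypothetical and is not how Vanguard implements require_controller_input(): the decorator require_input replaces __init__ with a version that only records the constructor arguments, activate(**inputs) stores the keyword arguments under the attribute named by cache_name and then runs the real __init__, and the _pending marker stands in for whatever is_intermediate_warp_function() actually checks.

```python
import functools
from typing import Any, Callable, TypeVar

T = TypeVar("T", bound=type)


def require_input(cache_name: str) -> Callable[[T], T]:
    """Hypothetical decorator: defer __init__ until activate(**inputs)."""

    def decorator(cls: T) -> T:
        real_init = cls.__init__

        @functools.wraps(real_init)
        def lazy_init(self, *args: Any, **kwargs: Any) -> None:
            # Only remember the constructor arguments for now.
            self._pending = (args, kwargs)

        def activate(self, **inputs: Any) -> None:
            # Keyword arguments only, as with Vanguard's activate().
            setattr(self, cache_name, inputs)
            args, kwargs = self._pending
            real_init(self, *args, **kwargs)
            del self._pending  # no longer intermediate

        cls.__init__ = lazy_init
        cls.activate = activate
        return cls

    return decorator


def is_intermediate(obj: object) -> bool:
    """True while obj is still waiting for activate()."""
    return hasattr(obj, "_pending")


@require_input("controller_inputs")
class CentreOnMean:
    def __init__(self) -> None:
        data = self.controller_inputs["train_y"]
        self.mu = sum(data) / len(data)

    def forward(self, y: float) -> float:
        return y - self.mu


warp = CentreOnMean()
print(is_intermediate(warp))  # True
warp.activate(train_y=[0.0, 1.0, 2.0, 3.0, 4.0])
print(is_intermediate(warp), warp.forward(1.0))  # False -1.0
```

Calling forward() before activate() raises AttributeError here because mu was never set, which is the same kind of opaque failure the warning at the end of this section describes.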

vanguard.warps.intermediate.is_intermediate_warp_function(func)

Establish if a warp function is intermediate.

Parameters:

func (WarpFunction) – A warp function instance which may be intermediate.

Return type:

bool

Returns:

True if the warp function is intermediate, otherwise False.

vanguard.warps.intermediate.require_controller_input(cache_name)

Force a warp function to initialise lazily, so that it can take controller class input.

Parameters:

cache_name (str) – The name of the class attribute which will hold the input parameters.

Example:
>>> import torch
>>> from vanguard.warps.warpfunctions import AffineWarpFunction
>>>
>>> @require_controller_input("controller_inputs")
... class GaussianScaledAffineWarpFunction(AffineWarpFunction):
...     '''
...     Scale inputs by the mean and standard deviation
...     of the training data.
...     '''
...     def __init__(self):
...         train_y = self.controller_inputs["train_y"]
...         mu, sigma = train_y.mean().item(), train_y.std().item()
...         super().__init__(1/sigma, -mu/sigma)
>>>
>>> warp_function = GaussianScaledAffineWarpFunction()
>>> warp_function(torch.Tensor([1]))  
Traceback (most recent call last):
    ...
AttributeError: ...
>>> warp_function.activate(train_y=torch.as_tensor([0.0, 1.0, 2.0, 3.0, 4.0]))
>>> warp_function(torch.as_tensor([1.0])).detach().cpu()
tensor([[-0.6325]])
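
The value -0.6325 printed above can be checked by hand, assuming AffineWarpFunction(a, b) computes a * y + b, so that the warp is (y - μ)/σ. torch.Tensor.std() uses the unbiased (n - 1) estimator by default, which matches statistics.stdev from the standard library:

```python
import statistics

train_y = [0.0, 1.0, 2.0, 3.0, 4.0]
mu = statistics.mean(train_y)      # 2.0
sigma = statistics.stdev(train_y)  # sqrt(10 / 4), about 1.5811 (n - 1 denominator)

# Assumed affine form: a * y + b with a = 1/sigma and b = -mu/sigma.
x = (1 / sigma) * 1.0 + (-mu / sigma)
print(round(x, 4))  # -0.6325
```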

Return type:

Callable[[type[WarpFunction]], type[WarpFunction]]

Note

The SetWarp decorator calls activate() on the user's behalf, so in most cases this step can be ignored. Only keyword arguments can be passed to activate().

Warning

Despite best efforts, using an intermediate warp function before it has been activated can produce opaque error messages, such as the AttributeError in the example above. When debugging any error involving such a warp function, first check that it was activated.