Creating Warp Functions
Vanguard provides only a limited number of built-in warp functions, so users may need to create their own.
Base Warp Function
All warp functions should subclass the WarpFunction class.
- class vanguard.warps.basefunction.MultitaskWarpFunction(*warps)
Module for multitask warp functions.
The warps are expected to be applied element-wise to vectors, possibly with a different warp for each dimension.
One-dimensional warps (subclasses of WarpFunction), one per dimension, are passed directly to this class's constructor.
- Example:
>>> import torch
>>> from vanguard.warps.basefunction import MultitaskWarpFunction, WarpFunction
>>>
>>> # New warp functions should inherit from the WarpFunction class:
>>> class AddTwo(WarpFunction):
...     def forward(self, y):
...         return y + 2
...
...     def inverse(self, x):
...         return x - 2
...
...     def deriv(self, y):
...         return 1
>>>
>>> # Warp functions can be composed together:
>>> add_two = AddTwo()
>>> add_two.forward(torch.Tensor([0]))
tensor([2.])
>>> add_four = AddTwo() @ AddTwo()
>>> add_four.forward(torch.Tensor([0]))
tensor([4.])
>>> multitask_warp = MultitaskWarpFunction(add_two, add_four)
>>> multitask_warp.forward(torch.Tensor([[0, 1], [-2, -3]]))
tensor([[2., 5.],
        [0., 1.]])
- Parameters:
warps (WarpFunction) – The one-dimensional warp functions to apply, one per dimension.
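The element-wise behaviour described above can be sketched in plain Python, with nested lists standing in for tensors. The function apply_multitask and the helper warps are hypothetical illustrations of the idea, not part of Vanguard; the sketch only reproduces the column-by-column result of the MultitaskWarpFunction example.

```python
# Hypothetical sketch: apply one 1-D warp per column, element-wise.
# Plain Python lists stand in for tensors; this is not Vanguard's code.
from typing import Callable, List

Warp = Callable[[float], float]


def apply_multitask(warps: List[Warp], rows: List[List[float]]) -> List[List[float]]:
    """Apply warps[j] to every entry in column j of a 2-D input."""
    for row in rows:
        if len(row) != len(warps):
            raise ValueError("need exactly one warp per dimension")
    return [[warp(value) for warp, value in zip(warps, row)] for row in rows]


def add_two(y: float) -> float:
    return y + 2


def add_four(y: float) -> float:
    return add_two(add_two(y))


# Column 0 is shifted by 2, column 1 by 4.
result = apply_multitask([add_two, add_four], [[0, 1], [-2, -3]])
print(result)  # [[2, 5], [0, 1]]
```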
- __init__(*warps)
Initialise self.
- Parameters:
warps (WarpFunction) – The warp functions to be applied.
- compose(other)
Compose with another warp function.
- Parameters:
other (WarpFunction) – The other warp function.
- Returns:
A new MultitaskWarpFunction instance whose functions are composed task-wise.
Note
For convenience, it is often easier to use the @ operator in place of compose().
- Example:
>>> warp_1, warp_2 = MultitaskWarpFunction(), MultitaskWarpFunction()
>>>
>>> # This is equivalent to warp_1(warp_2(...)) task-wise.
>>> composed_warp = warp_1 @ warp_2
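Task-wise composition can be pictured as composing the warps for each dimension separately. In this minimal sketch, compose_taskwise is a hypothetical helper operating on lists of plain Python callables rather than MultitaskWarpFunction instances; it assumes the same warp_1(warp_2(...)) ordering as the example.

```python
# Hypothetical sketch of task-wise composition on plain Python callables:
# dimension j of the result computes first[j](second[j](y)).
from typing import Callable, List

Warp = Callable[[float], float]


def compose_pair(f: Warp, g: Warp) -> Warp:
    """Return the single-dimension composition y -> f(g(y))."""
    return lambda y: f(g(y))


def compose_taskwise(first: List[Warp], second: List[Warp]) -> List[Warp]:
    """Compose two per-dimension warp lists, dimension by dimension."""
    if len(first) != len(second):
        raise ValueError("both multitask warps need the same number of tasks")
    return [compose_pair(f, g) for f, g in zip(first, second)]


composed = compose_taskwise(
    [lambda y: y + 1, lambda y: y - 1],  # outer warps, applied second
    [lambda y: 2 * y, lambda y: 3 * y],  # inner warps, applied first
)
print([warp(1.0) for warp in composed])  # [3.0, 2.0]
```

Using a separate compose_pair helper avoids Python's late binding of loop variables inside lambdas, so each dimension keeps its own pair of warps.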
- compose_with_self(n)
Repeatedly compose a warp function with itself.
- Parameters:
n (int) – The number of times to compose.
- Returns:
A new WarpFunction instance with composed functions.
- Raises:
ValueError – If n is negative.
Note
When n == 0, this method returns the identity warp.
Warning
The warp functions are not copied before composition, so every component of the returned warp function is the same object. When the warp is applied to a controller class with the SetWarp decorator, the warp function (and its components) are copied, so this is no longer an issue.
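The rules for compose_with_self() (identity when n == 0, ValueError for negative n, and shared rather than copied components) can be illustrated with a small sketch. Here a composed warp is represented as a plain list of callables; this representation and the apply helper are assumptions made for illustration only.

```python
# Hypothetical sketch of repeated self-composition, following the notes above:
# n == 0 gives the identity, negative n raises ValueError, and the same
# object is reused for every component (no copies are made).
from typing import Callable, List

Warp = Callable[[float], float]


def compose_with_self(warp: Warp, n: int) -> List[Warp]:
    """Return the components of warp composed with itself n times."""
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    return [warp] * n  # n references to one object; an empty list is the identity


def apply(components: List[Warp], y: float) -> float:
    # Apply the innermost (last) component first, matching warp_1(warp_2(...)).
    for component in reversed(components):
        y = component(y)
    return y


def add_two(y: float) -> float:
    return y + 2


add_ten = compose_with_self(add_two, 5)
print(apply(add_ten, 0.0))  # 10.0
print(apply(compose_with_self(add_two, 0), 7.0))  # 7.0 (identity)
print(all(c is add_two for c in add_ten))  # True: components are shared
```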
- class vanguard.warps.basefunction.WarpFunction
Base module for warp functions.
Subclasses must implement the forward() and inverse() methods. Implementing the deriv() method is optional; if it is not implemented, the derivative is computed with autograd, which is significantly slower.
- Example:
>>> import torch
>>> from vanguard.warps.basefunction import WarpFunction
>>>
>>> # New warp functions should inherit from the WarpFunction class:
>>> class AddTwo(WarpFunction):
...     def forward(self, y):
...         return y + 2
...
...     def inverse(self, x):
...         return x - 2
...
...     def deriv(self, y):
...         return 1
>>>
>>> # Warp functions can be composed together:
>>> add_four = AddTwo() @ AddTwo()
>>> add_four.forward(torch.Tensor([0]))
tensor([4.])
>>>
>>> # A warp function can also be composed with itself repeatedly:
>>> add_ten = AddTwo() @ 5
>>> add_ten.forward(torch.Tensor([0]))
tensor([10.])
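The subclass contract can be illustrated without torch. In the hypothetical SketchWarp base class below, forward() and inverse() must be overridden, while deriv() has a default. Vanguard uses autograd for that default; this sketch swaps in a central finite difference as a stand-in, so it shows the shape of the contract rather than the library's implementation.

```python
# Hypothetical sketch of the subclass contract: forward() and inverse() are
# required; deriv() is optional. A central finite difference stands in for
# the autograd fallback described in the text.
import math


class SketchWarp:
    def forward(self, y: float) -> float:
        raise NotImplementedError

    def inverse(self, x: float) -> float:
        raise NotImplementedError

    def deriv(self, y: float, eps: float = 1e-6) -> float:
        # Fallback used only when a subclass does not override deriv().
        return (self.forward(y + eps) - self.forward(y - eps)) / (2 * eps)


class LogWarp(SketchWarp):
    """Example warp for positive data: x = log(y). deriv() is not overridden."""

    def forward(self, y: float) -> float:
        return math.log(y)

    def inverse(self, x: float) -> float:
        return math.exp(x)


warp = LogWarp()
y = 3.0
print(math.isclose(warp.inverse(warp.forward(y)), y))  # True: inverse undoes forward
print(math.isclose(warp.deriv(y), 1 / y, rel_tol=1e-6))  # True: d/dy log(y) = 1/y
```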
- property components: list[WarpFunction]
Get the components of the composition.
- compose(other)
Compose with another warp function.
- Parameters:
other (WarpFunction) – The other warp function.
- Returns:
A new WarpFunction instance with composed functions.
Note
For convenience, it is often easier to use the @ operator in place of compose().
- Example:
>>> warp_1, warp_2 = WarpFunction(), WarpFunction()
>>>
>>> # This is equivalent to warp_1(warp_2(...))
>>> composed_warp = warp_1 @ warp_2
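Assuming warp_1 @ warp_2 behaves like warp_1(warp_2(...)), a composed warp also needs a consistent inverse and derivative: the inverse undoes the warps in reverse order, and the derivative follows the chain rule. The Warp1D dataclass below is a hypothetical plain-Python sketch of that bookkeeping, not Vanguard's implementation.

```python
# Hypothetical sketch of how a composition f @ g can behave as a warp:
#   forward: f(g(y)); inverse: g^{-1}(f^{-1}(x)); deriv: f'(g(y)) * g'(y).
import math
from dataclasses import dataclass
from typing import Callable


@dataclass
class Warp1D:
    forward: Callable[[float], float]
    inverse: Callable[[float], float]
    deriv: Callable[[float], float]

    def __matmul__(self, other: "Warp1D") -> "Warp1D":
        return Warp1D(
            forward=lambda y: self.forward(other.forward(y)),
            inverse=lambda x: other.inverse(self.inverse(x)),
            deriv=lambda y: self.deriv(other.forward(y)) * other.deriv(y),
        )


scale_warp = Warp1D(lambda y: 3 * y, lambda x: x / 3, lambda y: 3.0)
log_warp = Warp1D(math.log, math.exp, lambda y: 1 / y)

log_of_scaled = log_warp @ scale_warp  # y -> log(3y)
y = 2.0
print(math.isclose(log_of_scaled.forward(y), math.log(6.0)))  # True
print(math.isclose(log_of_scaled.inverse(log_of_scaled.forward(y)), y))  # True
print(math.isclose(log_of_scaled.deriv(y), 1 / y))  # True: d/dy log(3y) = 1/y
```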
- compose_with_self(n)
Repeatedly compose a warp function with itself.
- Parameters:
n (int) – The number of times to compose.
- Returns:
A new WarpFunction instance with composed functions.
- Raises:
ValueError – If n is negative.
Note
When n == 0, this method returns the identity warp.
Warning
The warp functions are not copied before composition, so every component of the returned warp function is the same object. When the warp is applied to a controller class with the SetWarp decorator, the warp function (and its components) are copied, so this is no longer an issue.
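The WarpFunction example composes AddTwo() @ 5, which suggests that the @ operator accepts either another warp or an integer. A minimal sketch of such dispatch follows, assuming an integer is routed to compose_with_self() and a warp to compose(); the Composable class is hypothetical.

```python
# Hypothetical sketch of @ dispatch: an int n means compose_with_self(n),
# anything else means compose(other). Not Vanguard's implementation.
from typing import Callable, Union


class Composable:
    def __init__(self, *components: Callable[[float], float]):
        # Stored outermost first: Composable(f, g)(y) == f(g(y)).
        self.components = list(components)

    def __call__(self, y: float) -> float:
        for component in reversed(self.components):
            y = component(y)
        return y

    def compose(self, other: "Composable") -> "Composable":
        return Composable(*self.components, *other.components)

    def compose_with_self(self, n: int) -> "Composable":
        if n < 0:
            raise ValueError("n must be non-negative")
        return Composable(*(self.components * n))

    def __matmul__(self, other: Union["Composable", int]) -> "Composable":
        if isinstance(other, int):
            return self.compose_with_self(other)
        return self.compose(other)


add_two = Composable(lambda y: y + 2)
print((add_two @ 5)(0))  # 10
print((add_two @ add_two)(0))  # 4
print((add_two @ 0)(3))  # 3 (identity)
```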
- freeze()
Return a copy of the warp with frozen parameters.
Note
We override torch.nn.Module.parameters to return an empty generator, so setting requires_grad=True cannot make the parameters trainable again. This is the most reliable way of freezing parameters and keeping them frozen in downstream usage.
- Returns:
A copy of self with parameters frozen.
Intermediate Warp Functions
Enable lazy initialisation in controllers.
Some warp functions need the input data passed to the controller class in
order to initialise properly. To avoid having to supply this ahead of time,
the require_controller_input() decorator allows a warp function to be
initialised lazily, becoming a full warp function only upon activation.
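Lazy initialisation of this kind can be sketched with a class decorator that postpones __init__ until an activate() call supplies the controller inputs. The decorator name lazy_init and the MeanShift class are hypothetical, and the mechanics are an assumption about how such a decorator could work, not a description of require_controller_input() itself.

```python
# Hypothetical sketch of lazy initialisation via a class decorator.
# lazy_init is an illustrative name; only activate() mirrors the text.
from typing import Any


def lazy_init(cache_name: str):
    def decorator(cls):
        real_init = cls.__init__

        def deferred_init(self, *args: Any, **kwargs: Any) -> None:
            # Remember the constructor arguments; run the real __init__ later.
            self._pending = (args, kwargs)

        def activate(self, **controller_inputs: Any) -> None:
            # Keyword arguments only, cached under cache_name before __init__ runs.
            setattr(self, cache_name, controller_inputs)
            args, kwargs = self._pending
            real_init(self, *args, **kwargs)

        cls.__init__ = deferred_init
        cls.activate = activate
        return cls

    return decorator


@lazy_init("controller_inputs")
class MeanShift:
    def __init__(self):
        train_y = self.controller_inputs["train_y"]
        self.mean = sum(train_y) / len(train_y)

    def __call__(self, y: float) -> float:
        return y - self.mean


warp = MeanShift()  # nothing computed yet; calling warp now raises AttributeError
warp.activate(train_y=[1.0, 2.0, 3.0])
print(warp(5.0))  # 3.0
```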
- vanguard.warps.intermediate.is_intermediate_warp_function(func)
Establish if a warp function is intermediate.
- Parameters:
func (WarpFunction) – A warp function instance which may be intermediate.
- Returns:
True if the warp function is intermediate.
- vanguard.warps.intermediate.require_controller_input(cache_name)
Make a warp function class initialise lazily, so that it can take controller class input.
- Parameters:
cache_name (str) – The name of the class attribute that will hold the input parameters.
- Example:
>>> import torch
>>> from vanguard.warps.intermediate import require_controller_input
>>> from vanguard.warps.warpfunctions import AffineWarpFunction
>>>
>>> @require_controller_input("controller_inputs")
... class GaussianScaledAffineWarpFunction(AffineWarpFunction):
...     '''
...     Scale inputs by the mean and standard deviation
...     of the training data.
...     '''
...     def __init__(self):
...         train_y = self.controller_inputs["train_y"]
...         mu, sigma = train_y.mean().item(), train_y.std().item()
...         super().__init__(1/sigma, -mu/sigma)
>>>
>>> warp_function = GaussianScaledAffineWarpFunction()
>>> warp_function(torch.Tensor([1]))
Traceback (most recent call last):
...
AttributeError: ...
>>> warp_function.activate(train_y=torch.as_tensor([0.0, 1.0, 2.0, 3.0, 4.0]))
>>> warp_function(torch.as_tensor([1.0])).detach().cpu()
tensor([[-0.6325]])
- Return type:
Callable[[type[WarpFunction]], type[WarpFunction]]
Note
The SetWarp decorator calls activate() on the user's behalf, so in most cases there is no need to worry about this step. Only keyword arguments can be passed to activate().
Warning
Despite best efforts, failing to activate an intermediate warp function before use can produce opaque error messages, so when debugging any error involving such a warp function, first check whether activation has failed.
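The expected output of the GaussianScaledAffineWarpFunction example can be checked by hand. This sketch assumes that AffineWarpFunction(a, b) maps y to a * y + b, and uses statistics.stdev, which, like torch's default std(), divides by n - 1.

```python
# Worked arithmetic for the GaussianScaledAffineWarpFunction example,
# assuming the affine warp computes a * y + b with a = 1/sigma, b = -mu/sigma.
import statistics

train_y = [0.0, 1.0, 2.0, 3.0, 4.0]
mu = statistics.mean(train_y)  # 2.0
sigma = statistics.stdev(train_y)  # sqrt(10 / 4), about 1.5811
a, b = 1 / sigma, -mu / sigma

warped = a * 1.0 + b  # equals (1 - mu) / sigma
print(round(warped, 4))  # -0.6325
```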