def __init__(self, warp: WarpFunction, *args: Any, **kwargs: Any) -> None:
    """
    Initialise the distribution and attach the warp to it.

    :param warp: The warp to be used to define the distribution.
    :param args: Positional arguments forwarded unchanged to the parent
        class initialiser.
    :param kwargs: Keyword arguments forwarded unchanged to the parent
        class initialiser.
    """
    # The parent initialiser runs first; the warp is stored afterwards so it
    # is available to log_prob() and sample().
    super().__init__(*args, **kwargs)
    self.warp = warp
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    """
    Calculate the log-probability of the values under the warped Gaussian distribution.

    The result is the parent distribution's log-probability evaluated at
    ``self.warp(value)``, plus the log of the absolute value of
    ``self.warp.deriv(value)`` (the change-of-variables term).

    :param value: Shape should be compatible with the distributions shape.
    :returns: The log probability of the values.
    """
    # Density of the warped values under the parent (un-warped) distribution.
    base_log_density = super().log_prob(self.warp(value))
    # Change-of-variables correction: log |d warp / d value|.
    log_abs_derivative = self.warp.deriv(value).abs().log()
    return base_log_density + log_abs_derivative
def sample(self, *args: Any, **kwargs: Any):
    """
    Sample from the distribution.

    Samples are drawn from the parent distribution and then mapped through
    ``self.warp.inverse``.

    :param args: Positional arguments forwarded to the parent ``sample``.
    :param kwargs: Keyword arguments forwarded to the parent ``sample``.
    :returns: The parent distribution's samples after applying the inverse warp.
    """
    return self.warp.inverse(super().sample(*args, **kwargs))
@classmethod
def from_data(
    cls,
    warp: WarpFunction,
    samples: Union[torch.Tensor, numpy.typing.NDArray[np.floating]],
    optimiser: type[torch.optim.Optimizer] = torch.optim.Adam,
    n_iterations: int = 100,
    lr: float = 0.001,
) -> Self:
    """
    Fit a warped Gaussian distribution to the given data using the supplied warp.

    The mean and variance will be optimised along with the free parameters of the warp.
    The parameters of ``warp`` are updated in place by the optimiser.

    :param warp: The warp to use.
    :param samples: (n_samples, ...) The data to fit.
    :param optimiser: A subclass of :class:`torch.optim.Optimizer` used to tune the parameters.
    :param n_iterations: The number of optimisation iterations.
    :param lr: The learning rate for optimisation.
    :returns: A fit distribution.
    """
    t_samples = torch.as_tensor(samples, dtype=BaseGPController.get_default_tensor_dtype())
    optim = optimiser(params=[{"params": warp.parameters(), "lr": lr}])  # pyright: ignore [reportCallIssue]

    for i in range(n_iterations):
        # Reset gradients before every backward pass. Without this, ``.grad``
        # accumulates across iterations and each step would use the sum of all
        # previous gradients instead of the gradient of the current loss.
        optim.zero_grad()
        # Minimise the negative profile log-likelihood of the warp parameters.
        loss = -cls._mle_log_prob_parametrised_with_warp_parameters(warp, t_samples)
        # The graph is retained on every iteration except the last one.
        loss.backward(retain_graph=i < n_iterations - 1)
        optim.step()

    # With the warp fitted, estimate the Gaussian parameters from the warped
    # data. They are detached so the returned distribution carries no graph.
    w_samples = warp(t_samples)
    loc = w_samples.mean(dim=0).detach()  # pyright: ignore [reportCallIssue]
    # A small constant is added to the standard deviation, presumably to keep
    # the scale strictly positive for constant data.
    scale = w_samples.std(dim=0).detach() + 1e-4
    return cls(warp, loc=loc, scale=scale)
@staticmethod
def _mle_log_prob_parametrised_with_warp_parameters(warp: WarpFunction, data: torch.Tensor) -> torch.Tensor:
    """
    Compute the log probability of the data under the warped Gaussian.

    The Gaussian mean and scale are not free parameters here: they are
    estimated from the warped data (mean and standard deviation along
    dimension 0) and detached, leaving only a function of the warp parameters.
    The constant term of the Gaussian log-density is omitted.

    :param warp: The warp applied to the data; its ``deriv`` supplies the
        change-of-variables term.
    :param data: The data to evaluate, with samples along dimension 0.
    :returns: The summed Gaussian term plus the summed log-absolute-derivative
        of the warp.
    """
    warped = warp(data)
    # Plug-in estimates, detached so gradients flow only through ``warped``
    # and the Jacobian term.
    centre = warped.mean(dim=0).detach()
    spread = warped.std(dim=0).detach() + 1e-4

    squared_residual = (warped - centre) ** 2
    per_point = -squared_residual / (2 * spread**2) - torch.log(spread)  # pyright: ignore [reportOperatorIssue]
    gaussian_term = per_point.sum()

    jacobian_term = warp.deriv(data).abs().log().sum()
    return gaussian_term + jacobian_term