surkit.losses package
Submodules
surkit.losses.losses_jax module
surkit.losses.losses_oneflow module
surkit.losses.losses_pytorch module
- surkit.losses.losses_pytorch.elbo(x, y, samples, net)
Calculate the negative evidence lower bound (ELBO) for variational inference in a Bayesian neural network.
- Parameters:
x (tensor) – input
y (tensor) – target
samples (int) – number of network samples drawn to estimate the ELBO
net (module) – Bayesian neural network
- Returns:
the loss, i.e. the negative ELBO
- Return type:
tensor
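The elbo signature (inputs, targets, a sample count and a Bayesian network) matches a Monte Carlo estimate of the negative ELBO. The following is a minimal sketch of that technique in the Bayes-by-backprop style, not surkit's implementation. The BayesLinear layer, its kl attribute, the standard-normal prior, the fixed Gaussian observation noise noise_std, and the name neg_elbo are all hypothetical assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BayesLinear(nn.Module):
    """Hypothetical mean-field Gaussian linear layer (not a surkit class)."""

    def __init__(self, in_features, out_features, prior_std=1.0):
        super().__init__()
        self.w_mu = nn.Parameter(torch.zeros(out_features, in_features))
        self.w_rho = nn.Parameter(torch.full((out_features, in_features), -3.0))
        self.b_mu = nn.Parameter(torch.zeros(out_features))
        self.b_rho = nn.Parameter(torch.full((out_features,), -3.0))
        self.prior_std = prior_std
        self.kl = torch.tensor(0.0)  # filled in by forward()

    def _kl(self, mu, std):
        # KL(N(mu, std^2) || N(0, prior_std^2)), summed over all entries
        p = self.prior_std
        return (torch.log(p / std) + (std**2 + mu**2) / (2 * p**2) - 0.5).sum()

    def forward(self, x):
        w_std = F.softplus(self.w_rho)  # softplus keeps std positive
        b_std = F.softplus(self.b_rho)
        # Reparameterization trick: w = mu + std * eps, eps ~ N(0, 1)
        w = self.w_mu + w_std * torch.randn_like(w_std)
        b = self.b_mu + b_std * torch.randn_like(b_std)
        self.kl = self._kl(self.w_mu, w_std) + self._kl(self.b_mu, b_std)
        return F.linear(x, w, b)


def neg_elbo(x, y, samples, net, noise_std=0.1):
    """Monte Carlo estimate of the negative ELBO (sketch, assumed likelihood)."""
    total = torch.tensor(0.0)
    for _ in range(samples):
        pred = net(x)  # each forward pass draws fresh weights
        # Negative Gaussian log-likelihood with an assumed, fixed noise level
        nll = -torch.distributions.Normal(pred, noise_std).log_prob(y).sum()
        kl = sum(m.kl for m in net.modules() if isinstance(m, BayesLinear))
        total = total + nll + kl
    return total / samples  # average over the sampled networks


# Usage on a toy 1-D regression problem
torch.manual_seed(0)
net = nn.Sequential(BayesLinear(1, 16), nn.Tanh(), BayesLinear(16, 1))
x = torch.linspace(-1, 1, 32).unsqueeze(1)
y = torch.sin(3 * x)
loss = neg_elbo(x, y, samples=4, net=net)
loss.backward()  # gradients flow to the variational parameters
```

Because the KL term does not depend on the sampled weights, it could equally be added once outside the loop; it is kept inside here so each term of the average is a full single-sample ELBO estimate.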
- surkit.losses.losses_pytorch.get(loss, reduction='mean')
Return the loss function whose name matches the given string, or return the given callable loss function.
- Parameters:
loss (str | function) – name of a loss function, or a callable PyTorch loss function
reduction (str) – reduction applied to the output
- Returns:
a callable loss function
- Return type:
function
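A minimal sketch of the string-or-callable dispatch that get describes, assuming a hypothetical name table built on standard torch.nn loss classes that accept a reduction argument. The names "mse", "l1" and "smooth_l1" and the function get_loss are illustrative only; surkit's actual accepted names and error handling may differ.

```python
import torch
import torch.nn as nn

# Hypothetical name-to-class table; surkit's registry may use other names.
_LOSSES = {
    "mse": nn.MSELoss,
    "l1": nn.L1Loss,
    "smooth_l1": nn.SmoothL1Loss,
}


def get_loss(loss, reduction="mean"):
    """Sketch of string-or-callable dispatch (not surkit's source)."""
    if callable(loss):
        # Already a loss function or nn.Module: hand it back as-is.
        return loss
    if isinstance(loss, str):
        key = loss.lower()
        if key not in _LOSSES:
            raise ValueError(f"unknown loss {loss!r}; expected one of {sorted(_LOSSES)}")
        # Instantiate the torch.nn loss with the requested reduction
        return _LOSSES[key](reduction=reduction)
    raise TypeError(f"loss must be a str or callable, got {type(loss).__name__}")


pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 0.0, 3.0])
mse = get_loss("mse")(pred, target)                  # mean of squared errors: 4/3
l1_sum = get_loss("l1", reduction="sum")(pred, target)  # sum of absolute errors: 2
```

Checking callables before strings is safe because Python strings are not callable; the reduction argument only matters for names resolved through the table.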