surkit.losses package

Submodules

surkit.losses.losses_jax module

surkit.losses.losses_oneflow module

surkit.losses.losses_pytorch module

surkit.losses.losses_pytorch.elbo(x, y, samples, net)

Calculate the negative evidence lower bound (ELBO) for variational inference in a Bayesian neural network.

Parameters:
  • x (tensor) – input data

  • y (tensor) – target values

  • samples (int) – number of network samples used to estimate the ELBO

  • net (module) – Bayesian neural network

Returns:

The loss, i.e. the negative ELBO.

Return type:

tensor
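For reference, the standard negative ELBO that a loss of this kind estimates is written below. This is a sketch of the common formulation, not a statement of surkit's exact implementation. Here $w$ denotes the network weights, $q_\phi(w)$ the variational posterior with parameters $\phi$, $p(w)$ the prior, and $S$ plays the role of `samples`. Whether surkit uses a particular likelihood $p(y \mid x, w)$, rescales the KL term, or reads the KL value from `net` is an assumption not confirmed on this page.

```latex
-\mathcal{L}(\phi)
  = \mathrm{KL}\big(q_\phi(w) \,\|\, p(w)\big)
    - \mathbb{E}_{q_\phi(w)}\big[\log p(y \mid x, w)\big]
  \approx \mathrm{KL}\big(q_\phi(w) \,\|\, p(w)\big)
    - \frac{1}{S} \sum_{s=1}^{S} \log p\big(y \mid x, w^{(s)}\big),
  \qquad w^{(s)} \sim q_\phi(w)
```

Minimizing this quantity is equivalent to maximizing the ELBO. Averaging over $S$ sampled networks gives an unbiased Monte Carlo estimate of the expected log-likelihood term, so a larger `samples` lowers the variance of the loss at the cost of more forward passes.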

surkit.losses.losses_pytorch.get(loss, reduction='mean')

Return the loss function named by the given string, or return the given loss function itself if it is already callable.

Parameters:
  • loss (str | function) – name of a loss function, or a callable PyTorch loss function

  • reduction (str) – reduction to apply to the loss output (default: 'mean')

Returns:

a callable loss function

Return type:

function
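The lookup behaviour described above can be sketched in plain Python as a name-to-function registry. This is a minimal illustration under stated assumptions, not surkit's implementation: the loss names "mse" and "mae", the helper functions, and the list-based inputs are all hypothetical, and the real PyTorch backend operates on tensors. The reduction values 'mean', 'sum' and 'none' follow PyTorch's convention.

```python
from typing import Callable, List, Sequence, Union

# Minimal sketch; the loss names and list-based losses are hypothetical,
# not surkit's actual API.
LossFn = Callable[[Sequence[float], Sequence[float]], Union[float, List[float]]]


def _reduce(values: List[float], reduction: str) -> Union[float, List[float]]:
    """Apply a PyTorch-style reduction to per-element losses."""
    if reduction == "mean":
        return sum(values) / len(values)
    if reduction == "sum":
        return sum(values)
    if reduction == "none":
        return values
    raise ValueError(f"unknown reduction: {reduction!r}")


def _mse(pred, target, reduction="mean"):
    # Squared error per element, then reduce.
    return _reduce([(p - t) ** 2 for p, t in zip(pred, target)], reduction)


def _mae(pred, target, reduction="mean"):
    # Absolute error per element, then reduce.
    return _reduce([abs(p - t) for p, t in zip(pred, target)], reduction)


# Hypothetical name registry; the names surkit actually accepts may differ.
_LOSSES = {"mse": _mse, "mae": _mae}


def get(loss: Union[str, LossFn], reduction: str = "mean") -> LossFn:
    """Resolve a loss name to a callable, or pass a callable through as-is."""
    if callable(loss):
        return loss
    if isinstance(loss, str):
        try:
            fn = _LOSSES[loss.lower()]
        except KeyError:
            raise ValueError(f"unknown loss: {loss!r}") from None
        # Bind the requested reduction so callers only pass (pred, target).
        return lambda pred, target: fn(pred, target, reduction=reduction)
    raise TypeError("loss must be a string or a callable")
```

With this sketch, `get("mse")([1.0, 2.0], [0.0, 2.0])` evaluates to `0.5`, and passing a callable returns that same object. Binding `reduction` at lookup time is one plausible reading of how `get` configures named losses; whether surkit also applies `reduction` to user-supplied callables is not stated here.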

Module contents