Source code for surkit.losses.losses_pytorch

#!/usr/bin/env python
# -*- coding:UTF-8 -*-

import torch
import torch.nn as nn

__all__ = ['get', 'elbo']


[docs]def get(loss, reduction='mean'): """ Return a loss function based on the given string, or return the given callable loss function Args: loss (str | function): name of a loss function or a callable pytorch loss function reduction (str): reduction applied to the output Returns: function: a callable loss function """ loss_dict = { "mean absolute error": nn.L1Loss(reduction=reduction), "mae": nn.L1Loss(reduction=reduction), "mean squared error": nn.MSELoss(reduction=reduction), "mse": nn.MSELoss(reduction=reduction), "gaussiannll": nn.GaussianNLLLoss(reduction=reduction), } if isinstance(loss, str): return loss_dict[loss.lower()] if callable(loss): return loss
[docs]def elbo(x, y, samples, net): """ Calculate negative evidence lower bound for bayes nn variational inference Args: x (tensor): input y (tensor): target samples (int): number of nn samples net (module): bayes nn Returns: tensor: loss, negative elbo """ out = [None for _ in range(samples)] log_priors = torch.zeros(samples) log_posts = torch.zeros(samples) log_likes = torch.zeros(samples) for i in range(samples): out[i] = net(x).reshape(-1) # make predictions log_priors[i] = net.log_prior() # get log prior log_posts[i] = net.log_post() # get log variational posterior log_likes[i] = torch.distributions.Normal(out[i], net.noise_tol).log_prob( y.reshape(-1)).sum() # calculate the log likelihood # calculate monte carlo estimate of prior posterior and likelihood log_prior = log_priors.mean() log_post = log_posts.mean() log_like = log_likes.mean() # calculate the negative elbo (which is our loss function) loss = log_post - log_prior - log_like return loss