Source code for surkit.backend.pytorch_bkd

#!/usr/bin/env python
# -*- coding:UTF-8 -*-

import torch
from torch import autograd

# When a CUDA device is available, switch the global default tensor type to
# torch.cuda.FloatTensor: tensors created without an explicit dtype/device
# (e.g. by torch.zeros / torch.tensor) become float32 tensors on the GPU.
# This is a process-wide side effect that happens at import time.
# NOTE(review): torch.set_default_tensor_type is deprecated since PyTorch 2.1
# in favor of torch.set_default_dtype / torch.set_default_device — consider
# migrating once the minimum supported torch version allows it.
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Aliases exposed by this backend module (presumably so backend-agnostic code
# can refer to the model/tensor base classes without importing torch directly).
Module = torch.nn.Module
Tensor = torch.Tensor


def is_tensor(obj):
    """
    Check whether ``obj`` is a torch tensor.

    Args:
        obj: any Python object.

    Returns:
        True if ``obj`` is an instance of ``torch.Tensor`` (subclasses such as
        ``torch.nn.Parameter`` included), otherwise False.
    """
    return isinstance(obj, torch.Tensor)


def zeros(shape, dtype=None):
    """
    Create a zero-filled tensor that tracks gradients.

    Args:
        shape: shape passed straight to ``torch.zeros`` (int or sequence of ints).
        dtype: optional torch dtype; ``None`` uses the current default dtype.

    Returns:
        A leaf tensor of zeros with ``requires_grad=True``.
    """
    filled = torch.zeros(shape, dtype=dtype, requires_grad=True)
    return filled


def zeros_like(tensor):
    """
    Create a zero tensor with the same shape/dtype/device as ``tensor``.

    Args:
        tensor: reference tensor.

    Returns:
        A leaf tensor of zeros with ``requires_grad=True``.
    """
    filled = torch.zeros_like(tensor, requires_grad=True)
    return filled
def cat(tensor_list, dim=0):
    """
    Concatenate a sequence of tensors along an existing dimension.

    Args:
        tensor_list: sequence of tensors with matching shapes except in ``dim``.
        dim: dimension to concatenate along (default 0).

    Returns:
        The concatenated tensor.
    """
    joined = torch.cat(tensor_list, dim=dim)
    return joined


def unsqueeze(tensor, dim):
    """
    Insert a new size-1 dimension into ``tensor`` at position ``dim``.

    Args:
        tensor: input tensor.
        dim: index at which the new dimension is inserted.

    Returns:
        The tensor with one extra dimension.
    """
    expanded = torch.unsqueeze(tensor, dim=dim)
    return expanded


def squeeze(tensor, dim):
    """
    Remove dimension ``dim`` from ``tensor`` if its size is 1.

    Args:
        tensor: input tensor.
        dim: dimension to squeeze; left unchanged when its size is not 1.

    Returns:
        The (possibly) squeezed tensor.
    """
    squeezed = torch.squeeze(tensor, dim=dim)
    return squeezed
def forward(model, x):
    """
    Run a forward pass.

    Args:
        model: any callable, typically a ``torch.nn.Module``.
        x: input passed to ``model``.

    Returns:
        Whatever ``model(x)`` returns.
    """
    output = model(x)
    return output
def np_to_tensor(array, requires_grad=False):
    """
    Copy array-like data into a new float32 tensor.

    Args:
        array: anything ``torch.tensor`` accepts (NumPy array, nested lists, scalars).
        requires_grad: whether the new tensor should track gradients.

    Returns:
        A newly allocated ``torch.float32`` tensor holding a copy of ``array``.
    """
    return torch.tensor(
        array,
        dtype=torch.float32,
        requires_grad=requires_grad,
    )
def sin(tensor):
    """Element-wise sine of ``tensor``."""
    result = torch.sin(tensor)
    return result


def cos(tensor):
    """Element-wise cosine of ``tensor``."""
    result = torch.cos(tensor)
    return result


def tan(tensor):
    """Element-wise tangent of ``tensor``."""
    result = torch.tan(tensor)
    return result


def arcsin(tensor):
    """Element-wise inverse sine of ``tensor`` (``torch.asin``, aliased as ``torch.arcsin``)."""
    result = torch.asin(tensor)
    return result


def arccos(tensor):
    """Element-wise inverse cosine of ``tensor`` (``torch.acos``, aliased as ``torch.arccos``)."""
    result = torch.acos(tensor)
    return result


def arctan(tensor):
    """Element-wise inverse tangent of ``tensor`` (``torch.atan``, aliased as ``torch.arctan``)."""
    result = torch.atan(tensor)
    return result
def sinh(tensor):
    """Element-wise hyperbolic sine of ``tensor``."""
    result = torch.sinh(tensor)
    return result


def cosh(tensor):
    """Element-wise hyperbolic cosine of ``tensor``."""
    result = torch.cosh(tensor)
    return result


def tanh(tensor):
    """Element-wise hyperbolic tangent of ``tensor``."""
    result = torch.tanh(tensor)
    return result


def arcsinh(tensor):
    """Element-wise inverse hyperbolic sine (``torch.asinh``, aliased as ``torch.arcsinh``)."""
    result = torch.asinh(tensor)
    return result


def arccosh(tensor):
    """Element-wise inverse hyperbolic cosine (``torch.acosh``, aliased as ``torch.arccosh``)."""
    result = torch.acosh(tensor)
    return result


def arctanh(tensor):
    """Element-wise inverse hyperbolic tangent (``torch.atanh``, aliased as ``torch.arctanh``)."""
    result = torch.atanh(tensor)
    return result
def power(tensor, exponent):
    """
    Raise ``tensor`` to ``exponent`` element-wise.

    Args:
        tensor: base (tensor or scalar, as accepted by ``torch.pow``).
        exponent: scalar or tensor exponent.

    Returns:
        ``tensor ** exponent`` computed by ``torch.pow``.
    """
    result = torch.pow(tensor, exponent=exponent)
    return result


def exp(tensor):
    """Element-wise natural exponential of ``tensor``."""
    result = torch.exp(tensor)
    return result


def log(tensor):
    """Element-wise natural logarithm of ``tensor``."""
    result = torch.log(tensor)
    return result


def log2(tensor):
    """Element-wise base-2 logarithm of ``tensor``."""
    result = torch.log2(tensor)
    return result


def log10(tensor):
    """Element-wise base-10 logarithm of ``tensor``."""
    result = torch.log10(tensor)
    return result


def sqrt(tensor):
    """Element-wise square root of ``tensor``."""
    result = torch.sqrt(tensor)
    return result
def grad(y, x):
    """
    Calculate dy/dx.

    Computes the vector-Jacobian product of ``y`` with respect to ``x`` using a
    tensor of ones as the upstream gradient, i.e. the gradient of ``y.sum()``
    with respect to ``x``. When each element of ``y`` depends only on the
    corresponding element of ``x``, this is the element-wise derivative.

    The graph of the result is retained (``create_graph=True``), so higher-order
    derivatives can be obtained by calling ``grad`` again on the result.

    Args:
        y (tensor): output tensor; must be computed from ``x`` through autograd.
        x (tensor): input tensor with ``requires_grad=True``.

    Returns:
        dy/dx, a tensor with the same shape as ``x``.
    """
    # grad_outputs must have the shape of the *outputs* (y), not of x:
    # torch.ones_like(x) only worked when y and x happened to share a shape.
    return autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)[0]
def save(model, path):
    """
    Serialize ``model`` to ``path`` with ``torch.save``.

    The object itself is pickled (not only a ``state_dict``), so reading it
    back requires unpickling the object's class.

    Args:
        model: object to serialize, typically a ``torch.nn.Module``.
        path: destination file path or writable file-like object.
    """
    torch.save(obj=model, f=path)
def load(path, map_location=None):
    """
    Load an object previously written with ``save`` (i.e. ``torch.save``).

    ``save`` pickles the whole model object, so loading it requires full
    unpickling. Starting with PyTorch 2.6, ``torch.load`` defaults to
    ``weights_only=True``, which rejects such pickles. This function therefore
    requests ``weights_only=False`` whenever the installed torch supports the
    argument, and falls back to the plain call on older versions.

    Security note: full unpickling can execute arbitrary code; only load files
    from trusted sources.

    Args:
        path: file path or file-like object passed to ``torch.load``.
        map_location: optional ``torch.load`` ``map_location`` (e.g. ``"cpu"``)
            used to remap storages, for example when a checkpoint saved on a GPU
            is opened on a CPU-only machine. ``None`` keeps torch's default.

    Returns:
        The unpickled object.
    """
    import inspect

    load_kwargs = {"map_location": map_location}
    # weights_only exists only in newer torch releases; passing it to an older
    # torch.load would raise TypeError, so probe the signature first.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)