#!/usr/bin/env python
# -*- coding:UTF-8 -*-
import torch
from torch import autograd
# Runs at import time: when a CUDA device is visible, switch torch's default
# tensor type to CUDA float32, so floating-point tensors created without an
# explicit device/dtype default to it.
# NOTE(review): torch.set_default_tensor_type is deprecated as of PyTorch 2.1
# in favour of torch.set_default_dtype / torch.set_default_device; confirm the
# supported torch version (and the device semantics callers rely on) before
# migrating.
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
# Aliases so callers can reference torch's module and tensor base classes
# through this module.
Module = torch.nn.Module
Tensor = torch.Tensor
def is_tensor(obj):
    """Report whether ``obj`` is a torch tensor.

    Args:
        obj: Any object.

    Returns:
        bool: The result of ``torch.is_tensor(obj)``.
    """
    answer = torch.is_tensor(obj)
    return answer
def zeros(shape, dtype=None, requires_grad=True):
    """Create a zero-filled tensor that, by default, tracks gradients.

    Args:
        shape: Size of the tensor (an int or a sequence of ints).
        dtype (torch.dtype, optional): Element type; ``None`` uses torch's
            default dtype.
        requires_grad (bool): Whether autograd records operations on the
            result. Defaults to ``True`` (the original fixed behaviour). Pass
            ``False`` for integer/bool dtypes, which cannot require gradients.

    Returns:
        torch.Tensor: A tensor of zeros with the requested shape and dtype.
    """
    return torch.zeros(shape, dtype=dtype, requires_grad=requires_grad)
def zeros_like(tensor, dtype=None, requires_grad=True):
    """Create a zero tensor with the same shape as ``tensor``.

    Args:
        tensor (torch.Tensor): Template whose shape (and, by default, dtype
            and device) is copied.
        dtype (torch.dtype, optional): Override element type; ``None`` keeps
            the dtype of ``tensor``.
        requires_grad (bool): Whether autograd records operations on the
            result. Defaults to ``True`` (the original fixed behaviour). Pass
            ``False`` when the result is integer/bool, which cannot require
            gradients.

    Returns:
        torch.Tensor: A tensor of zeros shaped like ``tensor``.
    """
    return torch.zeros_like(tensor, dtype=dtype, requires_grad=requires_grad)
def cat(tensor_list, dim=0):
    """Join a sequence of tensors along an existing dimension.

    Args:
        tensor_list: Sequence of tensors with matching shapes except in ``dim``.
        dim (int): Dimension to concatenate along. Defaults to 0.

    Returns:
        torch.Tensor: The concatenated tensor.
    """
    joined = torch.cat(tensor_list, dim=dim)
    return joined
def unsqueeze(tensor, dim):
    """Insert a size-1 dimension into ``tensor`` at position ``dim``.

    Args:
        tensor (torch.Tensor): Input tensor.
        dim (int): Index at which the new axis is inserted (negative allowed).

    Returns:
        torch.Tensor: View of ``tensor`` with one extra dimension.
    """
    expanded = torch.unsqueeze(tensor, dim)
    return expanded
def squeeze(tensor, dim=None):
    """Remove size-1 dimensions from ``tensor``.

    Args:
        tensor (torch.Tensor): Input tensor.
        dim (int, optional): Dimension to squeeze. It is removed only if its
            size is 1. When ``None`` (new default), every size-1 dimension is
            removed, matching ``torch.squeeze(tensor)``.

    Returns:
        torch.Tensor: View of ``tensor`` with the selected size-1 dims removed.
    """
    if dim is None:
        # Dispatch to the no-dim overload rather than passing None through.
        return torch.squeeze(tensor)
    return torch.squeeze(tensor, dim)
def forward(model, x):
    """Run ``model`` on input ``x``.

    Args:
        model: Any callable, typically a ``torch.nn.Module``.
        x: Input passed as the single positional argument.

    Returns:
        Whatever ``model(x)`` returns.
    """
    prediction = model(x)
    return prediction
def np_to_tensor(array, requires_grad=False, dtype=torch.float32):
    """Copy array-like data (e.g. a NumPy array or nested list) into a tensor.

    Args:
        array: Array-like data accepted by ``torch.tensor``.
        requires_grad (bool): Whether autograd tracks the new tensor.
        dtype (torch.dtype): Element type of the result. Defaults to
            ``torch.float32`` (the previously hard-coded value).

    Returns:
        torch.Tensor: A new tensor holding a copy of ``array``.
    """
    return torch.tensor(array, requires_grad=requires_grad, dtype=dtype)
def sin(tensor):
    """Return the element-wise sine of ``tensor`` (radians)."""
    result = torch.sin(tensor)
    return result
def cos(tensor):
    """Return the element-wise cosine of ``tensor`` (radians)."""
    result = torch.cos(tensor)
    return result
def tan(tensor):
    """Return the element-wise tangent of ``tensor`` (radians)."""
    result = torch.tan(tensor)
    return result
def arcsin(tensor):
    """Return the element-wise inverse sine of ``tensor``, in radians."""
    result = torch.arcsin(tensor)
    return result
def arccos(tensor):
    """Return the element-wise inverse cosine of ``tensor``, in radians."""
    result = torch.arccos(tensor)
    return result
def arctan(tensor):
    """Return the element-wise inverse tangent of ``tensor``, in radians."""
    result = torch.arctan(tensor)
    return result
def sinh(tensor):
    """Return the element-wise hyperbolic sine of ``tensor``."""
    result = torch.sinh(tensor)
    return result
def cosh(tensor):
    """Return the element-wise hyperbolic cosine of ``tensor``."""
    result = torch.cosh(tensor)
    return result
def tanh(tensor):
    """Return the element-wise hyperbolic tangent of ``tensor``."""
    result = torch.tanh(tensor)
    return result
def arcsinh(tensor):
    """Return the element-wise inverse hyperbolic sine of ``tensor``."""
    result = torch.arcsinh(tensor)
    return result
def arccosh(tensor):
    """Return the element-wise inverse hyperbolic cosine of ``tensor``."""
    result = torch.arccosh(tensor)
    return result
def arctanh(tensor):
    """Return the element-wise inverse hyperbolic tangent of ``tensor``."""
    result = torch.arctanh(tensor)
    return result
def power(tensor, exponent):
    """Raise ``tensor`` element-wise to ``exponent``.

    Args:
        tensor: Base, a tensor (or number, as accepted by ``torch.pow``).
        exponent: Scalar or tensor exponent, broadcast against ``tensor``.

    Returns:
        torch.Tensor: ``tensor ** exponent`` computed by ``torch.pow``.
    """
    raised = torch.pow(tensor, exponent)
    return raised
def exp(tensor):
    """Return ``e`` raised element-wise to ``tensor``."""
    result = torch.exp(tensor)
    return result
def log(tensor):
    """Return the element-wise natural logarithm of ``tensor``."""
    result = torch.log(tensor)
    return result
def log2(tensor):
    """Return the element-wise base-2 logarithm of ``tensor``."""
    result = torch.log2(tensor)
    return result
def log10(tensor):
    """Return the element-wise base-10 logarithm of ``tensor``."""
    result = torch.log10(tensor)
    return result
def sqrt(tensor):
    """Return the element-wise square root of ``tensor``."""
    result = torch.sqrt(tensor)
    return result
def grad(y, x):
    """Calculate dy/dx, keeping the graph so the result is differentiable.

    Computes the vector-Jacobian product of ``y`` with a tensor of ones, i.e.
    the gradient of ``y.sum()`` with respect to ``x``. When each element of
    ``y`` depends only on the matching element of ``x``, this is the
    element-wise derivative dy/dx.

    Args:
        y (tensor): Output tensor computed from ``x``; any shape.
        x (tensor): Input tensor with ``requires_grad=True``.

    Returns:
        tensor: The gradient, shaped like ``x``. Because ``create_graph=True``,
        the result is itself part of the autograd graph, so ``grad`` can be
        applied again for higher-order derivatives.
    """
    # grad_outputs must have the shape of y (the tensor being differentiated),
    # not x: ones_like(x) raises a shape-mismatch error whenever the two
    # differ, e.g. for a scalar loss or a multi-output network.
    return autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)[0]
def save(model, path):
    """Serialize ``model`` (or any picklable object) to ``path`` via ``torch.save``.

    Args:
        model: Object to write, e.g. an ``nn.Module`` or a tensor.
        path: Destination file path or writable file-like object.
    """
    torch.save(obj=model, f=path)
def load(path, map_location=None, **kwargs):
    """Deserialize an object previously written with ``torch.save``.

    Args:
        path: Source file path or readable file-like object.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
            loads a checkpoint saved on a GPU onto a CPU-only machine.
            ``None`` (default) keeps torch's default placement, as before.
        **kwargs: Extra keyword arguments forwarded to ``torch.load``
            (for example ``weights_only``).

    Returns:
        The deserialized object.
    """
    # SECURITY NOTE: torch.load unpickles data; only load files from trusted
    # sources. Since PyTorch 2.6 the default is weights_only=True, which
    # refuses arbitrary pickled objects such as a whole nn.Module. Callers
    # that trust the file may pass weights_only=False explicitly.
    return torch.load(path, map_location=map_location, **kwargs)