#!/usr/bin/env python
# -*- coding:UTF-8 -*-
import pickle
import flax
# import flax.training.checkpoints
import flax.linen as nn
import flax.training.train_state
import jax
import jax.numpy as jnp
# Torch-like aliases: the Flax Linen module base class and the JAX array type.
Module, Tensor = nn.Module, jnp.ndarray
def is_tensor(obj):
    """Report whether ``obj`` is a JAX array (an instance of ``jnp.ndarray``)."""
    matches = isinstance(obj, jnp.ndarray)
    return matches
def zeros(shape, dtype=None):
    """Allocate a zero-filled JAX array.

    Args:
        shape: int or tuple of ints giving the array shape.
        dtype: element type; when None, ``jnp.zeros`` picks its default dtype.

    Returns:
        A new array of zeros.
    """
    filled = jnp.zeros(shape, dtype=dtype)
    return filled
def zeros_like(tensor, dtype=None):
    """Return an array of zeros with the same shape as ``tensor``.

    Args:
        tensor: reference array (or array-like) whose shape is copied.
        dtype: optional element-type override. When None (the default, and
            the previous behavior), the dtype of ``tensor`` is kept. Mirrors
            the ``dtype`` parameter accepted by ``zeros``.

    Returns:
        A new zero-filled array.
    """
    return jnp.zeros_like(tensor, dtype=dtype)
def cat(tensor_list, dim=0):
    """Join a sequence of arrays along the existing axis ``dim`` (torch-style name).

    Args:
        tensor_list: sequence of arrays with matching shapes except along ``dim``.
        dim: axis to concatenate along.

    Returns:
        The concatenated array.
    """
    joined = jnp.concatenate(tensor_list, axis=dim)
    return joined
def unsqueeze(tensor, dim):
    """Insert a new length-1 axis into ``tensor`` at position ``dim``."""
    expanded = jnp.expand_dims(tensor, axis=dim)
    return expanded
def squeeze(tensor, dim=None):
    """Remove length-1 axes from ``tensor``.

    Args:
        tensor: input array.
        dim: axis (or tuple of axes) to remove; ``jnp.squeeze`` raises if a
            selected axis does not have length 1. When None (new,
            backward-compatible default), every length-1 axis is removed.

    Returns:
        The squeezed array.
    """
    return jnp.squeeze(tensor, axis=dim)
@jax.jit
def _apply_train_state(state, x):
    """Jitted forward pass using a TrainState's own ``apply_fn`` and ``params``."""
    return state.apply_fn(state.params, x)


def forward(model, x, params=None):
    """Run a forward pass of ``model`` on ``x``.

    Args:
        model: either a ``flax.training.train_state.TrainState`` (or subclass),
            whose own ``apply_fn``/``params`` are used and ``params`` is
            ignored; or a Flax module applied as ``model.apply(params, x)``.
        x: model input.
        params: variables passed to ``model.apply``. Required when ``model``
            is not a TrainState.

    Returns:
        The model output.

    Raises:
        ValueError: if ``model`` is not a TrainState and ``params`` is None.
    """
    # isinstance instead of ``type(...) ==`` so TrainState subclasses (which
    # have no ``.apply`` method) are routed to the TrainState path too.
    if isinstance(model, flax.training.train_state.TrainState):
        # TrainState is a pytree, so it can safely pass through jax.jit.
        return _apply_train_state(model, x)
    if params is None:
        raise ValueError("params must be provided when model is not a TrainState")
    # A Flax module is not a valid JAX argument type, so it cannot be passed
    # through jax.jit as the previous decorator attempted; apply it directly
    # and let callers jit at a higher level if needed.
    return model.apply(params, x)
def np_to_tensor(array):
    """Build a new JAX array from an array-like (NumPy array, nested list, scalar).

    ``jnp.array`` copies its input by default.
    """
    converted = jnp.array(array)
    return converted
@jax.jit
def sin(tensor):
    """Element-wise sine of ``tensor`` (radians), compiled with ``jax.jit``."""
    out = jnp.sin(tensor)
    return out
@jax.jit
def cos(tensor):
    """Element-wise cosine of ``tensor`` (radians), compiled with ``jax.jit``."""
    out = jnp.cos(tensor)
    return out
@jax.jit
def tan(tensor):
    """Element-wise tangent of ``tensor`` (radians), compiled with ``jax.jit``."""
    out = jnp.tan(tensor)
    return out
@jax.jit
def arcsin(tensor):
    """Element-wise inverse sine of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.arcsin(tensor)
    return out
@jax.jit
def arccos(tensor):
    """Element-wise inverse cosine of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.arccos(tensor)
    return out
@jax.jit
def arctan(tensor):
    """Element-wise inverse tangent of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.arctan(tensor)
    return out
@jax.jit
def sinh(tensor):
    """Element-wise hyperbolic sine of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.sinh(tensor)
    return out
@jax.jit
def cosh(tensor):
    """Element-wise hyperbolic cosine of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.cosh(tensor)
    return out
@jax.jit
def tanh(tensor):
    """Element-wise hyperbolic tangent of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.tanh(tensor)
    return out
@jax.jit
def arcsinh(tensor):
    """Element-wise inverse hyperbolic sine of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.arcsinh(tensor)
    return out
@jax.jit
def arccosh(tensor):
    """Element-wise inverse hyperbolic cosine of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.arccosh(tensor)
    return out
@jax.jit
def arctanh(tensor):
    """Element-wise inverse hyperbolic tangent of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.arctanh(tensor)
    return out
@jax.jit
def power(tensor, exponent):
    """Raise ``tensor`` element-wise to ``exponent`` (broadcast), compiled with ``jax.jit``.

    Both arguments are traced, so ``exponent`` may be a scalar or an array.
    """
    out = jnp.power(tensor, exponent)
    return out
@jax.jit
def exp(tensor):
    """Element-wise natural exponential of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.exp(tensor)
    return out
@jax.jit
def log(tensor):
    """Element-wise natural logarithm of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.log(tensor)
    return out
@jax.jit
def log2(tensor):
    """Element-wise base-2 logarithm of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.log2(tensor)
    return out
@jax.jit
def log10(tensor):
    """Element-wise base-10 logarithm of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.log10(tensor)
    return out
@jax.jit
def sqrt(tensor):
    """Element-wise square root of ``tensor``, compiled with ``jax.jit``."""
    out = jnp.sqrt(tensor)
    return out
@jax.jit
def grad(state, x, ind_x, ind_y):
    """Compute dy/dx = d out[:, ind_y] / d x[:, ind_x] for every row of a batch.

    Args:
        state: object exposing ``apply_fn`` and ``params`` (e.g. a Flax
            ``TrainState``). It must be a valid pytree because this function
            is jitted. ``state.apply_fn(state.params, x)`` must return an
            array indexable as ``out[:, ind_y]``.
        x: floating-point input batch. The result is indexed as
            ``[:, ind_x]``, so presumably shaped ``(batch, n_inputs)``.
        ind_x: index of the input column to differentiate with respect to.
        ind_y: index of the output column to differentiate.

    Returns:
        Array with one derivative value per row of ``x``.
    """
    def summed_output(inputs):
        # Summing over the batch gives jax.grad the scalar it needs. Row i of
        # the resulting gradient equals the per-sample derivative only if each
        # output row depends solely on its own input row (no cross-sample
        # coupling) -- an assumption about the model, not checked here.
        return jnp.sum(state.apply_fn(state.params, inputs)[:, ind_y])

    # Differentiate w.r.t. the single positional argument (argnums=0). The
    # previous argnums=1 referred to a nonexistent second argument and
    # always raised.
    return jax.grad(summed_output)(x)[:, ind_x]
def save(state, path):
    """Pickle the Flax state dict of ``state`` to ``path``.

    Args:
        state: object accepted by ``flax.serialization.to_state_dict``
            (e.g. a ``TrainState``).
        path: destination file; opened in ``"wb"`` mode, so an existing file
            is overwritten.
    """
    state_dict = flax.serialization.to_state_dict(state)
    # The context manager flushes and closes the file even if pickling fails;
    # the previous bare open() leaked the handle.
    with open(path, "wb") as f:
        pickle.dump(state_dict, f)
def load(state, path):
    """Restore values into the structure of ``state`` from a pickle written by ``save``.

    Args:
        state: template object that the loaded dict is mapped onto via
            ``flax.serialization.from_state_dict``.
        path: file previously produced by ``save``.

    Returns:
        The object returned by ``flax.serialization.from_state_dict``, built
        from ``state`` and the restored values.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code; only load files
    # from trusted sources.
    with open(path, "rb") as f:
        state_dict = pickle.load(f)
    return flax.serialization.from_state_dict(target=state, state=state_dict)