Source code for surkit.backend.jax_bkd

#!/usr/bin/env python
# -*- coding:UTF-8 -*-
import pickle

import flax
# import flax.training.checkpoints
import flax.linen as nn
import flax.training.train_state
import jax
import jax.numpy as jnp

# Backend-neutral aliases for the JAX/flax backend: `Module` is the flax linen
# base class for networks and `Tensor` is the JAX array type.
# (Presumably consumed by backend-agnostic surkit code — confirm against callers.)
Module, Tensor = nn.Module, jnp.ndarray

def is_tensor(obj):
    """Return True when ``obj`` is a JAX array (``jnp.ndarray``), else False."""
    array_type = jnp.ndarray
    return isinstance(obj, array_type)
def zeros(shape, dtype=None):
    """Create a JAX array of the given ``shape`` filled with zeros.

    Args:
        shape: int or tuple of ints giving the output shape.
        dtype: optional element type; ``None`` lets JAX choose its default.
    """
    filled = jnp.zeros(shape, dtype)
    return filled
def zeros_like(tensor):
    """Return a zero-filled array with the same shape and dtype as ``tensor``."""
    result = jnp.zeros_like(tensor)
    return result
def cat(tensor_list, dim=0):
    """Concatenate a sequence of arrays along axis ``dim`` (default 0)."""
    joined = jnp.concatenate(tensor_list, axis=dim)
    return joined
def unsqueeze(tensor, dim):
    """Insert a new length-1 axis into ``tensor`` at position ``dim``."""
    expanded = jnp.expand_dims(tensor, dim)
    return expanded
def squeeze(tensor, dim=None):
    """Remove length-1 axes from ``tensor``.

    Args:
        tensor: input array.
        dim: axis (or tuple of axes) to remove; each selected axis must have
            length 1. ``None`` (the default) removes every length-1 axis,
            matching ``jnp.squeeze``'s own default. Existing callers that
            pass ``dim`` explicitly behave exactly as before.

    Returns:
        The array with the selected axes removed.
    """
    return jnp.squeeze(tensor, axis=dim)
def _apply_train_state(state, x):
    """Evaluate ``state.apply_fn`` with the parameters stored on ``state``."""
    return state.apply_fn(state.params, x)


def _apply_module(model, params, x):
    """Evaluate ``model.apply`` with explicitly supplied ``params``."""
    return model.apply(params, x)


# A TrainState is a pytree (array leaves plus static apply_fn/tx fields), so it
# can be passed to jit as a regular traced argument. A flax ``nn.Module`` is not
# a valid JAX argument type, so it must be passed as a static (hashed) argument.
# NOTE(review): relies on flax linen Modules being hashable dataclasses — confirm
# for custom Module subclasses that override __hash__/__eq__.
_apply_train_state_jit = jax.jit(_apply_train_state)
_apply_module_jit = jax.jit(_apply_module, static_argnums=0)


def forward(model, x, params=None):
    """Run a forward pass of ``model`` on input ``x``.

    Args:
        model: either a ``flax.training.train_state.TrainState`` (or subclass),
            in which case its own ``apply_fn``/``params`` are used and
            ``params`` is ignored, or a flax ``nn.Module``, in which case
            ``params`` supplies the variables for ``model.apply``.
        x: network input.
        params: module variables; required when ``model`` is not a TrainState.

    Returns:
        The network output.

    Raises:
        ValueError: if ``model`` is not a TrainState and ``params`` is None.
    """
    # Dispatch in Python, outside of jit: jitting the whole function would force
    # an nn.Module through tracing, which JAX rejects. isinstance (rather than
    # an exact type check) also accepts user subclasses of TrainState.
    if isinstance(model, flax.training.train_state.TrainState):
        return _apply_train_state_jit(model, x)
    if params is None:
        raise ValueError("params must be provided when model is not a TrainState")
    return _apply_module_jit(model, params, x)
def np_to_tensor(array):
    """Convert ``array`` (NumPy array, list, scalar, ...) into a JAX array.

    ``jnp.array`` copies its input by default, so the result does not alias
    the caller's buffer.
    """
    converted = jnp.array(array)
    return converted
@jax.jit
def sin(tensor):
    """Element-wise sine of ``tensor`` (input in radians)."""
    return jnp.sin(tensor)


@jax.jit
def cos(tensor):
    """Element-wise cosine of ``tensor`` (input in radians)."""
    return jnp.cos(tensor)


@jax.jit
def tan(tensor):
    """Element-wise tangent of ``tensor`` (input in radians)."""
    return jnp.tan(tensor)


@jax.jit
def arcsin(tensor):
    """Element-wise inverse sine of ``tensor``."""
    return jnp.arcsin(tensor)


@jax.jit
def arccos(tensor):
    """Element-wise inverse cosine of ``tensor``."""
    return jnp.arccos(tensor)


@jax.jit
def arctan(tensor):
    """Element-wise inverse tangent of ``tensor``."""
    return jnp.arctan(tensor)
@jax.jit
def sinh(tensor):
    """Element-wise hyperbolic sine of ``tensor``."""
    return jnp.sinh(tensor)


@jax.jit
def cosh(tensor):
    """Element-wise hyperbolic cosine of ``tensor``."""
    return jnp.cosh(tensor)


@jax.jit
def tanh(tensor):
    """Element-wise hyperbolic tangent of ``tensor``."""
    return jnp.tanh(tensor)


@jax.jit
def arcsinh(tensor):
    """Element-wise inverse hyperbolic sine of ``tensor``."""
    return jnp.arcsinh(tensor)


@jax.jit
def arccosh(tensor):
    """Element-wise inverse hyperbolic cosine of ``tensor``."""
    return jnp.arccosh(tensor)


@jax.jit
def arctanh(tensor):
    """Element-wise inverse hyperbolic tangent of ``tensor``."""
    return jnp.arctanh(tensor)
@jax.jit
def power(tensor, exponent):
    """Element-wise ``tensor ** exponent``.

    Under jit both arguments are traced, so ``exponent`` may be a scalar or an
    array broadcastable against ``tensor``.
    """
    return jnp.power(tensor, exponent)


@jax.jit
def exp(tensor):
    """Element-wise natural exponential of ``tensor``."""
    return jnp.exp(tensor)


@jax.jit
def log(tensor):
    """Element-wise natural logarithm of ``tensor``."""
    return jnp.log(tensor)


@jax.jit
def log2(tensor):
    """Element-wise base-2 logarithm of ``tensor``."""
    return jnp.log2(tensor)


@jax.jit
def log10(tensor):
    """Element-wise base-10 logarithm of ``tensor``."""
    return jnp.log10(tensor)


@jax.jit
def sqrt(tensor):
    """Element-wise non-negative square root of ``tensor``."""
    return jnp.sqrt(tensor)
@jax.jit
def grad(state, x, ind_x, ind_y):
    """
    Calculate dy/dx for one output column and one input column.

    The network output ``y = state.apply_fn(state.params, x)`` is indexed as
    ``y[:, ind_y]`` and the gradient as ``[:, ind_x]``, so ``x`` and ``y`` are
    expected to be 2-D with samples along axis 0.

    Args:
        state: object exposing ``apply_fn`` and ``params`` (e.g. a flax
            TrainState); must be a pytree because the function is jitted.
        x (tensor): network input, shape (N, d_in).
        ind_x (int): column of ``x`` to differentiate with respect to.
        ind_y (int): column of the network output to differentiate.

    Returns:
        Tensor of shape (N,) whose i-th entry is d y[i, ind_y] / d x[i, ind_x].
    """
    def summed_output(inputs):
        # Summing over the batch yields the scalar jax.grad needs. Row i of the
        # resulting gradient equals dy_i/dx_i only if the network treats rows
        # independently (no batch coupling) — assumed here.
        return jnp.sum(state.apply_fn(state.params, inputs)[:, ind_y])

    # summed_output takes a single argument, so differentiate w.r.t. argnums=0
    # (the default); argnums=1 would make jax.grad raise on every call.
    return jax.grad(summed_output)(x)[:, ind_x]
def save(state, path):
    """Serialize ``state`` to ``path`` as a pickled flax state dict.

    Args:
        state: flax-serializable object (e.g. a TrainState); converted with
            ``flax.serialization.to_state_dict`` before pickling.
        path: destination file path; an existing file is truncated ("wb").
    """
    state_dict = flax.serialization.to_state_dict(state)
    # Use a context manager so the file is flushed and closed deterministically,
    # even if pickling raises, instead of relying on garbage collection.
    with open(path, "wb") as f:
        pickle.dump(state_dict, f)
def load(state, path):
    """Restore a state previously written by ``save``.

    Args:
        state: template object whose structure the pickled state dict is
            mapped onto via ``flax.serialization.from_state_dict``.
        path: file containing a pickled state dict.

    Returns:
        The result of ``from_state_dict(target=state, state=<loaded dict>)``.
    """
    # SECURITY: pickle.load can execute arbitrary code while unpickling; only
    # load checkpoint files that come from a trusted source.
    with open(path, "rb") as f:
        pkl_file = pickle.load(f)
    state = flax.serialization.from_state_dict(target=state, state=pkl_file)
    return state