# Source code for pantea.models.nn.activation
import jax.numpy as jnp
from flax import linen as nn
from frozendict import frozendict
from pantea.types import Array
def identity(x: Array) -> Array:
    """Return the input unchanged (linear / pass-through activation).

    :param x: input value or array
    :return: the very same object that was passed in
    """
    return x
def tanh(x: Array) -> Array:
    """Return the hyperbolic tangent of the input.

    Thin wrapper that delegates to ``nn.tanh``.

    :param x: input value or array
    :return: ``tanh(x)``
    """
    return nn.tanh(x)
def logistic(x: Array) -> Array:
    """Return the logistic (sigmoid) function ``1 / (1 + exp(-x))``.

    The value is evaluated as ``exp(-logaddexp(0, -x))``, which is the same
    function written without an explicit ``exp(-x)`` term. The direct formula
    overflows ``exp(-x)`` to ``inf`` for large negative inputs; the value
    still comes out as 0, but its derivative becomes ``0 * inf = NaN``.

    :param x: input value or array
    :return: logistic function of ``x``, in the open/closed range [0, 1]
    """
    # logaddexp(0, -x) == log(1 + exp(-x)), computed without overflow.
    return jnp.exp(-jnp.logaddexp(0.0, -x))
def softplus(x: Array) -> Array:
    """Return the softplus of the input.

    Thin wrapper that delegates to ``nn.softplus``.

    :param x: input value or array
    :return: ``softplus(x)``
    """
    return nn.softplus(x)
def relu(x: Array) -> Array:
    """Return the rectified linear unit of the input.

    Thin wrapper that delegates to ``nn.relu``.

    :param x: input value or array
    :return: ``relu(x)``
    """
    return nn.relu(x)
def gaussian(x: Array) -> Array:
    """Return the unnormalised Gaussian ``exp(-x**2 / 2)``.

    :param x: input value or array
    :return: ``exp(-0.5 * x**2)``; equals 1 at ``x == 0`` and is even in ``x``
    """
    return jnp.exp(-0.5 * x**2)
def cos(x: Array) -> Array:
    """Return the cosine of the input.

    :param x: input value or array (the argument is passed to ``jnp.cos``)
    :return: ``cos(x)``
    """
    return jnp.cos(x)
def revlogistic(x: Array) -> Array:
    """Return the reversed logistic function ``1 - 1 / (1 + exp(-x))``.

    Algebraically this equals ``1 / (1 + exp(x))``, and it is evaluated as
    ``exp(-logaddexp(0, x))``. The direct formula has two numerical problems:
    for large positive inputs ``1 - 1 / (1 + tiny)`` cancels to exactly 0,
    and for large negative inputs ``exp(-x)`` overflows, which makes the
    derivative ``NaN``.

    :param x: input value or array
    :return: ``1 - logistic(x)``
    """
    # logaddexp(0, x) == log(1 + exp(x)), computed without overflow.
    return jnp.exp(-jnp.logaddexp(0.0, x))
def exp(x: Array) -> Array:
    """Return the decaying exponential ``exp(-x)``.

    Note the sign: despite the name, this is ``exp(-x)``, not ``exp(x)``.

    :param x: input value or array
    :return: ``exp(-x)``
    """
    return jnp.exp(-x)
def harmonic(x: Array) -> Array:
    """Return the square of the input, ``x * x``.

    :param x: input value or array
    :return: ``x`` multiplied by itself
    """
    return x * x
# Immutable lookup table from activation-function name to its implementation.
# The set of functions follows the n2p2 activation-function list, see
# https://compphysvienna.github.io/n2p2/api/neural_network.html?highlight=activation%20function
_activation_function_map: frozendict = frozendict(
    {
        "identity": identity,
        "tanh": nn.tanh,
        "logistic": logistic,
        "softplus": nn.softplus,
        "relu": nn.relu,
        "gaussian": gaussian,
        "cos": cos,
        # revlogistic is defined in this module but was missing from the table.
        "revlogistic": revlogistic,
        "exp": exp,
        "harmonic": harmonic,
    }
)