Source code for pantea.descriptors.acsf.radial
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from pantea.descriptors.acsf.cutoff import CutoffFunction
from pantea.descriptors.acsf.symmetry import BaseSymmetryFunction
from pantea.pytree import register_jax_pytree_node
from pantea.types import Array
[docs]class RadialSymmetryFunction(BaseSymmetryFunction, metaclass=ABCMeta):
"""A base class for `two body` (radial) symmetry functions."""
def __hash__(self) -> int:
"""Enforce to use the parent class's hash method (JIT)."""
return super().__hash__()
@abstractmethod
def __call__(self, rij: Array) -> Array: ...
[docs]@dataclass
class G1(RadialSymmetryFunction):
"""Plain cutoff function as symmetry function."""
cfn: CutoffFunction
def __post_init__(self) -> None:
self._assert_jit_dynamic_attributes()
self._assert_jit_static_attributes(expected=("cfn",))
def __hash__(self) -> int:
"""Enforce to use the parent class's hash method (JIT)."""
return super().__hash__()
@jax.jit
def __call__(self, rij: Array) -> Array:
return self.cfn(rij)
[docs]@dataclass
class G2(RadialSymmetryFunction):
"""Radial exponential symmetry function."""
cfn: CutoffFunction
r_shift: float
eta: float
def __post_init__(self) -> None:
self._assert_jit_dynamic_attributes()
self._assert_jit_static_attributes(expected=("cfn", "r_shift", "eta"))
def __hash__(self) -> int:
"""Enforce to use the parent class's hash method (JIT)."""
return super().__hash__()
@jax.jit
def __call__(self, rij: Array) -> Array:
return jnp.exp(-self.eta * (rij - self.r_shift) ** 2) * self.cfn(rij)
register_jax_pytree_node(G1)
register_jax_pytree_node(G2)