# Source code for pantea.descriptors.acsf.angular
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from pantea.descriptors.acsf.cutoff import CutoffFunction
from pantea.descriptors.acsf.symmetry import BaseSymmetryFunction
from pantea.pytree import register_jax_pytree_node
from pantea.types import Array
class AngularSymmetryFunction(BaseSymmetryFunction, metaclass=ABCMeta):
    """Abstract base for `three body` (angular) symmetry functions.

    Concrete subclasses implement ``__call__``. It takes the three pairwise
    distances of an atom triplet (i, j, k) together with the angular term
    ``cost`` and returns the value of the symmetry function.
    """

    @abstractmethod
    def __call__(
        self,
        rij: Array,
        rik: Array,
        rjk: Array,
        cost: Array,
    ) -> Array:
        """Evaluate the symmetry function for the given triplet geometry.

        :param rij: distance(s) between atoms i and j
        :param rik: distance(s) between atoms i and k
        :param rjk: distance(s) between atoms j and k
        :param cost: angular term; presumably cos(theta_ijk) -- confirm
            against the caller that builds the triplets
        :return: symmetry-function value(s)
        """

    def __hash__(self) -> int:
        """Delegate hashing to BaseSymmetryFunction (needed for JIT)."""
        return super().__hash__()
@dataclass
class G3(AngularSymmetryFunction):
    r"""Angular symmetry function.

    For a single atom triplet (i, j, k) this evaluates

    .. math::

        2^{1-\zeta} (1 + \lambda \cos\theta_{ijk})^{\zeta}
        \exp\left(-\eta \left[(r_{ij}-r_s)^2 + (r_{ik}-r_s)^2
        + (r_{jk}-r_s)^2\right]\right)
        f_c(r_{ij})\, f_c(r_{ik})\, f_c(r_{jk})

    with :math:`\lambda` = ``lambda0`` and :math:`r_s` = ``r_shift``.
    The method works elementwise on its array arguments. It does not sum
    over neighbor pairs.
    """

    cfn: CutoffFunction  # cutoff applied to r_ij, r_ik and r_jk
    eta: float  # width parameter of the Gaussian envelope
    zeta: float  # exponent of the angular term
    lambda0: float  # angular sign parameter (conventionally +1 or -1)
    r_shift: float  # shift subtracted from every distance in the Gaussian

    def __post_init__(self) -> None:
        self._assert_jit_dynamic_attributes()
        self._assert_jit_static_attributes(
            expected=("cfn", "eta", "zeta", "lambda0", "r_shift")
        )

    def __hash__(self) -> int:
        """Enforce to use the parent class's hash method (JIT)."""
        # Required: @dataclass (eq=True) would otherwise set __hash__ = None.
        return super().__hash__()

    @jax.jit
    def __call__(
        self,
        rij: Array,
        rik: Array,
        rjk: Array,
        cost: Array,
    ) -> Array:
        """Evaluate the angular term for each triplet.

        :param rij: distance(s) between atoms i and j
        :param rik: distance(s) between atoms i and k
        :param rjk: distance(s) between atoms j and k
        :param cost: cosine of the angle at atom i (presumably between
            r_ij and r_ik -- confirm against the caller)
        :return: elementwise symmetry-function values
        """
        # For |lambda0| <= 1 and a true cosine the base is >= 0. Rounding can
        # push ``cost`` slightly past +/-1, though, and a tiny negative base
        # makes ``power`` return NaN for a non-integer ``zeta``. Clamp at 0
        # (the reference n2p2 implementation also zeroes non-positive bases).
        base = jnp.maximum(1.0 + self.lambda0 * cost, 0.0)
        # Every distance is shifted by r_shift. With r_shift == 0 this reduces
        # to the plain (rij**2 + rik**2 + rjk**2) form.
        shifted_sq = (
            (rij - self.r_shift) ** 2
            + (rik - self.r_shift) ** 2
            + (rjk - self.r_shift) ** 2
        )
        return (
            2.0 ** (1.0 - self.zeta)
            * jnp.power(base, self.zeta)
            * jnp.exp(-self.eta * shifted_sq)
            * self.cfn(rij)
            * self.cfn(rik)
            * self.cfn(rjk)
        )
@dataclass
class G9(AngularSymmetryFunction):
    r"""
    Modified angular symmetry function.
    J. Behler, J. Chem. Phys. 134, 074106 (2011).

    For a single atom triplet (i, j, k) this evaluates

    .. math::

        2^{1-\zeta} (1 + \lambda \cos\theta_{ijk})^{\zeta}
        \exp\left(-\eta \left[(r_{ij}-r_s)^2 + (r_{ik}-r_s)^2\right]\right)
        f_c(r_{ij})\, f_c(r_{ik})

    with :math:`\lambda` = ``lambda0`` and :math:`r_s` = ``r_shift``.
    Unlike ``G3``, the j-k distance enters neither the Gaussian nor the
    cutoff. ``rjk`` is accepted only to keep the common call signature.
    The method works elementwise and does not sum over neighbor pairs.
    """

    cfn: CutoffFunction  # cutoff applied to r_ij and r_ik
    eta: float  # width parameter of the Gaussian envelope
    zeta: float  # exponent of the angular term
    lambda0: float  # angular sign parameter (conventionally +1 or -1)
    r_shift: float  # shift subtracted from r_ij and r_ik in the Gaussian

    def __post_init__(self) -> None:
        self._assert_jit_dynamic_attributes()
        self._assert_jit_static_attributes(
            expected=("cfn", "eta", "zeta", "lambda0", "r_shift")
        )

    def __hash__(self) -> int:
        """Enforce to use the parent class's hash method (JIT)."""
        # Required: @dataclass (eq=True) would otherwise set __hash__ = None.
        return super().__hash__()

    @jax.jit
    def __call__(
        self,
        rij: Array,
        rik: Array,
        rjk: Array,
        cost: Array,
    ) -> Array:
        """Evaluate the modified angular term for each triplet.

        :param rij: distance(s) between atoms i and j
        :param rik: distance(s) between atoms i and k
        :param rjk: unused by this function form
        :param cost: cosine of the angle at atom i (presumably between
            r_ij and r_ik -- confirm against the caller)
        :return: elementwise symmetry-function values
        """
        # TODO: define params argument instead of reading attributes.
        # Clamp the angular base at 0. A rounding-induced tiny negative value
        # would otherwise give NaN from ``power`` for a non-integer ``zeta``.
        base = jnp.maximum(1.0 + self.lambda0 * cost, 0.0)
        # With r_shift == 0 this reduces to the plain (rij**2 + rik**2) form.
        shifted_sq = (rij - self.r_shift) ** 2 + (rik - self.r_shift) ** 2
        return (
            2.0 ** (1.0 - self.zeta)
            * jnp.power(base, self.zeta)
            * jnp.exp(-self.eta * shifted_sq)
            * self.cfn(rij)
            * self.cfn(rik)
        )
# Register both angular functions with the project's pytree helper. Their
# ``__call__`` is wrapped in ``jax.jit`` with ``self`` as the first argument,
# so instances presumably need to be valid JAX pytrees (see pantea.pytree).
for _angular_cls in (G3, G9):
    register_jax_pytree_node(_angular_cls)
del _angular_cls  # keep the module namespace identical to explicit calls