# Source code for pantea.descriptors.acsf.cutoff
from __future__ import annotations
import math
from dataclasses import dataclass
from functools import partial, update_wrapper
from typing import Any, Callable, Mapping
import jax
import jax.numpy as jnp
from pantea.pytree import BaseJaxPytreeDataClass, register_jax_pytree_node
from pantea.types import Array
@dataclass(frozen=True)
class CutoffFunction(BaseJaxPytreeDataClass):
    """Cutoff function for ACSF descriptor.

    Cutoff functions are utilized in the calculation of Atom-centered Symmetry
    Function (ACSF) descriptors. These functions serve to limit the influence
    of atoms located beyond a specified distance from the central atom.

    The ACSF descriptors employ cutoff functions to determine the range within
    which neighboring atoms contribute to the descriptor calculation. In fact,
    the cutoff function assigns a weight to each neighbor atom based on its
    distance from the central atom. Typically, a smooth cutoff function is
    employed to smoothly taper off the contribution of atoms as they move away
    from the central atom.

    The choice of cutoff function can vary depending on the specific
    application. The types available through :meth:`from_type` are the hard
    cutoff, the hyperbolic tangent (``tanh`` and ``tanhu``), cosine,
    exponential, and polynomial (``poly1``, ``poly2``) cutoffs.

    See `cutoff function`_ and `cutoff type`_ for more details.

    .. _`cutoff function`: https://compphysvienna.github.io/n2p2/api/cutoff_functions.html?highlight=cutoff#
    .. _`cutoff type`: https://compphysvienna.github.io/n2p2/topics/keywords.html?highlight=cutoff_type
    """

    # Distance at and beyond which the cutoff evaluates to zero.
    r_cutoff: float
    # Callable mapping an array of distances to the (unmasked) cutoff values.
    cutoff_function: Callable

    @classmethod
    def from_type(
        cls,
        cutoff_type: str,
        r_cutoff: float,
    ) -> CutoffFunction:
        """Create a cutoff function from the input cutoff type.

        :param cutoff_type: one of ``"hard"``, ``"tanhu"``, ``"tanh"``,
            ``"cos"``, ``"exp"``, ``"poly1"`` or ``"poly2"``
        :param r_cutoff: cutoff radius
        :return: a cutoff function of the requested type
        :raises KeyError: if ``cutoff_type`` is not a known type
        """

        def _on_reduced_distance(
            polynomial: Callable[[Array], Array],
        ) -> Callable[[Array], Array]:
            # The polynomial cutoffs are defined on x = r / r_cutoff, so the
            # distance has to be scaled before the polynomial is evaluated;
            # otherwise they only vanish at the cutoff when r_cutoff == 1.
            def reduced(r: Array) -> Array:
                return polynomial(r / r_cutoff)

            update_wrapper(reduced, polynomial)
            return reduced

        _cutoff_function_map: Mapping[str, Callable[[Array], Array]] = {
            "hard": _hard,
            "tanhu": _wrapped_partial(_tanhu, r_cutoff=r_cutoff),
            "tanh": _wrapped_partial(_tanh, r_cutoff=r_cutoff),
            "cos": _wrapped_partial(_cos, r_cutoff=r_cutoff),
            "exp": _wrapped_partial(_exp, r_cutoff=r_cutoff),
            "poly1": _on_reduced_distance(_poly1),
            "poly2": _on_reduced_distance(_poly2),
        }
        try:
            cutoff_function = _cutoff_function_map[cutoff_type]
        except KeyError:
            # Same exception type as a plain lookup, with a usable message.
            raise KeyError(
                f"Unknown cutoff type {cutoff_type!r}; "
                f"expected one of {sorted(_cutoff_function_map)}"
            ) from None
        return cls(r_cutoff, cutoff_function)

    def __post_init__(self) -> None:
        # Checks inherited from the base class; presumably they verify which
        # fields are treated as dynamic vs. static under jit (both fields are
        # declared static here) -- confirm in pantea.pytree.
        self._assert_jit_dynamic_attributes()
        self._assert_jit_static_attributes(expected=("r_cutoff", "cutoff_function"))

    @jax.jit
    def __call__(self, r: Array) -> Array:
        """Evaluate the cutoff for distances ``r``.

        Returns the cutoff function value where ``r < r_cutoff`` and zero
        elsewhere, with the same shape as ``r``.
        """
        inside = r < self.r_cutoff
        # Mask the *input* as well as the output: out-of-range distances can
        # make the cutoff function non-finite (e.g. the exponential form at
        # r >= r_cutoff), and a non-finite value in the unselected branch of
        # `where` turns gradients into NaN. The forward result is unchanged.
        safe_r = jnp.where(inside, r, jnp.zeros_like(r))
        return jnp.where(
            inside,
            self.cutoff_function(safe_r),
            jnp.zeros_like(r),
        )

    def __hash__(self) -> int:
        """Override the hash function from the base jax pytree data class."""
        # Defined explicitly so that @dataclass(frozen=True) keeps this hash
        # instead of generating a field-based one.
        return super().__hash__()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(r_cutoff={self.r_cutoff})"
# Prefactor (1 / tanh(1)) ** 3, written in terms of e; it scales the `_tanh`
# cutoff so that it equals 1 at r = 0.
_TANH_PRE: float = math.pow(
    (math.e + 1 / math.e) / (math.e - 1 / math.e),
    3,
)
def _hard(r: Array) -> Array:
    """Hard cutoff: a weight of one for every distance in ``r``."""
    unit_weights = jnp.ones_like(r)
    return unit_weights
def _tanhu(r: Array, r_cutoff: float) -> Array:
    """Unscaled tanh cutoff: ``tanh(1 - r / r_cutoff) ** 3``."""
    damping = jnp.tanh(1.0 - r / r_cutoff)
    return damping**3
def _tanh(r: Array, r_cutoff: float) -> Array:
    """Tanh cutoff: ``tanh(1 - r / r_cutoff) ** 3`` scaled by ``_TANH_PRE``."""
    damping = jnp.tanh(1.0 - r / r_cutoff)
    return _TANH_PRE * damping**3
def _cos(r: Array, r_cutoff: float) -> Array:
    """Cosine cutoff: ``0.5 * (cos(pi * r / r_cutoff) + 1)``."""
    phase = jnp.pi * r / r_cutoff
    return 0.5 * (jnp.cos(phase) + 1.0)
def _exp(r: Array, r_cutoff: float) -> Array:
    """Exponential cutoff: ``exp(1 - 1 / (1 - (r / r_cutoff) ** 2))``."""
    reduced = r / r_cutoff
    return jnp.exp(1.0 - 1.0 / (1.0 - reduced**2))
def _poly1(r: Array, r_cutoff: float = 1.0) -> Array:
    """Polynomial cutoff ``(2x - 3) x**2 + 1`` with ``x = r / r_cutoff``.

    The polynomial equals 1 at ``x = 0`` and 0 at ``x = 1``.

    :param r: distances
    :param r_cutoff: cutoff radius used to scale ``r``; the default of 1.0
        evaluates the polynomial on ``r`` itself
    :return: polynomial values, same shape as ``r``
    """
    x = r / r_cutoff
    return (2.0 * x - 3.0) * x**2 + 1.0
def _poly2(r: Array, r_cutoff: float = 1.0) -> Array:
    """Polynomial cutoff ``((15 - 6x) x - 10) x**3 + 1`` with ``x = r / r_cutoff``.

    The polynomial equals 1 at ``x = 0`` and 0 at ``x = 1``.

    :param r: distances
    :param r_cutoff: cutoff radius used to scale ``r``; the default of 1.0
        evaluates the polynomial on ``r`` itself
    :return: polynomial values, same shape as ``r``
    """
    x = r / r_cutoff
    return ((15.0 - 6.0 * x) * x - 10) * x**3 + 1.0
def _wrapped_partial(function: Callable, **kwargs: Any) -> Callable[[Array], Array]:
    """Bind ``kwargs`` to ``function`` and copy its metadata onto the result.

    A bare ``functools.partial`` object has no ``__name__``/``__doc__`` of its
    own; ``update_wrapper`` copies them over from ``function``.
    """
    # update_wrapper returns the wrapper it was given, i.e. the partial.
    return update_wrapper(partial(function, **kwargs), function)
# Hand CutoffFunction to pantea's pytree registration helper. Presumably this
# registers the class as a JAX pytree node so instances can be passed through
# jit (e.g. as `self` in the jitted __call__) -- confirm in pantea.pytree.
register_jax_pytree_node(CutoffFunction)