# Source code for pantea.descriptors.acsf.acsf

import itertools
from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional, Tuple

import jax.numpy as jnp
from jax import jacfwd, jit, lax, vmap

from pantea.atoms.distance import (
    _calculate_distances_per_atom,
    _calculate_distances_with_aux_per_atom,
)
from pantea.atoms.neighbor import _calculate_cutoff_masks_per_atom
from pantea.atoms.structure import Structure, StructureAsKernelArgs
from pantea.descriptors.acsf.angular import AngularSymmetryFunction
from pantea.descriptors.acsf.radial import RadialSymmetryFunction
from pantea.descriptors.acsf.symmetry import NeighborElements
from pantea.logger import logger
from pantea.pytree import BaseJaxPytreeDataClass, register_jax_pytree_node
from pantea.types import Array

AssignedRadialSymmetryFunction = Tuple[RadialSymmetryFunction, NeighborElements]
AssignedAngularSymmetryFunction = Tuple[AngularSymmetryFunction, NeighborElements]


@dataclass
class AtomCenteredSymmetryFunction(BaseJaxPytreeDataClass):
    """
    Atom-centered Symmetry Function (`ACSF`_) descriptor captures information about
    the distribution of neighboring atoms around a central atom by considering both
    radial (two-body) and angular (three-body) symmetry functions within a cutoff
    distance.

    Radial symmetry functions describe the distances, while angular symmetry
    functions hold information about the angles formed by the central atom with
    pairs of neighboring atoms.

    The ACSF represents a fingerprint of the local atomic environment and can be
    used in various machine learning potentials.

    .. _ACSF: https://compphysvienna.github.io/n2p2/topics/descriptors.html?highlight=symmetry%20function#
    """

    central_element: str
    radial_symmetry_functions: Tuple[AssignedRadialSymmetryFunction, ...]
    angular_symmetry_functions: Tuple[AssignedAngularSymmetryFunction, ...]

    def __call__(
        self,
        structure: Structure,
        atom_index: Optional[Array] = None,
    ) -> Array:
        """
        Calculate descriptor values for the given structure and atom index.

        :param structure: input structure instance
        :type structure: Structure
        :param atom_index: atom indices; by default all atoms whose element is
            the central element of the descriptor are selected.
        :type atom_index: Optional[Array], optional
        :return: descriptor values
        :rtype: Array
        :raises ValueError: reported through ``logger.error`` when any of the
            requested atoms is not of the central element type.
        """
        if self.num_symmetry_functions == 0:
            logger.warning("No symmetry function was found")

        if atom_index is None:
            index = structure.select(self.central_element)
        else:
            index = jnp.atleast_1d(atom_index)
            selected_atom_types = structure.atom_types[index]
            # Check all atom types match the central element
            if not jnp.all(
                structure.element_map.element_to_atom_type[self.central_element]
                == selected_atom_types
            ):
                # Format every selected type individually: calling ``int()`` on
                # an array with more than one element raises ``TypeError`` and
                # would hide the intended ``ValueError``.
                found_types = ", ".join(
                    str(int(atom_type)) for atom_type in jnp.ravel(selected_atom_types)
                )
                logger.error(
                    f"Inconsistent central element '{self.central_element}': "
                    f" input atom index={atom_index}"
                    f" (atom_types='{found_types}')",
                    exception=ValueError,
                )

        return _jitted_calculate_acsf_descriptor(
            self,
            structure.positions[index],
            structure.as_kernel_args(),
        )  # type: ignore

    def grad(
        self,
        structure: Structure,
        atom_index: Optional[Array] = None,
    ) -> Array:
        """
        Compute gradient of the ACSF descriptor with respect to the atom position.

        :param structure: input Structure instance
        :param atom_index: atom index in Structure [0, natoms); by default the
            positions of all atoms in the structure are used.
        :return: gradient of the descriptor value with respect to the atom position
        :raises ValueError: reported through ``logger.error`` when an index lies
            outside ``[0, natoms)``.

        Please note that `grad_per_element` method is way much faster than the
        current implementation of this method.
        """
        if atom_index is None:
            positions = structure.positions
        else:
            index = jnp.atleast_1d(atom_index)
            # check atom index
            if not jnp.all((0 <= index) & (index < structure.natoms)):
                logger.error(
                    f"unexpected {atom_index=}. "
                    f"Input index must be between [0, {structure.natoms})",
                    exception=ValueError,
                )
            positions = structure.positions[index]

        return _jitted_calculate_grad_acsf_descriptor(
            self,
            positions,
            structure.as_kernel_args(),
        )  # type: ignore

    @property
    def num_radial_symmetry_functions(self) -> int:
        """Return number of `two-body` (radial) symmetry functions."""
        return len(self.radial_symmetry_functions)

    @property
    def num_angular_symmetry_functions(self) -> int:
        """Return number of `three-body` (angular) symmetry functions."""
        return len(self.angular_symmetry_functions)

    @property
    def num_symmetry_functions(self) -> int:
        """Return the total (`two-body` and `three-body`) number of symmetry functions."""
        return self.num_radial_symmetry_functions + self.num_angular_symmetry_functions

    @property
    def r_cutoff(self) -> float:  # type: ignore
        """
        Return the maximum cutoff radius for list of the symmetry functions.

        :raises ValueError: if the descriptor holds no symmetry function at all
            (``max`` of an empty sequence).
        """
        return max(
            symmetry_function.r_cutoff
            for (symmetry_function, _) in itertools.chain(
                self.radial_symmetry_functions,
                self.angular_symmetry_functions,
            )
        )

    def __hash__(self) -> int:
        """Enforce to use the parent class's hash method (JIT)."""
        # Defined explicitly so that ``@dataclass`` (eq=True) does not reset
        # ``__hash__`` to ``None``; the instance is passed as a static jit argument.
        return super().__hash__()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(central_element='{self.central_element}'"
            f", num_symmetry_functions={self.num_symmetry_functions})"
        )
def _calculate_acsf_descriptor_per_atom(
    acsf: AtomCenteredSymmetryFunction,
    position: Array,
    structure: StructureAsKernelArgs,
) -> Array:
    """
    Evaluate every symmetry function of ``acsf`` for one reference atom.

    The output holds the radial terms first, followed by the angular terms,
    in the order they are stored on the descriptor.

    :param acsf: descriptor providing the radial and angular symmetry functions
    :param position: position of the reference atom (i)
    :param structure: kernel arguments of the structure (positions, lattice,
        atom types and element map are read)
    :return: array with ``acsf.num_symmetry_functions`` entries
    """
    atom_types = structure.atom_types
    element_map = structure.element_map

    # distances (and the auxiliary position differences) from atom i to all atoms
    distances_i, position_differences_i = _calculate_distances_with_aux_per_atom(
        position, structure.positions, structure.lattice
    )

    values: Array = jnp.empty(acsf.num_symmetry_functions, dtype=position.dtype)

    # two-body (radial) contributions occupy the leading slots
    for slot, (radial_function, elements) in enumerate(acsf.radial_symmetry_functions):
        radial_term = _calculate_radial_acsf_per_atom(
            radial_function,
            distances_i,
            atom_types,
            element_map[elements.neighbor_j],
        )
        values = values.at[slot].set(radial_term)

    # three-body (angular) contributions follow the radial ones
    offset = acsf.num_radial_symmetry_functions
    for shift, (angular_function, elements) in enumerate(
        acsf.angular_symmetry_functions
    ):
        angular_term = _calculate_angular_acsf_per_atom(
            angular_function,
            atom_types,
            position_differences_i,
            distances_i,
            structure.lattice,
            element_map[elements.neighbor_j],
            element_map[elements.neighbor_k],  # type: ignore
        )
        values = values.at[offset + shift].set(angular_term)

    return values


# Batched (over reference-atom positions) and jitted descriptor kernels.
# The descriptor itself is a static argument of the jitted functions.
_calculate_acsf_descriptor = vmap(
    _calculate_acsf_descriptor_per_atom,
    in_axes=(None, 0, None),
)

_jitted_calculate_acsf_descriptor = jit(
    _calculate_acsf_descriptor,
    static_argnums=(0,),
)

# Forward-mode Jacobian with respect to the reference-atom position (argument 1).
_calculate_grad_acsf_descriptor_per_atom = jacfwd(
    _calculate_acsf_descriptor_per_atom,
    argnums=1,
)

_calculate_grad_acsf_descriptor = vmap(
    _calculate_grad_acsf_descriptor_per_atom,
    in_axes=(None, 0, None),
)

_jitted_calculate_grad_acsf_descriptor = jit(
    _calculate_grad_acsf_descriptor,
    static_argnums=(0,),
)


def _calculate_radial_acsf_per_atom(
    radial_symmetry_function: RadialSymmetryFunction,
    distances_i: Array,
    atom_types: Array,
    neighbor_atom_type_j: Array,
) -> Array:
    """
    Sum a radial symmetry function over the neighbors of one atom.

    Only atoms selected by the cutoff mask and whose atom type equals
    ``neighbor_atom_type_j`` contribute to the sum.
    """
    cutoff = jnp.array(radial_symmetry_function.r_cutoff)
    within_cutoff = _calculate_cutoff_masks_per_atom(distances_i, cutoff)
    is_type_j = atom_types == neighbor_atom_type_j
    selected = within_cutoff & is_type_j

    radial_terms = radial_symmetry_function(distances_i)
    return jnp.sum(radial_terms, where=selected, axis=0)


def _calculate_angular_acsf_per_atom(
    angular_symmetry_function: AngularSymmetryFunction,
    atom_types: Array,
    position_differences_i: Array,
    distances_i: Array,
    lattice: Array,
    neighbor_atom_type_j: Array,
    neighbor_atom_type_k: Array,
) -> Array:
    """
    Accumulate an angular symmetry function over neighbor pairs (j, k) of one atom.

    Neighbors j are scanned along the leading axis; for each j the
    contributions of all admissible neighbors k are summed by the scan body.
    """
    # cutoff-radius mask shared by both neighbor roles
    cutoff = jnp.array(angular_symmetry_function.r_cutoff)
    within_cutoff = _calculate_cutoff_masks_per_atom(distances_i, cutoff)

    # neighbors allowed as j and as k (inside the cutoff and of the right type)
    selected_ij = within_cutoff & (atom_types == neighbor_atom_type_j)
    selected_ik = within_cutoff & (atom_types == neighbor_atom_type_k)

    scan_body = partial(
        _inner_loop_over_angular_acsf_terms,
        mask_ik=selected_ik,
        position_differences_i=position_differences_i,
        distances_i=distances_i,
        lattice=lattice,
        kernel=angular_symmetry_function,
    )
    accumulated, _ = lax.scan(
        scan_body,
        jnp.array(0.0),
        (position_differences_i, distances_i, selected_ij),
    )

    # When j and k are of the same type each pair was visited twice, so halve
    # the sum. Named module-level functions are used instead of lambdas to
    # avoid jit recompilation caused by a lambda's changing hash.
    same_neighbor_types = neighbor_atom_type_j == neighbor_atom_type_k
    return lax.cond(
        same_neighbor_types,
        _divide_by_two,
        _return_input,
        accumulated,
    )  # type: ignore


# Called by lax.scan (no need for @jax.jit)
def _inner_loop_over_angular_acsf_terms(
    total: Array,
    inputs: Tuple[Array, ...],
    position_differences_i: Array,
    distances_i: Array,
    mask_ik: Array,
    lattice: Array,
    kernel: Callable[[Array, Array, Array, Array], Array],
) -> Tuple[Array, Array]:
    """
    Scan body: add the angular contribution of one neighbor j to ``total``.

    :param total: running sum carried by the scan
    :param inputs: per-step slice ``(position difference, distance, mask)`` of
        neighbor j
    :return: updated running sum and the contribution of this step
    """
    # Scan occurs along the leading axis
    diff_ij, dist_ij, selected_j = inputs

    # Guard the division so that masked-out zero denominators do not leak
    # NaNs into the gradient,
    # see https://github.com/google/jax/issues/1052#issuecomment-514083352
    denominator = dist_ij * distances_i
    vanishes = denominator == 0.0
    safe_denominator = jnp.where(vanishes, 1.0, denominator)

    # dot product of the displacement vectors over the product of the
    # distances (the cosine of the j-i-k angle when the distances are the
    # norms of those vectors)
    cosine = jnp.where(
        mask_ik,
        jnp.inner(diff_ij, position_differences_i) / safe_denominator,
        1.0,
    )
    cosine = jnp.where(vanishes, 0.0, cosine)  # type: ignore

    # distance between neighbors j and k, zeroed for excluded k
    dist_jk = jnp.where(
        mask_ik,
        _calculate_distances_per_atom(diff_ij, position_differences_i, lattice),
        0.0,
    )

    # sum over k, skipping k == j (zero j-k distance)
    sum_over_k = jnp.sum(
        kernel(dist_ij, distances_i, dist_jk, cosine),
        where=mask_ik & (dist_jk > 0.0),  # type:ignore
        axis=0,
    )
    contribution = jnp.where(selected_j, sum_over_k, 0.0)

    return total + contribution, contribution


def _return_input(array: Array) -> Array:
    """Identity branch for ``lax.cond``."""
    return array


def _divide_by_two(array: Array) -> Array:
    """Halving branch for ``lax.cond`` (double-counting correction)."""
    return array * 0.5


# Short public alias
ACSF = AtomCenteredSymmetryFunction

register_jax_pytree_node(AtomCenteredSymmetryFunction)