Source code for pantea.atoms.structure

from __future__ import annotations

import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, DefaultDict, Dict, Iterator, List, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from ase import Atoms as AseAtoms

from pantea.atoms.box import Box, _wrap_into_box
from pantea.atoms.element import ElementMap
from pantea.logger import logger
from pantea.pytree import BaseJaxPytreeDataClass, register_jax_pytree_node
from pantea.types import Array, Dtype, Element, default_dtype
from pantea.units import units


[docs]@dataclass class Structure(BaseJaxPytreeDataClass): """ A Structure contains arrays that store atomic attributes for a collection of atoms in a simulation box and at any time. The currently existing attributes of atoms within a Structure are as follows: * `positions`: positions of atoms, an array of (natoms, 3) * `forces`: force components, an array of (natoms, 3) * `energies`: associated atom potential energies, an array of (natoms,) * `charges`: charge of atoms, an array of (natoms,) * `total_energy`: total potential energy, scalar value * `total_charge`: total charge, scalar value Structure serves as a fundamental data container for atoms. Multiple structures can be used to train a potential, or alternatively, the total energy and force components can be computed for a specific structure. Each structure has additional attributes as: * `Box`: applying periodic boundary condition (PBC) along x, y, and z directions * `ElementMap`: determines how to extract assigned atom types from the element and vice versa .. note:: The structure can be viewed as a separate domain for implementing MPI in large-scale molecular dynamics (MD) simulations, as demonstrated in the `miniMD`_ code. .. _miniMD: https://github.com/Mantevo/miniMD """ positions: Array forces: Array energies: Array charges: Array total_energy: Array total_charge: Array atom_types: Array element_map: ElementMap box: Optional[Box] = None def __post_init__(self) -> None: self._assert_jit_dynamic_attributes( expected=( "positions", "forces", "energies", "charges", "total_energy", "total_charge", "atom_types", ) ) self._assert_jit_static_attributes( expected=( "element_map", "box", ) ) if self.box is not None: self.positions = _wrap_into_box(self.positions, self.lattice)
[docs] @classmethod def from_ase( cls, atoms: AseAtoms, dtype: Optional[Dtype] = None, ) -> Structure: """ Create an instance of the structure from `ASE`_ atoms input. :param atoms: input `ASE`_ atoms instance :param dtype: data type for arrays, defaults to None :return: initialized structure .. _ASE: https://wiki.fysik.dtu.dk/ase/index.html """ logging.debug(f"Initializing {cls.__name__} from ASE atoms") if dtype is None: dtype = default_dtype.FLOATX kwargs = dict() data = { "elements": [ ElementMap.get_element_from_atomic_number(n) for n in atoms.get_atomic_numbers() ], "lattice": np.array(atoms.get_cell() * units.FROM_ANGSTROM, dtype=dtype), "positions": atoms.get_positions() * units.FROM_ANGSTROM, } for attr, ase_attr in zip( ("energies", "charges"), ("potential_energies", "charges"), ): try: data[attr] = getattr(atoms, f"get_{ase_attr}")() except RuntimeError: continue for attr in ("energies", "charges"): if attr in data: data[f"total_{attr}"] = sum(data[attr]) input_data = defaultdict(list, data) try: element_map: ElementMap = ElementMap.from_list(input_data["elements"]) kwargs.update( cls._init_arrays(input_data, element_map=element_map, dtype=dtype), ) kwargs["element_map"] = element_map kwargs["box"] = cls._init_box(input_data["lattice"], dtype=dtype) except KeyError: logger.error( "Can not find at least one of the expected keyword in the input data.", exception=KeyError, ) return cls(**kwargs)
[docs] @classmethod def from_dict( cls, data: Dict[str, Any], dtype: Optional[Dtype] = None, ) -> Structure: """ Create a structure using an input data dictionary that contains distinct lists of positions, forces, elements, lattice, etc. :param data: input data :param dtype: data type for arrays, defaults to None :return: the initialized Structure """ logging.debug(f"Initializing {cls.__name__} from an input dictionary") if dtype is None: dtype = default_dtype.FLOATX input_data: DefaultDict[str, List] = defaultdict(list, data) kwargs: Dict[str, Any] = dict() try: element_map = ElementMap.from_list(input_data["elements"]) kwargs.update( cls._init_arrays(input_data, element_map=element_map, dtype=dtype), ) kwargs["element_map"] = element_map kwargs["box"] = cls._init_box(input_data["lattice"], dtype=dtype) except KeyError: logger.error( "Cannot find at least one of the expected keyword in the input data dictionary.", exception=KeyError, ) return cls(**kwargs)
@classmethod def _init_arrays( cls, data: Dict[str, Any], element_map: ElementMap, dtype: Dtype, ) -> Dict[str, Array]: logger.debug(f"{cls.__name__}: allocating arrays as follows:") arrays: Dict[str, Array] = dict() for atom_attr in Structure._get_atom_attributes(): try: array: Array if atom_attr == "atom_types": array = jnp.array( [ element_map.get_atom_type_from_element(name) for name in data["elements"] ], dtype=default_dtype.INDEX, ) else: array = jnp.array(data[atom_attr], dtype=dtype) arrays[atom_attr] = jnp.squeeze(array) logger.debug( f"{atom_attr:12} -> Array(shape={array.shape}, dtype='{array.dtype}')" ) except KeyError: logger.error( f"Cannot find atom attribute {atom_attr} in the input data", exception=KeyError, ) return arrays @classmethod def _init_box( cls, lattice: List[float], dtype: Dtype, ) -> Optional[Box]: if len(lattice) > 0: return Box.from_list(lattice, dtype=dtype) else: logger.debug("No lattice info were found") return None @classmethod def _get_atom_attributes(cls) -> Tuple[str, ...]: return cls._get_jit_dynamic_attributes() def __hash__(self) -> int: """Use parent class's hash method because of JIT.""" return super().__hash__() @property def natoms(self) -> int: """Return number of atoms in the structure""" return self.positions.shape[0] @property def dtype(self) -> Dtype: """Return data type of the arrays in the structure (default is `float64`).""" return self.positions.dtype @property def lattice(self) -> Optional[Array]: """Return the cell matrix (3x3).""" if self.box is not None: return self.box.lattice
[docs] def get_unique_elements(self) -> Tuple[Element, ...]: return self.element_map.unique_elements
[docs] def get_elements(self) -> Tuple[Element, ...]: """Get array of elements.""" to_element = self.element_map.atom_type_to_element atom_types_host = jax.device_get(self.atom_types) return tuple(str(to_element[at]) for at in atom_types_host)
[docs] def select(self, element: Element) -> Array: """ Retrieve the indices of all atoms that correspond to the given element. :param element: element name (e.g. `H` for hydrogen) :return: atom indices """ return jnp.nonzero( self.atom_types == self.element_map.element_to_atom_type[element] )[0]
[docs] def to_dict(self) -> Dict[str, np.ndarray]: """ The atomic attributes are represented as a dictionary of NumPy arrays. This format can be employed, for instance, when saving the structure data into a file. :return: dictionary of atom attributes. :rtype: Dict[str, np.ndarray] """ data = dict() for atom_attr in self._get_atom_attributes(): array: Array = getattr(self, atom_attr) data[atom_attr] = np.asarray(array) data["lattice"] = self.box.lattice if self.box else [] data["elements"] = [ self.element_map.get_element_from_atom_type(n) for n in data["atom_types"] ] return data
[docs] def to_ase(self) -> AseAtoms: """ Represent the structure as ASE atoms. The returned object can be utilized with the `ASE`_ package for visualization or modification of the structure. :return: `ASE`_ representation of the structure .. _ASE: https://wiki.fysik.dtu.dk/ase/index.html """ logger.debug(f"Converting {self.__class__.__name__} to ASE atoms") to_element = self.element_map.atom_type_to_element cell = ( units.TO_ANGSTROM * np.asarray(self.box.lattice) if self.box is not None else None ) return AseAtoms( symbols=[to_element[int(at)] for at in self.atom_types], positions=[units.TO_ANGSTROM * np.asarray(pos) for pos in self.positions], cell=cell, pbc=True if self.box else False, charges=[np.asarray(ch) for ch in self.charges], )
def _get_energy_offset(self, atom_energy: Dict[Element, float]) -> Array: energy_offset: Array = jnp.empty_like(self.energies) for element in self.get_unique_elements(): energy_offset = energy_offset.at[self.select(element)].set( atom_energy[element] ) return energy_offset
[docs] def remove_energy_offset(self, atom_energy: Dict[Element, float]) -> None: """ Remove the input reference energies from individual atoms and the total energy. :param atom_energy: atom reference energy :type atom_energy: Dict[Element, float]] """ energy_offset = self._get_energy_offset(atom_energy) self.energies -= energy_offset self.total_energy -= energy_offset.sum()
[docs] def add_energy_offset(self, atom_energy: Dict[Element, float]) -> None: """ Add the input reference energies to individual atoms and the total energy. :param atom_energy: atom reference energy :type atom_energy: Dict[Element, float]] """ energy_offset = self._get_energy_offset(atom_energy) self.energies += energy_offset self.total_energy += energy_offset.sum()
def __repr__(self) -> str: return ( f"{self.__class__.__name__}" f"(natoms={self.natoms}, " f"elements={self.get_unique_elements()}, " f"dtype={self.dtype})" ) # ---
[docs] def as_kernel_args(self) -> StructureAsKernelArgs: return StructureAsKernelArgs( self.positions, self.atom_types, self.lattice, self.total_energy, self.element_map.element_to_atom_type, )
def _get_positions_per_element(self) -> Iterator[Tuple[Element, Array]]: for element in self.get_unique_elements(): atom_index = self.select(element) yield element, self.positions[atom_index]
[docs] def get_positions_per_element(self) -> Dict[Element, Array]: """Get position of atoms per element.""" return { element: position for element, position in self._get_positions_per_element() }
def _get_forces_per_element(self) -> Iterator[Tuple[Element, Array]]: for element in self.get_unique_elements(): atom_index = self.select(element) yield element, self.forces[atom_index]
[docs] def get_forces_per_element(self) -> Dict[Element, Array]: """Get force components per element.""" return {element: force for element, force in self._get_forces_per_element()}
[docs]class StructureAsKernelArgs(NamedTuple): """ This is a jit-complied compatible representation of structure to be used for for energy and force computing kernels. """ positions: Array atom_types: Array lattice: Array total_energy: Array element_map: Dict[Element, Array]
register_jax_pytree_node(Structure)