# Source code for pantea.atoms.box
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Sequence
import jax
import jax.numpy as jnp
from pantea.logger import logger
from pantea.pytree import BaseJaxPytreeDataClass, register_jax_pytree_node
from pantea.types import Array, Dtype, default_dtype
@dataclass
class Box(BaseJaxPytreeDataClass):
    """
    Simulation box is used to apply periodic boundary
    conditions (PBC) in the presence of a lattice info.

    .. warning::
        The current implementation only works for orthogonal cells
        and does not support triclinic cells.
    """

    # Cell matrix. `from_list` builds it with shape (3, 3); every PBC
    # operation in this class reads only its diagonal entries.
    lattice: Array

    def __post_init__(self) -> None:
        """Run the JIT attribute checks inherited from the parent class."""
        # NOTE(review): presumably these verify that `lattice` is the only
        # dynamic (traced) attribute and that there are no static ones --
        # confirm against BaseJaxPytreeDataClass.
        self._assert_jit_static_attributes()
        self._assert_jit_dynamic_attributes(expected=("lattice",))

    @classmethod
    def from_list(
        cls,
        data: Sequence[float],
        dtype: Optional[Dtype] = None,
    ) -> Box:
        """
        Create a box from a sequence of lattice values.

        :param data: lattice values that can be reshaped into a 3x3 matrix
        :type data: Sequence[float]
        :param dtype: data type of the lattice array,
            defaults to ``default_dtype.FLOATX`` when None
        :type dtype: Optional[Dtype]
        :return: a new box holding the (3, 3) lattice array
        :rtype: Box
        """
        logger.debug(f"Initializing {cls.__name__} from list")
        if dtype is None:
            dtype = default_dtype.FLOATX
        lattice = jnp.array(data, dtype=dtype).reshape(3, 3)
        # Construct through `cls` (not the hard-coded `Box`) so that calling
        # this alternate constructor on a subclass returns that subclass.
        return cls(lattice)

    @jax.jit
    def apply_pbc(self, dx: Array) -> Array:
        """
        Apply periodic boundary condition (PBC) on the position differences.

        For this method to function correctly, it is essential that all atoms are
        initially positioned within the boundaries of the box.
        Otherwise, the results may not be as anticipated. This could happen, for
        example, when the time step is too large.

        :param dx: positional differences
        :type dx: Array
        :return: PBC applied position differences
        :rtype: Array
        """
        return _apply_pbc(dx, self.lattice)

    @jax.jit
    def wrap_into_box(self, positions: Array) -> Array:
        """
        Adjust the coordinates of the atoms to ensure they fall
        within the boundaries of the simulation box that has
        periodic boundary conditions (PBC).

        :param positions: atom positions
        :type positions: Array
        :return: wrapped atom positions using the PBC.
        :rtype: Array
        """
        # NOTE: this is a Python-side effect inside a jitted function, so the
        # message is emitted only while JAX traces the function, not on every
        # call that reuses the compiled version.
        logger.debug("Shift all atoms within the simulation box")
        return _wrap_into_box(positions, self.lattice)

    @property
    def lx(self) -> Array:
        """Return length of cell in x-direction."""
        return self.lattice[0, 0]

    @property
    def ly(self) -> Array:
        """Return length of cell in y-direction."""
        return self.lattice[1, 1]

    @property
    def lz(self) -> Array:
        """Return length of cell in z-direction."""
        return self.lattice[2, 2]

    @property
    def length(self) -> Array:
        """Return length of cell in x, y, and z-directions."""
        return self.lattice.diagonal()

    @property
    def volume(self) -> Array:
        """Return calculated volume of the box."""
        # Product of the diagonal lengths, i.e. the orthogonal-cell volume.
        return jnp.prod(self.length)

    @property
    def dtype(self) -> Dtype:
        """Return the data type of the lattice array."""
        return self.lattice.dtype

    def __hash__(self) -> int:
        """Enforce to use the parent class's hash method (JIT)."""
        # Defined explicitly because `@dataclass` (with its default eq=True)
        # would otherwise set `__hash__` to None and make instances unhashable.
        return super().__hash__()

    def __repr__(self) -> str:
        """Return a string showing the class name, the lattice and its dtype."""
        return f"{self.__class__.__name__}(lattice={self.lattice}, dtype={self.dtype})"
def _apply_pbc(dx: Array, lattice: Array) -> Array:
    """
    Fold position differences back by one cell length per axis (PBC).

    Components larger than half of the cell length are shifted down by one
    cell length; afterwards, components smaller than minus half of the cell
    length are shifted up by one cell length. Only the diagonal of the
    lattice is used.
    """
    cell = jnp.diagonal(lattice)
    half_cell = 0.5 * cell
    # First pass: pull back components that exceed +half of the cell.
    folded = jnp.where(dx > half_cell, dx - cell, dx)
    # Second pass operates on the already-folded values, as in a sequential update.
    return jnp.where(folded < -half_cell, folded + cell, folded)


# JIT-compiled variant of the function above.
_jitted_apply_pbc = jax.jit(_apply_pbc)
def _wrap_into_box(positions: Array, lattice: Array) -> Array:
    """
    Map atom positions back into the simulation box (PBC).

    Each coordinate is replaced by its remainder with respect to the
    corresponding diagonal entry of the lattice.
    """
    cell = jnp.diagonal(lattice)
    wrapped = jnp.remainder(positions, cell)
    return wrapped


# JIT-compiled variant of the function above.
_jitted_wrap_into_box = jax.jit(_wrap_into_box)
# Register Box through the project's pytree helper.
# NOTE(review): presumably this makes Box instances valid JAX pytrees so they
# can be passed as `self` to the jitted methods -- confirm in pantea.pytree.
register_jax_pytree_node(Box)