Source code for pantea.simulation.system

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Protocol, Tuple

import jax
import jax.numpy as jnp

from pantea.atoms import Structure
from pantea.atoms.box import Box
from pantea.atoms.element import ElementMap
from pantea.logger import logger
from pantea.types import Array, Element
from pantea.units import units

# Boltzmann constant taken from the project's unit system; used below to
# convert kinetic energy into temperature.
KB: float = units.BOLTZMANN_CONSTANT


@jax.jit
def _get_kinetic_energy(velocities: Array, masses: Array) -> Array:
    return 0.5 * jnp.sum(masses * velocities * velocities)


@jax.jit
def _get_temperature(velocities: Array, masses: Array) -> Array:
    """Return the instantaneous temperature, ``2 * E_kin / (3 * N * KB)``.

    ``N`` is the length of the leading axis of `velocities`, and ``E_kin`` is
    the total kinetic energy computed from `velocities` and `masses`.
    """
    atom_count = velocities.shape[0]
    twice_kinetic = 2 * _get_kinetic_energy(velocities, masses)
    return twice_kinetic / (3 * atom_count * KB)


@jax.jit
def _get_virial(
    velocities: Array,
    masses: Array,
    positions: Array,
    forces: Array,
) -> Array:
    """Return ``2 * E_kin + sum(positions * forces)`` as a scalar array.

    The first term is twice the total kinetic energy; the second is the
    element-wise product of positions and forces summed over all entries.
    """
    kinetic_term = 2 * _get_kinetic_energy(velocities, masses)
    force_term = jnp.sum(positions * forces)
    return kinetic_term + force_term


@jax.jit
def _calculate_center_of_mass(array: Array, masses: Array) -> Array:
    return jnp.sum(masses * array, axis=0) / jnp.sum(masses)


class PotentialInterface(Protocol):
    """Structural interface for the potential objects used by `System`.

    Any object providing both methods below satisfies the protocol; no
    inheritance is required.
    """

    def __call__(self, structure: Structure) -> Array:
        """Return the total potential energy of `structure`."""
        ...

    def compute_forces(self, structure: Structure) -> Array:
        """Return the forces acting on the atoms of `structure`."""
        ...
@dataclass
class System:
    """An extended atom Structure for molecular simulations (e.g. MD).

    Couples a structure with the potential that drives it and with the
    per-atom velocities and masses needed for dynamics.

    Attributes:
        potential: Object providing the potential energy (``__call__``) and
            forces (``compute_forces``) for a structure.
        structure: The atomic structure; its ``forces`` and ``total_energy``
            are overwritten from the potential on construction.
        velocities: Per-atom velocities; ``(natoms, 3)`` when created through
            `from_structure`.
        masses: Per-atom masses; ``(natoms, 1)`` when created through
            `from_structure`, so they broadcast against the velocities.
    """

    potential: PotentialInterface
    structure: Structure
    velocities: Array
    masses: Array

    @classmethod
    def from_structure(
        cls,
        structure: Structure,
        potential: PotentialInterface,
        temperature: float = 300.0,
        seed: int = 2024,
    ) -> System:
        """Build a system from a structure with random initial velocities.

        Args:
            structure: Input structure. It is deep-copied, so the caller's
                object is not modified when forces and energy are refreshed.
            potential: Potential used for energies and forces.
            temperature: Target temperature of the initial velocities.
            seed: Seed for the random number generator.

        Returns:
            A new system whose forces and total energy are already computed.
        """
        logger.debug(f"Creating {cls.__name__} from Structure")
        # Column vector so the masses broadcast against (natoms, 3) arrays.
        masses = ElementMap.get_masses_from_structure(structure).reshape(-1, 1)
        velocities = cls.generate_random_velocities(
            jnp.array(temperature), masses, seed
        )
        return cls(potential, deepcopy(structure), velocities, masses)

    def __post_init__(self) -> None:
        """Refresh the structure's forces and total energy from the potential."""
        self.update_forces_from_positions()
        self.update_total_potential_energy_from_positions()

    @classmethod
    def generate_random_velocities(
        cls,
        temperature: Array,
        masses: Array,
        seed: int,
    ) -> Array:
        """Generate Maxwell-Boltzmann distributed random velocities.

        The returned velocities have zero centre-of-mass velocity and an
        instantaneous temperature equal to `temperature`.

        Args:
            temperature: Target temperature.
            masses: Per-atom masses of shape ``(natoms, 1)``.
            seed: Seed for `jax.random.PRNGKey`.

        Returns:
            Array of shape ``(natoms, 3)``. All zeros when the drift-free
            sample carries no kinetic energy (e.g. a single atom).
        """
        key = jax.random.PRNGKey(seed)
        natoms = masses.shape[0]
        # Maxwell-Boltzmann components have a variance proportional to 1/m.
        # Only that mass dependence is applied here; the overall magnitude is
        # fixed by the rescaling below.
        velocities = jax.random.normal(key, shape=(natoms, 3)) / jnp.sqrt(masses)
        # Remove the drift BEFORE rescaling: subtracting it afterwards would
        # lower the kinetic energy and miss the target temperature, whereas
        # a uniform rescale keeps the centre-of-mass velocity at zero.
        velocities = velocities - _calculate_center_of_mass(velocities, masses)
        current_temperature = _get_temperature(velocities, masses)
        # Guard the 0/0 case (no kinetic energy left after drift removal).
        has_motion = current_temperature > 0.0
        safe_temperature = jnp.where(has_motion, current_temperature, 1.0)
        scale = jnp.where(has_motion, jnp.sqrt(temperature / safe_temperature), 0.0)
        return velocities * scale

    def update_forces_from_positions(self) -> None:
        """Recompute the forces and store them on the structure."""
        self.structure.forces = self.potential.compute_forces(self.structure)

    def update_total_potential_energy_from_positions(self) -> None:
        """Recompute the potential energy and store it on the structure."""
        self.structure.total_energy = self.potential(self.structure)

    @classmethod
    def compute_forces(
        cls,
        potential: PotentialInterface,
        structure: Structure,
    ) -> Array:
        """Return the forces `potential` computes for `structure`."""
        return potential.compute_forces(structure)

    def get_elements(self) -> Tuple[Element, ...]:
        """Return the elements reported by the underlying structure."""
        return self.structure.get_elements()

    def get_pressure(self) -> Array:
        """Return the pressure, ``virial / (3 * volume)``.

        Raises:
            AssertionError: If the structure has no box.
        """
        box = self.structure.box
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; the exception type and message are unchanged.
        if box is None:
            raise AssertionError(
                "Calculating pressure... input structure must have PBC box"
            )
        virial = _get_virial(self.velocities, self.masses, self.positions, self.forces)
        return virial / (3.0 * box.volume)  # type: ignore

    def get_temperature(self) -> Array:
        """Return the instantaneous temperature from the current velocities."""
        return _get_temperature(self.velocities, self.masses)

    def get_center_of_mass_velocity(self) -> Array:
        """Return the mass-weighted mean velocity."""
        return _calculate_center_of_mass(self.velocities, self.masses)

    def get_center_of_mass_position(self) -> Array:
        """Return the mass-weighted mean position."""
        return _calculate_center_of_mass(self.positions, self.masses)

    def get_potential_energy(self) -> Array:
        """Return total potential energy.

        The potential is evaluated on every call; the value stored on the
        structure is not reused.
        """
        return self.potential(self.structure)

    def get_kinetic_energy(self) -> Array:
        """Return the total kinetic energy from the current velocities."""
        return _get_kinetic_energy(self.velocities, self.masses)

    def get_total_energy(self) -> Array:
        """Return the sum of potential and kinetic energy."""
        return self.get_potential_energy() + self.get_kinetic_energy()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"potential={self.potential.__class__.__name__}, "
            f"structure={self.structure}, "
            f"temperature={self.get_temperature():.2f})"
        )

    @property
    def positions(self) -> Array:
        """Positions of the underlying structure."""
        return self.structure.positions

    @property
    def forces(self) -> Array:
        """Forces currently stored on the underlying structure."""
        return self.structure.forces

    @property
    def box(self) -> Optional[Box]:
        """Box of the underlying structure, or None if it has none."""
        return self.structure.box

    @property
    def natoms(self) -> int:
        """Number of atoms in the underlying structure."""
        return self.structure.natoms