Source code for pantea.simulation.thermostat

from __future__ import annotations

from typing import NamedTuple, Protocol

import jax
import jax.numpy as jnp

from pantea.simulation.system import _get_temperature
from pantea.types import Array


@jax.jit
def _get_rescaled_velocities(
    params: BrendsenThermostatParams,
    velocities: Array,
):
    scaling_factor = 1.0 / jnp.sqrt(
        1.0
        + (params.time_step / params.time_constant)
        * (params.current_temperature / params.target_temperature - 1.0)
    )
    return velocities * scaling_factor


[docs]class MDSimulatorInterface(Protocol): time_step: float
[docs]class SystemInterface(Protocol): velocities: Array
[docs] def get_temperature(self) -> Array: ...
[docs]class BrendsenThermostatParams(NamedTuple): time_step: Array time_constant: Array current_temperature: Array target_temperature: Array
[docs]class BrendsenThermostat: """Control simulation temperature using Brendsen thermostat.""" def __init__( self, target_temperature: float, time_constant: float, ) -> None: self.target_temperature: Array = jnp.array(target_temperature) self.time_constant: Array = jnp.array(time_constant)
[docs] def get_rescaled_velocities( self, simulator: MDSimulatorInterface, system: SystemInterface, ) -> Array: current_temperature = system.get_temperature() params = BrendsenThermostatParams( simulator.time_step, self.time_constant, current_temperature, self.target_temperature, ) return _get_rescaled_velocities(params, system.velocities)