Source code for pantea.utils.compare
from typing import Any, Dict, List, Union
import jax.numpy as jnp
from pantea.atoms.structure import Structure
from pantea.types import Array
def compare(
    structure1: Structure,
    structure2: Structure,
    errors: Union[str, List] = "RMSEpa",
    return_difference: bool = False,
) -> Dict[str, Array]:
    """
    Compare the `force` and `total energy` values of two structures
    and return the requested error metrics.

    Metric names are matched case-insensitively; names that are not one of
    the supported metrics are ignored. For every requested metric two entries
    are added to the result, ``force_<metric>`` and ``energy_<metric>``.
    Only the energy value of the per-atom variants (`RMSEpa`, `MSEpa`) is
    divided by ``structure1.natoms``; the force value of a per-atom variant
    is identical to the corresponding non per-atom force value.

    :param structure1: first structure
    :param structure2: second structure
    :param errors: a single metric name or a list of metric names among
        `RMSE`, `RMSEpa`, `MSE`, and `MSEpa`. Defaults to `RMSEpa`
    :param return_difference: whether to also return the raw force and energy
        differences under the `frc_diff` and `eng_diff` keys, defaults to False
    :raises AssertionError: if the two structures do not have the same number
        of atoms or the same atom types
    :return: a dictionary of error metrics.
    """
    types1, types2 = structure1.atom_types, structure2.atom_types
    # Raise explicitly rather than via an `assert` statement so the check
    # survives `python -O`; AssertionError is kept so callers see the same
    # exception type as before. The length test prevents a one-atom structure
    # from being silently broadcast against a larger one.
    if len(types1) != len(types2) or not bool(jnp.all(types1 == types2)):
        raise AssertionError(
            "Expected similar structures with the same atom types."
        )

    frc_diff: Array = structure1.forces - structure2.forces
    eng_diff: Array = structure1.total_energy - structure2.total_energy

    # Materialise once so a one-shot iterable is not exhausted by the join.
    metric_names = [errors] if isinstance(errors, str) else list(errors)
    print(f"Comparing two structures, error metrics: {', '.join(metric_names)}")
    requested = {name.lower() for name in metric_names}

    # Every metric is derived from these two mean squared errors.
    frc_mse = jnp.mean(frc_diff**2)
    eng_mse = jnp.mean(eng_diff**2)
    natoms = structure1.natoms

    result: Dict[str, Any] = {}
    if "rmse" in requested:
        result["force_RMSE"] = jnp.sqrt(frc_mse)
        result["energy_RMSE"] = jnp.sqrt(eng_mse)
    if "rmsepa" in requested:
        # Force error is not normalised by the atom count; only energy is.
        result["force_RMSEpa"] = jnp.sqrt(frc_mse)
        result["energy_RMSEpa"] = jnp.sqrt(eng_mse) / natoms
    if "mse" in requested:
        result["force_MSE"] = frc_mse
        result["energy_MSE"] = eng_mse
    if "msepa" in requested:
        result["force_MSEpa"] = frc_mse
        result["energy_MSEpa"] = eng_mse / natoms
    if return_difference:
        result["frc_diff"] = frc_diff
        result["eng_diff"] = eng_diff
    return result