Potential training

An example notebook showing how to train a high-dimensional neural network potential (HDNNP) with Pantea.

[1]:
# !gpustat
[2]:
from utils import set_env
# set_env('.env')
[3]:
import os

# Configure JAX before it is imported: enable 64-bit floats, run on the CPU,
# and disable XLA's GPU memory preallocation (harmless when running on CPU).
os.environ["JAX_ENABLE_X64"] = "1"
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

Imports

[5]:
import logging
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import random
from tqdm import tqdm
from collections import defaultdict
import jax.numpy as jnp
import jax

import pantea
from pantea.types import default_dtype
from pantea.datasets import RunnerDataset
from pantea.potentials import NeuralNetworkPotential
from pantea.logger import LoggingContextManager
[7]:
# pantea.logger.set_logging_level(logging.DEBUG)
default_dtype.FLOATX = jnp.float64
print(f"default dtype: {default_dtype.FLOATX.dtype}")
print(f"default device: {jax.devices()[0]}")
default dtype: float64
default device: TFRT_CPU_0

Dataset

[8]:
base_dir = Path('GRN')
[9]:
structures = RunnerDataset(Path(base_dir, "input.data"), persist=True)
# structures = RunnerDataset(Path(base_dir, "input.data"), transform=ToStructure(r_cutoff=3.0), persist=True)
print("Total number of structures:", len(structures))
structures
Total number of structures: 801
[9]:
RunnerDataset(filename='GRN/input.data', persist=True, dtype=float64)
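
Note that persist=True presumably caches parsed structures in memory after first access, so repeated indexing does not re-read input.data.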
[10]:
# indices = random.choices(range(len(structures)), k=100)
# structures = [structures[i] for i in range(len(structures))] # len(structures)
[11]:
# import torch
# validation_split = 0.10
# nsamples = len(structures)
# split = int(np.floor(validation_split * nsamples))
# train_structures, valid_structures = torch.utils.data.random_split(structures, lengths=[nsamples-split, split])
# structures = valid_structures
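
The commented cell above sketches a torch-based train/validation split. A dependency-free variant using only numpy could look like the following (a sketch; the split is illustrative and not used in the rest of this notebook):

[ ]:
# Hold out 10% of the structures for validation (numpy-only sketch).
rng = np.random.default_rng(seed=2023)  # hypothetical seed, for reproducibility
nsamples = len(structures)
split = int(np.floor(0.10 * nsamples))
perm = rng.permutation(nsamples)
valid_structures = [structures[int(i)] for i in perm[:split]]
train_structures = [structures[int(i)] for i in perm[split:]]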
[12]:
s = structures[0]
s
[12]:
Structure(natoms=24, elements=('C',), dtype=float64)
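
A Structure bundles the per-atom data of one configuration. A quick look at two fields (natoms appears in the repr above and total_energy is used for validation later; other attribute names should be checked against the Pantea docs):

[ ]:
print("number of atoms:", s.natoms)
print("total energy:", s.total_energy)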
[13]:
# energies = jnp.asarray([x.total_energy for x in structures]).reshape(-1)
# print("Energy difference:", max(energies) - min(energies))
# sns.histplot(energies);
[14]:
# with LoggingContextManager(level=logging.DEBUG):
# structures[0].to_dict()
[15]:
# from ase.visualize import view
# atoms = s.to_ase()
# view(atoms, viewer="x3d", repeat=1)
[16]:
# from ase.io import read, write
# write("atoms.png", atoms * (2, 2, 1), rotation='30z,-80x')
# write("atoms.xyz", atoms * (2, 2, 1))
# ![atoms](atoms.png)
[17]:
# from pantea.atoms import Structure
# sp = Structure.from_ase(atoms)
# view(sp.to_ase(), viewer="x3d", repeat=3)

Potential

[18]:
nnp = NeuralNetworkPotential.from_file(Path(base_dir, "input.nn"))
nnp
[18]:
NeuralNetworkPotential(atomic_potential={'C': AtomicPotential(
  descriptor=ACSF(central_element='C', symmetry_functions=30),
  scaler=Scaler(scale_type='center', scale_min=0.0, scale_max=1.0),
  model=NeuralNetworkModel(hidden_layers=((15, 'tanh'), (15, 'tanh')), param_dtype=float64),
)})
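
The repr summarizes the Behler-Parrinello architecture: for each element (here only C) the atomic environment is encoded by 30 ACSF symmetry functions, rescaled by the scaler, and mapped to an atomic energy by a small network with two tanh hidden layers of 15 nodes each; the predicted total energy is the sum of the atomic energies.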
[19]:
# nnp.load()
# nnp.load_scaler()

Extrapolation warnings

[20]:
# nnp.set_extrapolation_warnings(100)

Fit scaler

[21]:
# Keep only the first five structures to make this demo fast.
structures = [structures[index] for index in range(5)]
[22]:
# with LoggingContextManager(level=logging.DEBUG):
nnp.fit_scaler(structures)
Fitting descriptor scaler...
100%|██████████| 5/5 [00:05<00:00,  1.06s/it]
Done.
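
Fitting the scaler makes one pass over the given structures (five here) to accumulate statistics of the symmetry-function values; with scale_type='center' the descriptors are presumably shifted by their mean before entering the network.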

[24]:
%time nnp(s)
CPU times: user 34.3 ms, sys: 25 µs, total: 34.3 ms
Wall time: 33.4 ms
[24]:
Array(2.4637737, dtype=float64)
[26]:
%time nnp.compute_forces(s)
CPU times: user 62.9 ms, sys: 0 ns, total: 62.9 ms
Wall time: 61.6 ms
[26]:
Array([[-0.73818644, -1.11250721,  2.19987674],
       [-0.60619114,  1.33811373,  0.23699997],
       [ 0.39212829,  1.86935193, -1.38083261],
       [-1.24370145,  1.38746971, -0.16639672],
       [ 1.17766406, -0.06752429, -0.27530415],
       [-0.90934218, -1.97845414,  1.2512508 ],
       [-0.08749216, -1.34951345,  1.21127088],
       [ 0.64631809, -0.81786649,  1.80367966],
       [-0.95089644, -1.12067772, -1.071214  ],
       [-1.98676865,  2.32836132,  1.22984263],
       [-0.67455221,  0.96647395, -0.77369183],
       [ 0.625575  , -1.00776043, -1.20962872],
       [ 1.50403387,  0.24804563, -0.56286058],
       [ 0.83629278,  2.39211837, -0.65869605],
       [ 1.84471198,  1.33191484, -1.44894841],
       [-0.52273753,  1.91611271,  0.86284524],
       [ 1.44972508,  0.02283892, -1.30880114],
       [-1.74267357,  1.41118767, -0.40835024],
       [ 0.93299509,  0.35660988, -2.28476991],
       [-1.45420113, -2.42908854, -0.36560575],
       [ 0.59659043, -2.43054421,  0.84919574],
       [ 0.956988  , -2.31821003,  0.65325916],
       [ 0.20396445,  0.49034126,  0.39256995],
       [ 0.2309334 , -2.1927167 ,  1.32401521]], dtype=float64)
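
Each row is the force vector acting on one of the 24 atoms, so the array has shape (natoms, 3) with Cartesian x, y, z columns.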

Training

[ ]:
# Train the model; fit_model returns a history of per-epoch metrics.
h = nnp.fit_model(structures)

# Plot every loss-like metric against the epoch number.
for key in h:
    if 'loss' in key:
        plt.plot(h['epoch'], h[key], label=key)
plt.legend();
[ ]:
# nnp.save()

Validation

Energy

[ ]:
print(f"{len(structures)=}")
true_energy = [s.total_energy for s in structures]
pred_energy = [nnp(s) for s in structures]
ii = range(len(structures))
plt.scatter(true_energy, pred_energy, label='NNP')
plt.plot(true_energy, true_energy, 'r', label="REF")
plt.xlabel("true energy")
plt.ylabel("pred energy")
plt.legend()
plt.show()
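
Beyond the scatter plot, scalar error metrics give a compact summary of the energy fit. A short numpy sketch over the lists computed above (RMSE and MAE are standard metrics, not part of the Pantea API):

[ ]:
true_e = np.asarray(true_energy, dtype=float)
pred_e = np.asarray(pred_energy, dtype=float)
print(f"energy RMSE: {np.sqrt(np.mean((pred_e - true_e) ** 2)):.6f}")
print(f"energy MAE:  {np.mean(np.abs(pred_e - true_e)):.6f}")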

Force

[ ]:
true_forces = []
pred_forces = []
print(f"{len(structures)=}")
for structure in structures:
    true_forces.append(structure.forces)
    pred_forces.append(nnp.compute_forces(structure))

dim = 0  # Cartesian component to plot: 0 -> x, 1 -> y, 2 -> z
to_axis = {d: c for d, c in enumerate('xyz')}
true_forces = jnp.concatenate(true_forces, axis=0)
pred_forces = jnp.concatenate(pred_forces, axis=0)

plt.scatter(true_forces[:, dim], pred_forces[:, dim], label='NNP')
plt.plot(true_forces[:, dim], true_forces[:, dim], 'r', label='REF')

label = f"force [{to_axis[dim]}]"
plt.ylabel("pred " + label)
plt.xlabel("true " + label)
plt.legend()
plt.show()
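
The same kind of summary can be computed for the forces, per Cartesian component, reusing the concatenated arrays above (a sketch):

[ ]:
for d, axis in to_axis.items():
    err = pred_forces[:, d] - true_forces[:, d]
    rmse = float(jnp.sqrt(jnp.mean(err ** 2)))
    print(f"force RMSE [{axis}]: {rmse:.6f}")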