Getting started#

The example code below illustrates how to use the different modules in Pantea.

Initialization#

[1]:
# Select the CPU backend for JAX through an environment variable.
# NOTE(review): presumably this must run before `jax` is first imported to take effect — confirm.
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"  # disable GPU-accelerated computing

Imports#

[2]:
import pantea
import jax.numpy as jnp
import jax
import numpy as np
import matplotlib.pylab as plt
import seaborn as sns
from tqdm import tqdm
from pathlib import Path
[3]:
# from pantea.logger import set_logging_level
# import logging
# set_logging_level(logging.INFO)

Dataset#

RuNNer#

Read dataset in RuNNer format.

[4]:
from pantea.datasets import Dataset
# NOTE(review): `base_dir` is assigned here but never passed to the loader below, which
# is given the bare relative filename "input.data" — confirm which location is intended.
base_dir = Path('./home/H2O_2')
# Load the structures stored in RuNNer format.
structures = Dataset.from_runner("input.data", persist=False)
print("Total number of structures:", len(structures))
# structures.preload()
structures
Total number of structures: 20
[4]:
Dataset(datasource=RunnerDataSource(filename='input.data', dtype=float64), persist=False)

Data loader#

[5]:
# from torch.utils.data import DataLoader

Split train and validation structures#

[6]:
# import torch
# validation_split = 0.032
# nsamples = len(structures)
# split = int(np.floor(validation_split * nsamples))
# train_structures, valid_structures = torch.utils.data.random_split(structures, lengths=[nsamples-split, split])
# structures = valid_structures

Structure#

[7]:
# Take the first structure of the dataset; the following cells refer to it as `s`.
s = structures[0]
s
[7]:
Structure(natoms=12, elements=('H', 'O'), dtype=float64)
[8]:
from ase.visualize import view
# Convert the structure with `to_ase()`; the result can be handed to ASE utilities
# such as the (commented-out) viewer below.
atoms = s.to_ase()
# view(atoms, viewer='ngl') # ase, ngl
[9]:
from ase.io.vasp import write_vasp
# Write the converted structure to a file named 'POSCAR' (relative path).
write_vasp('POSCAR', atoms)

Compare between structures#

[10]:
from pantea.utils.compare import compare
# Compare the first two structures of the dataset; the displayed result is a dict
# with the keys 'force_RMSEpa' and 'energy_RMSEpa'.
compare(structures[0], structures[1])
Comparing two structures, error metrics: RMSEpa
[10]:
{'force_RMSEpa': Array(0.85609414, dtype=float64),
 'energy_RMSEpa': Array(0.01299882, dtype=float64)}

Calculate distance between atoms#

[11]:
from pantea.atoms import calculate_distances
# Distances computed for atom index 0 of structure `s`.
dis = calculate_distances(s, atom_indices=0)
# Show the first five entries of the first row.
dis[0, :5]
[11]:
Array([0.        , 2.24720275, 3.85385356, 4.84207409, 7.55265933],      dtype=float64)
[12]:
# sns.displot(dis.flatten(), bins=20)
# plt.axvline(dis.mean(), color='r');

Find neighboring atom#

[13]:
from pantea.atoms import Neighbor

# Build the neighbor information of `s` using a cutoff radius of 10.0
# (units are not shown here — presumably those of the dataset; confirm).
neighbor = Neighbor.from_structure(s, r_cutoff=10.0)
print(neighbor)
# Summing the mask row of atom 0 yields the neighbor count reported below.
print("Number of neighbors for atom index 0:", jnp.sum(neighbor.masks[0]))
Neighbor(r_cutoff=10.0)
Number of neighbors for atom index 0: 11

Per-atom energy offset#

[14]:
# structure = structures[0]
# atom_energy = {'O': 2.4, 'H': 1.2}

# structure.add_energy_offset(atom_energy)
# structure.total_energy

Descriptor#

Atomic environment descriptor.

[15]:
from pantea.descriptors import ACSF
from pantea.descriptors.acsf import G2, G3, G9, CutoffFunction
[16]:
# Descriptor whose central element is oxygen.
acsf = ACSF('O')

# A single cutoff function (radius 12.0, type "tanhu") shared by all symmetry functions.
cfn = CutoffFunction.from_cutoff_type(12.0, cutoff_type="tanhu")
g2_1 = G2(cfn, 0.0, 0.001)
g2_2 = G2(cfn, 0.0, 0.01)
g3_1 = G3(cfn, 0.2, 1.0, 1.0, 0.0)
# Fixed: this was previously constructed with G3, which made it an exact duplicate of
# `g3_1` (hence two identical descriptor columns) and left the imported G9 unused.
g9_1 = G9(cfn, 0.2, 1.0, 1.0, 0.0)

# G2 terms are registered with one neighbor element, G3/G9 terms with two.
acsf.add(g2_1, 'H')
acsf.add(g2_2, 'H')
acsf.add(g3_1, 'H', 'H')
acsf.add(g3_1, 'H', 'O')
acsf.add(g9_1, 'H', 'O')
acsf
[16]:
ACSF(central_element='O', symmetry_functions=5)

Computing descriptor values#

[17]:
# Evaluate the descriptor on structure `s` and show the row at index 1
# (one value per symmetry function added to `acsf`).
val = acsf(s)
val[1, :]
[17]:
Array([1.08272759e+00, 9.35614402e-01, 2.52484053e-04, 7.64628911e-06,
       7.64628911e-06], dtype=float64)
[18]:
# sns.displot(val[:, 0], bins=20);

Gradient#

[19]:
# Descriptor gradient for atom index 2 of the first structure
# (the displayed array has shape (5, 3)).
acsf.grad(structures[0], atom_index=2)
[19]:
Array([[ 1.79978042e-02,  5.45284166e-02,  2.27747548e-02],
       [ 2.44782245e-02,  5.73790821e-02,  2.26291032e-02],
       [-4.11964454e-09,  4.37485864e-08,  2.58386814e-08],
       [-4.53174665e-06,  4.41009856e-04,  6.82317265e-05],
       [-4.53174665e-06,  4.41009856e-04,  6.82317265e-05]],      dtype=float64)

Scaler#

Descriptor scaler.

[21]:
from pantea.descriptors import DescriptorScaler

Fitting scaling parameters#

[23]:
scaler = DescriptorScaler(scale_type='scale_center')
# acsf = nnp.descriptor["H"]

# Feed the descriptor values of every structure to the scaler, one structure at a time.
for structure in tqdm(structures):
    x = acsf(structure)
    scaler.fit(x)

scaler
100%|██████████| 20/20 [00:00<00:00, 72.13it/s]
[23]:
DescriptorScaler(scale_type='scale_center', scale_min=0.0, scale_max=1.0)
[24]:
# Scale each structure's descriptor values and stack the results along the first axis.
scaled_x = jnp.concatenate(
    [scaler(acsf(structure)) for structure in tqdm(structures)],
    axis=0,
)
scaled_x.shape
100%|██████████| 20/20 [00:00<00:00, 133.38it/s]
[24]:
(80, 5)
[25]:
# sx = scaled_x[:, 4]  # column index must be < 5: scaled_x has shape (80, 5)
# sns.displot(sx, bins=30)
# plt.axvline(sx.mean(), color='r', lw=3);
# plt.axvline(0, color='k');

Model#

[28]:
from pantea.models import NeuralNetworkModel
from pantea.models.nn import UniformInitializer
from flax import linen as nn
[29]:
# Network configured with two hidden layers of 8 units each, both with 'tanh' activation.
model = NeuralNetworkModel(hidden_layers=((8, 'tanh'), (8, 'tanh')))
model
[29]:
NeuralNetworkModel(hidden_layers=((8, 'tanh'), (8, 'tanh')), param_dtype=float64)
[30]:
rng = jax.random.PRNGKey(2022)                       # PRNG Key
x = jnp.ones(shape=(8, acsf.num_symmetry_functions)) # Dummy Input
params = model.init(rng, x)                          # Initialize the parameters
# `jax.tree_map` is deprecated (and removed in recent JAX releases);
# `jax.tree_util.tree_map` is the stable spelling. The lambda argument is
# named `leaf` so it does not shadow the dummy input `x` above.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)  # Check the parameters
[30]:
{'params': {'layers_0': {'bias': (8,), 'kernel': (5, 8)},
  'layers_2': {'bias': (8,), 'kernel': (8, 8)},
  'layers_4': {'bias': (1,), 'kernel': (8, 1)}}}

Computing output energy#

[31]:
# Apply the model with the initialized parameters to all scaled descriptor rows.
energies = model.apply(params, scaled_x[:, :])
[32]:
# sns.displot(energies, bins=30);

Atomic Potential#

An atomic potential calculates the energy of a specific element in structures. It forms the basic building block of the final potential, which typically contains multiple elements. An atomic potential bundles all the necessary components, such as descriptors, scalers, and models, in order to output the per-atom energy.

[33]:
from pantea.potentials import AtomicPotential
/home/hossein/miniconda3/envs/pantea/lib/python3.10/site-packages/pydantic/_internal/_fields.py:149: UserWarning: Field "model_save_naming_format" has conflict with protected namespace "model_".

You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
  warnings.warn(
[34]:
# Bundle the descriptor, scaler, and model built earlier into one atomic potential.
atomic_potential = AtomicPotential(descriptor=acsf, scaler=scaler, model=model)
atomic_potential
[34]:
AtomicPotential(
  descriptor=ACSF(central_element='O', symmetry_functions=5),
  scaler=DescriptorScaler(scale_type='scale_center', scale_min=0.0, scale_max=1.0),
  model=NeuralNetworkModel(hidden_layers=((8, 'tanh'), (8, 'tanh')), param_dtype=float64),
)
[35]:
# Apply the potential to structure `s`; the inner 'params' entry of the initialized
# parameter dictionary is passed, not the whole dictionary.
out =  atomic_potential.apply(params["params"], s)
out.shape
[35]:
(4, 1)
[36]:
# Collect the potential's output for every structure in the dataset.
energies = []
for structure in tqdm(structures):
    out = atomic_potential.apply(params['params'], structure)
    energies.append(out)

# Stack the per-structure outputs along the first axis.
energies = jnp.concatenate(energies, axis=0)
energies.shape
100%|██████████| 20/20 [00:00<00:00, 114.17it/s]
[36]:
(80, 1)
[37]:
# sns.displot(energies, bins=30);

Note that the graph above is exactly the same as the one obtained earlier by using the model directly.

Neural Network Potential#

An instance of a neural network potential (NNP), including the descriptor, scaler, and model for multiple elements, can be initialized directly from the input potential files.

[38]:
from pantea.datasets import Dataset
from pantea.potentials import NeuralNetworkPotential
from ase.visualize import view

Read dataset#

[39]:
# Directory that holds the files used in this section (relative path).
base_dir = Path("GRN")

# Atomic data
structures = Dataset.from_runner(Path(base_dir, "input.data"))

# structures = [structures[i] for i in range(10)]
# Keep the first structure for the prediction examples.
structure = structures[0]
structure
# view(structure.to_ase() * (3, 3, 2), viewer='ngl')
[39]:
Structure(natoms=24, elements=('C',), dtype=float64)

Load potential parameters#

[40]:
# Potential
# Build the potential from the "input.nn" file located in `base_dir`.
nnp = NeuralNetworkPotential.from_file(Path(base_dir, "input.nn"))

# nnp.save()
# NOTE(review): presumably loads previously saved potential parameters — confirm
# which files it expects and where it looks for them.
nnp.load()

Predictions#

The warm-up period is due to lazy class loading and just-in-time (JIT) compilation.

[41]:
# Call the potential on the structure; the result is displayed as a scalar array.
total_energy = nnp(structure)
total_energy
[41]:
Array(-7.87381352, dtype=float64)
[42]:
# Compute forces for the structure and show the first five rows
# (three components per row in the displayed output).
forces = nnp.compute_forces(structure)
forces[:5]
[42]:
Array([[-0.06574249,  0.09717347,  0.12231123],
       [ 0.00342561, -0.1228041 ,  0.00807254],
       [-0.03078268, -0.02746679, -0.05159936],
       [-0.05093628, -0.096709  ,  0.0169795 ],
       [ 0.01947876, -0.09550387, -0.01810463]], dtype=float64)