Learning molecular dynamics: MDtrajNet
Back in March 2022, we introduced a novel concept of directly learning dynamics via 4D-spacetime atomistic AI models (4D models for short). The idea is to predict the nuclear coordinates as a continuous function of time. The model GICnet was published in JPCL in 2023:
Fuchun Ge, Lina Zhang, Yi-Fan Hou, Yuxinxin Chen, Arif Ullah, Pavlo O. Dral*. Four-dimensional-spacetime atomistic artificial intelligence models. J. Phys. Chem. Lett. 2023, 14, 7732–7743. DOI: 10.1021/acs.jpclett.3c01592.
However, this proof-of-concept work has many limitations. A major one is that GICnet does not generalize: a model can only be trained and used for one specific molecule. MDtrajNet overcomes these limitations. We also present MDtrajNet-1, a foundational model that directly generates MD trajectories across chemical space.
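The core 4D idea, predicting positions as a continuous function of time instead of stepping through time, can be illustrated on a system whose exact answer is known. For a one-dimensional harmonic oscillator with angular frequency ω, the position is x(t) = x0 cos(ωt) + (v0/ω) sin(ωt), so it can be evaluated at any t in a single call, whereas conventional MD reaches the same t by integrating thousands of small velocity Verlet steps. This toy example only illustrates the concept; it is not how GICnet or MDtrajNet are built, since those models learn such a time-dependent mapping for real molecules from MD data.

```python
import numpy as np

def direct_position(x0, v0, omega, t):
    """Closed-form position of a 1D harmonic oscillator at an arbitrary time t."""
    return x0 * np.cos(omega * t) + (v0 / omega) * np.sin(omega * t)

def velocity_verlet_position(x0, v0, omega, t, dt):
    """Position at time t obtained step by step, as in conventional MD."""
    x, v = x0, v0
    acc = -omega**2 * x
    for _ in range(int(round(t / dt))):
        x = x + v * dt + 0.5 * acc * dt**2
        acc_new = -omega**2 * x
        v = v + 0.5 * (acc + acc_new) * dt
        acc = acc_new
    return x

x0, v0, omega = 1.0, 0.5, 2.0
t = 3.7
x_direct = direct_position(x0, v0, omega, t)              # one function evaluation
x_md = velocity_verlet_position(x0, v0, omega, t, 1e-3)   # 3700 integration steps
print(x_direct, x_md)
```

Both routes agree to high precision here; the point of a 4D model is to obtain the left-hand route for systems where no closed form exists.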
MDtrajNet
MDtrajNet combines equivariant neural networks with a Transformer-based architecture, achieving high accuracy and transferability in predicting long-time trajectories for both known and unseen systems. For various molecular systems, the errors of the trajectories generated by the foundational model MDtrajNet-1 are close to those of conventional ab initio MD. The model's flexible design supports diverse application scenarios, including different statistical ensembles, boundary conditions, and interaction types.
See our preprint for more details:
Fuchun Ge and Pavlo O. Dral*. Artificial intelligence for direct prediction of molecular dynamics across chemical space. ChemRxiv. 2025. DOI: 10.26434/chemrxiv-2025-kc7sn.
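MDtrajNet builds on equivariant neural networks. For trajectory prediction, rotational equivariance means that rotating the initial coordinates and velocities rotates the predicted positions in exactly the same way. The sketch below shows how this property can be tested numerically with a random rotation matrix. The predictors are hypothetical stand-ins: a free-flight model (x + v·t), chosen because it is trivially equivariant, and a deliberately biased variant with a fixed drift direction that breaks equivariance. Neither is MDtrajNet's actual network.

```python
import numpy as np

def toy_predictor(xyz, vel, t):
    """Hypothetical stand-in for a 4D model: free-flight positions at time t."""
    return xyz + vel * t

def biased_predictor(xyz, vel, t):
    """Hypothetical non-equivariant model: adds a drift along a fixed lab-frame axis."""
    return xyz + vel * t + np.array([0.1, 0.0, 0.0]) * t

def random_rotation(rng):
    """Random proper 3x3 rotation matrix from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))      # make the decomposition unique
    if np.linalg.det(q) < 0:         # turn a reflection into a proper rotation
        q[:, 0] = -q[:, 0]
    return q

rng = np.random.default_rng(0)
xyz = rng.normal(size=(8, 3))                 # 8 atoms, rows are Cartesian vectors
vel = rng.normal(scale=0.01, size=(8, 3))
R = random_rotation(rng)

# Equivariance test: predict-then-rotate must equal rotate-then-predict
pred_then_rotate = toy_predictor(xyz, vel, 5.0) @ R.T
rotate_then_pred = toy_predictor(xyz @ R.T, vel @ R.T, 5.0)
print(np.allclose(pred_then_rotate, rotate_then_pred))       # True: equivariant

biased_then_rotate = biased_predictor(xyz, vel, 5.0) @ R.T
rotate_then_biased = biased_predictor(xyz @ R.T, vel @ R.T, 5.0)
print(np.allclose(biased_then_rotate, rotate_then_biased))   # False: fixed direction breaks it
```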
Note
In this tutorial, we cover only the Python API. Using MDtrajNet via input files or the command line is currently not supported.
Now, let’s see how to use MDtrajNet in MLatom!
Prerequisites
MLatom 3.17.4 or later
e3nn 0.4.4 (no guarantee for other versions)
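If you are unsure which versions are installed, a quick check from Python can help. This is a sketch using the standard-library importlib.metadata; it assumes the packages are installed under the PyPI distribution names "mlatom" and "e3nn", and, following the note above, treats e3nn 0.4.4 as the tested version rather than a hard requirement. The helpers as_tuple and check are illustrative names, not part of MLatom.

```python
from importlib.metadata import PackageNotFoundError, version

def as_tuple(v):
    """Turn '3.17.4' into (3, 17, 4); non-numeric suffixes such as 'rc1' are dropped."""
    parts = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:            # pad so that '3.17' compares like '3.17.0'
        parts.append(0)
    return tuple(parts)

def check(package, minimum=None, tested=None):
    """Report whether an installed package meets a minimum or matches a tested version."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return f"{package}: not installed"
    if minimum and as_tuple(installed) < as_tuple(minimum):
        return f"{package} {installed}: older than required {minimum}"
    if tested and as_tuple(installed) != as_tuple(tested):
        return f"{package} {installed}: untested (tutorial used {tested})"
    return f"{package} {installed}: OK"

print(check("mlatom", minimum="3.17.4"))
print(check("e3nn", tested="0.4.4"))
```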
Note
MLatom downloads the MDtrajNet-1 model files for you. If the download fails, you can download them manually by following the instructions in the error message.
Tutorial
Get started with these examples of how to use MDtrajNet (notebook file and model file).
import mlatom as ml
import torch
import py3Dmol
import numpy as np
Using MDtrajNet-1
Here we use the MDtrajNet-1 model, which can be used out-of-the-box without additional training. Let's propagate a trajectory of the urea molecule.
# Load MDtrajNet-1 model
mdtrajnet1 = ml.MDtrajNet(model_file="MDtrajNet1")
ani1ccx = ml.methods('ANI-1ccx') # reference MLIP
# Set up initial conditions
urea = ml.molecule.from_numpy(
np.array(
[[-8.6269673e-04, 1.4393133e-01, 3.0587106e-03],
[ 1.1928320e+00, -1.5854324e+00, 1.8119251e-02],
[ 2.0238326e+00, -5.6408115e-02, 4.5813598e-02],
[-2.0032527e+00, -1.0434096e-01, 8.0641443e-03],
[-1.1603842e+00, -1.6217848e+00, 1.5372290e-02],
[ 1.1671916e+00, -5.7923901e-01, -1.2174180e-02],
[-1.1425726e+00, -6.1858261e-01, -2.3802600e-03],
[-2.7523117e-02, 1.3611412e+00, -4.8005872e-04]]
),
np.array(['C', 'H', 'H', 'H', 'H', 'N', 'N', 'O'])
)
urea.add_xyz_vectorial_property(
np.array(
[[ 0.00193602, 0.0003093 , -0.00290614],
[ 0.00918389, 0.0049542 , -0.03214212],
[-0.00590873, -0.02061948, 0.01312468],
[ 0.01689092, 0.0109872 , 0.01195774],
[-0.00867305, 0.01350845, -0.02433634],
[-0.00258934, 0.0021774 , 0.00338955],
[ 0.00418191, 0.00347437, 0.00301366],
[-0.00357088, -0.00573639, -0.00144728]]
),
'xyz_velocities'
)
# Propagate with MDtrajNet-1
traj4D = mdtrajnet1.propagate(
molecule=urea,
time=120, # simulate 120 fs
time_step=4, # temporal resolution
time_segment=8, # length of the time segment
rescale_velocities=True, # rescale the velocities to enhance stability with ANI-1ccx
potential_model=ani1ccx
).to_database() # first-time evaluation of models using GPU takes some time
# Propagate with traditional MD using ANI-1ccx
trajMD = ml.md(
model=ani1ccx,
molecule=urea,
time_step=0.05,
maximum_propagation_time=120
).molecular_trajectory.to_database()[::80]  # keep every 80th frame (80 x 0.05 fs = 4 fs) to match time_step=4 above
# Dump trajectories
traj4D.write_file_with_xyz_coordinates('urea_4d.xyz')
trajMD.write_file_with_xyz_coordinates('urea_ani-1ccx.xyz')
model loaded from /mlatom/software/models/mdtrajnet1_model/MDtrajNet-1.0.pt
model loaded from /mlatom/software/models/mdtrajnet1_model/MDtrajNet-1.1.pt
model loaded from /mlatom/software/models/mdtrajnet1_model/MDtrajNet-1.2.pt
model loaded from /mlatom/software/models/mdtrajnet1_model/MDtrajNet-1.3.pt
/mlatom/software/miniconda3/lib/python3.11/site-packages/torchani/resources/
0 fs remaining
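The [::80] slice in the ANI-1ccx run keeps every 80th MD frame so that both trajectories share the same 4 fs spacing (4 fs / 0.05 fs = 80). The arithmetic below spells this out. It assumes that both trajectories include the initial frame at t = 0 and that the 4D trajectory is written every time_step fs; these are our reading of the parameters, not documented guarantees.

```python
# Sampling parameters used in the urea example
total_time = 120   # fs, same for both runs
step_4d = 4        # fs, time_step passed to MDtrajNet.propagate
step_md = 0.05     # fs, time_step passed to ml.md

stride = round(step_4d / step_md)                  # MD frames per 4D frame
n_md_frames = round(total_time / step_md) + 1      # including t = 0
n_4d_frames = round(total_time / step_4d) + 1      # including t = 0
n_md_after_slicing = len(range(0, n_md_frames, stride))

print(stride, n_md_frames, n_4d_frames, n_md_after_slicing)
```

If you change time_step in either call, adjust the slice accordingly so that frame i refers to the same time in both trajectories.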
If you use a Jupyter notebook, you can use the following code to visualize the trajectories and compare them frame by frame (MDtrajNet-1 in default colors, ANI-1ccx MD in cyan):
traj4D.view()
trajMD.view()
for i in range(0, 30, 5):
    viewer = py3Dmol.view(width=200, height=150)
    viewer.addModel(traj4D[i].get_xyz_string(), 'xyz')   # model 0: MDtrajNet-1
    viewer.addModel(trajMD[i].get_xyz_string(), 'xyz')   # model 1: ANI-1ccx MD
    viewer.setStyle({"stick": {'radius': 0.2}, "sphere": {"scale": 0.25}})
    viewer.setStyle({'model': 1}, {"stick": {'radius': 0.2}, "sphere": {"scale": 0.25, 'color': 'cyan'}})
    viewer.zoom(2.5)
    viewer.rotate(-45, 'x')
    viewer.show()
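Besides visual inspection, the agreement can be quantified frame by frame with the root-mean-square deviation (RMSD) after optimal superposition (the Kabsch algorithm). The function below is a plain NumPy sketch, not an MLatom feature; it assumes both frames list the same atoms in the same order, which holds here because both trajectories start from the same urea molecule.

```python
import numpy as np

def kabsch_rmsd(P, Q):
    """RMSD between two (N, 3) coordinate sets after optimal rotation and translation."""
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    P = P - P.mean(axis=0)                   # remove translation
    Q = Q - Q.mean(axis=0)
    H = P.T @ Q                              # 3x3 covariance matrix
    U, S, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))   # guard against improper rotations
    D = np.diag([1.0, 1.0, d])
    R = Vt.T @ D @ U.T                       # rotation that maps P onto Q
    diff = P @ R.T - Q
    return float(np.sqrt((diff ** 2).sum() / len(P)))

# Self-check: a rotated and shifted copy of a structure should give RMSD close to 0
rng = np.random.default_rng(1)
A = rng.normal(size=(8, 3))
theta = 0.7
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
B = A @ Rz.T + np.array([1.0, -2.0, 0.5])
print(kabsch_rmsd(A, B))
```

Applied to the urea trajectories, something like [kabsch_rmsd(traj4D[i].xyz_coordinates, trajMD[i].xyz_coordinates) for i in range(0, 30, 5)] would give RMSDs for the same frames shown above.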
Train an MDtrajNet model
If you prefer not to use MDtrajNet-1 or want a model of your own, you can train a new MDtrajNet model. We use diamond, a periodic system, as an example.
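The code below builds a 2x2x2 diamond supercell (64 carbon atoms) from fractional coordinates in units of a/4 and draws random velocities with a standard deviation of 0.0032 Å/fs per Cartesian component. As a sanity check of these numbers, the sketch computes the shortest C-C distance in the unit cell, which for diamond should be √3·a/4 (about 1.545 Å), and the temperature implied by the velocity distribution via equipartition, m⟨v_x²⟩ = k_B·T. It assumes velocities are in Å/fs (the unit reported in the training log) and uses standard physical constants; the result is close to 150 K.

```python
import numpy as np
from itertools import combinations

a = 3.567                                    # Å, lattice constant of diamond
frac = np.array([[0, 0, 0], [0, 2, 2], [2, 0, 2], [2, 2, 0],
                 [3, 3, 3], [3, 1, 1], [1, 3, 1], [1, 1, 3]]) / 4
xyz = frac * a

def min_image_distance(r1, r2, box):
    """Distance between two points in a cubic periodic box (minimum-image convention)."""
    d = r1 - r2
    d -= box * np.round(d / box)
    return np.linalg.norm(d)

d_min = min(min_image_distance(xyz[i], xyz[j], a) for i, j in combinations(range(8), 2))
print(f"nearest-neighbour distance: {d_min:.4f} Å (expected {np.sqrt(3) * a / 4:.4f} Å)")

# Temperature implied by velocity components drawn from N(0, sigma^2)
amu = 1.66053906660e-27      # kg
k_B = 1.380649e-23           # J/K
m_C = 12.011 * amu           # kg, average carbon mass
sigma = 0.0032 * 1e5         # Å/fs -> m/s (1 Å/fs = 1e5 m/s)
T = m_C * sigma**2 / k_B     # equipartition: m <v_x^2> = k_B T
print(f"temperature for sigma = 0.0032 Å/fs: about {T:.0f} K")
```

This is the nominal temperature of the sampling distribution; the 64 sampled velocities, after the mean velocity is subtracted, give an instantaneous kinetic temperature that scatters around it.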
# create a 2x2x2 supercell of diamond
a = 3.567 # the lattice constant of diamond
xyz = np.array(
[
[0, 0, 0],
[0, 2, 2],
[2, 0, 2],
[2, 2, 0],
[3, 3, 3],
[3, 1, 1],
[1, 3, 1],
[1, 1, 3]
]
) * a / 4 - a
z = np.array([6] * 8)
dia = ml.molecule.from_numpy(xyz, z)
dia.pbc = True
dia.cell = a
dia = dia.proliferate(XYZshifts=range(2))
dia.pbc = True
dia.cell = a * 2
# assign some random velocities
np.random.seed(0)
v = np.random.normal(0, 0.0032, (64, 3))
v -= np.mean(v, axis=0)
dia.add_xyz_vectorial_property(v, 'xyz_velocities')
# Create dataset (it will take some time)
ani1xnr = ml.models.methods("ANI-1xnr")
MDdia = ml.md(
model=ani1xnr,
molecule=dia,
time_step=0.05,
maximum_propagation_time=120
).molecular_trajectory
# Train the model
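pbc = torch.tensor([True, True, True])   # periodic along x, y and z
cell = torch.eye(3) * a * 2              # cubic 2x2x2 supercell with edge length 2a
tc = 10.0                                # time cutoff, passed as hyperparameters['time_cutoff']
n_train = 1024
n_valid = 128
batch_size = 64
max_epochs = 4                           # small value for a quick demonstration
trajs = [MDdia]
model = ml.MDtrajNet(model_file='model.pt', species=["C"])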
model.train(
trajectories=trajs,
batch_size=batch_size,
num_train=n_train,
num_valid=n_valid,
max_epochs=max_epochs,
hyperparameters={
'time_cutoff': tc,
't_embed_dim': 4,
"radial_cutoff": 3.567,
"irreps_key": "8x0e + 8x1o",
"irreps_query": "8x0e + 8x1o",
"irreps_value": "8x0e + 8x1o",
},
pbc=pbc,
cell=cell
)
the trained 4D model will be saved in model.pt
--------------------------------------------------------------------------------
epoch 0 :: lr: 0.001 time: 0.00 s
validation RMSEs::
    geometry : 0.011723 Å
    velocity : 0.002334 Å/fs
validation loss :: 0.011723
model saved in model.pt
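The irreps_key, irreps_query, and irreps_value strings use e3nn's irreducible-representation notation: "8x0e + 8x1o" means 8 even scalars (l = 0) plus 8 odd vectors (l = 1), and each irrep of degree l contributes 2l + 1 components. The small parser below is our own illustrative helper (not part of MLatom or e3nn) that computes the resulting feature width; with e3nn installed, e3nn.o3.Irreps("8x0e + 8x1o").dim should give the same number.

```python
import re

def irreps_dim(irreps: str) -> int:
    """Total number of components in an e3nn-style irreps string such as '8x0e + 8x1o'."""
    total = 0
    for term in irreps.split("+"):
        match = re.fullmatch(r"\s*(?:(\d+)x)?(\d+)([eo])\s*", term)
        if match is None:
            raise ValueError(f"cannot parse irrep term: {term!r}")
        multiplicity = int(match.group(1) or 1)
        l = int(match.group(2))
        total += multiplicity * (2 * l + 1)   # each degree-l irrep has 2l + 1 components
    return total

print(irreps_dim("8x0e + 8x1o"))           # 32: 8 scalars + 8 three-component vectors
print(irreps_dim("16x0e + 8x1o + 4x2e"))   # 60
```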
Propagate trajectories with an MDtrajNet model
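Once you have a model, you can propagate trajectories with it. The code below loads a trained diamond model from diamond.pt (see the model file linked at the start of the tutorial); to use the model you just trained, load model.pt instead.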
model = ml.MDtrajNet(model_file='diamond.pt')
traj4Ddia = model.propagate(
molecule=dia,
time=120,
time_step=4,
time_segment=8,
rescale_velocities=True,
potential_model=ani1xnr
).to_database()
model loaded from diamond.pt
0 fs remaining
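To visualize the trajectories, first subsample the MD trajectory to the same 4 fs spacing and unwrap the coordinates across the periodic boundaries, so that atoms leaving the supercell do not jump to the opposite side; then display both trajectories: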
trajMDdia = MDdia.to_database()[::80]  # every 80th frame: 80 x 0.05 fs = 4 fs
# Unwrap coordinates frame by frame (minimum-image displacement, supercell edge 2a)
lastframe = traj4Ddia[0]
for frame in traj4Ddia[1:]:
    frame.xyz_coordinates = lastframe.xyz_coordinates + (frame.xyz_coordinates - lastframe.xyz_coordinates + a) % (a * 2) - a
    lastframe = frame
lastframe = trajMDdia[0]
for frame in trajMDdia[1:]:
    frame.xyz_coordinates = lastframe.xyz_coordinates + (frame.xyz_coordinates - lastframe.xyz_coordinates + a) % (a * 2) - a
    lastframe = frame
traj4Ddia.view()
trajMDdia.view()
for i in range(0, 30, 5):
    viewer = py3Dmol.view(width=200, height=150)
    viewer.addModel(traj4Ddia[i].get_xyz_string(), 'xyz')   # model 0: MDtrajNet
    viewer.addModel(trajMDdia[i].get_xyz_string(), 'xyz')   # model 1: ANI-1xnr MD
    viewer.setStyle({"model": 0}, {"stick": {'radius': 0.2}, "sphere": {"scale": 0.25}})
    viewer.setStyle({"model": 1}, {"stick": {'radius': 0.2}, "sphere": {"scale": 0.25, 'color': 'cyan'}})
    viewer.show()
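The unwrapping loops above apply the minimum-image convention: the displacement d of each atom between consecutive frames is mapped into [-L/2, L/2) with (d + L/2) mod L - L/2, where L = 2a is the supercell edge (this is the (d + a) % (a * 2) - a expression in the loops), and then added to the previous unwrapped position. The one-dimensional sketch below checks the formula on a hypothetical atom drifting across the boundary; it assumes the atom moves less than L/2 between saved frames, which is what makes the mapping unambiguous.

```python
import numpy as np

L = 2 * 3.567   # Å, edge of the 2x2x2 diamond supercell (2a)

def unwrap(frames, box):
    """Undo periodic wrapping frame by frame using the minimum-image displacement."""
    frames = np.asarray(frames, dtype=float)
    out = [frames[0]]
    for frame in frames[1:]:
        step = (frame - out[-1] + box / 2) % box - box / 2   # displacement in [-box/2, box/2)
        out.append(out[-1] + step)
    return np.array(out)

# A hypothetical atom drifting +1.5 Å per frame, stored wrapped into [0, L)
true_path = 6.0 + 1.5 * np.arange(6)
wrapped = true_path % L
unwrapped = unwrap(wrapped, L)
print(wrapped.round(3))
print(unwrapped.round(3))
```

The same function works unchanged on arrays of shape (n_frames, n_atoms, 3), since all operations are elementwise.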