Source code for mlatom.interfaces.torchani_interface

'''
.. code-block::

  !---------------------------------------------------------------------------! 
  ! Interface_TorchANI: Interface between TorchANI and MLatom                 ! 
  ! Implementations by: Fuchun Ge and Max Pinheiro Jr                         !
  ! Defines the MS-ANI model                                                  !
  ! Implemented by: Mikolaj Martyka                                           ! 
  !---------------------------------------------------------------------------! 
'''

from __future__ import annotations
from typing import Any, Union, Dict, List, Callable
from types import FunctionType
import os, sys, uuid
import numpy as np
import tqdm
from collections import OrderedDict
import torch
from torch import nn, Tensor
import torchani
from torchani.data import TransformableIterable, IterableAdapter


from .. import data
from .. import models
from ..decorators import doc_inherit

def median(yp, yt):
    # Median absolute deviation between predicted and reference values.
    return torch.median(torch.abs(yp - yt))

def molDB2ANIdata_state(molDB,
                        property_to_learn=None,
                        xyz_derivative_property_to_learn=None,
                        use_state=True):
    def molDBiter():
        for mol in molDB.molecules:
            ret = {'species': mol.get_element_symbols(), 'coordinates': mol.xyz_coordinates}
            if property_to_learn is not None:
                ret['energies'] = mol.__dict__[property_to_learn]
            if xyz_derivative_property_to_learn is not None:
                ret['forces'] = -1 * mol.get_xyz_vectorial_properties(xyz_derivative_property_to_learn)
            if use_state:
                ret['state'] = mol.__dict__['current_state']
            # debug hook (disabled): attach the molecule id to each record
            # ret['id'] = mol.__dict__['mol_id']
            yield ret
    return TransformableIterable(IterableAdapter(lambda: molDBiter()))
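
# Usage sketch (illustrative, not executed): each record yielded by
# molDB2ANIdata_state is a plain dict that TorchANI's data pipeline can
# consume; with use_state=True it also carries the electronic-state index.
# Assumes `db` is an mlatom.data.molecular_database whose molecules have
# 'energy', 'energy_gradients', and 'current_state' set:
#
#     ds = molDB2ANIdata_state(db, 'energy', 'energy_gradients')
#     rec = next(iter(ds))
#     # rec['species'], rec['coordinates'], rec['energies'],
#     # rec['forces'], rec['state']
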
class StateInputNet(torch.nn.Module):
    # Wraps an AEV computer and an ANI-style network; the electronic-state
    # index is broadcast to every atom and appended as one extra AEV feature.
    def __init__(self, AEV_comp, net):
        super(StateInputNet, self).__init__()
        self.AEV_computer = AEV_comp
        self.network = net

    def forward(self, species, coordinates, state):
        species_, aev = self.AEV_computer((species, coordinates))
        state_tensor = torch.transpose(state.expand(species.size()[1], -1), 0, 1)
        aev_with_state = torch.cat((aev, state_tensor.unsqueeze(2)), 2)
        out = self.network((species_, aev_with_state))
        return out

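# Shape sketch (illustrative): StateInputNet.forward() broadcasts the
# per-molecule state index to every atom and appends it as one extra AEV
# feature. For 2 molecules x 3 atoms with AEV length L:
#
#     state = torch.tensor([0.0, 1.0])                        # (2,)
#     st = torch.transpose(state.expand(3, -1), 0, 1)         # (2, 3)
#     aev_with_state = torch.cat((aev, st.unsqueeze(2)), 2)   # (2, 3, L + 1)
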
def unpackData2State(mol_db):
    # Flatten a database of multi-state molecules into one molecule per
    # electronic state, tagging each copy with its state index and parent id.
    train_data = data.molecular_database()
    for mol_id, i in enumerate(mol_db):
        for idx, state in enumerate(i.electronic_states):
            new_mol = data.molecule()
            new_mol.read_from_xyz_string(i.get_xyz_string())
            new_mol.current_state = idx
            new_mol.energy = state.energy
            new_mol.add_xyz_derivative_property(state.get_xyz_vectorial_properties("energy_gradients"), 'energy', 'energy_gradients')
            new_mol.mol_id = mol_id
            train_data.append(new_mol)
    return train_data

# To do: molDB2ANIdata and molDB2ANIdata_state duplicate most of their logic
# and should be merged.
def molDB2ANIdata(molDB,
                  property_to_learn=None,
                  xyz_derivative_property_to_learn=None):
    def molDBiter():
        for mol in molDB.molecules:
            ret = {'species': mol.get_element_symbols(), 'coordinates': mol.xyz_coordinates}
            if property_to_learn is not None:
                ret['energies'] = mol.__dict__[property_to_learn]
            if xyz_derivative_property_to_learn is not None:
                ret['forces'] = -1 * mol.get_xyz_vectorial_properties(xyz_derivative_property_to_learn)
            if mol.pbc is not None:
                ret['pbc'] = mol.pbc
            if mol.cell is not None:
                ret['cell'] = mol.cell
            yield ret
    return TransformableIterable(IterableAdapter(lambda: molDBiter()))

PADDING = {
    'species': -1,
    'coordinates': 0.0,
    'forces': 0.0,
    'energies': 0.0,
    'pbc': False,
    'cell': 0.0,
}

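# Batching sketch (illustrative): PADDING supplies the fill values used when
# molecules of different sizes are collated into fixed-shape tensors (-1 marks
# padding atoms in 'species', which is why atom counts are computed below as
# (species >= 0).sum(...)). Assuming `db` as above and an H/C/O dataset:
#
#     batches = (molDB2ANIdata(db, 'energy')
#                .species_to_indices(['C', 'H', 'O'])
#                .shuffle()
#                .collate(8, PADDING))
#     for properties in batches:
#         print(properties['species'].shape, properties['energies'].shape)
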
class ani(models.ml_model, models.torchani_model):
    '''
    Create an `ANI <http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A>`_ (`ANAKIN <https://www.google.com/search?q=Anakin+Skywalker>`_-ME: Accurate NeurAl networK engINe for Molecular Energies) model object.

    Interfaces to `TorchANI <https://pubs.acs.org/doi/10.1021/acs.jcim.0c00451>`_.

    Arguments:
        model_file (str, optional): The filename that the model will be saved to or loaded from.
        device (str, optional): Indicate which device the calculation will be run on, i.e. 'cpu' for CPU, 'cuda' for Nvidia GPUs. When not specified, it will try to use CUDA if a valid ``CUDA_VISIBLE_DEVICES`` is present in the system environment.
        hyperparameters (Dict[str, Any] | :class:`mlatom.models.hyperparameters`, optional): Updates the hyperparameters of the model with the provided ones.
        verbose (int, optional): 0 for silence, 1 for verbosity.
    '''
    hyperparameters = models.hyperparameters({
        #### Training ####
        'batch_size': models.hyperparameter(value=8, minval=1, maxval=1024, optimization_space='linear', dtype=int),
        'max_epochs': models.hyperparameter(value=1000000, minval=100, maxval=1000000, optimization_space='log', dtype=int),
        'learning_rate': models.hyperparameter(value=0.001, minval=0.0001, maxval=0.01, optimization_space='log'),
        'early_stopping_learning_rate': models.hyperparameter(value=1.0E-5, minval=1.0E-6, maxval=1.0E-4, optimization_space='log'),
        'lr_reduce_patience': models.hyperparameter(value=64, minval=16, maxval=256, optimization_space='linear'),
        'lr_reduce_factor': models.hyperparameter(value=0.5, minval=0.1, maxval=0.9, optimization_space='linear'),
        'lr_reduce_threshold': models.hyperparameter(value=0.0, minval=-0.01, maxval=0.01, optimization_space='linear'),
        #### Loss ####
        'force_coefficient': models.hyperparameter(value=0.1, minval=0.05, maxval=5, optimization_space='linear'),
        'median_loss': models.hyperparameter(value=False),
        'validation_loss_type': models.hyperparameter(value='MSE', choices=['MSE', 'mean_RMSE']),
        'loss_type': models.hyperparameter(value='weighted', choices=['weighted', 'geometric']),
        #### Network ####
        'neurons': models.hyperparameter(value=[[160, 128, 96]]),
        'activation_function': models.hyperparameter(value=lambda: torch.nn.CELU(0.1), optimization_space='choice', choices=["CELU", "ReLU", "GELU"], dtype=(str, type, FunctionType)),
        'fixed_layers': models.hyperparameter(value=False),
        #### AEV ####
        'Rcr': models.hyperparameter(value=5.2000e+00, minval=1.0, maxval=10.0, optimization_space='linear'),
        'Rca': models.hyperparameter(value=3.5000e+00, minval=1.0, maxval=10.0, optimization_space='linear'),
        'EtaR': models.hyperparameter(value=[1.6000000e+01]),
        'ShfR': models.hyperparameter(value=[9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00]),
        'Zeta': models.hyperparameter(value=[3.2000000e+01]),
        'ShfZ': models.hyperparameter(value=[1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00]),
        'EtaA': models.hyperparameter(value=[8.0000000e+00]),
        'ShfA': models.hyperparameter(value=[9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00]),
    })
    argsdict = {}
    model_file = None
    model = None
    property_name = 'y'
    species_order = []
    program = 'TorchANI'
    meta_data = {
        "genre": "neural network"
    }
    verbose = 1  # 2 gives more training information

    def __init__(self, model_file: str = None, device: str = None, hyperparameters: Union[Dict[str, Any], models.hyperparameters] = {}, verbose=1, **kwargs):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.hyperparameters = self.hyperparameters.copy()
        self.hyperparameters.update(hyperparameters)
        self.verbose = verbose
        self.energy_shifter = torchani.utils.EnergyShifter(None)
        if 'key' in kwargs:
            self.key = kwargs['key']
        else:
            self.key = None
        if model_file:
            if os.path.isfile(model_file):
                self.load(model_file)
            else:
                if self.verbose:
                    print(f'the trained ANI model will be saved in {model_file}')
            self.model_file = model_file

    def parse_args(self, args):
        super().parse_args(args)
        for hyperparam in self.hyperparameters:
            if hyperparam in args.hyperparameter_optimization['hyperparameters']:
                self.parse_hyperparameter_optimization(args, hyperparam)
            # elif hyperparam in args.data:
            #     self.hyperparameters[hyperparam].value = args.data[hyperparam]
            elif 'ani' in args.data and hyperparam in args.ani.data:
                self.hyperparameters[hyperparam].value = args.ani.data[hyperparam]

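    # Construction sketch (hypothetical file name): if 'ani.pt' exists it is
    # loaded, otherwise a fresh model will be trained and saved there.
    #
    #     model = ani(model_file='ani.pt', device='cpu',
    #                 hyperparameters={'batch_size': 16, 'max_epochs': 500})
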
    def reset(self):
        super().reset()
        self.model = None

    def save(self, model_file: str = '') -> None:
        '''
        Save the model to file (.pt format).

        Arguments:
            model_file (str, optional): The filename that the model will be saved into. If not provided, a randomly generated name will be used.
        '''
        if not model_file:
            model_file = f'ani_{str(uuid.uuid4())}.pt'
        self.model_file = model_file
        torch.save(
            {
                'network': self.networkdict,
                'args': self.argsdict,
                'nn': self.nn.state_dict(),
                'AEV_computer': self.aev_computer,
                'energy_shifter': self.energy_shifter,
            },
            model_file,
        )
        if self.verbose:
            print(f'model saved in {model_file}')

    def load(self, model_file: str = '', species_order: Union[List[str], None] = None, AEV_parameters: Union[Dict, None] = None, self_energies: Union[List[float], None] = None, reset_parameters: bool = False, method: str = '') -> None:
        '''
        Load a saved ANI model from file (.pt format).

        Arguments:
            model_file (str): The filename of the model to be loaded.
            species_order (List[str], optional): Manually provide the species order if it is not present in the saved model.
            AEV_parameters (Dict, optional): Manually provide AEV parameters if they are not present in the saved model.
            self_energies (List[float], optional): Manually provide self energies if they are not present in the saved model.
            reset_parameters (bool): Reset network parameters in the loaded model.
            method (str): Load an ANI method, see also :meth:`ani.load_ani_model`.
        '''
        if method:
            self.load_ani_model(method)
            return
        model_dict = torch.load(model_file, map_location=torch.device('cpu'))
        if 'property' in model_dict['args']:
            self.property_name = model_dict['args']['property']
        if 'species_order' in model_dict['args']:
            self.species_order = model_dict['args']['species_order']
        else:
            print('species order not found, please provide explicitly')
            self.species_order = species_order
        self.argsdict.update({'species_order': self.species_order})
        if 'AEV_computer' in model_dict:
            self.aev_computer = model_dict['AEV_computer']
            if 'use_cuda_extension' not in self.aev_computer.__dict__:
                self.aev_computer.use_cuda_extension = False
            self.argsdict.update({'Rcr': self.aev_computer.Rcr, 'Rca': self.aev_computer.Rca, 'EtaR': self.aev_computer.EtaR, 'ShfR': self.aev_computer.ShfR, 'Zeta': self.aev_computer.Zeta, 'ShfZ': self.aev_computer.ShfZ, 'EtaA': self.aev_computer.EtaA, 'ShfA': self.aev_computer.ShfA})
        elif 'Rcr' in model_dict['args']:
            self.AEV_setup(**model_dict['args'])
        else:
            print('AEV parameters not found, please provide explicitly')
            self.AEV_setup(**AEV_parameters)
        if 'energy_shifter' in model_dict:
            self.energy_shifter = model_dict['energy_shifter']
        elif 'energy_shifter_train' in model_dict['args']:
            self.energy_shifter = model_dict['args']['energy_shifter_train']
        elif 'self_energies_train' in model_dict['args']:
            self.energy_shifter = torchani.utils.EnergyShifter(model_dict['args']['self_energies_train'])
        elif 'self_energies' in model_dict['args']:
            self.energy_shifter = torchani.utils.EnergyShifter(model_dict['args']['self_energies'])
        else:
            print('self energy information not found, please provide explicitly')
            self.energy_shifter = torchani.utils.EnergyShifter(self_energies)
        self.energy_shifter.to(self.device)
        if 'network' in model_dict and 'nn' in model_dict:
            self.networkdict = model_dict['network']
            if isinstance(self.networkdict, OrderedDict) or type(self.networkdict) == dict:
                self.neurons = [[layer.out_features for layer in network if isinstance(layer, torch.nn.Linear)] for network in self.networkdict.values()]
                self.nn = torchani.ANIModel(self.networkdict if isinstance(self.networkdict, OrderedDict) else self.networkdict.values())
            elif type(self.networkdict) == list:
                self.neurons = [[[layer.out_features for layer in network if isinstance(layer, torch.nn.Linear)] for network in subdict.values()] for subdict in self.networkdict]
                self.nn = torchani.nn.Ensemble([torchani.ANIModel(subdict if isinstance(subdict, OrderedDict) else subdict.values()) for subdict in self.networkdict])
            if reset_parameters:
                self.NN_initialize()
            else:
                self.nn.load_state_dict(model_dict['nn'])
            self.optimizer_setup(**self.hyperparameters)
        else:
            print('network parameters not found')
        self.model = torchani.nn.Sequential(self.aev_computer, self.nn).to(self.device)
        self.model.eval()
        if self.verbose:
            print(f'model loaded from {model_file}')

    def load_ani_model(self, method: str, **hyperparameters) -> None:
        '''
        Load an ANI model.

        Arguments:
            method (str): Can be ``'ANI-1x'``, ``'ANI-1ccx'``, or ``'ANI-2x'``.
        '''
        self.hyperparameters.update(hyperparameters)
        if 'ANI-1x'.casefold() in method.casefold():
            model = torchani.models.ANI1x(periodic_table_index=True).to(self.device)
        elif 'ANI-1ccx'.casefold() in method.casefold():
            model = torchani.models.ANI1ccx(periodic_table_index=True).to(self.device)
        elif 'ANI-2x'.casefold() in method.casefold():
            model = torchani.models.ANI2x(periodic_table_index=True).to(self.device)
        else:
            print("method not found, please check ANI_methods().available_methods")
            return
        self.species_order = model.species
        self.argsdict.update({'species_order': self.species_order})
        self.aev_computer = model.aev_computer
        self.networkdict = [OrderedDict(**{k: v for k, v in nn.items()}) for nn in model.neural_networks]
        self.neurons = [[[layer.out_features for layer in network if isinstance(layer, torch.nn.Linear)] for network in subdict.values()] for subdict in self.networkdict]
        self.nn = model.neural_networks
        self.optimizer_setup(**self.hyperparameters)
        self.energy_shifter = model.energy_shifter
        self.model = torchani.nn.Sequential(self.aev_computer, self.nn).to(self.device).float()
        if self.verbose:
            print(f'loaded {method} model')

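    # Method-loading sketch: pull one of the published TorchANI potentials
    # into this wrapper instead of training from scratch.
    #
    #     model = ani()
    #     model.load_ani_model('ANI-2x')  # also reachable via load(method='ANI-2x')
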
    @doc_inherit
    def train(
        self,
        molecular_database: data.molecular_database,
        property_to_learn: str = 'energy',
        xyz_derivative_property_to_learn: str = None,
        validation_molecular_database: Union[data.molecular_database, str, None] = 'sample_from_molecular_database',
        hyperparameters: Union[Dict[str, Any], models.hyperparameters] = {},
        spliting_ratio: float = 0.8,
        save_model: bool = True,
        file_to_save_model: str = None,
        check_point: str = None,
        reset_optim_state: bool = False,
        use_last_model: bool = False,
        reset_parameters: bool = False,
        reset_network: bool = False,
        reset_optimizer: bool = False,
        reset_energy_shifter: Union[bool, torchani.utils.EnergyShifter, list] = False,
        save_every_epoch: bool = False,
        save_epoch_interval: int = None,
        energy_weighting_function: Callable = None,
        energy_weighting_function_kwargs: dict = {},
        verbose: int = None,
    ) -> None:
        r'''
        validation_molecular_database (:class:`mlatom.data.molecular_database` | str, optional): Explicitly defines the database for validation, or use ``'sample_from_molecular_database'`` to sample it from the training set.
        hyperparameters (Dict[str, Any] | :class:`mlatom.models.hyperparameters`, optional): Updates the hyperparameters of the model with the provided ones.
        spliting_ratio (float, optional): The fraction of the whole training set used as the sub-training set.
        save_model (bool, optional): Whether to save the model to disk during the training process. Note that the model might be saved many times during training.
        reset_optim_state (bool, optional): Whether to reset the state of the optimizer.
        use_last_model (bool, optional): Whether to keep ``self.model`` as it is in the last training epoch. If ``False``, the best model will be loaded into memory at the end of training.
        reset_parameters (bool, optional): Whether to reset the model's parameters before training.
        reset_network (bool, optional): Whether to re-construct the network before training.
        reset_optimizer (bool, optional): Whether to re-define the optimizer before training.
        save_every_epoch (bool, optional): Whether to save the model in every epoch; valid when ``save_model`` is ``True``.
        save_epoch_interval (int, optional): The interval (in epochs) at which the model is saved.
        energy_weighting_function (Callable, optional): A weighting function :math:`W(E_\mathrm{ref})` that assigns weights to training points based on their reference energies.
        energy_weighting_function_kwargs (dict, optional): Extra weighting-function arguments in a dictionary.
        '''
        if hyperparameters:
            self.hyperparameters.update(hyperparameters)
        if verbose:
            self.verbose = verbose
        if file_to_save_model:
            self.model_file = file_to_save_model
        energy_weighting_function_kwargs = {k: (v.value if isinstance(v, models.hyperparameter) else v) for k, v in energy_weighting_function_kwargs.items()}
        if reset_energy_shifter:
            if self.energy_shifter:
                self.energy_shifter_ = self.energy_shifter
            if type(reset_energy_shifter) == list:
                self.energy_shifter = torchani.utils.EnergyShifter(reset_energy_shifter)
            elif isinstance(reset_energy_shifter, torchani.utils.EnergyShifter):
                self.energy_shifter = reset_energy_shifter
            else:
                self.energy_shifter = torchani.utils.EnergyShifter(None)
        if not molecular_database._is_uniform_cell():
            print('non-uniform PBC cells detected, using batch_size=1')
            self.hyperparameters.batch_size = 1
        self.data_setup(molecular_database, validation_molecular_database, spliting_ratio, property_to_learn, xyz_derivative_property_to_learn)
        # print energy shifter information
        if self.verbose == 2:
            print('\nDelta self atomic energies: ')
            for sp, esp in zip(self.species_order, self.energy_shifter.self_energies):
                print(f'{sp} : {esp}')
            print('\n')
            sys.stdout.flush()
        if not self.model:
            self.model_setup(**self.hyperparameters)
        else:
            if 'fixed_layers' in hyperparameters:
                if hyperparameters['fixed_layers'] and type(hyperparameters['fixed_layers']) == list:
                    self.fix_layers(layers_to_fix=hyperparameters['fixed_layers'])
        if reset_network:
            self.NN_setup(**self.hyperparameters)
        if reset_parameters:
            self.NN_initialize()
        if reset_optimizer:
            self.optimizer_setup(**self.hyperparameters)
        self.model.train()
        if self.verbose:
            print(self.model)
        if check_point and os.path.isfile(check_point):
            if self.verbose == 2:
                print(f'\nCheckpoint file {check_point} found. Model training will start from checkpoint.\n')
                sys.stdout.flush()
            checkpoint = torch.load(check_point)
            self.nn.load_state_dict(checkpoint['nn'])
            if not reset_optim_state:
                self.AdamW.load_state_dict(checkpoint['AdamW'])
                self.SGD.load_state_dict(checkpoint['SGD'])
                self.AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
                self.SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

        def validate():
            total_error = 0.0
            mae_sum = torch.nn.L1Loss(reduction='sum')
            energy_mae = 0.0
            energy_loss = 0.0
            count = 0
            for properties in self.validation_set:
                true_energies = properties['energies'].to(self.device).float()
                species = properties['species'].to(self.device)
                pbc = properties['pbc'][0].to(self.device) if 'pbc' in properties else None
                cell = properties['cell'][0].float().to(self.device) if 'cell' in properties else None
                num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
                if callable(energy_weighting_function):
                    weightings_e = energy_weighting_function(self.energy_shifter((species, true_energies)).energies, **energy_weighting_function_kwargs)
                else:
                    weightings_e = 1
                coordinates = properties['coordinates'].to(self.device).float()
                _, predicted_energies = self.model((species, coordinates), pbc=pbc, cell=cell)
                energy_mae += mae_sum(predicted_energies*weightings_e, true_energies*weightings_e).item()
                energy_loss += (loss_function(predicted_energies, true_energies, weightings_e) / num_atoms.sqrt()).sum().item()
                if self.hyperparameters.validation_loss_type == 'mean_RMSE':
                    total_error += loss_function(predicted_energies, true_energies, weightings_e, reduction='sum').nanmean().sqrt().item()
                else:
                    total_error += loss_function(predicted_energies, true_energies, weightings_e, reduction='sum').nanmean().item()
                count += predicted_energies.shape[0]
            return total_error/count, (total_error/count)**0.5, energy_mae/count, energy_loss/count

        def loss_function(prediction, reference, weightings=1, reduction='none'):
            return torch.nn.functional.mse_loss(prediction*weightings, reference*weightings, reduction=reduction)

        if self.verbose:
            print("training starting from epoch", self.AdamW_scheduler.last_epoch + 1)
        for _ in range(self.AdamW_scheduler.last_epoch + 1, self.hyperparameters.max_epochs + 1):
            mse, rmse, mae, validation_loss = validate()
            if self.verbose:
                print('validation loss:', mse, 'at epoch', self.AdamW_scheduler.last_epoch + 1)
                if self.verbose == 2:
                    print('validation MAE:', mae, 'at epoch', self.AdamW_scheduler.last_epoch + 1)
                    print('validation RMSE:', rmse, 'at epoch', self.AdamW_scheduler.last_epoch + 1)
                    print('validation energy loss:', validation_loss, 'at epoch', self.AdamW_scheduler.last_epoch + 1)
                    print('best validation MSE:', self.AdamW_scheduler.best)
                sys.stdout.flush()
            learning_rate = self.AdamW.param_groups[0]['lr']
            if self.verbose:
                print('learning_rate:', learning_rate)
            if learning_rate < self.hyperparameters.early_stopping_learning_rate:
                break
            if self.AdamW_scheduler.is_better(mse, self.AdamW_scheduler.best) or save_every_epoch:
                if save_model:
                    self.save(self.model_file)
            if save_epoch_interval:
                if (self.AdamW_scheduler.last_epoch + 1) % save_epoch_interval == 0 and save_model:
                    self.save(f'{self.model_file}.epoch{self.AdamW_scheduler.last_epoch + 1}')
            self.AdamW_scheduler.step(mse)
            self.SGD_scheduler.step(mse)
            for properties in tqdm.tqdm(
                self.subtraining_set,
                total=len(self.subtraining_set),
                desc="epoch {}".format(self.AdamW_scheduler.last_epoch),
                disable=not self.verbose,
            ):
                true_energies = properties['energies'].to(self.device).float()
                species = properties['species'].to(self.device)
                num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
                pbc = properties['pbc'][0].to(self.device) if 'pbc' in properties else None
                cell = properties['cell'][0].float().to(self.device) if 'cell' in properties else None
                if callable(energy_weighting_function):
                    weightings_e = energy_weighting_function(self.energy_shifter((species, true_energies)).energies, **energy_weighting_function_kwargs)
                    weightings_f = energy_weighting_function(self.energy_shifter((species, true_energies)).energies, **energy_weighting_function_kwargs)[:, None, None]
                else:
                    weightings_e = 1
                    weightings_f = 1
                if xyz_derivative_property_to_learn:
                    coordinates = properties['coordinates'].to(self.device).float().requires_grad_(True)
                    true_forces = properties['forces'].to(self.device).float()
                    _, predicted_energies = self.model((species, coordinates), pbc=pbc, cell=cell)
                    forces = -torch.autograd.grad(predicted_energies.sum(), coordinates, create_graph=True, retain_graph=True)[0]
                    if self.hyperparameters.median_loss:
                        energy_loss = median(predicted_energies, true_energies)
                    else:
                        energy_loss = (loss_function(predicted_energies, true_energies, weightings_e) / num_atoms.sqrt()).nanmean()
                    force_loss = (loss_function(true_forces, forces, weightings_f).sum(dim=(1, 2)) / num_atoms).nanmean()
                    if self.hyperparameters.loss_type == 'weighted':
                        loss = energy_loss + self.hyperparameters.force_coefficient * force_loss
                    if self.hyperparameters.loss_type == 'geometric':
                        loss = (energy_loss * force_loss)**0.5
                else:
                    coordinates = properties['coordinates'].to(self.device).float()
                    _, predicted_energies = self.model((species, coordinates), pbc=pbc, cell=cell)
                    loss = (loss_function(predicted_energies, true_energies, weightings_e) / num_atoms.sqrt()).nanmean()
                self.AdamW.zero_grad()
                self.SGD.zero_grad()
                loss.backward()
                self.AdamW.step()
                self.SGD.step()
            if self.verbose == 2:
                print('Training loss:', loss.item(), 'at epoch', self.AdamW_scheduler.last_epoch, '\n')
                sys.stdout.flush()
            if check_point:
                torch.save({
                    'nn': self.nn.state_dict(),
                    'AdamW': self.AdamW.state_dict(),
                    'SGD': self.SGD.state_dict(),
                    'AdamW_scheduler': self.AdamW_scheduler.state_dict(),
                    'SGD_scheduler': self.SGD_scheduler.state_dict(),
                }, check_point)
        # print the performance of the best model
        if self.verbose == 2:
            print('\nPerformance of the best model on the validation set')
            _, best_rmse, best_mae, best_val_loss = validate()
            print('best validation MAE:', best_mae)
            print('best validation RMSE:', best_rmse)
            print('best validation energy loss:', best_val_loss)
        if save_model and not use_last_model:
            self.load(self.model_file)

    @doc_inherit
    def predict(
        self,
        molecular_database: data.molecular_database = None,
        molecule: data.molecule = None,
        calculate_energy: bool = False,
        calculate_energy_gradients: bool = False,
        calculate_hessian: bool = False,
        property_to_predict: Union[str, None] = 'estimated_y',
        xyz_derivative_property_to_predict: Union[str, None] = None,
        hessian_to_predict: Union[str, None] = None,
        batch_size: int = 2**16,
    ) -> None:
        '''
        batch_size (int, optional): The batch size for batch predictions.
        '''
        molDB, property_to_predict, xyz_derivative_property_to_predict, hessian_to_predict = \
            super().predict(molecular_database=molecular_database, molecule=molecule,
                            calculate_energy=calculate_energy, calculate_energy_gradients=calculate_energy_gradients, calculate_hessian=calculate_hessian,
                            property_to_predict=property_to_predict, xyz_derivative_property_to_predict=xyz_derivative_property_to_predict, hessian_to_predict=hessian_to_predict)
        if not molDB._is_uniform_cell():
            batch_size = 1
        for batch in molDB.batches(batch_size):
            # the collated iterable yields a single batch here
            for properties in molDB2ANIdata(batch).species_to_indices(self.species_order).collate(batch_size, PADDING):
                species = properties['species'].to(self.device)
                xyz_coordinates = properties['coordinates'].float().to(self.device)
                break
            pbc = properties['pbc'][0].to(self.device) if 'pbc' in properties else None
            cell = properties['cell'][0].float().to(self.device) if 'cell' in properties else None
            if pbc is not None and cell is not None:
                xyz_coordinates = torchani.utils.map2central(cell, xyz_coordinates, pbc)
            xyz_coordinates = xyz_coordinates.requires_grad_(bool(xyz_derivative_property_to_predict or hessian_to_predict))
            ANI_NN_energies = self.energy_shifter(self.model((species, xyz_coordinates), pbc=pbc, cell=cell)).energies
            if property_to_predict:
                batch.add_scalar_properties(ANI_NN_energies.detach().cpu().numpy(), property_to_predict)
            if xyz_derivative_property_to_predict or hessian_to_predict:
                ANI_NN_energy_gradients = torch.autograd.grad(ANI_NN_energies.sum(), xyz_coordinates, create_graph=True, retain_graph=True)[0]
                if xyz_derivative_property_to_predict:
                    grads = ANI_NN_energy_gradients.detach().cpu().numpy()
                    batch.add_xyz_vectorial_properties(grads, xyz_derivative_property_to_predict)
                if hessian_to_predict:
                    ANI_NN_hessians = torchani.utils.hessian(xyz_coordinates, energies=ANI_NN_energies)
                    batch.add_hessian_properties(ANI_NN_hessians.detach().cpu().numpy(), hessian_to_predict)

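    # Prediction sketch (hypothetical geometry): predict energy and gradients
    # for a single molecule; results are stored on the molecule object.
    #
    #     mol = data.molecule()
    #     mol.read_from_xyz_string('3\n\nO 0.0 0.0 0.0\nH 0.0 0.76 0.59\nH 0.0 -0.76 0.59')
    #     model.predict(molecule=mol, calculate_energy=True,
    #                   calculate_energy_gradients=True)
    #     print(mol.energy)
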
    def AEV_setup(self, **kwargs):
        kwargs = models.hyperparameters(kwargs)
        Rcr = kwargs.Rcr
        Rca = kwargs.Rca
        EtaR = torch.tensor(kwargs.EtaR).to(self.device)
        ShfR = torch.tensor(kwargs.ShfR).to(self.device)
        Zeta = torch.tensor(kwargs.Zeta).to(self.device)
        ShfZ = torch.tensor(kwargs.ShfZ).to(self.device)
        EtaA = torch.tensor(kwargs.EtaA).to(self.device)
        ShfA = torch.tensor(kwargs.ShfA).to(self.device)
        self.aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, len(self.species_order))
        self.argsdict.update({'Rcr': Rcr, 'Rca': Rca, 'EtaR': EtaR, 'ShfR': ShfR, 'Zeta': Zeta, 'ShfZ': ShfZ, 'EtaA': EtaA, 'ShfA': ShfA, 'species_order': self.species_order})

    def NN_setup(self, **kwargs):
        kwargs = models.hyperparameters(kwargs)
        if len(kwargs.neurons) == 1:
            self.neurons = [kwargs.neurons[0].copy() for _ in range(len(self.species_order))]
        else:
            self.neurons = kwargs.neurons
        self.networkdict = OrderedDict()
        for i, specie in enumerate(self.species_order):
            self.neurons[i] += [1]
            layers = [torch.nn.Linear(self.aev_computer.aev_length, self.neurons[i][0])]
            for j in range(len(self.neurons[i]) - 1):
                if type(kwargs.activation_function) == str:
                    layers += [torch.nn.__dict__[kwargs.activation_function]()]
                elif callable(kwargs.activation_function):
                    layers += [kwargs.activation_function()]
                layers += [torch.nn.Linear(self.neurons[i][j], self.neurons[i][j + 1])]
            self.networkdict[specie] = torch.nn.Sequential(*layers)
        self.nn = torchani.ANIModel(self.networkdict)
        self.NN_initialize()
        self.optimizer_setup(**kwargs)

    def NN_initialize(self, a: float = 1.0) -> None:
        '''
        Reset the network parameters using :meth:`torch.nn.init.kaiming_normal_`.

        Arguments:
            a (float): See `torch.nn.init.kaiming_normal_() <https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_>`_.
        '''
        def init_params(m):
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight, a)
                torch.nn.init.zeros_(m.bias)

        self.nn.apply(init_params)

    def optimizer_setup(self, **kwargs):
        kwargs = models.hyperparameters(kwargs)
        if isinstance(self.networkdict, OrderedDict) or type(self.networkdict) == dict:
            wlist2d = [
                [
                    {'params': [self.networkdict[specie][j * 2].weight]}
                    if j == 0 or j == len(self.neurons[i]) - 1
                    else {'params': [self.networkdict[specie][j * 2].weight], 'weight_decay': 0.0001 / 10**j}
                    for j in range(len(self.neurons[i]))
                ]
                for i, specie in enumerate(self.species_order)
            ]
            blist2d = [
                [
                    {'params': [self.networkdict[specie][j * 2].bias]}
                    for j in range(len(self.neurons[i]))
                ]
                for i, specie in enumerate(self.species_order)
            ]
            self.AdamW = torch.optim.AdamW([i for j in wlist2d for i in j], lr=kwargs.learning_rate)
            self.SGD = torch.optim.SGD([i for j in blist2d for i in j], lr=kwargs.learning_rate)
        elif type(self.networkdict) == list:
            wlist3d = [[
                [
                    {'params': [subdict[specie][j * 2].weight]}
                    if j == 0 or j == len(self.neurons[k][i]) - 1
                    else {'params': [subdict[specie][j * 2].weight], 'weight_decay': 0.0001 / 10**j}
                    for j in range(len(self.neurons[k][i]))
                ]
                for i, specie in enumerate(self.species_order)
            ] for k, subdict in enumerate(self.networkdict)]
            blist3d = [[
                [
                    {'params': [subdict[specie][j * 2].bias]}
                    for j in range(len(self.neurons[k][i]))
                ]
                for i, specie in enumerate(self.species_order)
            ] for k, subdict in enumerate(self.networkdict)]
            self.AdamW = torch.optim.AdamW([i for k in wlist3d for j in k for i in j], lr=kwargs.learning_rate)
            self.SGD = torch.optim.SGD([i for k in blist3d for j in k for i in j], lr=kwargs.learning_rate)
        self.AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.AdamW, factor=kwargs.lr_reduce_factor, patience=kwargs.lr_reduce_patience, threshold=kwargs.lr_reduce_threshold)
        self.SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.SGD, factor=kwargs.lr_reduce_factor, patience=kwargs.lr_reduce_patience, threshold=kwargs.lr_reduce_threshold)

    def model_setup(self, **kwargs):
        self.AEV_setup(**kwargs)
        self.NN_setup(**kwargs)
        self.model = torchani.nn.Sequential(self.aev_computer, self.nn).float().to(self.device)

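    # Design note (sketch): mirroring the TorchANI training example, weights
    # are optimized with AdamW, where hidden layer j gets weight decay
    # 1e-4/10**j and the first and last layers get none, while biases are
    # optimized with plain SGD; both share one learning rate and a
    # ReduceLROnPlateau schedule keyed on the validation loss.
    #
    #     decays = [0.0001 / 10**j for j in (1, 2)]   # ~1e-05, ~1e-06
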
    def fix_layers(self, layers_to_fix: Union[List[List[int]], List[int]]):
        '''
        Fix specific layers to be non-trainable for each element.

        Arguments:
            layers_to_fix (List): Should be:

                - A list of integers. Layers indicated by the integers will be fixed.
                - A list of lists of integers. Each sub-list defines the layers to be fixed for one species, in the order of ``self.species_order``.

        Examples:

        .. code-block:: python

           >>> import mlatom as ml
           >>> ani = ml.models.ani(model_file='ANI.pt')
           model loaded from ANI.pt
           >>> ani.model  # show model summary
           Sequential(
             (0): AEVComputer()
             (1): ANIModel(
               (C): Sequential(
                 (0): Linear(in_features=240, out_features=160, bias=True)
                 (1): CELU(alpha=0.1)
                 (2): Linear(in_features=160, out_features=128, bias=True)
                 (3): CELU(alpha=0.1)
                 (4): Linear(in_features=128, out_features=96, bias=True)
                 (5): CELU(alpha=0.1)
                 (6): Linear(in_features=96, out_features=1, bias=True)
               )
               (H): Sequential(
                 (0): Linear(in_features=240, out_features=160, bias=True)
                 (1): CELU(alpha=0.1)
                 (2): Linear(in_features=160, out_features=128, bias=True)
                 (3): CELU(alpha=0.1)
                 (4): Linear(in_features=128, out_features=96, bias=True)
                 (5): CELU(alpha=0.1)
                 (6): Linear(in_features=96, out_features=1, bias=True)
               )
               (O): Sequential(
                 (0): Linear(in_features=240, out_features=160, bias=True)
                 (1): CELU(alpha=0.1)
                 (2): Linear(in_features=160, out_features=128, bias=True)
                 (3): CELU(alpha=0.1)
                 (4): Linear(in_features=128, out_features=96, bias=True)
                 (5): CELU(alpha=0.1)
                 (6): Linear(in_features=96, out_features=1, bias=True)
               )
             )
           )
           >>> parameters = dict(ani.model.named_parameters())
           >>> {key: parameters[key].requires_grad for key in parameters}  # show the trainability of parameters
           {'1.C.0.weight': True, '1.C.0.bias': True, '1.C.2.weight': True, '1.C.2.bias': True, '1.C.4.weight': True, '1.C.4.bias': True, '1.C.6.weight': True, '1.C.6.bias': True, '1.H.0.weight': True, '1.H.0.bias': True, '1.H.2.weight': True, '1.H.2.bias': True, '1.H.4.weight': True, '1.H.4.bias': True, '1.H.6.weight': True, '1.H.6.bias': True, '1.O.0.weight': True, '1.O.0.bias': True, '1.O.2.weight': True, '1.O.2.bias': True, '1.O.4.weight': True, '1.O.4.bias': True, '1.O.6.weight': True, '1.O.6.bias': True}
           >>> ani.fix_layers([[0, 2], [0], [4]])  # fix layers 0 and 2 for C, layer 0 for H, and layer 4 for O
           >>> {key: parameters[key].requires_grad for key in parameters}
           {'1.C.0.weight': False, '1.C.0.bias': False, '1.C.2.weight': False, '1.C.2.bias': False, '1.C.4.weight': True, '1.C.4.bias': True, '1.C.6.weight': True, '1.C.6.bias': True, '1.H.0.weight': False, '1.H.0.bias': False, '1.H.2.weight': True, '1.H.2.bias': True, '1.H.4.weight': True, '1.H.4.bias': True, '1.H.6.weight': True, '1.H.6.bias': True, '1.O.0.weight': True, '1.O.0.bias': True, '1.O.2.weight': True, '1.O.2.bias': True, '1.O.4.weight': False, '1.O.4.bias': False, '1.O.6.weight': True, '1.O.6.bias': True}
           >>> ani.fix_layers([[0, 2, 4]])  # fix layers 0, 2, and 4 for all networks
           >>> {key: parameters[key].requires_grad for key in parameters}
           {'1.C.0.weight': False, '1.C.0.bias': False, '1.C.2.weight': False, '1.C.2.bias': False, '1.C.4.weight': False, '1.C.4.bias': False, '1.C.6.weight': True, '1.C.6.bias': True, '1.H.0.weight': False, '1.H.0.bias': False, '1.H.2.weight': False, '1.H.2.bias': False, '1.H.4.weight': False, '1.H.4.bias': False, '1.H.6.weight': True, '1.H.6.bias': True, '1.O.0.weight': False, '1.O.0.bias': False, '1.O.2.weight': False, '1.O.2.bias': False, '1.O.4.weight': False, '1.O.4.bias': False, '1.O.6.weight': True, '1.O.6.bias': True}
        '''
        if layers_to_fix:
            if len(layers_to_fix) == 1:
                layers_to_fix = layers_to_fix * len(self.species_order)
            for name, parameter in self.model.named_parameters():
                indices = name.split('.')
                if int(indices[-2]) in layers_to_fix[self.species_order.index(indices[-3] if indices[-3] in data.element_symbol2atomic_number else data.atomic_number2element_symbol[int(indices[-3])])]:
                    parameter.requires_grad = False

    def data_setup(self, molecular_database, validation_molecular_database, spliting_ratio, property_to_learn, xyz_derivative_property_to_learn):
        assert molecular_database, 'provide molecular database'
        self.property_name = property_to_learn
        data_element_symbols = list(np.sort(np.unique(np.concatenate(molecular_database.element_symbols))))
        if not self.species_order:
            self.species_order = data_element_symbols
        else:
            for element in data_element_symbols:
                if element not in self.species_order:
                    print('element(s) outside supported species detected, please check the database')
                    return
        if validation_molecular_database == 'sample_from_molecular_database':
            idx = np.arange(len(molecular_database))
            np.random.shuffle(idx)
            molecular_database, validation_molecular_database = [molecular_database[i_split] for i_split in np.split(idx, [int(len(idx) * spliting_ratio)])]
        elif not validation_molecular_database:
            raise NotImplementedError("please specify validation_molecular_database or set it to 'sample_from_molecular_database'")
        if self.energy_shifter.self_energies is None:
            if np.isnan(molecular_database.get_properties(property_to_learn)).sum():
                # fit self-energies only on molecules that actually have the property
                molDB2ANIdata(molecular_database.filter_by_property(property_to_learn), property_to_learn, xyz_derivative_property_to_learn).subtract_self_energies(self.energy_shifter, self.species_order)
                self.subtraining_set = molDB2ANIdata(molecular_database, property_to_learn, xyz_derivative_property_to_learn).species_to_indices(self.species_order).subtract_self_energies(self.energy_shifter.self_energies).shuffle()
            else:
                self.subtraining_set = molDB2ANIdata(molecular_database, property_to_learn, xyz_derivative_property_to_learn).subtract_self_energies(self.energy_shifter, self.species_order).species_to_indices(self.species_order).shuffle()
        else:
            self.subtraining_set = molDB2ANIdata(molecular_database, property_to_learn, xyz_derivative_property_to_learn).species_to_indices(self.species_order).subtract_self_energies(self.energy_shifter.self_energies).shuffle()
            if len(self.species_order) != len(self.energy_shifter.self_energies):
                true_species_order = sorted(data_element_symbols, key=lambda x: self.species_order.index(x))
                expanded_self_energies = np.zeros((len(self.species_order)))
                for ii, sp in enumerate(self.species_order):
                    if sp in true_species_order:
                        expanded_self_energies[ii] = self.energy_shifter.self_energies[true_species_order.index(sp)]
                    elif 'energy_shifter_' in self.__dict__:
                        expanded_self_energies[ii] = self.energy_shifter_.self_energies[ii]
                self.energy_shifter.self_energies = torch.tensor(expanded_self_energies)
        self.validation_set = molDB2ANIdata(validation_molecular_database, property_to_learn, xyz_derivative_property_to_learn).species_to_indices(self.species_order).subtract_self_energies(self.energy_shifter.self_energies).shuffle()
        self.energy_shifter = self.energy_shifter.to(self.device)
        self.subtraining_set = self.subtraining_set.collate(self.hyperparameters.batch_size, PADDING)
        self.validation_set = self.validation_set.collate(self.hyperparameters.batch_size, PADDING)
        self.argsdict.update({'self_energies': self.energy_shifter.self_energies, 'property': self.property_name})

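    # Split sketch (illustrative): with the default
    # validation_molecular_database='sample_from_molecular_database' and
    # spliting_ratio=0.8, a 1000-molecule database is shuffled and divided
    # into 800 sub-training and 200 validation points, exactly as done above:
    #
    #     idx = np.arange(1000)
    #     np.random.shuffle(idx)
    #     sub_idx, val_idx = np.split(idx, [int(len(idx) * 0.8)])   # 800 / 200
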
class msani(models.ml_model, models.torchani_model):
    '''
    Create an MS-ANI model object, an extension of the ANI neural-network model to multi-state learning, first described in `10.26434/chemrxiv-2024-dtc1w <https://doi.org/10.26434/chemrxiv-2024-dtc1w>`_.

    Interfaces to `TorchANI <https://pubs.acs.org/doi/10.1021/acs.jcim.0c00451>`_.

    Arguments:
        model_file (str, optional): The filename that the model will be saved to or loaded from.
        device (str, optional): Indicate which device the calculation will be run on, i.e. 'cpu' for CPU, 'cuda' for Nvidia GPUs. When not specified, it will try to use CUDA if a valid ``CUDA_VISIBLE_DEVICES`` is present in the system environment.
        hyperparameters (Dict[str, Any] | :class:`mlatom.models.hyperparameters`, optional): Updates the hyperparameters of the model with the provided ones.
        verbose (int, optional): 0 for silence, 1 for verbosity.
        nstates (int, optional): The number of electronic states to learn (default: 1).
        validate_train (bool, optional): If ``True`` (default), validation uses the full training-type loss (energy, force, and gap terms) instead of the plain energy MSE.
    '''
    hyperparameters = models.hyperparameters({
        #### Training ####
        'batch_size': models.hyperparameter(value=16, minval=1, maxval=1024, optimization_space='linear', dtype=int),
        'max_epochs': models.hyperparameter(value=1000000, minval=100, maxval=1000000, optimization_space='log', dtype=int),
        'learning_rate': models.hyperparameter(value=0.001, minval=0.0001, maxval=0.01, optimization_space='log'),
        'early_stopping_learning_rate': models.hyperparameter(value=1.0E-5, minval=1.0E-6, maxval=1.0E-4, optimization_space='log'),
        'lr_reduce_patience': models.hyperparameter(value=64, minval=16, maxval=256, optimization_space='linear'),
        'lr_reduce_factor': models.hyperparameter(value=0.5, minval=0.1, maxval=0.9, optimization_space='linear'),
        'lr_reduce_threshold': models.hyperparameter(value=0.0, minval=-0.01, maxval=0.01, optimization_space='linear'),
        #### Loss ####
        'force_coefficient': models.hyperparameter(value=0.1, minval=0.05, maxval=5, optimization_space='linear'),
        'median_loss': models.hyperparameter(value=False),
        'gap_coefficient': models.hyperparameter(value=1.0, minval=0.05, maxval=5, optimization_space='linear'),
        #### Network ####
        'neurons': models.hyperparameter(value=[[160, 128, 96]]),
        'activation_function': models.hyperparameter(value=lambda: torch.nn.CELU(0.1), optimization_space='choice', choices=["CELU", "ReLU", "GELU"], dtype=(str, type, FunctionType)),
        'fixed_layers': models.hyperparameter(value=False),
        #### AEV ####
        'Rcr': models.hyperparameter(value=5.2000e+00, minval=1.0, maxval=10.0, optimization_space='linear'),
        'Rca': models.hyperparameter(value=3.5000e+00, minval=1.0, maxval=10.0, optimization_space='linear'),
        'EtaR': models.hyperparameter(value=[1.6000000e+01]),
        'ShfR': models.hyperparameter(value=[9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00]),
        'Zeta': models.hyperparameter(value=[3.2000000e+01]),
        'ShfZ': models.hyperparameter(value=[1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00]),
        'EtaA': models.hyperparameter(value=[8.0000000e+00]),
        'ShfA': models.hyperparameter(value=[9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00]),
    })
    argsdict = {}
    model_file = None
    model = None
    property_name = 'y'
    species_order = []
    program = 'TorchANI'
    meta_data = {
        "genre": "neural network"
    }
    verbose = 1

    def __init__(self, model_file: str = None, device: str = None, hyperparameters: Union[Dict[str, Any], models.hyperparameters] = {}, verbose=1, nstates=1, validate_train=True):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.hyperparameters = self.hyperparameters.copy()
        self.hyperparameters.update(hyperparameters)
        self.verbose = verbose
        self.nstates = nstates
        self.validate_train = validate_train
        self.energy_shifter = torchani.utils.EnergyShifter(None)
        if model_file:
            if os.path.isfile(model_file):
                self.load(model_file)
            else:
                if self.verbose:
                    print(f'the trained MS-ANI model will be saved in {model_file}')
            self.model_file = model_file
        self.hyperparameters.batch_size = self.hyperparameters.batch_size * self.nstates

    def parse_args(self, args):
        super().parse_args(args)
        for hyperparam in self.hyperparameters:
            if hyperparam in args.hyperparameter_optimization['hyperparameters']:
                self.parse_hyperparameter_optimization(args, hyperparam)
            # elif hyperparam in args.data:
            #     self.hyperparameters[hyperparam].value = args.data[hyperparam]
            elif 'ani' in args.data and hyperparam in args.ani.data:
                self.hyperparameters[hyperparam].value = args.ani.data[hyperparam]

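    # Construction sketch (hypothetical file name): nstates records per
    # geometry are batched together, which is why __init__ scales batch_size
    # by nstates.
    #
    #     model = msani(model_file='msani.pt', nstates=3)
    #     # effective batch size: 16 * 3 = 48 records (whole state-groups)
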
    def reset(self):
        super().reset()
        self.model = None

    def save(self, model_file: str = '') -> None:
        '''
        Save the model to file (.pt format).

        Arguments:
            model_file (str, optional): The filename that the model will be saved into. If not provided, a randomly generated name will be used.
        '''
        if not model_file:
            model_file = f'ani_{str(uuid.uuid4())}.pt'
        self.model_file = model_file
        torch.save(
            {
                'network': self.networkdict,
                'args': self.argsdict,
                'nn': self.nn.state_dict(),
                'AEV_computer': self.aev_computer,
                'energy_shifter': self.energy_shifter,
            },
            model_file,
        )
        if self.verbose:
            print(f'model saved in {model_file}')

    def load(self, model_file: str = '', species_order: Union[List[str], None] = None, AEV_parameters: Union[Dict, None] = None, self_energies: Union[List[float], None] = None, reset_parameters: bool = False, method: str = '') -> None:
        '''
        Load a saved MS-ANI model from file (.pt format).

        Arguments:
            model_file (str): The filename of the model to be loaded.
            species_order (List[str], optional): Manually provide the species order if it is not present in the saved model.
            AEV_parameters (Dict, optional): Manually provide AEV parameters if they are not present in the saved model.
            self_energies (List[float], optional): Manually provide self energies if they are not present in the saved model.
            reset_parameters (bool): Reset network parameters in the loaded model.
            method (str): Load an ANI method, see also :meth:`ani.load_ani_model`.
        '''
        if method:
            self.load_ani_model(method)
            return
        model_dict = torch.load(model_file, map_location=torch.device('cpu'))
        if 'property' in model_dict['args']:
            self.property_name = model_dict['args']['property']
        if 'species_order' in model_dict['args']:
            self.species_order = model_dict['args']['species_order']
        else:
            print('species order not found, please provide explicitly')
            self.species_order = species_order
        self.argsdict.update({'species_order': self.species_order})
        if 'AEV_computer' in model_dict:
            self.aev_computer = model_dict['AEV_computer']
            if 'use_cuda_extension' not in self.aev_computer.__dict__:
                self.aev_computer.use_cuda_extension = False
            self.argsdict.update({'Rcr': self.aev_computer.Rcr, 'Rca': self.aev_computer.Rca, 'EtaR': self.aev_computer.EtaR, 'ShfR': self.aev_computer.ShfR, 'Zeta': self.aev_computer.Zeta, 'ShfZ': self.aev_computer.ShfZ, 'EtaA': self.aev_computer.EtaA, 'ShfA': self.aev_computer.ShfA})
        elif 'Rcr' in model_dict['args']:
            self.AEV_setup(**model_dict['args'])
        else:
            print('AEV parameters not found, please provide explicitly')
            self.AEV_setup(**AEV_parameters)
        if 'energy_shifter' in model_dict:
            self.energy_shifter = model_dict['energy_shifter']
        elif 'energy_shifter_train' in model_dict['args']:
            self.energy_shifter = model_dict['args']['energy_shifter_train']
        elif 'self_energies_train' in model_dict['args']:
            self.energy_shifter = torchani.utils.EnergyShifter(model_dict['args']['self_energies_train'])
        elif 'self_energies' in model_dict['args']:
            self.energy_shifter = torchani.utils.EnergyShifter(model_dict['args']['self_energies'])
        else:
            print('self energy information not found, please provide explicitly')
            self.energy_shifter = torchani.utils.EnergyShifter(self_energies)
        self.energy_shifter.to(self.device)
        if 'network' in model_dict and 'nn' in model_dict:
            self.networkdict = model_dict['network']
            if isinstance(self.networkdict, OrderedDict) or type(self.networkdict) == dict:
                self.neurons = [[layer.out_features for layer in network if isinstance(layer, torch.nn.Linear)] for network in self.networkdict.values()]
                self.nn = torchani.ANIModel(self.networkdict if isinstance(self.networkdict, OrderedDict) else self.networkdict.values())
            elif type(self.networkdict) == list:
                self.neurons = [[[layer.out_features for layer in network if isinstance(layer, torch.nn.Linear)] for network in subdict.values()] for subdict in self.networkdict]
                self.nn = torchani.nn.Ensemble([torchani.ANIModel(subdict if isinstance(subdict, OrderedDict) else subdict.values()) for subdict in self.networkdict])
            if reset_parameters:
                self.NN_initialize()
            else:
                self.nn.load_state_dict(model_dict['nn'])
            self.optimizer_setup(**self.hyperparameters)
        else:
            print('network parameters not found')
        self.model = torch.nn.DataParallel(StateInputNet(self.aev_computer, self.nn)).to(self.device)
        self.model.eval()
        if self.verbose:
            print(f'model loaded from {model_file}')

    def load_ani_model(self, method: str, **hyperparameters) -> None:
        '''
        Load an ANI model.

        Arguments:
            method (str): Can be ``'ANI-1x'``, ``'ANI-1ccx'``, or ``'ANI-2x'``.
        '''
        self.hyperparameters.update(hyperparameters)
        if 'ANI-1x'.casefold() in method.casefold():
            model = torchani.models.ANI1x(periodic_table_index=True).to(self.device)
        elif 'ANI-1ccx'.casefold() in method.casefold():
            model = torchani.models.ANI1ccx(periodic_table_index=True).to(self.device)
        elif 'ANI-2x'.casefold() in method.casefold():
            model = torchani.models.ANI2x(periodic_table_index=True).to(self.device)
        else:
            print("method not found, please check ANI_methods().available_methods")
            return
        self.species_order = model.species
        self.argsdict.update({'species_order': self.species_order})
        self.aev_computer = model.aev_computer
        self.networkdict = [OrderedDict(**{k: v for k, v in nn.items()}) for nn in model.neural_networks]
        self.neurons = [[[layer.out_features for layer in network if isinstance(layer, torch.nn.Linear)] for network in subdict.values()] for subdict in self.networkdict]
        self.nn = model.neural_networks
        self.optimizer_setup(**self.hyperparameters)
        self.energy_shifter = model.energy_shifter
        self.model = torch.nn.DataParallel(StateInputNet(self.aev_computer, self.nn)).to(self.device).float()
        if self.verbose:
            print(f'loaded {method} model')

[文档] @doc_inherit def train( self, molecular_database: data.molecular_database, property_to_learn: str = 'energy', xyz_derivative_property_to_learn: str = None, validation_molecular_database: Union[data.molecular_database, str, None] = 'sample_from_molecular_database', hyperparameters: Union[Dict[str,Any], models.hyperparameters] = {}, spliting_ratio: float = 0.8, save_model: bool = True, check_point: str = None, reset_optim_state: bool = False, use_last_model: bool = False, reset_parameters: bool = False, reset_network: bool = False, reset_optimizer: bool = False, save_every_epoch: bool = False, energy_weighting_function: Callable = None, energy_weighting_function_kwargs: dict = {}, ) -> None: r''' validation_molecular_database (:class:`mlatom.data.molecular_database` | str, optional): Explicitly defines the database for validation, or use ``'sample_from_molecular_database'`` to make it sampled from the training set. hyperparameters (Dict[str, Any] | :class:`mlatom.models.hyperparameters`, optional): Updates the hyperparameters of the model with provided. spliting_ratio (float, optional): The ratio sub-training dataset in the whole training dataset. save_model (bool, optional): Whether save the model to disk during training process. Note that the model might be saved many times during training. reset_optim_state (bool, optional): Whether to reset the state of optimizer. use_last_model (bool, optional): Whether to keep the ``self.model`` as it is in the last training epoch. If ``False``, the best model will be loaded to memory at the end of training. reset_parameters (bool, optional): Whether to reset the model's parameters before training. reset_network (bool, optional): Whether to re-construct the network before training. reset_optimizer (bool, optional): Whether to re-define the optimizer before training . save_every_epoch (bool, optional): Whether to save model in every epoch, valid when ``save_model`` is ``True``. energy_weighting_function (Callable[Array-like], optional): A weighting function :math:`\mathit{W}(\mathbf{E_ref})` that assign weights to training points based on their reference energies. 
        '''
        if hyperparameters:
            self.hyperparameters.update(hyperparameters)
        energy_weighting_function_kwargs = {k: (v.value if isinstance(v, models.hyperparameter) else v) for k, v in energy_weighting_function_kwargs.items()}
        self.data_setup(molecular_database, validation_molecular_database, spliting_ratio, property_to_learn, xyz_derivative_property_to_learn)

        if not self.model:
            self.model_setup(**self.hyperparameters)
        if reset_network:
            self.NN_setup(**self.hyperparameters)
        if reset_parameters:
            self.NN_initialize()
        if reset_optimizer:
            self.optimizer_setup(**self.hyperparameters)

        self.model.train()
        if self.verbose:
            print(self.model)

        if check_point and os.path.isfile(check_point):
            checkpoint = torch.load(check_point)
            self.nn.load_state_dict(checkpoint['nn'])
            if not reset_optim_state:
                self.AdamW.load_state_dict(checkpoint['AdamW'])
                self.SGD.load_state_dict(checkpoint['SGD'])
                self.AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
                self.SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

        def validate():
            # energy-only validation error (mean squared error per molecule)
            total_mse = 0.0
            count = 0
            for properties in self.validation_set:
                true_energies = properties['energies'].to(self.device).float()
                species = properties['species'].to(self.device)
                if callable(energy_weighting_function):
                    weightings_e = energy_weighting_function(self.energy_shifter((species, true_energies)).energies, **energy_weighting_function_kwargs)
                else:
                    weightings_e = 1
                coordinates = properties['coordinates'].to(self.device).float()
                state = properties['state'].to(self.device).float()
                _, predicted_energies = self.model(species, coordinates, state)
                total_mse += loss_function(predicted_energies, true_energies, weightings_e, reduction='sum').nanmean().item()
                count += predicted_energies.shape[0]
            return total_mse / count

        def validate_training():
            # validation error computed with the same composite loss
            # (energy + force + gap terms) that is used for training
            total_mse = 0.0
            for properties in self.validation_set:
                true_energies = properties['energies'].to(self.device).float()
                species = properties['species'].to(self.device)
                num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
                state = properties['state'].to(self.device).float()
                # gaps between consecutive electronic states of the same molecule
                # (entries are ordered state-by-state within each molecule)
                true_gap_list = []
                for i in range(int(len(true_energies) / self.nstates)):
                    for j in range(1, self.nstates):
                        true_gap_list.append(abs(true_energies[i * self.nstates + j] - true_energies[i * self.nstates + j - 1]))
                true_gaps = torch.FloatTensor(true_gap_list).to(self.device)
                if callable(energy_weighting_function):
                    weightings = energy_weighting_function(self.energy_shifter((species, true_energies)).energies, **energy_weighting_function_kwargs)
                    weightings_e = weightings
                    weightings_f = weightings[:, None, None]
                else:
                    weightings_e = 1
                    weightings_f = 1
                if xyz_derivative_property_to_learn:
                    coordinates = properties['coordinates'].to(self.device).float().requires_grad_(True)
                    true_forces = properties['forces'].to(self.device).float()
                    _, predicted_energies = self.model(species, coordinates, state)
                    predicted_gap_list = []
                    for i in range(int(len(predicted_energies) / self.nstates)):
                        for j in range(1, self.nstates):
                            predicted_gap_list.append(abs(predicted_energies[i * self.nstates + j] - predicted_energies[i * self.nstates + j - 1]))
                    predicted_gaps = torch.FloatTensor(predicted_gap_list).to(self.device)
                    forces = -torch.autograd.grad(predicted_energies.sum(), coordinates, create_graph=True, retain_graph=True)[0]
                    if self.hyperparameters.median_loss:
                        energy_loss = median(predicted_energies, true_energies)
                    else:
                        energy_loss = (loss_function(predicted_energies, true_energies, weightings_e) / num_atoms.sqrt()).nanmean()
                    if self.hyperparameters.gap_coefficient != 0.0:
                        gap_loss = (loss_function(predicted_gaps, true_gaps, 1) / num_atoms[0].sqrt()).nanmean()
                    else:
                        gap_loss = 0
                    force_loss = (loss_function(true_forces, forces, weightings_f).sum(dim=(1, 2)) / num_atoms).nanmean()
                    loss = energy_loss + self.hyperparameters.force_coefficient * force_loss + self.hyperparameters.gap_coefficient * gap_loss
                    total_mse += loss.item()
                else:
                    coordinates = properties['coordinates'].to(self.device).float()
                    _, predicted_energies = self.model(species, coordinates, state)
                    predicted_gap_list = []
                    for i in range(int(len(predicted_energies) / self.nstates)):
                        for j in range(1, self.nstates):
                            predicted_gap_list.append(abs(predicted_energies[i * self.nstates + j] - predicted_energies[i * self.nstates + j - 1]))
                    predicted_gaps = torch.FloatTensor(predicted_gap_list).to(self.device)
                    energy_loss = (loss_function(predicted_energies, true_energies, weightings_e) / num_atoms.sqrt()).nanmean()
                    if self.hyperparameters.gap_coefficient != 0.0:
                        gap_loss = (loss_function(predicted_gaps, true_gaps, 1) / num_atoms[0].sqrt()).nanmean()
                    else:
                        gap_loss = 0
                    loss = energy_loss + self.hyperparameters.gap_coefficient * gap_loss
                    total_mse += loss.item()
            return total_mse

        def loss_function(prediction, reference, weightings=1, reduction='none'):
            # (optionally weighted) mean squared error
            return torch.nn.functional.mse_loss(prediction * weightings, reference * weightings, reduction=reduction)

        if self.verbose:
            print("training starting from epoch", self.AdamW_scheduler.last_epoch + 1)
        for _ in range(self.AdamW_scheduler.last_epoch + 1, self.hyperparameters.max_epochs + 1):
            if self.validate_train:
                rmse = validate_training()
            else:
                rmse = validate()
            if self.verbose:
                print('validation MSE:', rmse, 'at epoch', self.AdamW_scheduler.last_epoch + 1)
                sys.stdout.flush()
            learning_rate = self.AdamW.param_groups[0]['lr']
            if self.verbose:
                print('learning_rate:', learning_rate)
            if learning_rate < self.hyperparameters.early_stopping_learning_rate:
                break
            if self.AdamW_scheduler.is_better(rmse, self.AdamW_scheduler.best) or save_every_epoch:
                if save_model:
                    self.save(self.model_file)
            self.AdamW_scheduler.step(rmse)
            self.SGD_scheduler.step(rmse)
            for properties in tqdm.tqdm(
                self.subtraining_set,
                total=len(self.subtraining_set),
                desc="epoch {}".format(self.AdamW_scheduler.last_epoch),
                disable=not self.verbose,
            ):
                true_energies = properties['energies'].to(self.device).float()
                # reference gaps between consecutive states of the same molecule
                true_gap_list = []
                for i in range(int(len(true_energies) / self.nstates)):
                    for j in range(1, self.nstates):
                        true_gap_list.append(abs(true_energies[i * self.nstates + j] - true_energies[i * self.nstates + j - 1]))
                true_gaps = torch.FloatTensor(true_gap_list).to(self.device)
                species = properties['species'].to(self.device)
                state = properties['state'].to(self.device).float()
                num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
                if callable(energy_weighting_function):
                    weightings = energy_weighting_function(self.energy_shifter((species, true_energies)).energies, **energy_weighting_function_kwargs)
                    weightings_e = weightings
                    weightings_f = weightings[:, None, None]
                else:
                    weightings_e = 1
                    weightings_f = 1
                if xyz_derivative_property_to_learn:
                    coordinates = properties['coordinates'].to(self.device).float().requires_grad_(True)
                    true_forces = properties['forces'].to(self.device).float()
                    _, predicted_energies = self.model(species, coordinates, state)
                    predicted_gap_list = []
                    for i in range(int(len(predicted_energies) / self.nstates)):
                        for j in range(1, self.nstates):
                            predicted_gap_list.append(abs(predicted_energies[i * self.nstates + j] - predicted_energies[i * self.nstates + j - 1]))
                    # torch.stack keeps the autograd graph so that gap_loss can
                    # contribute gradients (torch.FloatTensor would detach it)
                    predicted_gaps = torch.stack(predicted_gap_list) if predicted_gap_list else torch.zeros(0, device=self.device)
                    forces = -torch.autograd.grad(predicted_energies.sum(), coordinates, create_graph=True, retain_graph=True)[0]
                    if self.hyperparameters.median_loss:
                        energy_loss = median(predicted_energies, true_energies)
                    else:
                        energy_loss = (loss_function(predicted_energies, true_energies, weightings_e) / num_atoms.sqrt()).nanmean()
                    if self.hyperparameters.gap_coefficient != 0.0:
                        gap_loss = (loss_function(predicted_gaps, true_gaps, 1) / num_atoms[0].sqrt()).nanmean()
                    else:
                        gap_loss = 0
                    force_loss = (loss_function(true_forces, forces, weightings_f).sum(dim=(1, 2)) / num_atoms).nanmean()
                    loss = energy_loss + self.hyperparameters.force_coefficient * force_loss + self.hyperparameters.gap_coefficient * gap_loss
                else:
                    coordinates = properties['coordinates'].to(self.device).float()
                    _, predicted_energies = self.model(species, coordinates, state)
                    predicted_gap_list = []
                    for i in range(int(len(predicted_energies) / self.nstates)):
                        for j in range(1, self.nstates):
                            predicted_gap_list.append(abs(predicted_energies[i * self.nstates + j] - predicted_energies[i * self.nstates + j - 1]))
                    predicted_gaps = torch.stack(predicted_gap_list) if predicted_gap_list else torch.zeros(0, device=self.device)
                    energy_loss = (loss_function(predicted_energies, true_energies, weightings_e) / num_atoms.sqrt()).nanmean()
                    if self.hyperparameters.gap_coefficient != 0.0:
                        gap_loss = (loss_function(predicted_gaps, true_gaps, 1) / num_atoms[0].sqrt()).nanmean()
                    else:
                        gap_loss = 0
                    loss = energy_loss + self.hyperparameters.gap_coefficient * gap_loss
                self.AdamW.zero_grad()
                self.SGD.zero_grad()
                loss.backward()
                self.AdamW.step()
                self.SGD.step()
            if check_point:
                torch.save({
                    'nn': self.nn.state_dict(),
                    'AdamW': self.AdamW.state_dict(),
                    'SGD': self.SGD.state_dict(),
                    'AdamW_scheduler': self.AdamW_scheduler.state_dict(),
                    'SGD_scheduler': self.SGD_scheduler.state_dict(),
                }, check_point)
        if save_model and not use_last_model:
            self.load(self.model_file)
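    # The composite loss assembled above is, per batch,
    #   loss = mean(MSE_w(E) / sqrt(N)) + c_f * mean(sum(MSE_w(F)) / N) + c_g * mean(MSE(gaps) / sqrt(N))
    # with N the number of atoms, c_f = force_coefficient and c_g = gap_coefficient.
    # A minimal self-contained sketch with toy tensors (shapes and the 0.1
    # coefficient are illustrative only, not part of the interface):
    #
    #   import torch
    #   E_pred, E_ref = torch.rand(4), torch.rand(4)              # 4 molecules
    #   F_pred, F_ref = torch.rand(4, 3, 3), torch.rand(4, 3, 3)  # 3 atoms each
    #   n_atoms = torch.full((4,), 3.0)
    #   mse = torch.nn.functional.mse_loss
    #   e_loss = (mse(E_pred, E_ref, reduction='none') / n_atoms.sqrt()).nanmean()
    #   f_loss = (mse(F_pred, F_ref, reduction='none').sum(dim=(1, 2)) / n_atoms).nanmean()
    #   loss = e_loss + 0.1 * f_loss   # + c_g * gap_loss when nstates > 1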
[docs]
    @doc_inherit
    def predict(
        self,
        molecular_database: data.molecular_database = None,
        molecule: data.molecule = None,
        nstates=1,
        current_state=0,
        calculate_energy: bool = False,
        calculate_energy_gradients: bool = False,
        calculate_hessian: bool = False,
        property_to_predict: Union[str, None] = 'estimated_y',
        xyz_derivative_property_to_predict: Union[str, None] = None,
        hessian_to_predict: Union[str, None] = None,
        batch_size: int = 2**16,
    ) -> None:
        '''
            batch_size (int, optional): The batch size for batch predictions.
        '''
        molDB, property_to_predict, xyz_derivative_property_to_predict, hessian_to_predict = \
            super().predict(molecular_database=molecular_database, molecule=molecule,
                            calculate_energy=calculate_energy,
                            calculate_energy_gradients=calculate_energy_gradients,
                            calculate_hessian=calculate_hessian,
                            property_to_predict=property_to_predict,
                            xyz_derivative_property_to_predict=xyz_derivative_property_to_predict,
                            hessian_to_predict=hessian_to_predict)

        for batch in molDB.batches(batch_size):
            state_energies = []
            state_gradients = []
            state_hessians = []
            # evaluate the model once per electronic state
            for i in range(nstates):
                for properties in molDB2ANIdata_state(batch, use_state=False).species_to_indices(self.species_order).collate(batch_size, PADDING):
                    species = properties['species'].to(self.device)
                    # float state label, matching the dtype used during training
                    state = torch.full((len(batch),), float(i)).to(self.device)
                    xyz_coordinates = properties['coordinates'].float().to(self.device).requires_grad_(bool(xyz_derivative_property_to_predict or hessian_to_predict))
                    break  # a single collated batch covers the whole molecular batch
                ANI_NN_energies = self.energy_shifter(self.model(species, xyz_coordinates, state)).energies
                if property_to_predict:
                    if nstates > 1:
                        batch.add_scalar_properties(ANI_NN_energies.detach().cpu().numpy(), property_to_predict + "_state" + str(i))
                    else:
                        batch.add_scalar_properties(ANI_NN_energies.detach().cpu().numpy(), property_to_predict)
                    state_energies.append(ANI_NN_energies.detach().cpu().numpy())
                if xyz_derivative_property_to_predict or hessian_to_predict:
                    ANI_NN_energy_gradients = torch.autograd.grad(ANI_NN_energies.sum(), xyz_coordinates, create_graph=True, retain_graph=True)[0]
                    if xyz_derivative_property_to_predict:
                        grads = ANI_NN_energy_gradients.detach().cpu().numpy()
                        state_gradients.append(grads)
                        if nstates > 1:
                            batch.add_xyz_vectorial_properties(grads, xyz_derivative_property_to_predict + "_state" + str(i))
                        else:
                            batch.add_xyz_vectorial_properties(grads, xyz_derivative_property_to_predict)
                    if hessian_to_predict:
                        ANI_NN_hessians = torchani.utils.hessian(xyz_coordinates, energies=ANI_NN_energies)
                        state_hessians.append(ANI_NN_hessians.detach().cpu().numpy())
                        batch.add_scalar_properties(ANI_NN_hessians.detach().cpu().numpy(), hessian_to_predict)
            # distribute the per-state results over molecule.electronic_states
            for idx, mol in enumerate(batch):
                mol_copy = mol.copy()
                mol_copy.electronic_states = []
                for _ in range(nstates - len(mol.electronic_states)):
                    mol.electronic_states.append(mol_copy.copy())
                if calculate_energy:
                    for i in range(nstates):
                        mol.electronic_states[i].energy = state_energies[i][idx]
                    mol.__dict__["energy"] = mol.electronic_states[current_state].energy
                if xyz_derivative_property_to_predict:
                    for i in range(nstates):
                        mol.electronic_states[i].add_xyz_derivative_property(np.array(state_gradients[i][idx]).astype(float), 'energy', 'energy_gradients')
                    mol.add_xyz_derivative_property(np.array(state_gradients[current_state][idx]).astype(float), 'energy', 'energy_gradients')
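    # Illustrative use of the multi-state prediction above (a sketch: `model`
    # is assumed to be a trained instance of this class and `mol` an MLatom
    # molecule prepared elsewhere):
    #
    #   model.predict(molecule=mol, nstates=3, current_state=0,
    #                 calculate_energy=True, calculate_energy_gradients=True)
    #   for i, st in enumerate(mol.electronic_states):
    #       print(i, st.energy)   # per-state energies; mol.energy holds state 0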
    def AEV_setup(self, **kwargs):
        kwargs = models.hyperparameters(kwargs)
        Rcr = kwargs.Rcr
        Rca = kwargs.Rca
        EtaR = torch.tensor(kwargs.EtaR).to(self.device)
        ShfR = torch.tensor(kwargs.ShfR).to(self.device)
        Zeta = torch.tensor(kwargs.Zeta).to(self.device)
        ShfZ = torch.tensor(kwargs.ShfZ).to(self.device)
        EtaA = torch.tensor(kwargs.EtaA).to(self.device)
        ShfA = torch.tensor(kwargs.ShfA).to(self.device)
        self.aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, len(self.species_order))
        self.argsdict.update({'Rcr': Rcr, 'Rca': Rca, 'EtaR': EtaR, 'ShfR': ShfR, 'Zeta': Zeta, 'ShfZ': ShfZ, 'EtaA': EtaA, 'ShfA': ShfA, 'species_order': self.species_order})

    def NN_setup(self, **kwargs):
        kwargs = models.hyperparameters(kwargs)
        if len(kwargs.neurons) == 1:
            self.neurons = [kwargs.neurons[0].copy() for _ in range(len(self.species_order))]
        self.networkdict = OrderedDict()
        for i, specie in enumerate(self.species_order):
            self.neurons[i] += [1]  # append the single output neuron
            # the extra input (+1) carries the electronic-state label that
            # StateInputNet concatenates to each AEV
            layers = [torch.nn.Linear(self.aev_computer.aev_length + 1, self.neurons[i][0])]
            for j in range(len(self.neurons[i]) - 1):
                if isinstance(kwargs.activation_function, str):
                    layers += [torch.nn.__dict__[kwargs.activation_function]()]
                elif callable(kwargs.activation_function):
                    layers += [kwargs.activation_function()]
                layers += [torch.nn.Linear(self.neurons[i][j], self.neurons[i][j + 1])]
            self.networkdict[specie] = torch.nn.Sequential(*layers)
        self.nn = torchani.ANIModel(self.networkdict)
        self.NN_initialize()
        self.optimizer_setup(**kwargs)
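    # For example, with neurons = [[160, 128, 96]] and a GELU activation, the
    # per-element network built above is (a sketch; aev_length depends on the
    # AEV hyperparameters):
    #
    #   Sequential(
    #       Linear(aev_length + 1, 160), GELU(),  # +1 input for the state label
    #       Linear(160, 128), GELU(),
    #       Linear(128, 96), GELU(),
    #       Linear(96, 1),                        # appended output neuron
    #   )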
[docs]
    def NN_initialize(self, a: float = 1.0) -> None:
        '''
        Reset the network parameters using :meth:`torch.nn.init.kaiming_normal_`.

        Arguments:
            a (float): See `torch.nn.init.kaiming_normal_() <https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_>`_.
        '''
        def init_params(m):
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight, a)
                torch.nn.init.zeros_(m.bias)

        self.nn.apply(init_params)
    def optimizer_setup(self, **kwargs):
        kwargs = models.hyperparameters(kwargs)
        # Weights are optimized with AdamW (hidden layers get a weight decay of
        # 0.0001 / 10**j); biases are optimized with plain SGD.
        if isinstance(self.networkdict, dict):
            wlist2d = [
                [
                    {'params': [self.networkdict[specie][j * 2].weight]}
                    if j == 0 or j == len(self.neurons[i]) - 1
                    else {'params': [self.networkdict[specie][j * 2].weight], 'weight_decay': 0.0001 / 10**j}
                    for j in range(len(self.neurons[i]))
                ]
                for i, specie in enumerate(self.species_order)
            ]
            blist2d = [
                [
                    {'params': [self.networkdict[specie][j * 2].bias]}
                    for j in range(len(self.neurons[i]))
                ]
                for i, specie in enumerate(self.species_order)
            ]
            self.AdamW = torch.optim.AdamW([i for j in wlist2d for i in j], lr=kwargs.learning_rate)
            self.SGD = torch.optim.SGD([i for j in blist2d for i in j], lr=kwargs.learning_rate)
        elif isinstance(self.networkdict, list):
            wlist3d = [[
                [
                    {'params': [subdict[specie][j * 2].weight]}
                    if j == 0 or j == len(self.neurons[k][i]) - 1
                    else {'params': [subdict[specie][j * 2].weight], 'weight_decay': 0.0001 / 10**j}
                    for j in range(len(self.neurons[k][i]))
                ]
                for i, specie in enumerate(self.species_order)
            ] for k, subdict in enumerate(self.networkdict)]
            blist3d = [[
                [
                    {'params': [subdict[specie][j * 2].bias]}
                    for j in range(len(self.neurons[k][i]))
                ]
                for i, specie in enumerate(self.species_order)
            ] for k, subdict in enumerate(self.networkdict)]
            self.AdamW = torch.optim.AdamW([i for k in wlist3d for j in k for i in j], lr=kwargs.learning_rate)
            self.SGD = torch.optim.SGD([i for k in blist3d for j in k for i in j], lr=kwargs.learning_rate)
        self.AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.AdamW, factor=kwargs.lr_reduce_factor, patience=kwargs.lr_reduce_patience, threshold=kwargs.lr_reduce_threshold)
        self.SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.SGD, factor=kwargs.lr_reduce_factor, patience=kwargs.lr_reduce_patience, threshold=kwargs.lr_reduce_threshold)

    def model_setup(self, **kwargs):
        self.AEV_setup(**kwargs)
        self.NN_setup(**kwargs)
        self.model = StateInputNet(self.aev_computer, self.nn).float().to(self.device)
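    # The j * 2 indexing in optimizer_setup above works because each Sequential
    # built in NN_setup alternates Linear/activation, so the j-th Linear layer
    # sits at index 2*j. A tiny sketch:
    #
    #   net = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.GELU(),
    #                             torch.nn.Linear(4, 1))
    #   assert isinstance(net[0], torch.nn.Linear)   # j = 0 -> index 0
    #   assert isinstance(net[2], torch.nn.Linear)   # j = 1 -> index 2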
[docs]
    def fix_layers(self, layers_to_fix: Union[List[List[int]], List[int]]):
        '''
        Fix specific layers to be non-trainable for each element.

        Arguments:
            layers_to_fix (List): Should be either:

                - a list with a single sub-list of integers; the layers indicated by the integers will be fixed for every species, or
                - a list of lists of integers; each sub-list defines the layers to be fixed for one species, in the order of ``self.species_order``.
        '''
        if layers_to_fix:
            if len(layers_to_fix) == 1:
                layers_to_fix = layers_to_fix * len(self.species_order)
            for name, parameter in self.model.named_parameters():
                indices = name.split('.')
                element = indices[-3] if indices[-3] in data.element_symbol2atomic_number else data.atomic_number2element_symbol[int(indices[-3])]
                if int(indices[-2]) in layers_to_fix[self.species_order.index(element)]:
                    parameter.requires_grad = False
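    # Hedged usage sketch (assuming `model` is an instance of this class):
    #
    #   model.fix_layers([[0, 4]])  # for every element, freeze the Sequential
    #                               # modules at indices 0 and 4, i.e. the
    #                               # first and third Linear layers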
    def data_setup(self,
                   molecular_database,
                   validation_molecular_database,
                   spliting_ratio,
                   property_to_learn,
                   xyz_derivative_property_to_learn,
                   ):
        assert molecular_database, 'please provide a molecular database'
        self.property_name = property_to_learn
        data_element_symbols = list(np.sort(np.unique(np.concatenate(molecular_database.element_symbols))))
        if not self.species_order:
            self.species_order = data_element_symbols
        else:
            for element in data_element_symbols:
                if element not in self.species_order:
                    print('element(s) outside the supported species detected, please check the database')
                    return
        if validation_molecular_database == 'sample_from_molecular_database':
            idx = np.arange(len(molecular_database))
            np.random.shuffle(idx)
            molecular_database, validation_molecular_database = [molecular_database[i_split] for i_split in np.split(idx, [int(len(idx) * spliting_ratio)])]
        elif not validation_molecular_database:
            raise NotImplementedError("please specify validation_molecular_database or set it to 'sample_from_molecular_database'")
        # unpack every molecule into one single-state molecule per electronic state
        molecular_database = unpackData2State(molecular_database)
        validation_molecular_database = unpackData2State(validation_molecular_database)
        if self.energy_shifter.self_energies is None:
            if np.isnan(molecular_database.get_properties(property_to_learn)).sum():
                # fit the self-atomic energies only on molecules with a defined target property
                molDB2ANIdata_state(molecular_database.filter_by_property(property_to_learn), property_to_learn, xyz_derivative_property_to_learn).subtract_self_energies(self.energy_shifter, self.species_order)
                self.subtraining_set = molDB2ANIdata_state(molecular_database, property_to_learn, xyz_derivative_property_to_learn).species_to_indices(self.species_order).subtract_self_energies(self.energy_shifter.self_energies).cache()
            else:
                self.subtraining_set = molDB2ANIdata_state(molecular_database, property_to_learn, xyz_derivative_property_to_learn).subtract_self_energies(self.energy_shifter, self.species_order).species_to_indices(self.species_order).cache()
        else:
            self.subtraining_set = molDB2ANIdata_state(molecular_database, property_to_learn, xyz_derivative_property_to_learn).species_to_indices(self.species_order).subtract_self_energies(self.energy_shifter.self_energies).cache()
        self.validation_set = molDB2ANIdata_state(validation_molecular_database, property_to_learn, xyz_derivative_property_to_learn).species_to_indices(self.species_order).subtract_self_energies(self.energy_shifter.self_energies).cache()
        self.energy_shifter = self.energy_shifter.to(self.device)
        self.subtraining_set = self.subtraining_set.collate(self.hyperparameters.batch_size, PADDING)
        self.validation_set = self.validation_set.collate(self.hyperparameters.batch_size, PADDING)
        self.argsdict.update({'self_energies': self.energy_shifter.self_energies, 'property': self.property_name})
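    # For example, with 100 molecules and spliting_ratio=0.8 the split above
    # amounts to (a sketch of the numpy logic only, not part of the interface):
    #
    #   idx = np.arange(100)
    #   np.random.shuffle(idx)
    #   sub_idx, val_idx = np.split(idx, [int(len(idx) * 0.8)])  # 80/20 split
    #
    # after which each molecule is unpacked into one entry per electronic state.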
[docs]
class ani_child(models.torchani_model):
    def __init__(self, parent, index, name='ani_child', device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = torch.device(device)
        self.name = name
        self.model = parent[index]
[docs]
    def predict(
        self,
        molecular_database: data.molecular_database = None,
        molecule: data.molecule = None,
        calculate_energy: bool = False,
        calculate_energy_gradients: bool = False,
        calculate_hessian: bool = False,
        batch_size: int = 2**16,
    ) -> None:
        molDB = super().predict(molecular_database=molecular_database, molecule=molecule)
        if not molDB._is_uniform_cell():
            batch_size = 1  # molecules with different cells cannot share a batch
        for batch in molDB.batches(batch_size):
            for properties in molDB2ANIdata(batch).species_to_indices('periodic_table').collate(batch_size, PADDING):
                species = properties['species'].to(self.device)
                if torchani.utils.PERIODIC_TABLE[0] == 'H':
                    # shift to atomic numbers for TorchANI versions whose
                    # periodic table is 0-based (starts at 'H')
                    species += 1
                xyz_coordinates = properties['coordinates'].float().to(self.device)
                break  # a single collated batch covers the whole molecular batch
            pbc = properties['pbc'][0].to(self.device) if 'pbc' in properties else None
            cell = properties['cell'][0].float().to(self.device) if 'cell' in properties else None
            if pbc is not None and cell is not None:
                xyz_coordinates = torchani.utils.map2central(cell, xyz_coordinates, pbc)
            xyz_coordinates = xyz_coordinates.requires_grad_(calculate_energy_gradients or calculate_hessian)
            ANI_NN_energies = self.model((species, xyz_coordinates), pbc=pbc, cell=cell).energies
            if calculate_energy:
                batch.add_scalar_properties(ANI_NN_energies.detach().cpu().numpy(), 'energy')
            if calculate_energy_gradients or calculate_hessian:
                ANI_NN_energy_gradients = torch.autograd.grad(ANI_NN_energies.sum(), xyz_coordinates, create_graph=True, retain_graph=True)[0]
                if calculate_energy_gradients:
                    grads = ANI_NN_energy_gradients.detach().cpu().numpy()
                    batch.add_xyz_vectorial_properties(grads, 'energy_gradients')
                if calculate_hessian:
                    ANI_NN_hessians = torchani.utils.hessian(xyz_coordinates, energies=ANI_NN_energies)
                    batch.add_hessian_properties(ANI_NN_hessians.detach().cpu().numpy(), 'hessian')
    def node(self):
        return models.model_tree_node(name=self.name, operator='predict', model=self)
[docs]
class torchani_methods(models.torchani_model, metaclass=models.meta_method):
    '''
    Create a model object with one of the ANI methods.

    Arguments:
        method (str): A string that specifies the method. Available choices: ``'ANI-1x'``, ``'ANI-1ccx'``, ``'ANI-2x'``, ``'ANI-1ccx-gelu'`` or ``'ANI-1x-gelu'``.
        device (str, optional): Indicates which device the calculation will be run on, i.e., 'cpu' for CPU, 'cuda' for NVIDIA GPUs. When not specified, it will try to use CUDA if a valid ``CUDA_VISIBLE_DEVICES`` is set in the system environment.
    '''
    # Atomic energies are only tabulated for ANI-1ccx; requesting the energy of
    # a single atom with the other methods (e.g. ANI-1ccx-gelu) raises an error.
    atomic_energies = {'ANI-1ccx': {1: -0.50088088, 6: -37.79199048, 7: -54.53379230, 8: -75.00968205}}

    def __init__(self, method: str = 'ANI-1ccx', device: str = 'cuda' if torch.cuda.is_available() else 'cpu', **kwargs):
        self.device = torch.device(device)
        self.method = method
        self.model_setup(method)

    def model_setup(self, method):
        self.method = method
        if 'ANI-1xnr'.casefold() in method.casefold():
            self.model = load_ani1xnr_model().to(self.device)
        elif 'ANI-1ccx-gelu'.casefold() in method.casefold():
            self.model = self.load_ani_gelu_model('ani1ccx_gelu')
        elif 'ANI-1x-gelu'.casefold() in method.casefold():
            self.model = self.load_ani_gelu_model('ani1x_gelu')
        elif 'ANI-1x'.casefold() in method.casefold():
            self.model = torchani.models.ANI1x(periodic_table_index=True).to(self.device).double()
        elif 'ANI-1ccx'.casefold() in method.casefold():
            self.model = torchani.models.ANI1ccx(periodic_table_index=True).to(self.device).double()
        elif 'ANI-2x'.casefold() in method.casefold():
            self.model = torchani.models.ANI2x(periodic_table_index=True).to(self.device).double()
        else:
            print("method not found, please check ANI_methods().available_methods")
        modelname = method.lower().replace('-', '')
        if 'gelu'.casefold() in method.casefold():
            self.element_symbols_available = self.model[0].species_order
            self.children = [models.model_tree_node(name=f'{modelname}_nn{index}', model=self.model[index], operator='predict') for index in range(len(self.model))]
        else:
            self.element_symbols_available = self.model.species
            self.children = [ani_child(self.model, index, name=f'{modelname}_nn{index}', device=self.device).node() for index in range(len(self.model))]
        if 'D4'.casefold() in self.method.casefold():
            d4 = models.model_tree_node(name='d4_wb97x', operator='predict', model=models.methods(method='D4', functional='wb97x'))
            ani_nns = models.model_tree_node(name=f'{modelname}_nn', children=self.children, operator='average')
            self.model = models.model_tree_node(name=modelname, children=[ani_nns, d4], operator='sum')
        else:
            self.model = models.model_tree_node(name=modelname, children=self.children, operator='average')

    def load_ani_gelu_model(self, method):
        currentdir = os.path.dirname(__file__)
        dirname = os.path.join(currentdir, '../ani_gelu_model')
        if not os.path.exists(f'{dirname}/{method}_cv0.pt'):
            self.download_models(dirname, method)
        return [ani(model_file=f'{dirname}/{method}_cv{ii}.pt', verbose=0) for ii in range(8)]

    def download_models(self, dirname, method):
        import requests
        urls = [f"https://github.com/dralgroup/mlatom/raw/refs/heads/main/mlatom/ani_gelu_model/{method}_cv{ii}.pt" for ii in range(8)]
        if not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)
        print('Downloading ANI GELU model parameters ...')
        for ii in range(8):
            resource_res = requests.get(urls[ii])
            with open(f'{dirname}/{method}_cv{ii}.pt', 'wb') as f:
                f.write(resource_res.content)
[docs]
    @doc_inherit
    def predict(
        self,
        molecular_database: data.molecular_database = None,
        molecule: data.molecule = None,
        calculate_energy: bool = False,
        calculate_energy_gradients: bool = False,
        calculate_hessian: bool = False,
        batch_size: int = 2**16,
    ) -> None:
        '''
            batch_size (int, optional): The batch size for batch predictions.
        '''
        molDB = super().predict(molecular_database=molecular_database, molecule=molecule)
        for element_symbol in np.unique(np.concatenate(molDB.element_symbols)):
            if element_symbol not in self.element_symbols_available:
                print(f" * Warning * Molecule contains element '{element_symbol}', which is not supported by method '{self.method}' (supported elements: {self.element_symbols_available}); no calculations performed")
                return
        # single atoms get their tabulated atomic energies; everything else
        # goes through the model tree, followed by the ensemble standard deviation
        monoatomic_idx = molDB.number_of_atoms == 1
        for mol in molDB[monoatomic_idx]:
            self.predict_monoatomic(molecule=mol, calculate_energy=calculate_energy, calculate_energy_gradients=calculate_energy_gradients, calculate_hessian=calculate_hessian)
        self.model.predict(molecular_database=molDB[~monoatomic_idx], calculate_energy=calculate_energy, calculate_energy_gradients=calculate_energy_gradients, calculate_hessian=calculate_hessian, batch_size=batch_size)
        for mol in molDB[~monoatomic_idx]:
            self.get_SD(molecule=mol, calculate_energy=calculate_energy, calculate_energy_gradients=calculate_energy_gradients, calculate_hessian=calculate_hessian)
    def predict_monoatomic(self, molecule, calculate_energy=True, calculate_energy_gradients=False, calculate_hessian=False):
        if calculate_energy:
            molecule.energy = self.atomic_energies[self.method][molecule.atoms[0].atomic_number]
        if calculate_energy_gradients:
            molecule.add_xyz_vectorial_property(np.zeros_like(molecule.xyz_coordinates), 'energy_gradients')
        if calculate_hessian:
            molecule.hessian = np.zeros((len(molecule),) * 2)

    def get_SD(self, molecule=None, calculate_energy=True, calculate_energy_gradients=False, calculate_hessian=False):
        properties = []
        atomic_properties = []
        if calculate_energy:
            properties.append('energy')
        if calculate_energy_gradients:
            atomic_properties.append('energy_gradients')
        if calculate_hessian:
            properties.append('hessian')
        modelname = self.method.lower().replace('-', '')
        if 'D4'.casefold() in self.method.casefold():
            modelname = f'{modelname}_nn'
        molecule.__dict__[modelname].standard_deviation(properties=properties + atomic_properties)

    def train(self, **kwargs):
        # default settings
        kwargs['save_model'] = True  # force saving the model
        if 'file_to_save_model' in kwargs:
            file_to_save_model = kwargs['file_to_save_model']
        else:
            file_to_save_model = None
        if 'reset_energy_shifter' not in kwargs:
            kwargs['reset_energy_shifter'] = True
        if 'hyperparameters' in kwargs:
            _hyperparameters = kwargs['hyperparameters']
        else:
            _hyperparameters = {}
        # check hyperparameters (freeze the first and third Linear layers by default)
        if 'fixed_layers' not in _hyperparameters:
            _hyperparameters['fixed_layers'] = [[0, 4]]
        if 'loss_type' not in _hyperparameters:
            _hyperparameters['loss_type'] = 'geometric'
        if 'max_epochs' not in _hyperparameters:
            _hyperparameters['max_epochs'] = 100
        kwargs['hyperparameters'] = _hyperparameters

        pretrained_models = []
        if 'gelu'.casefold() in self.method.casefold():
            if 'ANI-1ccx-gelu'.casefold() in self.method.casefold():
                pretrained_models = self.load_ani_gelu_model(method='ani1ccx_gelu')
            elif 'ANI-1x-gelu'.casefold() in self.method.casefold():
                pretrained_models = self.load_ani_gelu_model(method='ani1x_gelu')
        elif 'ANI-1xnr'.casefold() not in self.method.casefold():
            animodel = ani()
            if 'ANI-1ccx'.casefold() in self.method.casefold():
                animodel.load_ani_model('ANI-1ccx')
            elif 'ANI-1x'.casefold() in self.method.casefold():
                animodel.load_ani_model('ANI-1x')
            elif 'ANI-2x'.casefold() in self.method.casefold():
                animodel.load_ani_model('ANI-2x')
            for ii in range(8):
                pmodel = ani()
                pmodel.species_order = animodel.species_order
                pmodel.aev_computer = animodel.aev_computer
                pmodel.networkdict = animodel.networkdict[ii]
                pmodel.neurons = animodel.neurons[ii]
                pmodel.nn = animodel.nn[ii]
                pmodel.energy_shifter = animodel.energy_shifter
                pmodel.optimizer_setup(**pmodel.hyperparameters)
                pmodel.model = torchani.nn.Sequential(animodel.aev_computer, animodel.nn[ii]).to(pmodel.device).float()
                pretrained_models.append(pmodel)
        else:
            raise ValueError('ANI-1xnr cannot currently be retrained')

        retrained_models = []
        modelname = self.method.lower().replace('-', '')
        for ii, pmodel in enumerate(pretrained_models):
            print(f'\nStart retraining on model {ii}...')
            sys.stdout.flush()
            if file_to_save_model:
                kwargs['file_to_save_model'] = file_to_save_model + f'.cv{ii}'
            else:
                kwargs['file_to_save_model'] = f'{modelname}_retrained.pt.cv{ii}'
            pmodel.train(**kwargs)
            retrained_models.append(pmodel)
        self.children = [
            models.model_tree_node(
                name=f'{modelname}_nn{ii}',
                model=rmodel,
                operator='predict')
            for ii, rmodel in enumerate(retrained_models)]
        self.model = models.model_tree_node(
            name=modelname,
            children=self.children,
            operator='average')
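# Hedged usage sketch of torchani_methods (assuming `mol` is an MLatom
# molecule, e.g. built with data.molecule and read from an XYZ file):
#
#   ani1ccx = torchani_methods(method='ANI-1ccx', device='cpu')
#   ani1ccx.predict(molecule=mol, calculate_energy=True,
#                   calculate_energy_gradients=True)
#   print(mol.energy)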
def load_ani1xnr_model():
    # ANI-1xnr: https://github.com/atomistic-ml/ani-1xnr/
    # Universal reactive ML potential ANI-1xnr: https://doi.org/10.1038/s41557-023-01427-3
    species = ['H', 'C', 'N', 'O']

    def parse_ani1xnr_resources():
        import requests
        import zipfile
        import io
        local_dir = os.path.expanduser('~/.local/ANI1xnr/')
        url = "https://github.com/atomistic-ml/ani-1xnr/archive/refs/heads/main.zip"
        if not os.path.exists(local_dir + 'ani-1xnr-main'):
            os.makedirs(local_dir, exist_ok=True)
            print('Downloading ANI-1xnr model parameters ...')
            resource_res = requests.get(url)
            resource_zip = zipfile.ZipFile(io.BytesIO(resource_res.content))
            resource_zip.extractall(local_dir)
        return local_dir

    model_prefix = parse_ani1xnr_resources() + 'ani-1xnr-main/model/ani-1xnr/'
    const_file = model_prefix + 'rHCNO-5.2R_32-3.5A_a8-4.params'
    sae_file = model_prefix + 'sae_linfit.dat'
    consts = torchani.neurochem.Constants(const_file)
    aev_computer = torchani.AEVComputer(**consts)
    energy_shifter = torchani.neurochem.load_sae(sae_file)
    species_converter = torchani.nn.SpeciesConverter(species)
    neural_networks = []
    for ii in range(8):
        neural_network = torchani.neurochem.load_model(consts.species, model_prefix + f'train{ii}/networks')
        neural_networks.append(torchani.nn.Sequential(species_converter, aev_computer, neural_network, energy_shifter))
    model_ensemble = torchani.nn.Ensemble(modules=neural_networks)
    model_ensemble.species = species
    return model_ensemble
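# Sketch of evaluating the returned ANI-1xnr ensemble directly with TorchANI
# (the SpeciesConverter in each member maps atomic numbers to internal indices;
# the water geometry below is illustrative only):
#
#   model = load_ani1xnr_model()
#   species = torch.tensor([[8, 1, 1]])                  # atomic numbers
#   coordinates = torch.tensor([[[0.00, 0.00, 0.00],
#                                [0.00, 0.00, 0.96],
#                                [0.93, 0.00, -0.24]]])
#   energy = model((species, coordinates)).energies      # ensemble average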
[docs]
class aimnet2_methods(models.torchani_model, metaclass=models.meta_method):
    '''
    Universal ML methods with AIMNet2: https://doi.org/10.26434/chemrxiv-2023-296ch

    Arguments:
        method (str): A string that specifies the method. Available choices: ``'AIMNet2@b973c'`` and ``'AIMNet2@wb97m-d3'``.
        device (str, optional): Indicates which device the calculation will be run on, i.e., 'cpu' for CPU, 'cuda' for NVIDIA GPUs. When not specified, it will try to use CUDA if a valid ``CUDA_VISIBLE_DEVICES`` is set in the system environment.
    '''
    element_symbols_available = ['H', 'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'I']

    def __init__(self, method: str = 'AIMNet2@b973c', device: str = 'cuda' if torch.cuda.is_available() else 'cpu', **kwargs):
        self.device = torch.device(device)
        self.method = method  # stored for the element-support check in predict()
        self.model_path = self.parse_aimnet2_resources(method)
        self.model = torch.jit.load(self.model_path)
[docs]
    @doc_inherit
    def predict(
        self,
        molecular_database: data.molecular_database = None,
        molecule: data.molecule = None,
        calculate_energy: bool = False,
        calculate_energy_gradients: bool = False,
        calculate_hessian: bool = False,
        batch_size: int = 2**16,
    ) -> None:
        molDB = super().predict(molecular_database=molecular_database, molecule=molecule)
        for element_symbol in np.unique(np.concatenate(molDB.element_symbols)):
            if element_symbol not in self.element_symbols_available:
                print(f" * Warning * Molecule contains element '{element_symbol}', which is not supported by method '{self.method}' (supported elements: {self.element_symbols_available}); no calculations performed")
                return
        for mol in molDB:
            self.predict_for_molecule(
                molecule=mol,
                calculate_energy=calculate_energy,
                calculate_energy_gradients=calculate_energy_gradients,
                calculate_hessian=calculate_hessian)
    def predict_for_molecule(
            self,
            molecule,
            calculate_energy: bool = False,
            calculate_energy_gradients: bool = False,
            calculate_hessian: bool = False):
        coord = torch.as_tensor(molecule.xyz_coordinates).to(torch.float).to(self.device).unsqueeze(0)
        numbers = torch.as_tensor(molecule.atomic_numbers).to(torch.long).to(self.device).unsqueeze(0)
        charge = torch.tensor([molecule.charge], dtype=torch.float, device=self.device)
        nninput = dict(coord=coord, numbers=numbers, charge=charge)
        prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(calculate_energy_gradients)
        if calculate_energy_gradients:
            nninput['coord'].requires_grad_(True)
        nnoutput = self.model(nninput)
        if calculate_energy:
            molecule.energy = nnoutput['energy'].item()
        if calculate_energy_gradients:
            if 'forces' in nnoutput:
                f = nnoutput['forces'][0]
            else:
                f = -torch.autograd.grad(nnoutput['energy'], nninput['coord'])[0][0]
            forces = f.detach().cpu().numpy()
            molecule.add_xyz_vectorial_property(forces, 'energy_gradients')
        if calculate_hessian:
            print('Hessian not available yet')
            molecule.hessian = np.zeros((len(molecule),) * 2)
        torch._C._set_grad_enabled(prev)

    @staticmethod
    def parse_aimnet2_resources(method):
        import requests
        local_dir = os.path.expanduser('~/.local/AIMNet2/')
        repo_name = "AIMNet2"
        tag_name = method.lower().replace('@', '_')
        url = "https://github.com/isayevlab/{}/raw/old/models/{}_ens.jpt".format(repo_name, tag_name)
        if not os.path.exists(local_dir + f'{tag_name}_ens.jpt'):
            os.makedirs(local_dir, exist_ok=True)
            print(f'Downloading {method} model parameters ...')
            resource_res = requests.get(url)
            with open(local_dir + f'{tag_name}_ens.jpt', 'wb') as f:
                f.write(resource_res.content)
        return local_dir + f'{tag_name}_ens.jpt'
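# Hedged usage sketch (assuming `mol` is an MLatom molecule with its charge
# attribute set, as required by the AIMNet2 input dictionary above):
#
#   aimnet2 = aimnet2_methods(method='AIMNet2@b973c', device='cpu')
#   aimnet2.predict(molecule=mol, calculate_energy=True,
#                   calculate_energy_gradients=True)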
def printHelp():
    helpText = __doc__.replace('.. code-block::\n\n', '') + '''
  To use Interface_ANI, please install TorchANI and its dependencies.

  Arguments with their default values:
    MLprog=TorchANI            enables this interface
    MLmodelType=ANI            requests the ANI model

    ani.batch_size=8           batch size
    ani.max_epochs=10000000    maximum number of epochs
    ani.learning_rate=0.001    learning rate used in the AdamW and SGD optimizers
    ani.early_stopping_learning_rate=0.00001
                               learning rate that triggers early stopping
    ani.force_coefficient=0.1  weight of the force term in the loss
    ani.Rcr=5.2                radial cutoff radius
    ani.Rca=3.5                angular cutoff radius
    ani.EtaR=1.6               radial smoothness in the radial part
    ani.ShfR=0.9,1.16875,      radial shifts in the radial part
      1.4375,1.70625,1.975,
      2.24375,2.5125,2.78125,
      3.05,3.31875,3.5875,
      3.85625,4.125,4.39375,
      4.6625,4.93125
    ani.Zeta=32                angular smoothness
    ani.ShfZ=0.19634954,       angular shifts
      0.58904862,0.9817477,
      1.3744468,1.7671459,
      2.1598449,2.552544,
      2.9452431
    ani.EtaA=8                 radial smoothness in the angular part
    ani.ShfA=0.9,1.55,2.2,2.85 radial shifts in the angular part

  Cite TorchANI:
    X. Gao, F. Ramezanghorbani, O. Isayev, J. S. Smith, A. E. Roitberg,
    J. Chem. Inf. Model. 2020, 60, 3408

  Cite the ANI model:
    J. S. Smith, O. Isayev, A. E. Roitberg, Chem. Sci. 2017, 8, 3192
'''
    print(helpText)