diff --git a/examples/energy_force.py b/examples/energy_force.py index 89a28d692..84d36be2a 100644 --- a/examples/energy_force.py +++ b/examples/energy_force.py @@ -26,7 +26,7 @@ # The ``periodic_table_index`` arguments tells TorchANI to use element index # in periodic table to index species. If not specified, you need to use # 0, 1, 2, 3, ... to index species -model = torchani.models.ANI1ccx(periodic_table_index=True).to(device) +model = torchani.models.ANI2x(periodic_table_index=True).to(device) ############################################################################### # Now let's define the coordinate and species. If you just want to compute the diff --git a/torchani/models.py b/torchani/models.py index 598ec382a..70096810f 100644 --- a/torchani/models.py +++ b/torchani/models.py @@ -26,7 +26,9 @@ import os import io import requests +import glob import zipfile +import shutil import torch from torch import Tensor from typing import Tuple, Optional @@ -73,15 +75,19 @@ def _from_neurochem_resources(cls, info_file_path, periodic_table_index=False, m @staticmethod def _parse_neurochem_resources(info_file_path): def get_resource(resource_path, file_path): - return os.path.join(resource_path, 'resources/' + file_path) + return os.path.join(resource_path, file_path) - resource_path = os.path.dirname(__file__) - local_dir = os.path.expanduser('~/.local/torchani') + resource_path = os.path.join(os.path.dirname(__file__), 'resources/') + local_dir = os.path.expanduser('~/.local/torchani/') + repo_name = "ani-model-zoo" + tag_name = "ani-2x" + extracted_name = '{}-{}'.format(repo_name, tag_name) + url = "https://github.com/aiqm/{}/archive/{}.zip".format(repo_name, tag_name) if not os.path.isfile(get_resource(resource_path, info_file_path)): if not os.path.isfile(get_resource(local_dir, info_file_path)): print('Downloading ANI model parameters ...') - resource_res = requests.get("https://www.dropbox.com/sh/otrzul6yuye8uzs/AABuaihE22vtaB_rdrI0r6TUa?dl=1") + resource_res = requests.get(url) resource_zip = zipfile.ZipFile(io.BytesIO(resource_res.content)) try: resource_zip.extractall(resource_path) @@ -91,6 +97,14 @@ def get_resource(resource_path, file_path): else: resource_path = local_dir + files = glob.glob(os.path.join(resource_path, extracted_name, "resources", "*")) + for f in files: + try: + shutil.move(f, resource_path) + except shutil.Error: + pass + shutil.rmtree(os.path.join(resource_path, extracted_name)) + info_file = get_resource(resource_path, info_file_path) with open(info_file) as f: