Skip to content

Commit

Permalink
Update model URL (#484)
Browse files Browse the repository at this point in the history
* Update model URL

* Update models.py

* Update models.py

* fix

* cleanup

* save
  • Loading branch information
zasdfgbnm authored Jun 11, 2020
1 parent 9e7baa6 commit ac27ed3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/energy_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# The ``periodic_table_index`` arguments tells TorchANI to use element index
# in periodic table to index species. If not specified, you need to use
# 0, 1, 2, 3, ... to index species
model = torchani.models.ANI1ccx(periodic_table_index=True).to(device)
model = torchani.models.ANI2x(periodic_table_index=True).to(device)

###############################################################################
# Now let's define the coordinate and species. If you just want to compute the
Expand Down
22 changes: 18 additions & 4 deletions torchani/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
import os
import io
import requests
import glob
import zipfile
import shutil
import torch
from torch import Tensor
from typing import Tuple, Optional
Expand Down Expand Up @@ -73,15 +75,19 @@ def _from_neurochem_resources(cls, info_file_path, periodic_table_index=False, m
@staticmethod
def _parse_neurochem_resources(info_file_path):
def get_resource(resource_path, file_path):
return os.path.join(resource_path, 'resources/' + file_path)
return os.path.join(resource_path, file_path)

resource_path = os.path.dirname(__file__)
local_dir = os.path.expanduser('~/.local/torchani')
resource_path = os.path.join(os.path.dirname(__file__), 'resources/')
local_dir = os.path.expanduser('~/.local/torchani/')
repo_name = "ani-model-zoo"
tag_name = "ani-2x"
extracted_name = '{}-{}'.format(repo_name, tag_name)
url = "https://github.com/aiqm/{}/archive/{}.zip".format(repo_name, tag_name)

if not os.path.isfile(get_resource(resource_path, info_file_path)):
if not os.path.isfile(get_resource(local_dir, info_file_path)):
print('Downloading ANI model parameters ...')
resource_res = requests.get("https://www.dropbox.com/sh/otrzul6yuye8uzs/AABuaihE22vtaB_rdrI0r6TUa?dl=1")
resource_res = requests.get(url)
resource_zip = zipfile.ZipFile(io.BytesIO(resource_res.content))
try:
resource_zip.extractall(resource_path)
Expand All @@ -91,6 +97,14 @@ def get_resource(resource_path, file_path):
else:
resource_path = local_dir

files = glob.glob(os.path.join(resource_path, extracted_name, "resources", "*"))
for f in files:
try:
shutil.move(f, resource_path)
except shutil.Error:
pass
shutil.rmtree(os.path.join(resource_path, extracted_name))

info_file = get_resource(resource_path, info_file_path)

with open(info_file) as f:
Expand Down

0 comments on commit ac27ed3

Please sign in to comment.