Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ZakiaYahya committed Jul 10, 2023
1 parent 75b5503 commit 2436539
Showing 1 changed file with 38 additions and 22 deletions.
60 changes: 38 additions & 22 deletions ersilia/hub/fetch/actions/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class ModelRepositoryGetter(BaseAction):
def __init__(self, model_id, config_json, force_from_github, force_from_s3):
def __init__(self, model_id, config_json, force_from_github, force_from_s3, repo_path):
BaseAction.__init__(
self, model_id=model_id, config_json=config_json, credentials_json=None
)
Expand All @@ -28,6 +28,7 @@ def __init__(self, model_id, config_json, force_from_github, force_from_s3):
self.org = self.cfg.HUB.ORG
self.force_from_github = force_from_github
self.force_from_s3 = force_from_s3
self.repo_path = repo_path

def _dev_model_path(self):
pt = Paths()
Expand Down Expand Up @@ -87,6 +88,23 @@ def _change_py_version_in_dockerfile_if_necessary(self):
for s in S:
f.write(s + os.linesep)

@staticmethod
def _is_sudo():
return os.geteuid() == 0

def _remove_sudo_if_root(self):
path = self._model_path(model_id=self.model_id)
dockerfile_path = os.path.join(path, "Dockerfile")
if self.is_sudo():
self.logger.debug("User is root! Removing sudo commands")
with open(dockerfile_path, "r") as f:
content = f.read()
content = content.replace("RUN sudo ", "RUN ")
with open(dockerfile_path, "w") as f:
f.write(content)
else:
self.logger.debug("User is not root")

@throw_ersilia_exception
def get(self):
"""Copy model repository from local or download from S3 or GitHub"""
Expand All @@ -98,21 +116,25 @@ def get(self):
)
self._copy_from_local(dev_model_path, folder)
else:
if self.force_from_github:
self._copy_from_github(folder)
if self.repo_path is not None:
self._copy_from_local(self.repo_path, folder)
else:
try:
self.logger.debug("Trying to download from S3")
self._copy_zip_from_s3(folder)
except:
self.logger.debug(
"Could not download in zip format in S3. Downloading from GitHub repository."
)
if self.force_from_s3:
raise S3DownloaderError(model_id=self.model_id)
else:
self._copy_from_github(folder)
if self.force_from_github:
self._copy_from_github(folder)
else:
try:
self.logger.debug("Trying to download from S3")
self._copy_zip_from_s3(folder)
except:
self.logger.debug(
"Could not download in zip format in S3. Downloading from GitHub repository."
)
if self.force_from_s3:
raise S3DownloaderError(model_id=self.model_id)
else:
self._copy_from_github(folder)
self._change_py_version_in_dockerfile_if_necessary()
self._remove_sudo_if_root()


#  TODO: work outside GIT LFS
Expand Down Expand Up @@ -158,6 +180,7 @@ def __init__(
config_json=config_json,
force_from_github=force_from_gihtub,
force_from_s3=force_from_s3,
repo_path=repo_path
)
self.mpg = ModelParametersGetter(model_id=model_id, config_json=config_json)

Expand All @@ -167,17 +190,10 @@ def _get_repository(self):
def _get_model_parameters(self):
self.mpg.get()

def _copy_from_repo_path(self):
dst = self._model_path(self.model_id)
src = self.repo_path
shutil.copytree(src, dst)

@throw_ersilia_exception
def get(self):
self._get_repository()
if self.repo_path is None:
self._get_repository()
self._get_model_parameters()
else:
self._copy_from_repo_path()
if not os.path.exists(self._model_path(self.model_id)):
raise FolderNotFoundError(os.path.exists(self._model_path(self.model_id)))

0 comments on commit 2436539

Please sign in to comment.