Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifying the local Catalog functionality #1195

Merged
merged 7 commits into from
Jul 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 39 additions & 24 deletions ersilia/hub/content/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ...auth.auth import Auth
from ...db.hubdata.interfaces import AirtableInterface
import validators
from functools import lru_cache

try:
from validators import ValidationFailure
Expand Down Expand Up @@ -45,7 +46,7 @@
except:
Hdf5Explorer = None

from ...default import CARD_FILE, METADATA_JSON_FILE, SERVICE_CLASS_FILE
from ...default import EOS, INFORMATION_FILE, METADATA_JSON_FILE, SERVICE_CLASS_FILE


class BaseInformation(ErsiliaBase):
Expand All @@ -60,6 +61,7 @@ def __init__(self, config_json):
self._mode = None
self._task = None
self._input = None

self._input_shape = None
self._output = None
self._output_type = None
Expand Down Expand Up @@ -702,35 +704,48 @@ def get(self, model_id):


class LocalCard(ErsiliaBase):
"""
This class provides information on models that have been fetched and are available locally.
It retrieves and caches information about the models.
"""
def __init__(self, config_json):
ErsiliaBase.__init__(self, config_json=config_json)

def get(self, model_id):

@lru_cache(maxsize=32)
def _load_data(self, model_id):
"""
Loads the JSON data from the model's information file.
"""
model_path = self._model_path(model_id)
card_path = os.path.join(model_path, CARD_FILE)
if os.path.exists(card_path):
with open(card_path, "r") as f:
card = json.load(f)
return card
else:
return None

file_path = os.path.join(model_path, INFORMATION_FILE)

if os.path.exists(file_path):
try:
with open(file_path, "r") as f:
return json.load(f)
except json.JSONDecodeError:
return None
return None

def get(self, model_id):
"""
Returns the 'card' information for the specified model.
"""
data = self._load_data(model_id)
if data:
return data.get("card")
return None

def get_service_class(self, model_id):
"""
This method returns information about how the model was fetched by reading
the service class file located in the model's bundle directory. If the service
class file does not exist, it returns None.
Returns the 'service class' information for the specified model.
"""
service_class_path = os.path.join(
self._get_bundle_location(model_id), SERVICE_CLASS_FILE
)

if os.path.exists(service_class_path):
with open(service_class_path, "r") as f:
service_class = f.read().strip()
return service_class
else:
return None

data = self._load_data(model_id)
if data:
return data.get("service_class")
return None


class LakeCard(ErsiliaBase):
def __init__(self, config_json=None):
Expand Down