Skip to content

Commit

Permalink
Modifying the local Catalog functionality (#1195)
Browse files Browse the repository at this point in the history
* Add the path to information.json to local card

* change file path

* Remove local_card_file to avoid duplication

Update card.py

* Implementing rwview changes

---------

Co-authored-by: Dhanshree Arora <DhanshreeA@users.noreply.github.com>
  • Loading branch information
Malikbadmus and DhanshreeA authored Jul 15, 2024
1 parent bd31202 commit 3b89bf0
Showing 1 changed file with 39 additions and 24 deletions.
63 changes: 39 additions & 24 deletions ersilia/hub/content/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ...auth.auth import Auth
from ...db.hubdata.interfaces import AirtableInterface
import validators
from functools import lru_cache

try:
from validators import ValidationFailure
Expand Down Expand Up @@ -45,7 +46,7 @@
except:
Hdf5Explorer = None

from ...default import CARD_FILE, METADATA_JSON_FILE, SERVICE_CLASS_FILE
from ...default import EOS, INFORMATION_FILE, METADATA_JSON_FILE, SERVICE_CLASS_FILE


class BaseInformation(ErsiliaBase):
Expand All @@ -60,6 +61,7 @@ def __init__(self, config_json):
self._mode = None
self._task = None
self._input = None

self._input_shape = None
self._output = None
self._output_type = None
Expand Down Expand Up @@ -702,35 +704,48 @@ def get(self, model_id):


class LocalCard(ErsiliaBase):
"""
This class provides information on models that have been fetched and are available locally.
It retrieves and caches information about the models.
"""
def __init__(self, config_json):
ErsiliaBase.__init__(self, config_json=config_json)

def get(self, model_id):

@lru_cache(maxsize=32)
def _load_data(self, model_id):
"""
Loads the JSON data from the model's information file.
"""
model_path = self._model_path(model_id)
card_path = os.path.join(model_path, CARD_FILE)
if os.path.exists(card_path):
with open(card_path, "r") as f:
card = json.load(f)
return card
else:
return None

file_path = os.path.join(model_path, INFORMATION_FILE)

if os.path.exists(file_path):
try:
with open(file_path, "r") as f:
return json.load(f)
except json.JSONDecodeError:
return None
return None

def get(self, model_id):
"""
Returns the 'card' information for the specified model.
"""
data = self._load_data(model_id)
if data:
return data.get("card")
return None

def get_service_class(self, model_id):
"""
This method returns information about how the model was fetched by reading
the service class file located in the model's bundle directory. If the service
class file does not exist, it returns None.
Returns the 'service class' information for the specified model.
"""
service_class_path = os.path.join(
self._get_bundle_location(model_id), SERVICE_CLASS_FILE
)

if os.path.exists(service_class_path):
with open(service_class_path, "r") as f:
service_class = f.read().strip()
return service_class
else:
return None

data = self._load_data(model_id)
if data:
return data.get("service_class")
return None


class LakeCard(ErsiliaBase):
def __init__(self, config_json=None):
Expand Down

0 comments on commit 3b89bf0

Please sign in to comment.