Skip to content

Commit

Permalink
Refactored model loading in BaseTask class
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravarthik27 committed Nov 3, 2023
1 parent c98d84d commit 5806513
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions langtest/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,21 @@ def load_model(cls, model_path: str, model_hub: str, *args, **kwargs):
if "user_prompt" in kwargs:
cls.user_prompt = kwargs.get("user_prompt")
kwargs.pop("user_prompt")

if cls._name in models[model_hub] or cls._name in models["llm"]:
try:
if model_hub in LANGCHAIN_HUBS:
# LLM models
cls.model = models["llm"][cls._name].load_model(
hub=model_hub, path=model_path, *args, **kwargs
)
else:
# JSL, Huggingface, and Spacy models
cls.model = models[model_hub][cls._name].load_model(
path=model_path, *args, **kwargs
)
return cls.model
except TypeError:
raise ValueError(Errors.E081.format(hub=model_hub))

if model_hub in LANGCHAIN_HUBS:
# LLM models
cls.model = models["llm"][cls._name].load_model(
hub=model_hub, path=model_path, *args, **kwargs
)
else:
# JSL, Huggingface, and Spacy models
cls.model = models[model_hub][cls._name].load_model(
path=model_path, *args, **kwargs
)
return cls.model

@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
Expand Down

0 comments on commit 5806513

Please sign in to comment.