diff --git a/zshot/evaluation/zshot_evaluate.py b/zshot/evaluation/zshot_evaluate.py index 93dc9d1..5028248 100644 --- a/zshot/evaluation/zshot_evaluate.py +++ b/zshot/evaluation/zshot_evaluate.py @@ -78,8 +78,10 @@ def make_table(data: Dict, title: str = ""): return table tables = [] - mode = evaluation.pop('evaluation_mode') + mode = evaluation.get('evaluation_mode') for component in evaluation: + if component == 'evaluation_mode': + continue # General evaluation t_repr = make_table(evaluation[component], f"{component} - {name} \n General - {mode}-based").get_string() tables.append(fix_table_title(t_repr)) diff --git a/zshot/linker/linker_gliner.py b/zshot/linker/linker_gliner.py index addac6a..8277705 100644 --- a/zshot/linker/linker_gliner.py +++ b/zshot/linker/linker_gliner.py @@ -35,6 +35,7 @@ def load_models(self): """ Load GLINER model """ if self.model is None: self.model = GLiNER.from_pretrained(self.model_name, cache_dir=MODELS_CACHE_PATH).to(self.device) + self.model.eval() def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: """ diff --git a/zshot/linker/linker_smxm.py b/zshot/linker/linker_smxm.py index 10fa147..cf36347 100644 --- a/zshot/linker/linker_smxm.py +++ b/zshot/linker/linker_smxm.py @@ -39,6 +39,7 @@ def load_models(self): self.model = BertTaggerMultiClass.from_pretrained( self.model_name, output_hidden_states=True, cache_dir=MODELS_CACHE_PATH ).to(self.device) + self.model.eval() def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: """ diff --git a/zshot/mentions_extractor/mentions_extractor_gliner.py b/zshot/mentions_extractor/mentions_extractor_gliner.py index 120e2a9..81690e0 100644 --- a/zshot/mentions_extractor/mentions_extractor_gliner.py +++ b/zshot/mentions_extractor/mentions_extractor_gliner.py @@ -30,6 +30,7 @@ def load_models(self): """ Load GLINER model """ if self.model is None: self.model = GLiNER.from_pretrained(self.model_name, cache_dir=MODELS_CACHE_PATH).to(self.device) + self.model.eval() def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: """ diff --git a/zshot/mentions_extractor/mentions_extractor_smxm.py b/zshot/mentions_extractor/mentions_extractor_smxm.py index 60c1a7e..6bac5ce 100644 --- a/zshot/mentions_extractor/mentions_extractor_smxm.py +++ b/zshot/mentions_extractor/mentions_extractor_smxm.py @@ -34,6 +34,7 @@ def load_models(self): self.model = BertTaggerMultiClass.from_pretrained( self.model_name, output_hidden_states=True, cache_dir=MODELS_CACHE_PATH ).to(self.device) + self.model.eval() def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: """