diff --git a/src/xpmir/conversation/learning/__init__.py b/src/xpmir/conversation/learning/__init__.py index 23eee83..4b0dc17 100644 --- a/src/xpmir/conversation/learning/__init__.py +++ b/src/xpmir/conversation/learning/__init__.py @@ -1,4 +1,5 @@ from functools import cached_property +from datamaestro.record import RecordTypesCache import numpy as np from datamaestro_text.data.ir import TopicRecord from datamaestro_text.data.conversation import ( @@ -22,6 +23,11 @@ class DatasetConversationEntrySampler(BaseSampler): def conversations(self): return list(self.dataset.__iter__()) + def __post_init__(self): + super().__post_init__() + + self._recordtypes = RecordTypesCache() + def __iter__(self) -> RandomSerializableIterator[TopicConversationRecord]: def generator(random: np.random.RandomState): while True: @@ -38,8 +44,8 @@ def generator(random: np.random.RandomState): node_ix = random.randint(len(nodes)) node = nodes[node_ix] - node = node.entry().add( - ConversationHistoryItem(node.history()), no_check=True + node = self._recordtypes.update( + node.entry(), ConversationHistoryItem(node.history()) ) yield node diff --git a/src/xpmir/index/sparse.py b/src/xpmir/index/sparse.py index 18a6daf..7e7589b 100644 --- a/src/xpmir/index/sparse.py +++ b/src/xpmir/index/sparse.py @@ -236,6 +236,7 @@ def execute(self): ] # Cleanup the index before starting + # ENHANCE: recover index build when possible from shutil import rmtree if self.index_path.is_dir(): diff --git a/src/xpmir/interfaces/anserini.py b/src/xpmir/interfaces/anserini.py index 9710027..45176ea 100644 --- a/src/xpmir/interfaces/anserini.py +++ b/src/xpmir/interfaces/anserini.py @@ -12,7 +12,7 @@ from typing import List, Optional from experimaestro import tqdm as xpmtqdm, Task, Meta -from datamaestro_text.data.ir import DocumentStore, TextItem, IDItem +from datamaestro_text.data.ir import DocumentStore, TextItem, IDItem, TopicRecord import datamaestro_text.data.ir.csv as ir_csv from datamaestro_text.data.ir.trec import ( Documents, @@ -273,7 +273,7 @@ class AnseriniRetriever(Retriever): Attributes: index: The Anserini index - model: the model used to search. Only suupports BM25 so far. + model: the model used to search. Only supports BM25 so far. k: Number of results to retrieve """ @@ -306,10 +306,10 @@ def _get_store(self) -> Optional[Index]: if self.index.storeContents: return self.index - def retrieve(self, query: str) -> List[ScoredDocument]: + def retrieve(self, record: TopicRecord) -> List[ScoredDocument]: # see # https://github.com/castorini/anserini/blob/master/src/main/java/io/anserini/search/SimpleSearcher.java - hits = self.searcher.search(query, k=self.k) + hits = self.searcher.search(record[TextItem].text, k=self.k) store = self.get_store() # Batch retrieve documents diff --git a/src/xpmir/rankers/full.py b/src/xpmir/rankers/full.py index 2027fcf..2b70a37 100644 --- a/src/xpmir/rankers/full.py +++ b/src/xpmir/rankers/full.py @@ -20,7 +20,7 @@ class FullRetriever(Retriever): documents: Param[Documents] - def retrieve(self, query: str) -> List[ScoredDocument]: + def retrieve(self, record: TopicRecord) -> List[ScoredDocument]: return [ScoredDocument(doc, 0.0) for doc in self.documents] @@ -111,9 +111,9 @@ def score( # Add each result to the full document list scored_documents.extend(new_scores) - def retrieve(self, query: str): + def retrieve(self, record: TopicRecord) -> List[ScoredDocument]: # Only use retrieve_all - return self.retrieve_all({"_": query})["_"] + return self.retrieve_all({"_": record})["_"] def retrieve_all( self, queries: Dict[str, TopicRecord]