Skip to content

Commit

Permalink
fix: Retriever uses TopicRecord now
Browse files (browse the repository at this point in the history)
  • Loading branch information
bpiwowar committed Mar 5, 2024
1 parent d0286bf commit 17936da
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
10 changes: 8 additions & 2 deletions src/xpmir/conversation/learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import cached_property
from datamaestro.record import RecordTypesCache
import numpy as np
from datamaestro_text.data.ir import TopicRecord
from datamaestro_text.data.conversation import (
Expand All @@ -22,6 +23,11 @@ class DatasetConversationEntrySampler(BaseSampler):
def conversations(self):
return list(self.dataset.__iter__())

def __post_init__(self):
super().__post_init__()

self._recordtypes = RecordTypesCache()

def __iter__(self) -> RandomSerializableIterator[TopicConversationRecord]:
def generator(random: np.random.RandomState):
while True:
Expand All @@ -38,8 +44,8 @@ def generator(random: np.random.RandomState):
node_ix = random.randint(len(nodes))
node = nodes[node_ix]

node = node.entry().add(
ConversationHistoryItem(node.history()), no_check=True
node = self._recordtypes.update(
node.entry(), ConversationHistoryItem(node.history())
)

yield node
Expand Down
1 change: 1 addition & 0 deletions src/xpmir/index/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def execute(self):
]

# Cleanup the index before starting
# ENHANCE: recover index build when possible
from shutil import rmtree

if self.index_path.is_dir():
Expand Down
8 changes: 4 additions & 4 deletions src/xpmir/interfaces/anserini.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import List, Optional
from experimaestro import tqdm as xpmtqdm, Task, Meta

from datamaestro_text.data.ir import DocumentStore, TextItem, IDItem
from datamaestro_text.data.ir import DocumentStore, TextItem, IDItem, TopicRecord
import datamaestro_text.data.ir.csv as ir_csv
from datamaestro_text.data.ir.trec import (
Documents,
Expand Down Expand Up @@ -273,7 +273,7 @@ class AnseriniRetriever(Retriever):
Attributes:
index: The Anserini index
model: the model used to search. Only suupports BM25 so far.
model: the model used to search. Only supports BM25 so far.
k: Number of results to retrieve
"""

Expand Down Expand Up @@ -306,10 +306,10 @@ def _get_store(self) -> Optional[Index]:
if self.index.storeContents:
return self.index

def retrieve(self, query: str) -> List[ScoredDocument]:
def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
# see
# https://github.com/castorini/anserini/blob/master/src/main/java/io/anserini/search/SimpleSearcher.java
hits = self.searcher.search(query, k=self.k)
hits = self.searcher.search(record[TextItem].text, k=self.k)
store = self.get_store()

# Batch retrieve documents
Expand Down
6 changes: 3 additions & 3 deletions src/xpmir/rankers/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class FullRetriever(Retriever):

documents: Param[Documents]

def retrieve(self, query: str) -> List[ScoredDocument]:
def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
return [ScoredDocument(doc, 0.0) for doc in self.documents]


Expand Down Expand Up @@ -111,9 +111,9 @@ def score(
# Add each result to the full document list
scored_documents.extend(new_scores)

def retrieve(self, query: str):
def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
# Only use retrieve_all
return self.retrieve_all({"_": query})["_"]
return self.retrieve_all({"_": record})["_"]

def retrieve_all(
self, queries: Dict[str, TopicRecord]
Expand Down

0 comments on commit 17936da

Please sign in to comment.