diff --git a/requirements.txt b/requirements.txt index a7e5fa7..0d61bf8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Experimaestro -experimaestro>=1.5.0 +experimaestro>=1.5.1 datamaestro>=1.0.1 datamaestro_text>=2024.2.29 ir_datasets diff --git a/src/xpmir/conversation/learning/__init__.py b/src/xpmir/conversation/learning/__init__.py index 1cbe147..23eee83 100644 --- a/src/xpmir/conversation/learning/__init__.py +++ b/src/xpmir/conversation/learning/__init__.py @@ -1,3 +1,4 @@ +from functools import cached_property import numpy as np from datamaestro_text.data.ir import TopicRecord from datamaestro_text.data.conversation import ( @@ -17,12 +18,16 @@ class DatasetConversationEntrySampler(BaseSampler): dataset: Param[ConversationDataset] """The conversation dataset""" + @cached_property + def conversations(self): + return list(self.dataset.__iter__()) + def __iter__(self) -> RandomSerializableIterator[TopicConversationRecord]: def generator(random: np.random.RandomState): while True: # Pick a random conversation - conversation_ix = random.randint(0, len(self.dataset)) - conversation = self.dataset[conversation_ix] + conversation_ix = random.randint(0, len(self.conversations)) + conversation = self.conversations[conversation_ix] # Pick a random topic record entry nodes = [ @@ -33,8 +38,10 @@ def generator(random: np.random.RandomState): node_ix = random.randint(len(nodes)) node = nodes[node_ix] - yield node.entry().add( + node = node.entry().add( ConversationHistoryItem(node.history()), no_check=True ) + yield node + return RandomSerializableIterator(self.random, generator) diff --git a/src/xpmir/conversation/models/cosplade.py b/src/xpmir/conversation/models/cosplade.py index 81e53d4..0121438 100644 --- a/src/xpmir/conversation/models/cosplade.py +++ b/src/xpmir/conversation/models/cosplade.py @@ -19,6 +19,9 @@ ) from xpmir.letor.trainers.alignment import AlignmentLoss from xpmir.neural.splade import SpladeTextEncoderV2 +from xpmir.utils.logging import easylog + +logger = easylog() @define @@ -109,7 +112,8 @@ def forward(self, records: List[TopicConversationRecord]): ) pair_origins.append(ix) elif isinstance(item, AnswerConversationRecord): - answer = item + if (answer := item.get(AnswerEntry)) is None: + logger.warning("Answer record has no answer entry") else: # Ignore anything which is not a pair topic-response answer = None diff --git a/src/xpmir/evaluation.py b/src/xpmir/evaluation.py index 4797eac..14b6b56 100644 --- a/src/xpmir/evaluation.py +++ b/src/xpmir/evaluation.py @@ -10,7 +10,6 @@ AdhocAssessments, Documents, AdhocResults, - TextItem, IDItem, ) from datamaestro_text.data.ir.trec import TrecAdhocRun, TrecAdhocResults @@ -73,7 +72,7 @@ def print_line(fp, measure, scope, value): def get_run(retriever: Retriever, dataset: Adhoc): """Returns the scored documents for each topic in a dataset""" results = retriever.retrieve_all( - {topic[IDItem].id: topic[TextItem].text for topic in dataset.topics.iter()} + {topic[IDItem].id: topic for topic in dataset.topics.iter()} ) return { qid: {sd.document[IDItem].id: sd.score for sd in scoredocs} diff --git a/src/xpmir/utils/iter.py b/src/xpmir/utils/iter.py index 5a25ba8..d837bfc 100644 --- a/src/xpmir/utils/iter.py +++ b/src/xpmir/utils/iter.py @@ -17,7 +17,6 @@ ) from xpmir.utils.utils import easylog import logging -import atexit logger = easylog() @@ -394,22 +393,7 @@ def detach(self): def __next__(self): # Start a process if needed self.start() - - # Get the next element - element = self.queue.get() - - # Last element - if isinstance(element, StopIterationClass): - atexit.unregister(self.kill_subprocess) - raise StopIteration() - - # An exception occurred - if isinstance(element, Exception): - atexit.unregister(self.kill_subprocess) - logging.warning("Got an exception in the iteration process") - raise RuntimeError("Error in iterator process") from element - - return element + return next(self.mp_iterator) class StatefullIterator(Iterator[Tuple[T, State]], Protocol[State]):