Skip to content

Commit

Permalink
Fix for CoSPLADE 1st stage
Browse files: browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Feb 29, 2024
1 parent 96f9932 commit d0286bf
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 24 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Experimaestro

experimaestro>=1.5.0
experimaestro>=1.5.1
datamaestro>=1.0.1
datamaestro_text>=2024.2.29
ir_datasets
Expand Down
13 changes: 10 additions & 3 deletions src/xpmir/conversation/learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import cached_property
import numpy as np
from datamaestro_text.data.ir import TopicRecord
from datamaestro_text.data.conversation import (
Expand All @@ -17,12 +18,16 @@ class DatasetConversationEntrySampler(BaseSampler):
dataset: Param[ConversationDataset]
"""The conversation dataset"""

@cached_property
def conversations(self):
return list(self.dataset.__iter__())

def __iter__(self) -> RandomSerializableIterator[TopicConversationRecord]:
def generator(random: np.random.RandomState):
while True:
# Pick a random conversation
conversation_ix = random.randint(0, len(self.dataset))
conversation = self.dataset[conversation_ix]
conversation_ix = random.randint(0, len(self.conversations))
conversation = self.conversations[conversation_ix]

# Pick a random topic record entry
nodes = [
Expand All @@ -33,8 +38,10 @@ def generator(random: np.random.RandomState):
node_ix = random.randint(len(nodes))
node = nodes[node_ix]

yield node.entry().add(
node = node.entry().add(
ConversationHistoryItem(node.history()), no_check=True
)

yield node

return RandomSerializableIterator(self.random, generator)
6 changes: 5 additions & 1 deletion src/xpmir/conversation/models/cosplade.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
)
from xpmir.letor.trainers.alignment import AlignmentLoss
from xpmir.neural.splade import SpladeTextEncoderV2
from xpmir.utils.logging import easylog

logger = easylog()


@define
Expand Down Expand Up @@ -109,7 +112,8 @@ def forward(self, records: List[TopicConversationRecord]):
)
pair_origins.append(ix)
elif isinstance(item, AnswerConversationRecord):
answer = item
if (answer := item.get(AnswerEntry)) is None:
logger.warning("Answer record has no answer entry")
else:
# Ignore anything which is not a pair topic-response
answer = None
Expand Down
3 changes: 1 addition & 2 deletions src/xpmir/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
AdhocAssessments,
Documents,
AdhocResults,
TextItem,
IDItem,
)
from datamaestro_text.data.ir.trec import TrecAdhocRun, TrecAdhocResults
Expand Down Expand Up @@ -73,7 +72,7 @@ def print_line(fp, measure, scope, value):
def get_run(retriever: Retriever, dataset: Adhoc):
"""Returns the scored documents for each topic in a dataset"""
results = retriever.retrieve_all(
{topic[IDItem].id: topic[TextItem].text for topic in dataset.topics.iter()}
{topic[IDItem].id: topic for topic in dataset.topics.iter()}
)
return {
qid: {sd.document[IDItem].id: sd.score for sd in scoredocs}
Expand Down
18 changes: 1 addition & 17 deletions src/xpmir/utils/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
)
from xpmir.utils.utils import easylog
import logging
import atexit

logger = easylog()

Expand Down Expand Up @@ -394,22 +393,7 @@ def detach(self):
def __next__(self):
# Start a process if needed
self.start()

# Get the next element
element = self.queue.get()

# Last element
if isinstance(element, StopIterationClass):
atexit.unregister(self.kill_subprocess)
raise StopIteration()

# An exception occurred
if isinstance(element, Exception):
atexit.unregister(self.kill_subprocess)
logging.warning("Got an exception in the iteration process")
raise RuntimeError("Error in iterator process") from element

return element
return next(self.mp_iterator)


class StatefullIterator(Iterator[Tuple[T, State]], Protocol[State]):
Expand Down

0 comments on commit d0286bf

Please sign in to comment.