Update for new record interface
bpiwowar committed Mar 6, 2024
1 parent 80de710 commit ac7afff
Showing 16 changed files with 78 additions and 107 deletions.
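The changes below follow one pattern: the dedicated record classes of the old interface (SimpleTextTopicRecord, IDDocumentRecord, GenericDocumentRecord, ...) are replaced by generic records composed of item objects. The following sketch illustrates the new calls as they are used in this diff; it assumes the datamaestro / datamaestro_text versions targeted by this commit, and the identifiers "q1", "d42" and the score values are purely illustrative.

from datamaestro.record import Record, record_type
from datamaestro_text.data.ir import IDItem, SimpleTextItem, ScoredItem, create_record

# Build records from keyword helpers (replaces SimpleText*Record.from_text / ID*Record.from_id)
topic = create_record(id="q1", text="what is dense retrieval?")
document = create_record(text="a passage of text")

# ... or compose item instances directly (replaces ad-hoc subclasses such as ScoredIDDocumentRecord)
scored_document = Record(IDItem("d42"), ScoredItem(1.7))

# Items are looked up by their class
print(topic[IDItem].id)  # -> "q1"

# record_type(...) describes an item composition (used by samplers and tests below)
topic_recordtype = record_type(IDItem, SimpleTextItem)

# update(...) returns a record extended with additional items,
# as in node.entry().update(ConversationHistoryItem(...)) below
document_with_id = document.update(IDItem("d1"))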
7 changes: 1 addition & 6 deletions src/xpmir/conversation/learning/__init__.py
@@ -1,5 +1,4 @@
from functools import cached_property
-from datamaestro.record import RecordTypesCache
import numpy as np
from datamaestro_text.data.ir import TopicRecord
from datamaestro_text.data.conversation import (
@@ -26,8 +25,6 @@ def conversations(self):
def __post_init__(self):
super().__post_init__()

-self._recordtypes = RecordTypesCache("Conversation", ConversationHistoryItem)

def __iter__(self) -> RandomSerializableIterator[TopicConversationRecord]:
def generator(random: np.random.RandomState):
while True:
@@ -44,9 +41,7 @@ def generator(random: np.random.RandomState):
node_ix = random.randint(len(nodes))
node = nodes[node_ix]

-node = self._recordtypes.update(
-node.entry(), ConversationHistoryItem(node.history())
-)
+node = node.entry().update(ConversationHistoryItem(node.history()))

yield node

1 change: 1 addition & 0 deletions src/xpmir/conversation/models/cosplade.py
@@ -80,6 +80,7 @@ def __initialize__(self, options):
self.queries_encoder.initialize(options)
self.history_encoder.initialize(options)

+@property
def dimension(self):
return self.queries_encoder.dimension

16 changes: 6 additions & 10 deletions src/xpmir/documents/samplers.py
@@ -3,11 +3,7 @@
from experimaestro import Param, Config
import torch
import numpy as np
-from datamaestro_text.data.ir import DocumentStore, TextItem
-from datamaestro_text.data.ir.base import (
-SimpleTextTopicRecord,
-SimpleTextDocumentRecord,
-)
+from datamaestro_text.data.ir import DocumentStore, TextItem, create_record
from xpmir.letor import Random
from xpmir.letor.records import DocumentRecord, PairwiseRecord, ProductRecords
from xpmir.letor.samplers import BatchwiseSampler, PairwiseSampler
@@ -150,9 +146,9 @@ def iter(random: np.random.RandomState):
continue

yield PairwiseRecord(
-SimpleTextTopicRecord.from_text(spans_pos_qry[0]),
-SimpleTextDocumentRecord.from_text(spans_pos_qry[1]),
-SimpleTextDocumentRecord.from_text(spans_neg[random.randint(0, 2)]),
+create_record(text=spans_pos_qry[0]),
+create_record(text=spans_pos_qry[1]),
+create_record(text=spans_neg[random.randint(0, 2)]),
)

return RandomSerializableIterator(self.random, iter)
@@ -174,8 +170,8 @@ def iterator(random: np.random.RandomState):
res = self.get_text_span(text, random)
if not res:
continue
-batch.add_topics(SimpleTextTopicRecord.from_text(res[0]))
-batch.add_documents(SimpleTextDocumentRecord.from_text(res[1]))
+batch.add_topics(create_record(text=res[0]))
+batch.add_documents(create_record(text=res[1]))
batch.set_relevances(relevances)
yield batch

2 changes: 1 addition & 1 deletion src/xpmir/index/faiss.py
@@ -122,7 +122,7 @@ def train(
index.train(sample)

def execute(self):
-self.device.execute(self._execute, None)
+self.device.execute(self._execute)

def _execute(self, device_information: DeviceInformation):
# Initialization hooks
3 changes: 2 additions & 1 deletion src/xpmir/index/sparse.py
@@ -216,7 +216,8 @@ def task_outputs(self, dep):
)

def execute(self):
-mp.set_start_method("spawn")
+if mp.get_start_method(allow_none=True) is None:
+    mp.set_start_method("spawn")

max_docs = (
self.documents.documentcount
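The guard added above is needed because multiprocessing allows the start method to be set only once per process; calling mp.set_start_method("spawn") a second time raises RuntimeError. A minimal standalone illustration of the pattern (plain standard-library code, not taken from xpmir):

import multiprocessing as mp

# Only set the start method if nothing else has set it yet; an unconditional
# call could fail with "RuntimeError: context has already been set".
if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method("spawn")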
5 changes: 3 additions & 2 deletions src/xpmir/letor/distillation/samplers.py
@@ -15,6 +15,7 @@
ScoredItem,
SimpleTextItem,
IDItem,
+create_record,
)
from experimaestro import Config, Meta, Param
from xpmir.learning import Sampler
@@ -85,9 +86,9 @@ def iterate():
with self.path.open("rt") as fp:
for row in csv.reader(fp, delimiter="\t"):
if self.with_queryid:
-query = TopicRecord.from_id(row[2])
+query = create_record(id=row[2])
else:
-query = TopicRecord.from_text(row[2])
+query = create_record(text=row[2])

if self.with_docid:
documents = (
9 changes: 3 additions & 6 deletions src/xpmir/letor/records.py
@@ -4,8 +4,7 @@
TopicRecord,
DocumentRecord,
TextItem,
-SimpleTextTopicRecord,
-SimpleTextDocumentRecord,
+create_record,
)
from typing import (
Iterable,
@@ -145,10 +144,8 @@ def from_texts(
relevances: Optional[List[float]] = None,
):
records = PointwiseRecords()
-records.topics = list(map(lambda t: SimpleTextTopicRecord.from_text(t), topics))
-records.documents = list(
-map(lambda t: SimpleTextDocumentRecord.from_text(t), documents)
-)
+records.topics = list(map(lambda t: create_record(text=t), topics))
+records.documents = list(map(lambda t: create_record(text=t), documents))
records.relevances = relevances
return records

29 changes: 11 additions & 18 deletions src/xpmir/letor/samplers/__init__.py
@@ -2,7 +2,7 @@
from pathlib import Path
from typing import Iterator, List, Tuple, Dict, Any
import numpy as np
-from datamaestro.record import recordtypes
+from datamaestro.record import Record
from datamaestro_text.data.ir import (
Adhoc,
TrainingTriplets,
@@ -12,10 +12,8 @@
DocumentStore,
TextItem,
SimpleTextItem,
-IDDocumentRecord,
-SimpleTextTopicRecord,
+create_record,
DocumentRecord,
-IDTopicRecord,
IDItem,
)
from experimaestro import Param, tqdm, Task, Annotated, pathgenerator
@@ -344,7 +342,7 @@ def iter(random):
random.randint(0, len(self.topics))
]
yield PairwiseRecord(
-SimpleTextTopicRecord.from_text(title),
+create_record(text=title),
self.sample(positives),
self.sample(negatives),
)
@@ -453,7 +451,7 @@ def __next__(self):
)
if neg_id != pos.id:
break
-neg = IDDocumentRecord.from_id(neg_id)
+neg = create_record(id=neg_id)
else:
negatives = sample.negatives[self.negative_algo]
neg = negatives[self.random.randint(len(negatives))]
@@ -469,7 +467,7 @@ def __next__(self):

# --- Dataloader

-# FIXME: need to fix the change where there is a list of queries and type of return

class TSVPairwiseSampleDataset(PairwiseSampleDataset):
"""Read the pairwise sample dataset from a tsv file"""

@@ -516,23 +514,18 @@ def iter(self) -> Iterator[PairwiseSample]:
positives = []
negatives = {}
for topic_text in sample["queries"]:
-topics.append(SimpleTextTopicRecord.from_text(topic_text))
+topics.append(create_record(text=topic_text))
for pos_id in sample["pos_ids"]:
-positives.append(IDDocumentRecord.from_id(pos_id))
+positives.append(create_record(id=pos_id))
for algo in sample["neg_ids"].keys():
negatives[algo] = []
for neg_id in sample["neg_ids"][algo]:
-negatives[algo].append(IDDocumentRecord.from_id(neg_id))
+negatives[algo].append(create_record(id=neg_id))
yield PairwiseSample(
topics=topics, positives=positives, negatives=negatives
)


-@recordtypes(ScoredItem)
-class ScoredIDDocumentRecord(IDDocumentRecord):
-pass


# A class for loading the data, need to move the other places.
class PairwiseSamplerFromTSV(PairwiseSampler):

@@ -544,9 +537,9 @@ def iter() -> Iterator[PairwiseSample]:
for triplet in read_tsv(self.pairwise_samples_path):
q_id, pos_id, pos_score, neg_id, neg_score = triplet
yield PairwiseRecord(
-IDTopicRecord.from_id(q_id),
-ScoredIDDocumentRecord(IDItem(pos_id), ScoredItem(pos_score)),
-ScoredIDDocumentRecord(IDItem(neg_id), ScoredItem(neg_score)),
+Record(IDItem(q_id)),
+Record(IDItem(pos_id), ScoredItem(pos_score)),
+Record(IDItem(neg_id), ScoredItem(neg_score)),
)

return SkippingIterator(iter)
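With the generic interface, the ScoredIDDocumentRecord helper removed above is unnecessary: a scored document is simply a Record composed of an IDItem and a ScoredItem. A short sketch of the equivalent construction (document ids and scores are illustrative; exact item attributes may differ):

from datamaestro.record import Record
from datamaestro_text.data.ir import IDItem, ScoredItem

positive = Record(IDItem("d12"), ScoredItem(13.2))  # was ScoredIDDocumentRecord(IDItem(...), ScoredItem(...))
negative = Record(IDItem("d57"), ScoredItem(-4.1))
print(positive[IDItem].id)  # items are retrieved by class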
4 changes: 2 additions & 2 deletions src/xpmir/letor/samplers/hydrators.py
@@ -44,8 +44,8 @@ def transform_topics(self, topics: List[ir.TopicRecord]):
if self.querystore is None:
return None
return [
-ir.GenericTopicRecord.create(
-topic[IDItem].id, self.querystore[topic[IDItem].id]
+ir.create_record(
+id=topic[IDItem].id, text=self.querystore[topic[IDItem].id]
)
for topic in topics
]
8 changes: 4 additions & 4 deletions src/xpmir/rankers/__init__.py
@@ -22,7 +22,7 @@
from datamaestro_text.data.ir import (
Documents,
DocumentStore,
-SimpleTextTopicRecord,
+create_record,
IDItem,
)
from datamaestro_text.data.ir.base import DocumentRecord
@@ -98,16 +98,16 @@ def rsv(
) -> List[ScoredDocument]:
# Convert into document records
if isinstance(documents, str):
-documents = [ScoredDocument(DocumentRecord.from_text(documents), None)]
+documents = [ScoredDocument(create_record(text=documents), None)]
elif isinstance(documents[0], str):
documents = [
-ScoredDocument(DocumentRecord.from_text(scored_document), None)
+ScoredDocument(create_record(text=scored_document), None)
for scored_document in documents
]

# Convert into topic record
if isinstance(topic, str):
-topic = SimpleTextTopicRecord.from_text(topic)
+topic = create_record(text=topic)

return self.compute(topic, documents)

27 changes: 9 additions & 18 deletions src/xpmir/test/letor/test_samplers.py
@@ -1,3 +1,4 @@
+from datamaestro.record import record_type
import pytest
import numpy as np
from typing import Iterator, Tuple
@@ -17,24 +18,16 @@
class MyTrainingTriplets(TrainingTriplets):
def iter(
self,
-) -> Iterator[
-Tuple[
-ir.SimpleTextTopicRecord, ir.GenericDocumentRecord, ir.GenericDocumentRecord
-]
-]:
+) -> Iterator[Tuple[ir.TopicRecord, ir.DocumentRecord, ir.DocumentRecord]]:
count = 0

while True:
-yield ir.SimpleTextTopicRecord.from_text(
-f"q{count}"
-), ir.GenericDocumentRecord.create(
-1, f"doc+{count}"
-), ir.GenericDocumentRecord.create(
-2, f"doc-{count}"
-)
+yield ir.create_record(text=f"q{count}"), ir.create_record(
+id=1, text=f"doc+{count}"
+), ir.create_record(id=2, text=f"doc-{count}")

-topic_recordtype = ir.SimpleTextTopicRecord
-document_recordtype = ir.GenericDocumentRecord
+topic_recordtype = record_type(ir.IDItem, ir.SimpleTextItem)
+document_recordtype = record_type(ir.SimpleTextItem)


def test_serializing_tripletbasedsampler():
@@ -108,10 +101,8 @@ class FakeDocumentStore(ir.DocumentStore):
def documentcount(self):
return 10

-def document_int(self, internal_docid: int) -> ir.GenericDocumentRecord:
-return ir.GenericDocumentRecord.create(
-str(internal_docid), f"D{internal_docid} " * 10
-)
+def document_int(self, internal_docid: int) -> ir.DocumentRecord:
+return ir.create_record(id=str(internal_docid), text=f"D{internal_docid} " * 10)


def test_pairwise_randomspansampler():
22 changes: 12 additions & 10 deletions src/xpmir/test/letor/test_samplers_hydrator.py
@@ -1,6 +1,8 @@
+from functools import cached_property
import itertools
from experimaestro import Param
from typing import Iterator, Tuple
+from datamaestro.record import record_type
import datamaestro_text.data.ir as ir
from xpmir.letor.samplers import (
TrainingTriplets,
@@ -17,22 +19,22 @@
class TripletIterator(TrainingTriplets):
def iter(
self,
-) -> Iterator[Tuple[ir.IDTopicRecord, ir.IDDocumentRecord, ir.IDDocumentRecord]]:
+) -> Iterator[Tuple[ir.TopicRecord, ir.DocumentRecord, ir.DocumentRecord]]:
count = 0

while True:
-yield ir.IDTopicRecord.from_id(str(count)), ir.IDDocumentRecord.from_id(
-str(2 * count)
-), ir.IDDocumentRecord.from_id(str(2 * count + 1))
+yield ir.create_record(id=str(count)), ir.create_record(
+id=str(2 * count)
+), ir.create_record(id=str(2 * count + 1))
count += 1

-@property
+@cached_property
def topic_recordtype(self):
-return ir.IDTopicRecord
+return record_type(ir.IDItem)

-@property
+@cached_property
def document_recordtype(self):
-return ir.IDDocumentRecord
+return record_type(ir.IDItem)


class FakeTextStore(TextStore):
@@ -43,8 +45,8 @@ def __getitem__(self, key: str) -> str:
class FakeDocumentStore(ir.DocumentStore):
id: Param[str] = ""

-def document_ext(self, docid: str) -> ir.GenericDocumentRecord:
-return ir.GenericDocumentRecord.create(docid, f"D{docid}")
+def document_ext(self, docid: str) -> ir.DocumentRecord:
+return ir.create_record(id=docid, text=f"D{docid}")


def test_pairwise_hydrator():
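Two details in this test file are worth noting. First, topic_recordtype and document_recordtype now describe the record composition with record_type(...) built from item classes instead of naming a dedicated record class. Second, they switch from @property to functools.cached_property, so the record type is computed once per instance rather than on every access; the caching behaviour is plain standard-library Python:

from functools import cached_property

class Example:
    @cached_property
    def value(self):
        print("computed")  # runs only on first access
        return 42

e = Example()
e.value  # prints "computed" and returns 42
e.value  # returns the cached 42 without recomputing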