Skip to content

Commit

Permalink
Fix many bugs due to structure change
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Jul 2, 2024
1 parent c94cfbb commit 944700c
Show file tree
Hide file tree
Showing 13 changed files with 82 additions and 26 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ datamaestro>=1.0.1
datamaestro_text>=2024.2.29
ir_datasets
docstring_parser
xpmir_rust == 0.21.*
impact-index == 0.23.*
omegaconf>=2.2
attrs

Expand Down
2 changes: 1 addition & 1 deletion src/xpmir/datasets/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def execute(self):
# don't need to worry about the threshold here
for retriever in self.retrievers:
docids.update(
sd.document[IDItem].id for sd in retriever.retrieve(topic.text)
sd.document[IDItem].id for sd in retriever.retrieve(topic)
)

# Write the document IDs
Expand Down
17 changes: 10 additions & 7 deletions src/xpmir/index/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from xpmir.rankers import Retriever, TopicRecord, ScoredDocument
from xpmir.utils.iter import MultiprocessIterator
from xpmir.utils.multiprocessing import StoppableQueue, available_cpus
import xpmir_rust
import impact_index

logger = easylog()

Expand All @@ -42,13 +42,11 @@ class SparseRetrieverIndex(Config):
index_path: Meta[Path]
documents: Param[DocumentStore]

index: xpmir_rust.index.SparseBuilderIndex
index: impact_index.Index
ordered = False

def initialize(self, in_memory: bool):
self.index = xpmir_rust.index.SparseBuilderIndex.load(
str(self.index_path.absolute()), in_memory
)
self.index = impact_index.Index.load(str(self.index_path.absolute()), in_memory)

def retrieve(self, query: Dict[int, float], top_k: int) -> List[ScoredDocument]:
results = []
Expand Down Expand Up @@ -125,7 +123,12 @@ async def reducer(
batch: List[Tuple[str, InputType]],
queue: asyncio.Queue,
):
x = self.encoder([text for _, text in batch]).value.cpu().detach().numpy()
x = (
self.encoder([topic[TextItem].text for _, topic in batch])
.value.cpu()
.detach()
.numpy()
)
assert len(x) == len(batch), (
f"Discrepancy between counts of vectors ({len(x)})"
f" and number queries ({len(batch)})"
Expand Down Expand Up @@ -351,7 +354,7 @@ def index(
len(queues),
self.index_path,
)
indexer = xpmir_rust.index.SparseIndexer(str(self.index_path))
indexer = impact_index.IndexBuilder(str(self.index_path))
heap = [queue.get() for queue in queues]
heapq.heapify(heap)

Expand Down
5 changes: 4 additions & 1 deletion src/xpmir/interfaces/anserini.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def _get_store(self) -> Optional[Index]:
"""Returns the associated index (if any)"""
if self.index.storeContents:
return self.index
return self.store

def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
# see
Expand All @@ -317,7 +318,9 @@ def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
return [
ScoredDocument(
create_record(
InternalIDItem(hit.lucene_docid), id=hit.docid, text=hit.contents
InternalIDItem(hit.lucene_docid),
id=hit.docid,
text=getattr(hit, "contents", None),
),
hit.score,
)
Expand Down
12 changes: 9 additions & 3 deletions src/xpmir/neural/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,25 @@ def merge_queries(self, queries: QueriesRep):
By default, uses `merge`
"""
return self.merge(list)
return self.merge(queries)

def merge_documents(self, documents: DocsRep):
"""Merge query batches encoded with `encode_documents`"""
return self.merge(list)
return self.merge(documents)

def merge(self, objects: Union[DocsRep, QueriesRep]):
"""Merge objects
- for tensors, uses torch.cat
- for lists, concatenate all of them
"""
assert isinstance(objects, List), "Merging can only be done with lists"
assert isinstance(
objects, List
), f"Merging can only be done with lists, got {type(objects)}"

# Just returns the only object to merge
if len(objects) == 1:
return objects[0]

if isinstance(objects[0], torch.Tensor):
return torch.cat(objects)
Expand Down
8 changes: 7 additions & 1 deletion src/xpmir/neural/cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
PairwiseRecords,
)
from xpmir.rankers import LearnableScorer
from xpmir.text import TokenizerOptions
from xpmir.text.encoders import TextEncoderBase, TripletTextEncoder
from xpmir.rankers import (
DuoLearnableScorer,
Expand All @@ -28,6 +29,9 @@ class CrossScorer(LearnableScorer, DistributableModel):
AKA Cross-Encoder
"""

max_length: Param[int]
"""Maximum length (in tokens) for the query-document string"""

encoder: Param[TextEncoderBase[Tuple[str, str], torch.Tensor]]
"""an encoder for encoding the concatenated query-document tokens which
doesn't contains the final linear layer"""
Expand All @@ -40,14 +44,16 @@ def __initialize__(self, options):
super().__initialize__(options)
self.encoder.initialize(options)
self.classifier = torch.nn.Linear(self.encoder.dimension, 1)
self.tokenizer_options = TokenizerOptions(max_length=self.max_length)

def forward(self, inputs: BaseRecords, info: TrainerContext = None):
# Encode queries and documents
pairs = self.encoder(
[
(tr[TextItem].text, dr[TextItem].text)
for tr, dr in zip(inputs.topics, inputs.documents)
]
],
options=self.tokenizer_options,
) # shape (batch_size * dimension)
return self.classifier(pairs.value).squeeze(1)

Expand Down
5 changes: 2 additions & 3 deletions src/xpmir/neural/dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torch
from experimaestro import Param
from datamaestro_text.data.ir import TextItem
from xpmir.distributed import DistributableModel
from xpmir.learning.batchers import Batcher
from xpmir.letor.records import TopicRecord, DocumentRecord
from xpmir.neural import DualRepresentationScorer, QueriesRep, DocsRep
Expand Down Expand Up @@ -69,7 +68,7 @@ class Dense(DualVectorScorer[QueriesRep, DocsRep]):
"""A scorer based on a pair of (query, document) dense vectors"""

def score_product(self, queries, documents, info: Optional[TrainerContext] = None):
return queries @ documents.T
return queries.value @ documents.value.T

def score_pairs(self, queries, documents, info: Optional[TrainerContext] = None):
scores = (queries.unsqueeze(1) @ documents.unsqueeze(2)).squeeze(-1).squeeze(-1)
Expand Down Expand Up @@ -111,7 +110,7 @@ def encode_documents(self, records: List[DocumentRecord]):
return documents / documents.norm(dim=1, keepdim=True)


class DotDense(Dense, DistributableModel):
class DotDense(Dense):
"""Dual model based on inner product."""

def __validate__(self):
Expand Down
4 changes: 2 additions & 2 deletions src/xpmir/neural/splade.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
class Aggregation(Config):
"""The aggregation function for Splade"""

def get_output_module(self, linear: nn.Module):
def get_output_module(self, linear: nn.Module) -> nn.Module:
return AggregationModule(linear, self)


Expand Down Expand Up @@ -106,7 +106,7 @@ def forward(self, texts: List[str]) -> torch.Tensor:
"""Returns a batch x vocab tensor"""
tokenized = self.encoder.batch_tokenize(texts, mask=True, maxlen=self.maxlen)
out = self.model(tokenized)
return out
return TextsRepresentationOutput(out, tokenized)

@property
def dimension(self):
Expand Down
19 changes: 18 additions & 1 deletion src/xpmir/text/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,18 @@ class TripletTextEncoder(TextEncoderBase[Tuple[str, str, str], torch.Tensor]):
@define
class RepresentationOutput:
value: torch.Tensor
"""An arbitrary representation"""
"""An arbitrary representation (by default, the batch dimension is the
first)"""

def __len__(self):
return len(self.value)

def __getitem__(self, ix: Union[slice, int]):
return self.__class__(self.value[ix])

@property
def device(self):
return self.value.device


@define
Expand All @@ -189,6 +200,12 @@ class TextsRepresentationOutput(RepresentationOutput):
tokenized: TokenizedTexts
"""Tokenized texts"""

def to(self, device):
return self.__class__(self.value.to(device), self.tokenized.to(device))

def __getitem__(self, ix: Union[slice, int]):
return self.__class__(self.value[ix], self.tokenized[ix])


class TokenizedEncoder(Encoder, Generic[EncoderOutput, TokenizerOutput]):
"""Encodes a tokenized text into a vector"""
Expand Down
5 changes: 4 additions & 1 deletion src/xpmir/text/huggingface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,13 @@ def contextual_model(self) -> nn.Module:

def forward(self, tokenized: TokenizedTexts):
tokenized = tokenized.to(self.model.device)
kwargs = {}
if tokenized.token_type_ids is not None:
kwargs["token_type_ids"] = tokenized.token_type_ids

return self.model(
input_ids=tokenized.ids,
attention_mask=tokenized.mask,
token_type_ids=tokenized.token_type_ids,
)


Expand Down
4 changes: 0 additions & 4 deletions src/xpmir/text/huggingface/encoders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
from typing import Optional
from experimaestro import Param
from xpmir.learning import Module
from xpmir.text.encoders import (
Expand Down Expand Up @@ -71,9 +70,6 @@ class HFCLSEncoder(
):
"""Encodes a text using the [CLS] token"""

maxlen: Param[Optional[int]] = None
"""Limit the text to be encoded"""

def forward(self, tokenized: TokenizedTexts) -> TextsRepresentationOutput:
tokenized = tokenized.to(self.device)
y = self.model.contextual_model(
Expand Down
15 changes: 14 additions & 1 deletion src/xpmir/text/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from xpmir.text.utils import lengthToMask
from xpmir.learning.optim import ModuleInitOptions
from xpmir.utils.utils import Initializable
from xpmir.utils.misc import opt_slice
from xpmir.utils.torch import to_device


class TokenizedTexts(NamedTuple):
"""Tokenized texts output"""

tokens: List[List[str]]
tokens: Optional[List[List[str]]]
"""The list of tokens"""

ids: torch.LongTensor
Expand All @@ -30,6 +31,18 @@ class TokenizedTexts(NamedTuple):
token_type_ids: Optional[torch.LongTensor] = None
"""Type of each token"""

def __len__(self):
return len(self.ids)

def __getitem__(self, ix):
return TokenizedTexts(
opt_slice(self.tokens, ix),
self.ids[ix],
self.lens[ix],
opt_slice(self.mask, ix),
opt_slice(self.token_type_ids, ix),
)

def to(self, device: torch.device):
if device is self.ids.device:
return self
Expand Down
10 changes: 10 additions & 0 deletions src/xpmir/utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Optional, Sequence, TypeVar, Union


T = TypeVar("T")


def opt_slice(x: Optional[Sequence[T]], ix: Union[int, slice]) -> Optional[Sequence[T]]:
if x is None:
return None
return x[ix]

0 comments on commit 944700c

Please sign in to comment.