updates
bpiwowar committed Sep 12, 2023
1 parent 8812fe0 commit 2f0691b
Showing 21 changed files with 148 additions and 166 deletions.
1 change: 0 additions & 1 deletion docs/source/letor/index.rst
@@ -42,7 +42,6 @@ scorers, some have learnable parameters.
.. autoxpmconfig:: xpmir.rankers.RandomScorer
.. autoxpmconfig:: xpmir.rankers.AbstractLearnableScorer
.. autoxpmconfig:: xpmir.rankers.LearnableScorer
.. autoxpmconfig:: xpmir.neural.TorchLearnableScorer

.. autofunction:: xpmir.rankers.scorer_retriever

5 changes: 1 addition & 4 deletions docs/source/letor/trainers.rst
@@ -5,6 +5,7 @@ Trainers are responsible for defining the way to train
a learnable scorer.

.. autoxpmconfig:: xpmir.letor.trainers.Trainer
.. autoxpmconfig:: xpmir.learning.trainers.multiple.MultipleTrainer

.. autoxpmconfig:: xpmir.letor.trainers.LossTrainer
:members: process_microbatch
@@ -78,10 +79,6 @@ Losses
.. autoxpmconfig:: xpmir.letor.trainers.batchwise.CrossEntropyLoss
.. autoxpmconfig:: xpmir.letor.trainers.batchwise.SoftmaxCrossEntropy

Other
*****

.. autoxpmconfig:: xpmir.letor.trainers.multiple.MultipleTrainer

Distillation: Pairwise
**********************
4 changes: 3 additions & 1 deletion src/xpmir/distributed.py
@@ -47,7 +47,9 @@ def update(self, state: Context, model: nn.Module) -> nn.Module:
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
logger.info(
"Setting up DataParallel for text encoder (%d GPUs)", n_gpus
"Setting up DataParallel on %d GPUs for model model %s",
n_gpus,
str(model.__class__.__qualname__),
)
return DataParallel(model)
else:
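The log message now names the wrapped model class instead of assuming a text encoder. The wrapping itself is the standard torch idiom; a minimal sketch (any nn.Module can stand in for model):

import torch
from torch.nn import DataParallel, Module

def maybe_distribute(model: Module) -> Module:
    # Replicate the module over all visible GPUs when more than one is available
    if torch.cuda.device_count() > 1:
        return DataParallel(model)
    return model
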
9 changes: 9 additions & 0 deletions src/xpmir/interfaces/anserini.py
@@ -305,6 +305,15 @@ def retrieve(self, query: str) -> List[ScoredDocument]:
hits = self.searcher.search(query, k=self.k)
store = self.get_store()

# Batch retrieve documents
if store is not None:
return [
ScoredDocument(doc, hit.score)
for hit, doc in zip(
hits, store.documents_ext([hit.docid for hit in hits])
)
]

return [
ScoredDocument(
AnseriniDocument(hit.docid, hit.lucene_docid, hit.contents, hit.raw)
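The new branch replaces one store lookup per hit with a single batched call. A sketch of the pattern, assuming documents_ext returns documents in the same order as the ids it receives:

# Batch-fetch the full documents for all hits, then pair each hit
# with its document; the zip relies on documents_ext preserving id order.
docids = [hit.docid for hit in hits]
documents = store.documents_ext(docids)
results = [ScoredDocument(doc, hit.score) for hit, doc in zip(hits, documents)]
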
1 change: 1 addition & 0 deletions src/xpmir/learning/__init__.py
@@ -1,2 +1,3 @@
# flake8: noqa: F401
from .base import Random, Sampler
from .optim import Module
1 change: 0 additions & 1 deletion src/xpmir/learning/learner.py
@@ -126,7 +126,6 @@ class Learner(Task, EasyLogger):
hooks: Param[List[Hook]] = []
"""Global learning hooks
:class:`Initialization hooks <xpmir.context.InitializationHook>` are called
before and after the initialization of the trainer and listeners.
"""
3 changes: 3 additions & 0 deletions src/xpmir/learning/optim.py
@@ -78,6 +78,9 @@ def __init__(self):
def __call__(self, *args, **kwargs):
return torch.nn.Module.__call__(self, *args, **kwargs)

def to(self, *args, **kwargs):
return torch.nn.Module.to(self, *args, **kwargs)


class ModuleLoader(PathSerializationLWTask):
def execute(self):
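Module mixes an experimaestro Config with torch.nn.Module, so to is delegated explicitly (like __call__ just above it) to make sure the torch implementation is used rather than whatever the configuration machinery would resolve. A minimal sketch of the pattern, with bases simplified for illustration:

import torch.nn as nn
from experimaestro import Config

class Module(Config, nn.Module):  # bases simplified for illustration
    def __init__(self):
        nn.Module.__init__(self)
        super().__init__()

    # Pin the torch implementations so that calling and moving the module
    # bypass Config-level attribute handling
    def __call__(self, *args, **kwargs):
        return nn.Module.__call__(self, *args, **kwargs)

    def to(self, *args, **kwargs):
        return nn.Module.to(self, *args, **kwargs)

With this in place, model.to(device) on any learnable scorer behaves as it would for a plain torch module.
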
52 changes: 0 additions & 52 deletions src/xpmir/learning/trainers.py

This file was deleted.

17 changes: 12 additions & 5 deletions src/xpmir/letor/distillation/pairwise.py
@@ -4,7 +4,12 @@
from torch import nn
from torch.functional import Tensor
from experimaestro import Config, Param
from xpmir.letor.records import Document, PairwiseRecord, PairwiseRecords
from xpmir.letor.records import (
DocumentRecord,
TopicRecord,
PairwiseRecord,
PairwiseRecords,
)
from xpmir.learning.context import Loss
from xpmir.letor.trainers import TrainerContext, LossTrainer
from xpmir.utils.utils import batchiter, foreach
@@ -126,9 +131,9 @@ def train_batch(self, samples: List[PairwiseDistillationSample]):
for ix, sample in enumerate(samples):
records.add(
PairwiseRecord(
sample.query,
Document(None, sample.documents[0].content, None),
Document(None, sample.documents[1].content, None),
TopicRecord(sample.query),
DocumentRecord(sample.documents[0].document),
DocumentRecord(sample.documents[1].document),
)
)
teacher_scores[ix, 0] = sample.documents[0].score
@@ -138,7 +143,9 @@ def train_batch(self, samples: List[PairwiseDistillationSample]):
scores = self.ranker(records, self.context).reshape(2, len(records)).T

if torch.isnan(scores).any() or torch.isinf(scores).any():
self.logger.error("nan or inf relevance score detected. Aborting.")
self.logger.error(
"nan or inf relevance score detected. Aborting (pairwise distillation)."
)
sys.exit(1)

# Call the losses (distillation, pairwise and pointwise)
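Queries and documents are now wrapped in typed records rather than bare Document(None, text, None) tuples. A hedged sketch of building one record from base topic and document objects (constructors as used by the samplers below; the example strings are hypothetical):

from datamaestro_text.data.ir.base import TextDocument, TextTopic
from xpmir.letor.records import DocumentRecord, PairwiseRecord, TopicRecord

record = PairwiseRecord(
    TopicRecord(TextTopic("what is dense retrieval")),  # the query
    DocumentRecord(TextDocument("first scored passage")),
    DocumentRecord(TextDocument("second scored passage")),
)
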
48 changes: 32 additions & 16 deletions src/xpmir/letor/distillation/samplers.py
@@ -1,19 +1,26 @@
from typing import Iterable, Iterator, NamedTuple, Optional, Tuple

import numpy as np
from datamaestro.data import File
from experimaestro import Config, Meta, Param
from ir_datasets.formats import GenericDoc
from xpmir.letor.records import Query
from datamaestro_text.data.ir import DocumentStore
from xpmir.rankers import ScoredDocument
from datamaestro_text.data.ir.base import (
GenericTopic,
IDTopic,
TextTopic,
IDDocument,
TextDocument,
)
from experimaestro import Config, Meta, Param

from xpmir.datasets.adapters import TextStore
from xpmir.learning import Sampler
import numpy as np

from xpmir.letor.records import TopicRecord
from xpmir.rankers import ScoredDocument
from xpmir.utils.iter import SerializableIterator, SkippingIterator


class PairwiseDistillationSample(NamedTuple):
query: Query
query: TopicRecord
"""The query"""

documents: Tuple[ScoredDocument, ScoredDocument]
@@ -41,12 +48,21 @@ class PairwiseHydrator(PairwiseDistillationSamples):

def __iter__(self) -> Iterator[PairwiseDistillationSample]:
for sample in self.samples:
topic, documents = sample.query, sample.documents

if self.querystore is not None:
sample.query.text = self.querystore[sample.query.id]
topic = GenericTopic(
sample.query.get_id(), self.querystore[sample.query.get_id()]
)
if self.documentstore is not None:
for d in sample.documents:
d.content = self.documentstore.document_text(d.docid)
documents = tuple(
ScoredDocument(
self.documentstore.document_ext(d.document.get_id()), d.score
)
for d in sample.documents
)

sample = PairwiseDistillationSample(topic, documents)
yield sample
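
The hydrator now builds fresh topic and document objects instead of mutating the sample in place. A sketch of the intended use, assuming the parameter names from the class body above (both stores are optional):

# Wrap id-only samples so that iterating yields hydrated samples
hydrator = PairwiseHydrator(
    samples=tsv_samples,              # e.g. a PairwiseDistillationSamplesTSV
    querystore=my_query_text_store,   # TextStore: query id -> text
    documentstore=my_document_store,  # DocumentStore, used through document_ext
)
for sample in hydrator:
    ...  # sample.query and sample.documents now carry their content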


@@ -65,19 +81,19 @@ def iter(self) -> Iterator[PairwiseDistillationSample]:
with self.path.open("rt") as fp:
for row in csv.reader(fp, delimiter="\t"):
if self.with_queryid:
query = Query(row[2], None)
query = IDTopic(row[2])
else:
query = Query(None, row[2])
query = TextTopic(row[2])

if self.with_docid:
documents = (
ScoredDocument(GenericDoc(row[3], None), float(row[0])),
ScoredDocument(GenericDoc(row[4], None), float(row[1])),
ScoredDocument(IDDocument(row[3]), float(row[0])),
ScoredDocument(IDDocument(row[4]), float(row[1])),
)
else:
documents = (
ScoredDocument(GenericDoc(None, row[3]), float(row[0])),
ScoredDocument(GenericDoc(None, row[4]), float(row[1])),
ScoredDocument(TextDocument(row[3]), float(row[0])),
ScoredDocument(TextDocument(row[4]), float(row[1])),
)

yield PairwiseDistillationSample(query, documents)
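Read back from this loop, each tab-separated row holds: the teacher score of the first document, the score of the second, the query (an id when with_queryid is set, otherwise its text), then the two documents (ids or texts, per with_docid). With both flags set, a line such as (values hypothetical):

12.7	3.1	q42	d100	d213

yields IDTopic("q42") and two ScoredDocument entries built from IDDocument("d100") and IDDocument("d213") with scores 12.7 and 3.1.
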
10 changes: 4 additions & 6 deletions src/xpmir/letor/samplers.py
@@ -77,15 +77,13 @@ def batchwise_iter(self, batch_size: int) -> SerializableIterator[BatchwiseRecor


class ModelBasedSampler(Sampler):
"""Base class for retriever-based sampler
Attributes:
dataset: The topics and assessments
retriever: The document retriever
"""
"""Base class for retriever-based sampler"""

dataset: Param[Adhoc]
"""The IR adhoc dataset"""

retriever: Param[Retriever]
"""A retriever to sample negative documents"""

_store: DocumentStore

46 changes: 0 additions & 46 deletions src/xpmir/letor/trainers/multiple.py

This file was deleted.

18 changes: 1 addition & 17 deletions src/xpmir/neural/__init__.py
@@ -1,30 +1,14 @@
import itertools
from typing import Iterable, List, Optional
import torch
import torch.nn as nn
from xpmir.learning.batchers import Sliceable

from xpmir.learning.context import TrainerContext
from xpmir.letor.records import BaseRecords
from xpmir.learning.optim import Module
from xpmir.rankers import LearnableScorer


class TorchLearnableScorer(LearnableScorer, Module):
"""Base class for torch-learnable scorers"""

def __init__(self):
nn.Module.__init__(self)
super().__init__()

__call__ = nn.Module.__call__
to = nn.Module.to

def train(self, mode=True):
return nn.Module.train(self, mode)


class DualRepresentationScorer(TorchLearnableScorer):
class DualRepresentationScorer(LearnableScorer):
Neural scorer based on an (at least partially) independent representation
of the document and the question.
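With __call__, to and train now provided through xpmir.learning.optim.Module (see the optim.py change above), the TorchLearnableScorer shim is redundant and scorers derive from LearnableScorer directly. A hedged sketch of a custom scorer under that assumption, with the forward signature inferred from the self.ranker(records, self.context) call above:

from xpmir.rankers import LearnableScorer

class MyScorer(LearnableScorer):
    # LearnableScorer is assumed to behave as a torch module through
    # xpmir.learning.optim.Module, so no extra nn.Module plumbing is needed
    def forward(self, records, context):
        ...  # return one relevance score per record
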
4 changes: 2 additions & 2 deletions src/xpmir/neural/cross.py
@@ -7,7 +7,7 @@
BaseRecords,
PairwiseRecords,
)
from xpmir.neural import TorchLearnableScorer
from xpmir.rankers import LearnableScorer
from xpmir.text.encoders import DualTextEncoder, TripletTextEncoder
from xpmir.rankers import (
DuoLearnableScorer,
@@ -16,7 +16,7 @@
)


class CrossScorer(TorchLearnableScorer, DistributableModel):
class CrossScorer(LearnableScorer, DistributableModel):
"""Query-Document Representation Classifier
Based on a query-document representation (e.g. the BERT [CLS] token).
16 changes: 10 additions & 6 deletions src/xpmir/neural/dual.py
@@ -156,10 +156,14 @@ class FlopsRegularizer(DualVectorListener):
lambda_d: Param[float]
"""Lambda for documents"""

@staticmethod
def compute(x: torch.Tensor):
# Computes the mean for each term
y = x.abs().mean(0)
"""
:param x: term vectors (batch size x vocabulary dimension)
:returns: a pair (mean weight per vocabulary term, FLOPS regularization)
"""
# Computes the mean for each term (weights are positive)
y = x.mean(0)

# Returns the sum of squared means
return y, (y * y).sum()

@@ -175,9 +179,9 @@ def __call__(self, info: TrainerContext, queries, documents):
flops = self.lambda_d * flops_d + self.lambda_q * flops_q
info.add_loss(Loss("flops", flops, 1.0))

info.metrics.add(ScalarMetric("flops", flops.item(), len(q)))
info.metrics.add(ScalarMetric("flops_q", flops_q.item(), len(q)))
info.metrics.add(ScalarMetric("flops_d", flops_d.item(), len(d)))
info.metrics.add(ScalarMetric("flops", flops.item(), 1))
info.metrics.add(ScalarMetric("flops_q", flops_q.item(), 1))
info.metrics.add(ScalarMetric("flops_d", flops_d.item(), 1))

with torch.no_grad():
info.metrics.add(
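The regularizer computes the FLOPS loss: given term weights x of shape (batch size x vocabulary dimension), it takes the mean weight of each vocabulary term over the batch and sums the squares of those means; the abs() is dropped because the weights are assumed non-negative. The scalar metrics are now recorded with a count of 1, since the value is already a batch-level aggregate. A minimal sketch of the computation:

import torch

def flops(x: torch.Tensor) -> torch.Tensor:
    # x: (batch, vocabulary) non-negative term weights
    y = x.mean(0)         # mean weight of each vocabulary term over the batch
    return (y * y).sum()  # sum of squared means = FLOPS regularization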