From 2f0691b55c7f0bf90716a789a3ac20e560d9d5e9 Mon Sep 17 00:00:00 2001
From: Benjamin Piwowarski
Date: Tue, 12 Sep 2023 11:25:34 +0200
Subject: [PATCH] updates

---
 docs/source/letor/index.rst              |  1 -
 docs/source/letor/trainers.rst           |  5 +--
 src/xpmir/distributed.py                 |  4 +-
 src/xpmir/interfaces/anserini.py         |  9 ++++
 src/xpmir/learning/__init__.py           |  1 +
 src/xpmir/learning/learner.py            |  1 -
 src/xpmir/learning/optim.py              |  3 ++
 src/xpmir/learning/trainers.py           | 52 ------------------------
 src/xpmir/letor/distillation/pairwise.py | 17 +++++---
 src/xpmir/letor/distillation/samplers.py | 48 ++++++++++++++--------
 src/xpmir/letor/samplers.py              | 10 ++---
 src/xpmir/letor/trainers/multiple.py     | 46 ---------------------
 src/xpmir/neural/__init__.py             | 18 +-------
 src/xpmir/neural/cross.py                |  4 +-
 src/xpmir/neural/dual.py                 | 16 +++++---
 src/xpmir/neural/huggingface.py          |  4 +-
 src/xpmir/neural/interaction/__init__.py |  4 +-
 src/xpmir/neural/splade.py               | 22 ++++++++--
 src/xpmir/rankers/__init__.py            |  8 +++-
 src/xpmir/test/letor/test_samplers.py    | 38 ++++++++++++++++-
 src/xpmir/utils/utils.py                 |  3 ++
 21 files changed, 148 insertions(+), 166 deletions(-)
 delete mode 100644 src/xpmir/learning/trainers.py
 delete mode 100644 src/xpmir/letor/trainers/multiple.py

diff --git a/docs/source/letor/index.rst b/docs/source/letor/index.rst
index 24a51827..240e0d65 100644
--- a/docs/source/letor/index.rst
+++ b/docs/source/letor/index.rst
@@ -42,7 +42,6 @@ scorers, some are have learnable parameters.
 .. autoxpmconfig:: xpmir.rankers.RandomScorer
 .. autoxpmconfig:: xpmir.rankers.AbstractLearnableScorer
 .. autoxpmconfig:: xpmir.rankers.LearnableScorer
-.. autoxpmconfig:: xpmir.neural.TorchLearnableScorer

 .. autofunction:: xpmir.rankers.scorer_retriever

diff --git a/docs/source/letor/trainers.rst b/docs/source/letor/trainers.rst
index 0d766c7f..9840f618 100644
--- a/docs/source/letor/trainers.rst
+++ b/docs/source/letor/trainers.rst
@@ -5,6 +5,7 @@ Trainers are responsible for defining the the way to train a learnable scorer.
 .. autoxpmconfig:: xpmir.letor.trainers.Trainer
+.. autoxpmconfig:: xpmir.learning.trainers.multiple.MultipleTrainer
 .. autoxpmconfig:: xpmir.letor.trainers.LossTrainer
    :members: process_microbatch
@@ -78,10 +79,6 @@ Losses
 .. autoxpmconfig:: xpmir.letor.trainers.batchwise.CrossEntropyLoss
 .. autoxpmconfig:: xpmir.letor.trainers.batchwise.SoftmaxCrossEntropy

-Other
-*****
-
-.. autoxpmconfig:: xpmir.letor.trainers.multiple.MultipleTrainer

 Distillation: Pairwise
 **********************
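Note: the new documentation entry above points at xpmir.learning.trainers.multiple, while the old xpmir.letor.trainers.multiple implementation is deleted further down in this patch. A minimal, hypothetical usage sketch, assuming the relocated class keeps the trainers: Param[Dict[str, Trainer]] parameter of the deleted version and that the two sub-trainers passed in are already-configured Trainer objects:

# Hypothetical sketch (not part of this patch): combining two trainers under the
# module path referenced by the new docs entry above. The dict-of-trainers
# interface is taken from the deleted letor/trainers/multiple.py shown below.
from xpmir.learning.trainers.multiple import MultipleTrainer

def combined_trainer(pointwise_trainer, pairwise_trainer):
    # Each key becomes a scope for the sub-trainer's losses and metrics
    # (see context.scope(key) in the deleted implementation further down).
    return MultipleTrainer(
        trainers={"pointwise": pointwise_trainer, "pairwise": pairwise_trainer}
    )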
diff --git a/src/xpmir/distributed.py b/src/xpmir/distributed.py
index 00187e63..c8592cb9 100644
--- a/src/xpmir/distributed.py
+++ b/src/xpmir/distributed.py
@@ -47,7 +47,9 @@ def update(self, state: Context, model: nn.Module) -> nn.Module:
         n_gpus = torch.cuda.device_count()
         if n_gpus > 1:
             logger.info(
-                "Setting up DataParallel for text encoder (%d GPUs)", n_gpus
+                "Setting up DataParallel on %d GPUs for model %s",
+                n_gpus,
+                str(model.__class__.__qualname__),
             )
             return DataParallel(model)
         else:
diff --git a/src/xpmir/interfaces/anserini.py b/src/xpmir/interfaces/anserini.py
index cf8b6de8..56724d28 100644
--- a/src/xpmir/interfaces/anserini.py
+++ b/src/xpmir/interfaces/anserini.py
@@ -305,6 +305,15 @@ def retrieve(self, query: str) -> List[ScoredDocument]:
         hits = self.searcher.search(query, k=self.k)
         store = self.get_store()

+        # Batch retrieve documents
+        if store is not None:
+            return [
+                ScoredDocument(doc, hit.score)
+                for hit, doc in zip(
+                    hits, store.documents_ext([hit.docid for hit in hits])
+                )
+            ]
+
         return [
             ScoredDocument(
                 AnseriniDocument(hit.docid, hit.lucene_docid, hit.contents, hit.raw)
diff --git a/src/xpmir/learning/__init__.py b/src/xpmir/learning/__init__.py
index 9a98ea3a..40b50007 100644
--- a/src/xpmir/learning/__init__.py
+++ b/src/xpmir/learning/__init__.py
@@ -1,2 +1,3 @@
 # flake8: noqa: F401
 from .base import Random, Sampler
+from .optim import Module
diff --git a/src/xpmir/learning/learner.py b/src/xpmir/learning/learner.py
index 2e310632..89130a27 100644
--- a/src/xpmir/learning/learner.py
+++ b/src/xpmir/learning/learner.py
@@ -126,7 +126,6 @@ class Learner(Task, EasyLogger):
     hooks: Param[List[Hook]] = []
     """Global learning hooks
-
     :class:`Initialization hooks ` are called before and after
     the initialization of the trainer and listeners.
     """
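The distributed.py hunk above only reworks the log message, but the surrounding logic is worth spelling out: the module is wrapped in DataParallel only when more than one GPU is visible. A standalone PyTorch sketch of that decision, independent of xpmir (the Linear model here is an arbitrary stand-in):

import logging
import torch
from torch.nn import DataParallel, Linear

logger = logging.getLogger(__name__)

def wrap_if_multi_gpu(model: torch.nn.Module) -> torch.nn.Module:
    # Same decision as the update() hook above: replicate over all visible GPUs
    # only when there is more than one of them, otherwise keep the module as-is.
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        logger.info(
            "Setting up DataParallel on %d GPUs for model %s",
            n_gpus,
            model.__class__.__qualname__,
        )
        return DataParallel(model)
    return model

encoder = wrap_if_multi_gpu(Linear(768, 128))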
""" diff --git a/src/xpmir/learning/optim.py b/src/xpmir/learning/optim.py index 12bf6965..9e03c30c 100644 --- a/src/xpmir/learning/optim.py +++ b/src/xpmir/learning/optim.py @@ -78,6 +78,9 @@ def __init__(self): def __call__(self, *args, **kwargs): return torch.nn.Module.__call__(self, *args, **kwargs) + def to(self, *args, **kwargs): + return torch.nn.Module.to(self, *args, **kwargs) + class ModuleLoader(PathSerializationLWTask): def execute(self): diff --git a/src/xpmir/learning/trainers.py b/src/xpmir/learning/trainers.py deleted file mode 100644 index f692991b..00000000 --- a/src/xpmir/learning/trainers.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Dict, Iterator, List -from experimaestro import Config, Param -import torch.nn as nn -import numpy as np -from xpmir.utils.utils import EasyLogger -from xpmir.learning.context import ( - TrainingHook, - TrainerContext, -) - -from xpmir.utils.utils import foreach - - -class Trainer(Config, EasyLogger): - """Generic trainer""" - - hooks: Param[List[TrainingHook]] = [] - """Hooks for this trainer: this includes the losses, but can be adapted for - other uses - - The specific list of hooks depends on the specific trainer - """ - - def initialize( - self, - random: np.random.RandomState, - context: TrainerContext, - ): - self.random = random - # Old style (to be deprecated) - self.ranker = context.state.model - # Generic style - self.model = context.state.model - self.context = context - - foreach(self.hooks, self.context.add_hook) - - def to(self, device): - """Change the computing device (if this is needed)""" - foreach(self.context.hooks(nn.Module), lambda hook: hook.to(device)) - - def iter_batches(self) -> Iterator: - raise NotImplementedError - - def process_batch(self, batch): - raise NotImplementedError() - - def load_state_dict(self, state: Dict): - raise NotImplementedError() - - def state_dict(self): - raise NotImplementedError() diff --git a/src/xpmir/letor/distillation/pairwise.py b/src/xpmir/letor/distillation/pairwise.py index bd508241..1ae4162d 100644 --- a/src/xpmir/letor/distillation/pairwise.py +++ b/src/xpmir/letor/distillation/pairwise.py @@ -4,7 +4,12 @@ from torch import nn from torch.functional import Tensor from experimaestro import Config, Param -from xpmir.letor.records import Document, PairwiseRecord, PairwiseRecords +from xpmir.letor.records import ( + DocumentRecord, + TopicRecord, + PairwiseRecord, + PairwiseRecords, +) from xpmir.learning.context import Loss from xpmir.letor.trainers import TrainerContext, LossTrainer from xpmir.utils.utils import batchiter, foreach @@ -126,9 +131,9 @@ def train_batch(self, samples: List[PairwiseDistillationSample]): for ix, sample in enumerate(samples): records.add( PairwiseRecord( - sample.query, - Document(None, sample.documents[0].content, None), - Document(None, sample.documents[1].content, None), + TopicRecord(sample.query), + DocumentRecord(sample.documents[0].document), + DocumentRecord(sample.documents[1].document), ) ) teacher_scores[ix, 0] = sample.documents[0].score @@ -138,7 +143,9 @@ def train_batch(self, samples: List[PairwiseDistillationSample]): scores = self.ranker(records, self.context).reshape(2, len(records)).T if torch.isnan(scores).any() or torch.isinf(scores).any(): - self.logger.error("nan or inf relevance score detected. Aborting.") + self.logger.error( + "nan or inf relevance score detected. Aborting (pairwise distillation)." 
diff --git a/src/xpmir/letor/distillation/samplers.py b/src/xpmir/letor/distillation/samplers.py
index cd437bc9..1c0e7c00 100644
--- a/src/xpmir/letor/distillation/samplers.py
+++ b/src/xpmir/letor/distillation/samplers.py
@@ -1,19 +1,26 @@
 from typing import Iterable, Iterator, NamedTuple, Optional, Tuple
+
+import numpy as np
 from datamaestro.data import File
-from experimaestro import Config, Meta, Param
-from ir_datasets.formats import GenericDoc
-from xpmir.letor.records import Query
 from datamaestro_text.data.ir import DocumentStore
-from xpmir.rankers import ScoredDocument
+from datamaestro_text.data.ir.base import (
+    GenericTopic,
+    IDTopic,
+    TextTopic,
+    IDDocument,
+    TextDocument,
+)
+from experimaestro import Config, Meta, Param
+
 from xpmir.datasets.adapters import TextStore
 from xpmir.learning import Sampler
-import numpy as np
-
+from xpmir.letor.records import TopicRecord
+from xpmir.rankers import ScoredDocument
 from xpmir.utils.iter import SerializableIterator, SkippingIterator


 class PairwiseDistillationSample(NamedTuple):
-    query: Query
+    query: TopicRecord
     """The query"""

     documents: Tuple[ScoredDocument, ScoredDocument]
@@ -41,12 +48,21 @@ class PairwiseHydrator(PairwiseDistillationSamples):
     def __iter__(self) -> Iterator[PairwiseDistillationSample]:
         for sample in self.samples:
+            topic, documents = sample.query, sample.documents
+
             if self.querystore is not None:
-                sample.query.text = self.querystore[sample.query.id]
+                topic = GenericTopic(
+                    sample.query.get_id(), self.querystore[sample.query.get_id()]
+                )
             if self.documentstore is not None:
-                for d in sample.documents:
-                    d.content = self.documentstore.document_text(d.docid)
+                documents = tuple(
+                    ScoredDocument(
+                        self.documentstore.document_ext(d.document.get_id()), d.score
+                    )
+                    for d in sample.documents
+                )

+            sample = PairwiseDistillationSample(topic, documents)
             yield sample

@@ -65,19 +81,19 @@ def iter(self) -> Iterator[PairwiseDistillationSample]:
         with self.path.open("rt") as fp:
             for row in csv.reader(fp, delimiter="\t"):
                 if self.with_queryid:
-                    query = Query(row[2], None)
+                    query = IDTopic(row[2])
                 else:
-                    query = Query(None, row[2])
+                    query = TextTopic(row[2])

                 if self.with_docid:
                     documents = (
-                        ScoredDocument(GenericDoc(row[3], None), float(row[0])),
-                        ScoredDocument(GenericDoc(row[4], None), float(row[1])),
+                        ScoredDocument(IDDocument(row[3]), float(row[0])),
+                        ScoredDocument(IDDocument(row[4]), float(row[1])),
                     )
                 else:
                     documents = (
-                        ScoredDocument(GenericDoc(None, row[3]), float(row[0])),
-                        ScoredDocument(GenericDoc(None, row[4]), float(row[1])),
+                        ScoredDocument(TextDocument(row[3]), float(row[0])),
+                        ScoredDocument(TextDocument(row[4]), float(row[1])),
                     )

                 yield PairwiseDistillationSample(query, documents)
diff --git a/src/xpmir/letor/samplers.py b/src/xpmir/letor/samplers.py
index 6c312a35..c0c09700 100644
--- a/src/xpmir/letor/samplers.py
+++ b/src/xpmir/letor/samplers.py
@@ -77,15 +77,13 @@ def batchwise_iter(self, batch_size: int) -> SerializableIterator[BatchwiseRecor

 class ModelBasedSampler(Sampler):
-    """Base class for retriever-based sampler
-
-    Attributes:
-        dataset: The topics and assessments
-        retriever: The document retriever
-    """
+    """Base class for retriever-based samplers"""

     dataset: Param[Adhoc]
+    """The IR adhoc dataset"""
+
     retriever: Param[Retriever]
+    """A retriever to sample negative documents"""

     _store: DocumentStore
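For reference, the TSV reader above expects one sample per row: two teacher scores, the query (text or identifier depending on with_queryid), then the positive and negative documents (text or identifiers depending on with_docid). A small self-contained parsing sketch of that layout:

import csv
import io

# One row in the text-only layout (with_queryid=False, with_docid=False):
# columns 0-1 are the teacher scores, column 2 the query text,
# columns 3-4 the positive and negative document texts.
data = "5.1\t1.2\thow to feed a cat\tFeeding guidelines for cats\tPython csv module\n"

for row in csv.reader(io.StringIO(data), delimiter="\t"):
    scores = (float(row[0]), float(row[1]))
    query, positive, negative = row[2], row[3], row[4]
    print(scores, query, positive, negative)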
diff --git a/src/xpmir/letor/trainers/multiple.py b/src/xpmir/letor/trainers/multiple.py
deleted file mode 100644
index 2e5352e8..00000000
--- a/src/xpmir/letor/trainers/multiple.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from typing import Dict, Iterator
-from experimaestro import Param
-import numpy as np
-from xpmir.learning.context import (
-    TrainerContext,
-)
-from . import Trainer
-
-
-class MultipleTrainer(Trainer):
-    """This trainer can be used to combine various trainers"""
-
-    trainers: Param[Dict[str, Trainer]]
-    """The trainers"""
-
-    def initialize(
-        self,
-        random: np.random.RandomState,
-        context: TrainerContext,
-    ):
-        super().initialize(random, context)
-        for trainer in self.trainers.values():
-            trainer.initialize(random, context)
-
-    def load_state_dict(self, state: Dict):
-        for key, trainer in self.trainers.items():
-            trainer.load_state_dict(state[key])
-
-    def state_dict(self):
-        return {key: trainer.state_dict() for key, trainer in self.trainers.items()}
-
-    def to(self, device):
-        """Change the computing device (if this is needed)"""
-        super().to(device)
-        for trainer in self.trainers.values():
-            trainer.to(device)
-
-    def iter_batches(self) -> Iterator:
-        iters = {key: trainer.iter_batches() for key, trainer in self.trainers.items()}
-        while True:
-            yield {key: next(iter) for key, iter in iters.items()}
-
-    def process_batch(self, batch):
-        for key, trainer in self.trainers.items():
-            with self.context.scope(key):
-                trainer.process_batch(batch[key])
diff --git a/src/xpmir/neural/__init__.py b/src/xpmir/neural/__init__.py
index 4425ecc1..d645b112 100644
--- a/src/xpmir/neural/__init__.py
+++ b/src/xpmir/neural/__init__.py
@@ -1,30 +1,14 @@
 import itertools
 from typing import Iterable, List, Optional
 import torch
-import torch.nn as nn
 from xpmir.learning.batchers import Sliceable
 from xpmir.learning.context import TrainerContext
 from xpmir.letor.records import BaseRecords
-from xpmir.learning.optim import Module
 from xpmir.rankers import LearnableScorer


-class TorchLearnableScorer(LearnableScorer, Module):
-    """Base class for torch-learnable scorers"""
-
-    def __init__(self):
-        nn.Module.__init__(self)
-        super().__init__()
-
-    __call__ = nn.Module.__call__
-    to = nn.Module.to
-
-    def train(self, mode=True):
-        return nn.Module.train(self, mode)
-
-
-class DualRepresentationScorer(TorchLearnableScorer):
+class DualRepresentationScorer(LearnableScorer):
     """Neural scorer based on (at least a partially) independent representation
     of the document and the question.

diff --git a/src/xpmir/neural/cross.py b/src/xpmir/neural/cross.py
index 16388b85..731e755b 100644
--- a/src/xpmir/neural/cross.py
+++ b/src/xpmir/neural/cross.py
@@ -7,7 +7,7 @@
     BaseRecords,
     PairwiseRecords,
 )
-from xpmir.neural import TorchLearnableScorer
+from xpmir.rankers import LearnableScorer
 from xpmir.text.encoders import DualTextEncoder, TripletTextEncoder
 from xpmir.rankers import (
     DuoLearnableScorer,
 )

-class CrossScorer(TorchLearnableScorer, DistributableModel):
+class CrossScorer(LearnableScorer, DistributableModel):
     """Query-Document Representation Classifier

     Based on a query-document representation representation (e.g. BERT [CLS] token).
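With TorchLearnableScorer gone, scorers such as CrossScorer above derive directly from LearnableScorer and, per the rankers/__init__.py hunk later in this patch, implement forward() rather than __call__(). A hypothetical minimal subclass, assuming parameters are created in _initialize() and that record batches support len() as in the pairwise trainer above:

# Hypothetical toy scorer -- illustrative only, not part of this patch.
import torch
import torch.nn as nn
from xpmir.rankers import LearnableScorer

class BiasOnlyScorer(LearnableScorer):
    """Returns a single learnable constant for every (query, document) pair."""

    def _initialize(self, random):
        # Parameters are created here rather than in __init__
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, inputs, info=None):
        # One score per record in the batch
        return self.bias.expand(len(inputs))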
diff --git a/src/xpmir/neural/dual.py b/src/xpmir/neural/dual.py
index aefce10f..7e68cc35 100644
--- a/src/xpmir/neural/dual.py
+++ b/src/xpmir/neural/dual.py
@@ -156,10 +156,14 @@ class FlopsRegularizer(DualVectorListener):
     lambda_d: Param[float]
     """Lambda for documents"""

-    @staticmethod
     def compute(x: torch.Tensor):
+        """
+        :param x: term vectors (batch size x vocabulary dimension)
+        :returns: a couple (mean term weights, FLOPS regularization)
+        """
-        # Computes the mean for each term
-        y = x.abs().mean(0)
+        # Computes the mean for each term (weights are positive)
+        y = x.mean(0)
+
+        # Returns the sum of squared means
         return y, (y * y).sum()
@@ -175,9 +179,9 @@ def __call__(self, info: TrainerContext, queries, documents):
         flops = self.lambda_d * flops_d + self.lambda_q * flops_q
         info.add_loss(Loss("flops", flops, 1.0))

-        info.metrics.add(ScalarMetric("flops", flops.item(), len(q)))
-        info.metrics.add(ScalarMetric("flops_q", flops_q.item(), len(q)))
-        info.metrics.add(ScalarMetric("flops_d", flops_d.item(), len(d)))
+        info.metrics.add(ScalarMetric("flops", flops.item(), 1))
+        info.metrics.add(ScalarMetric("flops_q", flops_q.item(), 1))
+        info.metrics.add(ScalarMetric("flops_d", flops_d.item(), 1))

         with torch.no_grad():
             info.metrics.add(
diff --git a/src/xpmir/neural/huggingface.py b/src/xpmir/neural/huggingface.py
index fa664202..e7aa5ffe 100644
--- a/src/xpmir/neural/huggingface.py
+++ b/src/xpmir/neural/huggingface.py
@@ -1,6 +1,6 @@
 from xpmir.learning.context import TrainerContext
 from xpmir.letor.records import BaseRecords
-from xpmir.neural import TorchLearnableScorer
+from xpmir.rankers import LearnableScorer
 from experimaestro import Param
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
 from xpmir.letor.records import TokenizedTexts
@@ -9,7 +9,7 @@
 import torch


-class HFCrossScorer(TorchLearnableScorer, DistributableModel):
+class HFCrossScorer(LearnableScorer, DistributableModel):
     """Load a cross scorer model from the huggingface"""

     hf_id: Param[str]
diff --git a/src/xpmir/neural/interaction/__init__.py b/src/xpmir/neural/interaction/__init__.py
index 060605f6..64acabb4 100644
--- a/src/xpmir/neural/interaction/__init__.py
+++ b/src/xpmir/neural/interaction/__init__.py
@@ -1,12 +1,12 @@
 import torch
 from experimaestro import Param
-from xpmir.neural import TorchLearnableScorer
+from xpmir.rankers import LearnableScorer
 from xpmir.text import TokensEncoder
 from xpmir.letor.records import BaseRecords
 from xpmir.learning.context import TrainerContext


-class InteractionScorer(TorchLearnableScorer):
+class InteractionScorer(LearnableScorer):
     """Interaction-based neural scorer

     This is the base class for all scorers that depend on a map
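The updated compute() above drops the abs() call since SPLADE term weights are already non-negative. In plain PyTorch, the FLOPS regularizer it implements is simply the sum of the squared per-term mean activations over the batch:

import torch

def flops(activations: torch.Tensor) -> torch.Tensor:
    # activations: (batch, vocabulary), non-negative term weights.
    mean_per_term = activations.mean(0)           # (vocabulary,)
    return (mean_per_term * mean_per_term).sum()  # scalar regularization term

x = torch.rand(8, 30522)  # a batch of 8 pseudo term-weight vectors
regularization = flops(x)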
diff --git a/src/xpmir/neural/splade.py b/src/xpmir/neural/splade.py
index ae5b2793..f331ea41 100644
--- a/src/xpmir/neural/splade.py
+++ b/src/xpmir/neural/splade.py
@@ -1,5 +1,6 @@
 from typing import List, Optional
 from experimaestro import Config, Param
+import torch.nn.functional as F
 import torch.nn as nn
 import torch
 from xpmir.distributed import DistributableModel
@@ -15,18 +16,33 @@
 class Aggregation(Config):
     """The aggregation function for Splade"""

-    pass
+    def with_linear(self, logits, mask, weight, bias=None):
+        """Project before aggregating using a linear transformation
+
+        Can be optimized by further operators
+
+        :param logits: The logits output by the sequence representation model (B
+            x L x D)
+        :param mask: The mask (B x L), where 0 means that the element should be
+            masked out
+        :param weight: The linear transformation (D' x D)
+        """
+        projection = F.linear(logits, weight, bias)
+        return self(projection, mask)


 class MaxAggregation(Aggregation):
     """Aggregate using a max"""

     def __call__(self, logits, mask):
+        # Get the maximum (masking the values)
         values, _ = torch.max(
-            torch.log1p(torch.relu(logits) * mask.to(logits.device).unsqueeze(-1)),
+            torch.relu(logits) * mask.to(logits.device).unsqueeze(-1),
             dim=1,
         )
-        return values
+
+        # Computes log(1+x)
+        return torch.log1p(values.clamp(min=0))


 class SumAggregation(Aggregation):
diff --git a/src/xpmir/rankers/__init__.py b/src/xpmir/rankers/__init__.py
index d81ec348..278ae431 100644
--- a/src/xpmir/rankers/__init__.py
+++ b/src/xpmir/rankers/__init__.py
@@ -175,14 +175,20 @@ def rsv(
 class AbstractLearnableScorer(Scorer, Module):
     """Base class for all learnable scorer"""

+    # Ensures basic operations are redirected to torch.nn.Module methods
     __call__ = nn.Module.__call__
     to = nn.Module.to
+    train = nn.Module.train

     def __init__(self):
+        self.logger.info("Initializing %s", self)
         nn.Module.__init__(self)
         super().__init__()
         self._initialized = False

+    def __str__(self):
+        return f"scorer {self.__class__.__qualname__}"
+
     def _initialize(self, random):
         raise NotImplementedError(f"_initialize in {self.__class__}")
@@ -223,7 +229,7 @@ class LearnableScorer(AbstractLearnableScorer):

     A scorer with parameters that can be learnt"""

-    def __call__(self, inputs: "BaseRecords", info: Optional[TrainerContext]):
+    def forward(self, inputs: "BaseRecords", info: Optional[TrainerContext]):
         """Computes the score of all (query, document) pairs

         Different subclasses can process the input more or
diff --git a/src/xpmir/test/letor/test_samplers.py b/src/xpmir/test/letor/test_samplers.py
index d32b3355..42d08eb2 100644
--- a/src/xpmir/test/letor/test_samplers.py
+++ b/src/xpmir/test/letor/test_samplers.py
@@ -1,6 +1,11 @@
 from typing import Iterator, Tuple
+import datamaestro_text.data.ir as ir
 from datamaestro_text.data.ir.base import GenericTopic, GenericDocument
-from xpmir.letor.samplers import TrainingTriplets, TripletBasedSampler
+from xpmir.letor.samplers import (
+    TrainingTriplets,
+    TripletBasedSampler,
+    ModelBasedSampler,
+)

 # ---- Serialization
@@ -45,3 +50,34 @@ def test_serializing_tripletbasedsampler():
     assert (
         expected.negative.document.get_text() == record.negative.document.get_text()
     )
+
+
+class GeneratedDocuments(ir.Documents):
+    pass
+
+
+class GeneratedTopics(ir.Topics):
+    pass
+
+
+class GeneratedAssessments(ir.AdhocAssessments):
+    pass
+
+
+def adhoc_synthetic_dataset():
+    """Creates a random dataset"""
+    return ir.Adhoc(
+        documents=GeneratedDocuments(),
+        topics=GeneratedTopics(),
+        assessments=GeneratedAssessments(),
+    )
+
+
+def test_modelbasedsampler():
+    dataset = adhoc_synthetic_dataset()
+    sampler = ModelBasedSampler(
+        dataset=dataset, retriever=RandomRetriever(dataset=dataset)
+    ).instance()
+
+    for a in sampler._itertopics():
+        print(a)
diff --git a/src/xpmir/utils/utils.py b/src/xpmir/utils/utils.py
index 517a8958..d3fdd684 100644
--- a/src/xpmir/utils/utils.py
+++ b/src/xpmir/utils/utils.py
@@ -168,6 +168,9 @@ def find_java_home(min_version: int = 6) -> str:
     paths = []

     # (1) Use environment variable
+    if java_home := os.environ.get("FORCE_JAVA_HOME", None):
+        return java_home
+
     if java_home := os.environ.get("JAVA_HOME", None):
         paths.append(Path(java_home) / "bin" / "java")
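A note on the MaxAggregation change in the splade.py hunk above: log1p is now applied after the max instead of inside it. Since log1p is monotonic and the masked, ReLU-ed activations are non-negative, both orderings give the same result; the small check below verifies this in plain PyTorch:

import torch

def max_aggregation(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # New formulation: mask and ReLU, max over the sequence axis, then log1p.
    values, _ = torch.max(torch.relu(logits) * mask.unsqueeze(-1), dim=1)
    return torch.log1p(values.clamp(min=0))

logits = torch.randn(2, 5, 10)  # (batch, sequence length, vocabulary)
mask = torch.ones(2, 5)         # no padding in this toy batch
old = torch.log1p(torch.relu(logits) * mask.unsqueeze(-1)).max(dim=1).values
assert torch.allclose(max_aggregation(logits, mask), old)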