From 758c4dbf986b74216470f92c85d52392517691f8 Mon Sep 17 00:00:00 2001
From: Benjamin Piwowarski
Date: Sat, 11 May 2024 13:00:57 +0200
Subject: [PATCH] fix: CoSPLADE fixes

---
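Notes: the main behavioral changes are (1) the asymmetric loss, which now
only penalizes predicted term weights that fall below the gold SPLADE
weights, normalized by the number of entries of the prediction matrix;
(2) the (query, answer) pairs, which are now built from the current query
text, capped per record, and used to normalize the aggregated history
representation; and (3) the fallback for Python < 3.10, where
TemporaryDirectory has no ignore_cleanup_errors argument.

The two sketches below are not part of the patch and use made-up tensors,
shapes and index lists; they only illustrate (1) and (2). The
`sources`/`tokens` lists stand in for the gold token indices built in
AsymetricMSEContextualizedRepresentationLoss.__call__:

    import torch

    # Two queries over a toy vocabulary of 5 tokens
    input_value = torch.tensor([[0.2, 0.0, 0.5, 0.0, 0.1],
                                [0.0, 0.3, 0.0, 0.4, 0.0]])
    target_value = torch.tensor([[0.6, 0.0, 0.4, 0.0, 0.0],
                                 [0.0, 0.5, 0.0, 0.2, 0.0]])

    # (query index, token id) pairs of the gold tokens
    sources = [0, 0, 1, 1]
    tokens = [0, 2, 1, 3]

    # ReLU keeps only the under-estimations before squaring
    delta = (
        torch.relu(target_value[sources, tokens] - input_value[sources, tokens])
        ** 2
    )
    loss = torch.sum(delta) / input_value.numel()
    # delta = [0.16, 0, 0.04, 0] -> loss = 0.2 / 10 = 0.02

The history normalization relies on broadcasting: history_sizes has shape
(batch, 1) and q_answers has shape (batch, vocabulary), so the division
scales each row by the number of pairs of that record (clamped to 1, so
records without history are left unchanged):

    import torch

    q_answers = torch.ones(3, 5)  # aggregated (query, answer) encodings
    history_sizes = torch.tensor([[2.0], [1.0], [1.0]])  # pairs per record

    q_answers /= history_sizes  # row-wise scaling
    # the first row is now 0.5 everywhere, the other rows are unchanged
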
 src/xpmir/conversation/models/cosplade.py | 42 +++++++++++++++---------
 src/xpmir/index/sparse.py                 | 14 ++++----
 src/xpmir/learning/devices.py             |  9 +++++-
 3 files changed, 41 insertions(+), 24 deletions(-)

diff --git a/src/xpmir/conversation/models/cosplade.py b/src/xpmir/conversation/models/cosplade.py
index a0b34ae..cabed24 100644
--- a/src/xpmir/conversation/models/cosplade.py
+++ b/src/xpmir/conversation/models/cosplade.py
@@ -3,7 +3,7 @@
 from datamaestro_text.data.conversation.base import EntryType
 import torch
 import sys
-from experimaestro import Param
+from experimaestro import Param, Constant
 from datamaestro.record import Record
 from datamaestro_text.data.ir import TextItem
 from datamaestro_text.data.conversation import (
@@ -35,6 +35,9 @@ class AsymetricMSEContextualizedRepresentationLoss(
 ):
     """Computes the asymetric loss for CoSPLADE"""
 
+    version: Constant[int] = 2
+    """Current version"""
+
     def __call__(self, input: CoSPLADEOutput, target: TextsRepresentationOutput):
         # Builds up the list of tokens in the gold output
         ids = target.tokenized.ids.cpu()
@@ -47,19 +50,12 @@ def __call__(self, input: CoSPLADEOutput, target: TextsRepresentationOutput):
                 sources.append(ix)
                 tokens.append(token_id)
 
-        # Compute difference on selected tokens
-        difference = torch.nn.functional.mse_loss(
-            input.value[sources, tokens],
-            target.value[sources, tokens],
-            reduction="none",
-        )
-        loss = torch.zeros(
-            len(target.value), dtype=target.value.dtype, device=target.value.device
+        # Compute the loss: only under-estimated gold weights are penalized
+        delta = (
+            torch.relu(target.value[sources, tokens] - input.value[sources, tokens])
+            ** 2
         )
-
-        # Aggregate
-        sources_pt = torch.tensor(sources, device=target.value.device, dtype=torch.long)
-        return loss.scatter_add(0, sources_pt, difference).mean()
+        return torch.sum(delta) / input.value.numel()
 
 
 class CoSPLADE(ConversationRepresentationEncoder):
@@ -74,6 +70,9 @@ class CoSPLADE(ConversationRepresentationEncoder):
     history_encoder: Param[SpladeTextEncoderV2[Tuple[str, str]]]
     """Encoder for (query, answer) pairs"""
 
+    version: Constant[int] = 2
+    """Current version"""
+
     def __initialize__(self, options):
         super().__initialize__(options)
 
@@ -91,8 +90,12 @@ def forward(self, records: List[Record]):
         history_size = self.history_size or sys.maxsize
 
         # Process each topic record
+
+        #: Number of (query, answer) pairs per record, used for normalization
+        history_sizes = torch.zeros((len(records), 1))
+
         for ix, c_record in enumerate(records):
-            # Adds q_n, q_1, ..., q_{n-1}
+            # Adds q_n, q_{n-1}, ..., q_1
             queries.append(
                 [c_record[TextItem].text]
                 + [
@@ -104,10 +107,12 @@ def forward(self, records: List[Record]):
 
             # List of query/answer couples
             answer: Optional[AnswerEntry] = None
+            count = 0
             for item in c_record[ConversationHistoryItem].history:
                 entry_type = item[EntryType]
                 if entry_type == EntryType.USER_QUERY and answer is not None:
-                    query_answer_pairs.append((item[TextItem].text, answer.answer))
+                    count += 1
+                    query_answer_pairs.append((c_record[TextItem].text, answer.answer))
                     pair_origins.append(ix)
-                    if len(pair_origins) >= history_size:
+                    if count >= history_size:
                         break
@@ -118,6 +123,8 @@ def forward(self, records: List[Record]):
                 # Ignore anything which is not a pair topic-response
                 answer = None
 
+            history_sizes[ix, 0] = max(count, 1)
+
         # (1) encodes the queries
         q_queries = self.queries_encoder(queries).value
 
@@ -129,4 +136,7 @@ def forward(self, records: List[Record]):
         q_ix = q_ix.unsqueeze(-1).expand(x_pairs.shape)
         q_answers.scatter_add_(0, q_ix, x_pairs)
 
+        # Normalize by the number of pairs
+        q_answers /= history_sizes.to(q_queries.device)
+
         return CoSPLADEOutput(q_queries + q_answers, q_queries, q_answers)
diff --git a/src/xpmir/index/sparse.py b/src/xpmir/index/sparse.py
index 3b2bdbd..5845626 100644
--- a/src/xpmir/index/sparse.py
+++ b/src/xpmir/index/sparse.py
@@ -2,7 +2,6 @@
 
 import asyncio
 from functools import cached_property
-import logging
 import threading
 import heapq
 import torch
@@ -98,10 +97,10 @@ class SparseRetriever(Retriever, Generic[InputType]):
 
     def initialize(self):
         super().initialize()
-        logging.info("Initializing the encoder")
+        logger.info("Initializing the encoder")
        self.encoder.initialize(ModuleInitMode.DEFAULT.to_options(None))
         self.encoder.to(self.device.value)
-        logging.info("Initializing the index")
+        logger.info("Initializing the index")
         self.index.initialize(self.in_memory)
 
     def retrieve_all(
@@ -120,7 +119,7 @@ async def aio_search_worker(progress, results: Dict, queue: asyncio.Queue):
                 # Just stopped
                 pass
             except Exception:
-                logging.exception("Error in worker thread")
+                logger.exception("Error in worker thread")
 
         async def reducer(
             batch: List[Tuple[str, InputType]],
@@ -132,9 +131,9 @@
         ):
             (ix,) = vector.nonzero()
             query = {ix: float(v) for ix, v in zip(ix, vector[ix])}
-            logging.debug("Adding topic %s to the queue", key)
+            logger.debug("Adding topic %s to the queue", key)
             await queue.put((key, query, self.topk))
-            logging.debug("[done] Adding topic %s to the queue", key)
+            logger.debug("[done] Adding topic %s to the queue", key)
 
         async def aio_process():
             workers = []
@@ -164,6 +163,7 @@ async def aio_process():
                 worker.cancel()
             return results
 
+        logger.info("Retrieve all with %d CPUs", available_cpus())
         results = asyncio.run(aio_process())
         return results
 
@@ -209,7 +209,7 @@ def iterator(self):
         return batchiter(
             self.batch_size,
             zip(
-                range(sys.maxsize if self.max_docs == 0 else self.max_docs),
+                range(self.max_docs or sys.maxsize),
                 self.documents.iter_documents(),
             ),
         )
diff --git a/src/xpmir/learning/devices.py b/src/xpmir/learning/devices.py
index c2bd7a8..a0f68d0 100644
--- a/src/xpmir/learning/devices.py
+++ b/src/xpmir/learning/devices.py
@@ -1,3 +1,4 @@
+import sys
 from dataclasses import dataclass
 from pathlib import Path
 from experimaestro import Config, Param
@@ -127,7 +128,13 @@ def execute(self, callback, *args, **kwargs):
         if n_gpus == 1 or not self.distributed:
             callback(DeviceInformation(self.value, True), *args, **kwargs)
         else:
-            with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as directory:
+            # ignore_cleanup_errors only exists from Python 3.10 on
+            if sys.version_info < (3, 10):
+                tmp_directory = tempfile.TemporaryDirectory()
+            else:
+                tmp_directory = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
+
+            with tmp_directory as directory:
                 logger.info("Setting up distributed CUDA computing (%d GPUs)", n_gpus)
                 return mp.start_processes(
                     mp_launcher,