fix: CoSPLADE fixes
bpiwowar committed May 11, 2024
1 parent 931f599 commit 758c4db
Showing 3 changed files with 39 additions and 23 deletions.
40 changes: 25 additions & 15 deletions src/xpmir/conversation/models/cosplade.py
@@ -3,7 +3,7 @@
 from datamaestro_text.data.conversation.base import EntryType
 import torch
 import sys
-from experimaestro import Param
+from experimaestro import Param, Constant
 from datamaestro.record import Record
 from datamaestro_text.data.ir import TextItem
 from datamaestro_text.data.conversation import (
@@ -35,6 +35,9 @@ class AsymetricMSEContextualizedRepresentationLoss(
 ):
     """Computes the asymetric loss for CoSPLADE"""
 
+    version: Constant[int] = 2
+    """Current version"""
+
     def __call__(self, input: CoSPLADEOutput, target: TextsRepresentationOutput):
         # Builds up the list of tokens in the gold output
         ids = target.tokenized.ids.cpu()
@@ -47,19 +50,12 @@ def __call__(self, input: CoSPLADEOutput, target: TextsRepresentationOutput):
             sources.append(ix)
             tokens.append(token_id)
 
-        # Compute difference on selected tokens
-        difference = torch.nn.functional.mse_loss(
-            input.value[sources, tokens],
-            target.value[sources, tokens],
-            reduction="none",
-        )
-        loss = torch.zeros(
-            len(target.value), dtype=target.value.dtype, device=target.value.device
-        )
-
-        # Aggregate
-        sources_pt = torch.tensor(sources, device=target.value.device, dtype=torch.long)
-        return loss.scatter_add(0, sources_pt, difference).mean()
+        # Compute the loss
+        delta = (
+            torch.relu(target.value[sources, tokens] - input.value[sources, tokens])
+            ** 2
+        )
+        return torch.sum(delta) / input.value.numel()
 
 
 class CoSPLADE(ConversationRepresentationEncoder):
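
Note on the loss rewrite above: the old code computed a symmetric MSE on the gold tokens, while the new code keeps only the one-sided term, so a token contributes to the loss only when its predicted weight falls below the gold weight; over-predicting a gold token costs nothing. A minimal sketch of this one-sided penalty on dense toy tensors (values are made up; the committed code evaluates it only at the gold (sources, tokens) indices and normalizes by input.value.numel()):

import torch

# Toy gold and predicted sparse representations (batch x vocab)
target = torch.tensor([[0.8, 0.0, 0.5], [0.2, 0.9, 0.0]])
predicted = torch.tensor([[1.0, 0.1, 0.2], [0.0, 0.9, 0.3]])

# relu keeps only under-predictions; matching or exceeding the gold
# weight yields a zero term
delta = torch.relu(target - predicted) ** 2
loss = delta.sum() / predicted.numel()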
@@ -74,6 +70,9 @@ class CoSPLADE(ConversationRepresentationEncoder):
     history_encoder: Param[SpladeTextEncoderV2[Tuple[str, str]]]
     """Encoder for (query, answer) pairs"""
 
+    version: Constant[int] = 2
+    """Current version"""
+
     def __initialize__(self, options):
         super().__initialize__(options)
 
@@ -91,8 +90,12 @@ def forward(self, records: List[Record]):
         history_size = self.history_size or sys.maxsize
 
         # Process each topic record
+
+        #: History size for normalization
+        history_sizes = torch.zeros((len(records), 1))
+
         for ix, c_record in enumerate(records):
-            # Adds q_n, q_1, ..., q_{n-1}
+            # Adds q_n, q_{n-1}, ..., q_{1}
             queries.append(
                 [c_record[TextItem].text]
                 + [
@@ -104,10 +107,12 @@
 
             # List of query/answer couples
             answer: Optional[AnswerEntry] = None
+            count = 0
             for item in c_record[ConversationHistoryItem].history:
                 entry_type = item[EntryType]
                 if entry_type == EntryType.USER_QUERY and answer is not None:
-                    query_answer_pairs.append((item[TextItem].text, answer.answer))
+                    count += 1
+                    query_answer_pairs.append((c_record[TextItem].text, answer.answer))
                     pair_origins.append(ix)
                     if len(pair_origins) >= history_size:
                         break
@@ -118,6 +123,8 @@
                 # Ignore anything which is not a pair topic-response
                 answer = None
 
+            history_sizes[ix, 0] = max(count, 1)
+
         # (1) encodes the queries
         q_queries = self.queries_encoder(queries).value
 
@@ -129,4 +136,7 @@
         q_ix = q_ix.unsqueeze(-1).expand(x_pairs.shape)
         q_answers.scatter_add_(0, q_ix, x_pairs)
 
+        # Normalize by number of pairs
+        q_answers /= history_sizes.to(q_queries.device)
+
         return CoSPLADEOutput(q_queries + q_answers, q_queries, q_answers)
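
With these changes, forward() counts the (query, answer) pairs contributed by each conversation and divides the scattered sum by that count, turning the history representation into a mean rather than a sum. A toy sketch of the aggregation step (sizes and values are illustrative only):

import torch

# 2 conversations, vocabulary of size 4; three encoded (query, answer)
# pairs, the first two belonging to conversation 0
x_pairs = torch.tensor([[1.0, 0.0, 0.0, 2.0],
                        [0.0, 1.0, 0.0, 0.0],
                        [0.0, 0.0, 3.0, 0.0]])
pair_origins = [0, 0, 1]                      # conversation of each pair
history_sizes = torch.tensor([[2.0], [1.0]])  # max(count, 1) per conversation

q_answers = torch.zeros(2, 4)
q_ix = torch.tensor(pair_origins).unsqueeze(-1).expand(x_pairs.shape)
q_answers.scatter_add_(0, q_ix, x_pairs)  # sum the pair vectors per conversation
q_answers /= history_sizes                # the fix: average instead of sum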
14 changes: 7 additions & 7 deletions src/xpmir/index/sparse.py
@@ -2,7 +2,6 @@
 
 import asyncio
 from functools import cached_property
-import logging
 import threading
 import heapq
 import torch
@@ -98,10 +97,10 @@ class SparseRetriever(Retriever, Generic[InputType]):
 
     def initialize(self):
         super().initialize()
-        logging.info("Initializing the encoder")
+        logger.info("Initializing the encoder")
         self.encoder.initialize(ModuleInitMode.DEFAULT.to_options(None))
         self.encoder.to(self.device.value)
-        logging.info("Initializing the index")
+        logger.info("Initializing the index")
         self.index.initialize(self.in_memory)
 
     def retrieve_all(
@@ -120,7 +119,7 @@ async def aio_search_worker(progress, results: Dict, queue: asyncio.Queue):
                 # Just stopped
                 pass
             except Exception:
-                logging.exception("Error in worker thread")
+                logger.exception("Error in worker thread")
 
         async def reducer(
             batch: List[Tuple[str, InputType]],
@@ -132,9 +131,9 @@
         ):
             (ix,) = vector.nonzero()
             query = {ix: float(v) for ix, v in zip(ix, vector[ix])}
-            logging.debug("Adding topic %s to the queue", key)
+            logger.debug("Adding topic %s to the queue", key)
             await queue.put((key, query, self.topk))
-            logging.debug("[done] Adding topic %s to the queue", key)
+            logger.debug("[done] Adding topic %s to the queue", key)
 
         async def aio_process():
             workers = []
@@ -164,6 +163,7 @@ async def aio_process():
                 worker.cancel()
             return results
 
+        logger.info("Retrieve all with %d CPUs", available_cpus())
        results = asyncio.run(aio_process())
         return results
 
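All logging in this file now goes through the module-level logger rather than the root logging module, so records carry the module name and can be filtered per module. A sketch of the idiom, assuming sparse.py already defines its logger in the standard way (the diff removes import logging, so a logger object must be in scope):

import logging

logger = logging.getLogger(__name__)  # e.g. "xpmir.index.sparse"

logging.info("via the root logger")   # old style
logger.info("via the module logger")  # new style: tagged with the module name
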
@@ -209,7 +209,7 @@ def iterator(self):
         return batchiter(
             self.batch_size,
             zip(
-                range(sys.maxsize if self.max_docs == 0 else self.max_docs),
+                range(self.max_docs or sys.maxsize),
                 self.documents.iter_documents(),
             ),
         )
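
The iterator change swaps the explicit conditional for the shorter or idiom: since 0 is falsy, self.max_docs or sys.maxsize yields the configured limit when it is non-zero and an effectively unbounded range otherwise. A standalone check of the equivalence (hypothetical values):

import sys

for max_docs in (0, 100):
    old = sys.maxsize if max_docs == 0 else max_docs
    new = max_docs or sys.maxsize
    assert old == new  # same result under the 0-means-unlimited convention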
8 changes: 7 additions & 1 deletion src/xpmir/learning/devices.py
@@ -1,3 +1,4 @@
+import sys
 from dataclasses import dataclass
 from pathlib import Path
 from experimaestro import Config, Param
@@ -127,7 +128,12 @@ def execute(self, callback, *args, **kwargs):
         if n_gpus == 1 or not self.distributed:
             callback(DeviceInformation(self.value, True), *args, **kwargs)
         else:
-            with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as directory:
+            if sys.version_info.major == 3 and sys.version_info.minor < 10:
+                tmp_directory = tempfile.TemporaryDirectory()
+            else:
+                tmp_directory = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
+
+            with tmp_directory as directory:
                 logger.info("Setting up distributed CUDA computing (%d GPUs)", n_gpus)
                 return mp.start_processes(
                     mp_launcher,
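
This guard exists because the ignore_cleanup_errors keyword of tempfile.TemporaryDirectory was only added in Python 3.10. A standalone sketch of the same pattern, using a tuple comparison equivalent to the committed major/minor check:

import sys
import tempfile

# ignore_cleanup_errors is available only from Python 3.10 onwards
if sys.version_info < (3, 10):
    tmp_directory = tempfile.TemporaryDirectory()
else:
    tmp_directory = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)

with tmp_directory as directory:
    print("working in", directory)  # the directory is removed on exit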