Skip to content

Commit

Permalink
feat: multi-gpu sparse index building
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Feb 28, 2024
1 parent f13f521 commit a8fecac
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 47 deletions.
198 changes: 167 additions & 31 deletions src/xpmir/index/sparse.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Index for sparse models"""

import heapq
import torch
from queue import Empty
import torch.multiprocessing as mp
import numpy as np
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Generic
from typing import Dict, List, Tuple, Generic, Iterator, Union
from attrs import define
from experimaestro import (
Annotated,
Config,
Expand All @@ -19,10 +23,11 @@
from xpmir.learning import ModuleInitMode
from xpmir.learning.batchers import Batcher
from xpmir.utils.utils import batchiter, easylog
from xpmir.letor import Device, DEFAULT_DEVICE
from xpmir.letor import Device, DeviceInformation, DEFAULT_DEVICE
from xpmir.text.encoders import TextEncoderBase, TextsRepresentationOutput, InputType
from xpmir.rankers import Retriever, TopicRecord, ScoredDocument
from xpmir.utils.iter import MultiprocessIterator
from xpmir.utils.multiprocessing import StoppableQueue
import xpmir_rust

logger = easylog()
Expand Down Expand Up @@ -123,6 +128,22 @@ def retrieve(self, query: TopicRecord, top_k=None) -> List[ScoredDocument]:
return self.index.retrieve(query, top_k or self.topk)


@define(frozen=True)
class EncodedDocument:
    """Encoded representation of one document, sent from an encoder to the indexer.

    Instances are put on the per-encoder queue by `encode_documents` and
    consumed by `index`.
    """

    # Integer identifier of the document (its position in the document stream)
    docid: int

    # Dense vector for the document. `encode_documents` stores one row of the
    # array returned by `.cpu().numpy()`, and `index` relies on the NumPy
    # `nonzero()` / `astype` API, hence a NumPy array and not a torch tensor.
    value: np.ndarray


@define(frozen=True)
class DocumentRange:
    """Announces the range of document IDs an encoder is about to output.

    An encoder puts a `DocumentRange` on its queue before the encoded
    documents of each batch; the indexer uses it to merge the queues in
    increasing document ID order.
    """

    # Rank of the encoder; also the index of its queue in the list of queues
    rank: int

    # ID of the first document of the range
    start: int

    # ID of the last document of the range (inclusive)
    end: int

    def __lt__(self, other: "DocumentRange"):
        """Order ranges by their first document ID (needed by `heapq`)."""
        return other.start > self.start


class SparseRetrieverIndexBuilder(Task, Generic[InputType]):
"""Builds an index from a sparse representation
Expand All @@ -147,6 +168,7 @@ class SparseRetrieverIndexBuilder(Task, Generic[InputType]):
fast top-k strategies"""

device: Meta[Device] = DEFAULT_DEVICE
"""The device for building the index"""

max_postings: Meta[int] = 16384
"""Maximum number of postings (per term) before flushing to disk"""
Expand All @@ -172,52 +194,166 @@ def task_outputs(self, dep):
)

def execute(self):
    """Encode the documents and build the sparse index.

    Batches of `(docid, document)` pairs are handed to
    `self.device.execute`, which runs `device_execute` (one call per
    process of the device), while a separate process runs `index` and
    consumes the encoder queues.

    :raises RuntimeError: if the indexing process exits with a non-zero
        exit code
    """
    max_docs = (
        self.documents.documentcount
        if self.max_docs == 0
        else min(self.max_docs, self.documents.documentcount)
    )

    iter_batches = tqdm(
        MultiprocessIterator(
            batchiter(
                self.batch_size,
                zip(
                    range(sys.maxsize if self.max_docs == 0 else self.max_docs),
                    MultiprocessIterator(self.documents.iter_documents()).start(),
                ),
            )
        ),
        # Number of batches: round up so that a last, partial batch is
        # counted (floor division reported 0 batches for small collections)
        total=(max_docs + self.batch_size - 1) // self.batch_size,
        unit_scale=self.batch_size,
        unit="documents",
        desc="Building the index",
    )

    self.encoder.initialize(ModuleInitMode.DEFAULT.to_options(None))

    # One queue per encoder process; they all share the `closed` event
    closed = mp.Event()
    queues = [
        StoppableQueue(2 * self.batch_size + 1, closed)
        for _ in range(self.device.n_processes)
    ]

    # Cleanup the index before starting
    from shutil import rmtree

    if self.index_path.is_dir():
        rmtree(self.index_path)
    self.index_path.mkdir(parents=True)

    # Start the index process
    index_process = mp.Process(
        target=self.index,
        args=(queues,),
        daemon=True,
    )
    index_process.start()

    # Waiting for the encoder process to end
    logger.info(f"Starting to index {max_docs} documents")

    try:
        self.device.execute(self.device_execute, iter_batches, queues)
    finally:
        # Always wait for the indexer, even when encoding failed
        logger.info("Waiting for the index process to stop")
        index_process.join()
        if index_process.exitcode != 0:
            logger.warning(
                "Indexer process has finished with exit code %d",
                index_process.exitcode,
            )
            raise RuntimeError("Failure")
def index(
    self, queues: List[StoppableQueue[Union[DocumentRange, EncodedDocument]]]
):
    """Index encoded documents

    Merges the output of the encoder queues in increasing document ID
    order. Each queue yields a `DocumentRange`, followed by one
    `EncodedDocument` per ID of that range, and finally `None` when the
    encoder is done.

    :param queues: Queues are used to send tensors (one queue per encoder,
        indexed by encoder rank)
    """
    try:
        # Get ranges
        logger.info(
            "Starting the indexing process (%d queues) in %s",
            len(queues),
            self.index_path,
        )
        indexer = xpmir_rust.index.SparseIndexer(str(self.index_path))

        # First range announced by each encoder. An encoder that had no
        # batch to process sends `None` right away: it must not enter the
        # heap since `None` cannot be compared with a range.
        heap = [
            first for first in (queue.get() for queue in queues) if first is not None
        ]
        # The heap of ranges (not the list of queues) has to be heapified
        heapq.heapify(heap)

        # Loop over them
        while heap:
            # Process current range (the one with the lowest start)
            current = heap[0]
            logger.debug("Handling range: %s", current)
            queue = queues[current.rank]
            for docid in range(current.start, current.end + 1):
                encoded = queue.get()
                assert (
                    encoded.docid == docid
                ), f"Mismatch in document IDs ({encoded.docid} vs {docid})"

                (nonzero_ix,) = encoded.value.nonzero()
                indexer.add(
                    docid, nonzero_ix.astype(np.uint64), encoded.value[nonzero_ix]
                )

            # Get next range
            next_range = queue.get()  # type: DocumentRange
            if next_range is not None:
                # Drop `current` and insert the next range of the same
                # encoder; heapreplace always removes the top, whatever
                # the ordering of the new item
                heapq.heapreplace(heap, next_range)
            else:
                logger.info("Iterator %d is over", current.rank)
                heapq.heappop(heap)

        logger.info("Building the index")
        indexer.build(self.in_memory)
    except Empty:
        # NOTE(review): presumably raised by `StoppableQueue.get` once the
        # queues have been stopped by an encoder -- confirm
        logger.warning("One encoder got a problem... stopping")
        raise
    except Exception:
        # Close all the queues
        # NOTE(review): only the first queue is stopped; `execute` creates
        # all queues with the same event, which presumably stops them
        # all -- confirm
        logger.exception(
            "Got an exception in the indexing process, closing the queues"
        )
        queues[0].stop()
        raise

def device_execute(
    self,
    device_information: DeviceInformation,
    iter_batches: Iterator[List[Tuple[int, DocumentRecord]]],
    queues: List[StoppableQueue],
):
    """Encode batches of documents and send the result to the indexer.

    For each batch, a `DocumentRange` is first put on the queue of this
    process, followed by the encoded documents; `None` is sent once all
    batches have been processed.

    :param device_information: Information about the current process (its
        rank selects the output queue)
    :param iter_batches: Iterator over batches of `(docid, document)` pairs
    :param queues: Output queues, one per rank
    """
    # Bound before the `try` so that the exception handler can always stop
    # the queue, even when loading the encoder fails
    queue = queues[device_information.rank]

    try:
        # Encode all documents
        logger.info(
            "Load the encoder and "
            f"transfer to the target device {self.device.value}"
        )

        # NOTE(review): every rank uses `self.device.value` rather than a
        # per-rank device -- confirm this is intended with several GPUs
        encoder = self.encoder.to(self.device.value).eval()
        batcher = self.batcher.initialize(self.batch_size)

        # Index
        with torch.no_grad():
            for batch in iter_batches:
                # Signals the output range
                queue.put(
                    DocumentRange(
                        device_information.rank, batch[0][0], batch[-1][0]
                    )
                )
                # Outputs the documents
                batcher.process(batch, self.encode_documents, encoder, queue)

        # Signals the end of the stream to the indexer
        logger.info("Closing queue %d", device_information.rank)
        queue.put(None)
    except Exception:
        queue.stop()
        raise

def encode_documents(
    self,
    batch: List[Tuple[int, DocumentRecord]],
    encoder: TextEncoderBase[InputType, TextsRepresentationOutput],
    queue: "mp.Queue[EncodedDocument]",
):
    """Encode a batch of documents and put each result on `queue`.

    :param batch: List of `(docid, document)` pairs
    :param encoder: The encoder used to compute the representations
    :param queue: Queue receiving one `EncodedDocument` per document, in
        batch order
    """
    # Assumes for now dense vectors
    texts = [record[TextItem].text for _, record in batch]
    vectors = encoder(texts).value.cpu().numpy()  # bs * vocab

    for (docid, _), vector in zip(batch, vectors):
        queue.put(EncodedDocument(docid, vector))
36 changes: 29 additions & 7 deletions src/xpmir/learning/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ class DeviceInformation:
main: bool
"""Flag for the main process (all other are slaves)"""

count: int = 1
"""Number of processes"""

rank: int = 0
"""Rank when using multiple processes"""


class ComputationContext(Context):
device_information: DeviceInformation
Expand All @@ -40,19 +46,27 @@ def value(self):

return torch.device("cpu")

def execute(self, callback):
return callback(DeviceInformation(self.value, True))
n_processes = 1
"""Number of processes"""

def execute(self, callback, *args, **kwargs):
    """Call `callback` in the current process.

    The callback receives a `DeviceInformation` built from `self.value`
    (flagged as the main process), followed by `args` and `kwargs`.
    """
    information = DeviceInformation(self.value, True)
    callback(information, *args, **kwargs)


def mp_launcher(rank, path, world_size, device, callback, taskenv):
def mp_launcher(rank, path, world_size, callback, taskenv, args, kwargs):
logger.warning("Launcher of rank %d [%s]", rank, path)
TaskEnv._instance = taskenv
taskenv.slave = rank == 0

dist.init_process_group(
"gloo", init_method=f"file://{path}", rank=rank, world_size=world_size
)
callback(DistributedDeviceInformation(device, rank == 0, rank))
device = torch.device(f"cuda:{rank}")
callback(
DistributedDeviceInformation(device, rank == 0, rank, count=world_size),
*args,
**kwargs,
)

# Cleanup
dist.destroy_process_group()
Expand Down Expand Up @@ -94,12 +108,19 @@ def value(self):

return torch.device("cuda")

def execute(self, callback):
@cached_property
def n_processes(self):
    """Number of processes

    This is the number of CUDA devices reported by torch when
    `self.distributed` is set, and 1 otherwise.
    """
    return torch.cuda.device_count() if self.distributed else 1

def execute(self, callback, *args, **kwargs):
# Setup distributed computation
# See https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
n_gpus = torch.cuda.device_count()
if n_gpus == 1 or not self.distributed:
callback(DeviceInformation(self.value, True))
callback(DeviceInformation(self.value, True), *args, **kwargs)
else:
with tempfile.NamedTemporaryFile() as temporary:
logger.info("Setting up distributed CUDA computing (%d GPUs)", n_gpus)
Expand All @@ -108,9 +129,10 @@ def execute(self, callback):
args=(
temporary.name,
n_gpus,
self.value,
callback,
TaskEnv.instance(),
args,
kwargs,
),
nprocs=n_gpus,
join=True,
Expand Down
9 changes: 4 additions & 5 deletions src/xpmir/neural/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from abc import abstractmethod
import itertools
from typing import Iterable, Union, List, Optional, TypeVar, Generic
from typing import Iterable, Union, List, Optional, TypeVar, Generic, Sequence
import torch
from datamaestro_text.data.ir import TextItem
from xpmir.learning.batchers import Sliceable
from xpmir.learning.context import TrainerContext
from xpmir.letor.records import BaseRecords, ProductRecords, TopicRecord, DocumentRecord
from xpmir.rankers import LearnableScorer

QueriesRep = TypeVar("QueriesRep", bound=Sliceable["QueriesRep"])
DocsRep = TypeVar("DocsRep", bound=Sliceable["DocsRep"])
QueriesRep = TypeVar("QueriesRep", bound=Sequence)
DocsRep = TypeVar("DocsRep", bound=Sequence)


class DualRepresentationScorer(LearnableScorer, Generic[QueriesRep, DocsRep]):
Expand Down Expand Up @@ -57,7 +56,7 @@ def encode_documents(self, records: Iterable[DocumentRecord]) -> DocsRep:
def encode_queries(self, records: Iterable[TopicRecord]) -> QueriesRep:
"""Encode a list of texts (document or query)
The return value is model dependent, but should be sliceable
The return value is model dependent, but should be a sequence
By default, uses `merge`
"""
Expand Down
6 changes: 2 additions & 4 deletions src/xpmir/neural/interaction/common.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from abc import ABC, abstractmethod
from typing import List, Union
from typing import List, Union, Sequence
from attrs import evolve

import torch
from attrs import define
from experimaestro import Config

from xpmir.learning.batchers import Sliceable


@define
class SimilarityInput(Sliceable["SimilarityInput"]):
class SimilarityInput(Sequence["SimilarityInput"]):
value: torch.Tensor
"""A 3D tensor (batch x max_length x dim)"""

Expand Down
Loading

0 comments on commit a8fecac

Please sign in to comment.