Add msmarco v2.1 trec rag #269

Draft · wants to merge 13 commits into master
2 changes: 2 additions & 0 deletions ir_datasets/datasets/__init__.py
@@ -29,6 +29,8 @@
 from . import mr_tydi
 from . import msmarco_document
 from . import msmarco_document_v2
+from . import msmarco_document_v2_1
+from . import msmarco_segment_v2_1
 from . import msmarco_passage
 from . import msmarco_passage_v2
 from . import msmarco_qna
25 changes: 14 additions & 11 deletions ir_datasets/datasets/msmarco_document_v2.py
@@ -36,17 +36,21 @@ def default_text(self):
 
 
 class MsMarcoV2Docs(BaseDocs):
-    def __init__(self, dlc):
+    def __init__(self, dlc, docid_prefix='msmarco_doc_', docstore_size_hint=66500029281, name=NAME):
         super().__init__()
         self._dlc = dlc
+        self._docid_prefix = docid_prefix
+        self._docstore_size_hint = docstore_size_hint
+        self._name = name
 
     @ir_datasets.util.use_docstore
     def docs_iter(self):
-        with self._dlc.stream() as stream, \
-                tarfile.open(fileobj=stream, mode='r|') as tarf:
-            for record in tarf:
-                if not record.name.endswith('.gz'):
-                    continue
+        with tarfile.open(self._dlc.path(), mode='r:') as tarf:
+            # since there's no compression, it's fast to scan all records and sort them.
+            # The sorting has no effect on v2, but in v2.1, the files are out-of-sequence, so this
+            # addresses that problem.
+            records = sorted([r for r in tarf if r.name.endswith('.gz')], key=lambda x: x.name)
+            for record in records:
                 file = tarf.extractfile(record)
                 with gzip.open(file) as file:
                     for line in file:
@@ -84,18 +88,17 @@ def docs_store(self, field='doc_id'):
             data_cls=self.docs_cls(),
             lookup_field=field,
             index_fields=['doc_id'],
-            key_field_prefix='msmarco_doc_', # cut down on storage by removing prefix in lookup structure
-            size_hint=66500029281,
-            count_hint=ir_datasets.util.count_hint(NAME),
+            key_field_prefix=self._docid_prefix, # cut down on storage by removing prefix in lookup structure
+            size_hint=self._docstore_size_hint,
+            count_hint=ir_datasets.util.count_hint(self._name),
         )
         # return MsMArcoV2DocStore(self)
 
     def docs_count(self):
         if self.docs_store().built():
             return self.docs_store().count()
 
     def docs_namespace(self):
-        return NAME
+        return self._name
 
     def docs_lang(self):
         return 'en'
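
Note on the docs_iter change above: the streaming 'r|' tar mode is replaced by random-access 'r:' so that the member list can be sorted before extraction. A minimal standalone sketch of this pattern, with an illustrative tar path and member names (not the actual dataset layout):

import gzip
import tarfile

def iter_jsonl_lines(tar_path):
    # 'r:' opens the uncompressed tar with random access (unlike streaming 'r|'),
    # so scanning and sorting the member list is cheap
    with tarfile.open(tar_path, mode='r:') as tarf:
        members = sorted((m for m in tarf if m.name.endswith('.gz')),
                         key=lambda m: m.name)  # v2.1 stores members out of sequence
        for member in members:
            with gzip.open(tarf.extractfile(member)) as fh:
                yield from fh  # one JSON document per line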
31 changes: 31 additions & 0 deletions ir_datasets/datasets/msmarco_document_v2_1.py
@@ -0,0 +1,31 @@
+import ir_datasets
+from ir_datasets.util import DownloadConfig
+from ir_datasets.datasets.base import Dataset, YamlDocumentation
+from ir_datasets.formats import TsvQueries
+from ir_datasets.datasets.msmarco_passage import DUA
+from ir_datasets.datasets.msmarco_document_v2 import MsMarcoV2Docs
+
+_logger = ir_datasets.log.easy()
+
+NAME = 'msmarco-document-v2.1'
+
+def _init():
+    base_path = ir_datasets.util.home_path()/NAME
+    documentation = YamlDocumentation(f'docs/{NAME}.yaml')
+    dlc = DownloadConfig.context(NAME, base_path, dua=DUA)
+    # we can re-use MsMarcoV2Docs directly, just with a few modifications
+    collection = MsMarcoV2Docs(dlc['docs'], docid_prefix='msmarco_v2.1_doc_', docstore_size_hint=59680176084, name=NAME)
+    subsets = {}
+
+    subsets['trec-rag-2024'] = Dataset(
+        collection,
+        TsvQueries(dlc['rag-2024-test-topics'], namespace=NAME, lang='en'),
+    )
+
+    ir_datasets.registry.register(NAME, Dataset(collection, documentation('_')))
+    for s in sorted(subsets):
+        ir_datasets.registry.register(f'{NAME}/{s}', Dataset(subsets[s], documentation(s)))
+
+    return collection, subsets
+
+collection, subsets = _init()
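
Once _init() runs, the new IDs resolve through the standard ir_datasets API. A quick usage sketch, assuming the registration above (the doc_id is illustrative):

import ir_datasets

dataset = ir_datasets.load('msmarco-document-v2.1/trec-rag-2024')
for query in dataset.queries_iter():          # TSV topics: (query_id, text)
    print(query.query_id, query.text)

docstore = dataset.docs_store()               # built from the source tar on first use
doc = docstore.get('msmarco_v2.1_doc_00_0')   # ids carry the 'msmarco_v2.1_doc_' prefix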
68 changes: 39 additions & 29 deletions ir_datasets/datasets/msmarco_passage_v2.py
@@ -46,11 +46,22 @@ def parse_msmarco_passage(line):
         data['docid'])
 
 
+def passage_bundle_pos_from_key(key):
+    (string1, string2, bundlenum, position) = key.split('_')
+    assert string1 == 'msmarco' and string2 == 'passage'
+    return f'msmarco_passage_{bundlenum}', position
+
 class MsMarcoV2Passages(BaseDocs):
-    def __init__(self, dlc, pos_dlc=None):
+    def __init__(self, dlc, pos_dlc=None, cls=MsMarcoV2Passage, parse_passage=parse_msmarco_passage, name=NAME, docstore_size_hint=60880127751, bundle_pos_from_key=passage_bundle_pos_from_key, count=138_364_198):
         super().__init__()
         self._dlc = dlc
         self._pos_dlc = pos_dlc
+        self._cls = cls
+        self._parse_passage = parse_passage
+        self._name = name
+        self._docstore_size_hint = docstore_size_hint
+        self._bundle_pos_from_key = bundle_pos_from_key
+        self._count = count
 
     @ir_datasets.util.use_docstore
     def docs_iter(self):
@@ -59,30 +70,31 @@ def docs_iter(self):
             # files are used (i.e., no filtering is applied)
             yield from self.docs_store()
         else:
-            with self._dlc.stream() as stream, \
-                    tarfile.open(fileobj=stream, mode='r|') as tarf:
-                for record in tarf:
-                    if not record.name.endswith('.gz'):
-                        continue
+            with tarfile.open(self._dlc.path(), mode='r:') as tarf:
+                # since there's no compression, it's fast to scan all records and sort them.
+                # The sorting has no effect on v2, but in v2.1, the files are out-of-sequence, so this
+                # addresses that problem.
+                records = sorted([r for r in tarf if r.name.endswith('.gz')], key=lambda x: x.name)
+                for record in records:
                     file = tarf.extractfile(record)
                     with gzip.open(file) as file:
                         for line in file:
-                            yield parse_msmarco_passage(line)
+                            yield self._parse_passage(line)
 
     def docs_cls(self):
-        return MsMarcoV2Passage
+        return self._cls
 
     def docs_store(self, field='doc_id'):
         assert field == 'doc_id'
         # Unlike for msmarco-document-v2, using the docstore actually hurts performance.
-        return MsMarcoV2DocStore(self)
+        return MsMarcoV2DocStore(self, size_hint=self._docstore_size_hint, count=self._count)
 
     def docs_count(self):
         if self.docs_store().built():
             return self.docs_store().count()
 
     def docs_namespace(self):
-        return NAME
+        return self._name
 
     def docs_lang(self):
         return 'en'
@@ -92,7 +104,7 @@ def docs_path(self, force=True):
 
 
 class MsMarcoV2DocStore(ir_datasets.indices.Docstore):
-    def __init__(self, docs_handler):
+    def __init__(self, docs_handler, size_hint=60880127751, count=138_364_198):
         super().__init__(docs_handler.docs_cls(), 'doc_id')
         self.np = ir_datasets.lazy_libs.numpy()
         self.docs_handler = docs_handler
@@ -101,37 +113,38 @@ def __init__(self, docs_handler):
         self.base_path = docs_handler.docs_path(force=False) + '.extracted'
         if not os.path.exists(self.base_path):
             os.makedirs(self.base_path)
-        self.size_hint = 60880127751
+        self.size_hint = size_hint
+        self._count = count
 
     def get_many_iter(self, keys):
         self.build()
         # adapted from <https://microsoft.github.io/msmarco/TREC-Deep-Learning.html>
         bundles = {}
         for key in keys:
-            if not key.count('_') == 3:
+            try:
+                bundlenum, position = self.docs_handler._bundle_pos_from_key(key)
+            except:
                 continue
-            (string1, string2, bundlenum, position) = key.split('_')
-            assert string1 == 'msmarco' and string2 == 'passage'
             if bundlenum not in bundles:
                 bundles[bundlenum] = []
             bundles[bundlenum].append(int(position))
         for bundlenum, positions in bundles.items():
             positions = sorted(positions)
-            file = f'{self.base_path}/msmarco_passage_{bundlenum}'
+            file = f'{self.base_path}/{bundlenum}'
             if not os.path.exists(file):
                 # invalid doc_id -- doesn't point to a real bundle
                 continue
             if self.docs_handler._pos_dlc is not None:
                 # check the positions are valid for these doc_ids -- only return valid ones
-                mmp = self.np.memmap(os.path.join(self.pos_dlc.path(), f'msmarco_passage_{bundlenum}.pos'), dtype='<u4')
+                mmp = self.np.memmap(os.path.join(self.pos_dlc.path(), f'{bundlenum}.pos'), dtype='<u4')
                 positions = self.np.array(positions, dtype='<u4')
                 positions = positions[self.np.isin(positions, mmp)].tolist()
                 del mmp
             with open(file, 'rt', encoding='utf8') as in_fh:
                 for position in positions:
                     in_fh.seek(position)
                     try:
-                        yield parse_msmarco_passage(in_fh.readline())
+                        yield self.docs_handler._parse_passage(in_fh.readline())
                     except json.JSONDecodeError:
                         # invalid doc_id -- pointed to a wrong position
                         pass
@@ -141,12 +154,9 @@ def build(self):
             return
         np = ir_datasets.lazy_libs.numpy()
         ir_datasets.util.check_disk_free(self.base_path, self.size_hint)
-        with _logger.pbar_raw('extracting source documents', total=70, unit='file') as pbar, \
-                self.dlc.stream() as stream, \
-                tarfile.open(fileobj=stream, mode='r|') as tarf:
-            for record in tarf:
-                if not record.name.endswith('.gz'):
-                    continue
+        with tarfile.open(self.dlc.path(), mode='r:') as tarf:
+            records = sorted([r for r in tarf if r.name.endswith('.gz')], key=lambda x: x.name)
+            for record in _logger.pbar(records, desc='extracting source documents'):
                 file = tarf.extractfile(record)
                 fname = record.name.split('/')[-1][:-len('.gz')]
                 positions = []
@@ -158,7 +168,6 @@
                 # keep track of the positions for efficient slicing
                 with open(os.path.join(self.base_path, f'{fname}.pos'), 'wb') as posout:
                     posout.write(np.array(positions, dtype='<u4').tobytes())
-                pbar.update(1)
         (Path(self.base_path) / '_built').touch()
 
     def built(self):
@@ -169,14 +178,15 @@ def __iter__(self):
         return MsMarcoV2PassageIter(self, slice(0, self.count()))
 
     def _iter_source_files(self):
-        for i in range(70):
-            yield os.path.join(self.base_path, f'msmarco_passage_{i:02d}')
+        for path in sorted(os.listdir(self.base_path)):
+            if path.startswith('msmarco_') and not path.endswith('.pos'):
+                yield os.path.join(self.base_path, path)
 
     def count(self):
         if self.docs_handler._pos_dlc is not None:
             base_path = self.pos_dlc.path()
             return sum(os.path.getsize(os.path.join(base_path, f)) for f in os.listdir(base_path)) // 4
-        return 138_364_198
+        return self._count
 
 
 class MsMarcoV2PassageIter:
@@ -218,7 +228,7 @@ def __next__(self):
             pos = self.current_pos_mmap[self.slice.start - self.current_file_start_idx]
             self.current_file.seek(pos)
             self.next_index = self.slice.start
-        result = parse_msmarco_passage(self.current_file.readline())
+        result = self.docstore.docs_handler._parse_passage(self.current_file.readline())
         self.next_index += 1
         self.slice = slice(self.slice.start + (self.slice.step or 1), self.slice.stop, self.slice.step)
         return result
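
The bundle_pos_from_key hook added above is what get_many_iter now uses to turn a doc_id into an extracted bundle file plus a byte offset. A small sketch of the default passage scheme, assuming this PR's passage_bundle_pos_from_key (the doc_id is illustrative):

from ir_datasets.datasets.msmarco_passage_v2 import passage_bundle_pos_from_key

bundle, pos = passage_bundle_pos_from_key('msmarco_passage_41_45753370')
print(bundle)  # 'msmarco_passage_41' -> extracted bundle file under base_path
print(pos)     # '45753370' -> byte offset to seek() to before readline()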
78 changes: 78 additions & 0 deletions ir_datasets/datasets/msmarco_segment_v2_1.py
@@ -0,0 +1,78 @@
+import json
+from typing import NamedTuple
+import ir_datasets
+from ir_datasets.util import DownloadConfig
+from ir_datasets.datasets.base import Dataset, YamlDocumentation
+from ir_datasets.formats import TsvQueries
+from ir_datasets.datasets.msmarco_passage import DUA
+from ir_datasets.datasets.msmarco_passage_v2 import MsMarcoV2Passages
+
+_logger = ir_datasets.log.easy()
+
+NAME = 'msmarco-segment-v2.1'
+
+
+class MsMarcoV21SegmentedDoc(NamedTuple):
+    doc_id: str
+    url: str
+    title: str
+    headings: str
+    segment: str
+    start_char: int
+    end_char: int
+    msmarco_document_id: str
+    msmarco_document_segment_idx: int
+    def default_text(self):
+        """
+        title + headings + segment
+        This is consistent with MsMarcoV21Document, whose full-text equivalent is title + headings + body.
+        Note that Anserini additionally includes the url, i.e., it returns url + title + headings + segment.
+        E.g., https://github.com/castorini/anserini/blob/b8ce19f56bc4e85056ef703322f76646804ec640/src/main/java/io/anserini/collection/MsMarcoV2DocCollection.java#L169
+        """
+        return f'{self.title} {self.headings} {self.segment}'
+
+
+def parse_msmarco_segment(line):
+    data = json.loads(line)
+    msmarco_document_id, segment_info = data['docid'].split('#')
+    segment_idx, segment_file_offset = segment_info.split('_')
+    return MsMarcoV21SegmentedDoc(
+        data['docid'],
+        data['url'],
+        data['title'],
+        data['headings'],
+        data['segment'],
+        data['start_char'],
+        data['end_char'],
+        msmarco_document_id,
+        int(segment_idx),
+    )
+
+
+def segment_bundle_pos_from_key(key):
+    # key like: msmarco_v2.1_doc_00_0#4_5974
+    first, second = key.split('#')
+    (string1, string2, string3, bundle, doc_pos) = first.split('_')
+    (segment_num, segment_pos) = second.split('_')
+    assert string1 == 'msmarco' and string2 == 'v2.1' and string3 == 'doc'
+    return f'msmarco_v2.1_doc_segmented_{bundle}.json', segment_pos
+
+
+def _init():
+    base_path = ir_datasets.util.home_path()/NAME
+    documentation = YamlDocumentation(f'docs/{NAME}.yaml')
+    dlc = DownloadConfig.context(NAME, base_path, dua=DUA)
+    collection = MsMarcoV2Passages(dlc['docs'], cls=MsMarcoV21SegmentedDoc, parse_passage=parse_msmarco_segment, name=NAME, bundle_pos_from_key=segment_bundle_pos_from_key, count=113_520_750, docstore_size_hint=205178702472)
+    subsets = {}
+    subsets['trec-rag-2024'] = Dataset(
+        collection,
+        TsvQueries(dlc['rag-2024-test-topics'], namespace=NAME, lang='en'),
+    )
+
+    ir_datasets.registry.register(NAME, Dataset(collection, documentation('_')))
+    for s in sorted(subsets):
+        ir_datasets.registry.register(f'{NAME}/{s}', Dataset(subsets[s], documentation(s)))
+
+    return collection, subsets
+
+collection, subsets = _init()
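
To make the v2.1 segment id format concrete, here is the decomposition that segment_bundle_pos_from_key performs, spelled out on the example id from its own comment:

# example id from the comment in segment_bundle_pos_from_key above
key = 'msmarco_v2.1_doc_00_0#4_5974'
first, second = key.split('#')                # parent doc id, segment suffix
_, _, _, bundle, doc_pos = first.split('_')   # bundle '00', parent doc at byte 0
segment_idx, segment_pos = second.split('_')  # segment 4, at byte offset 5974
# lookups therefore open 'msmarco_v2.1_doc_segmented_00.json' in the extracted
# docstore and seek() to byte 5974 before readline()
print(bundle, segment_idx, segment_pos)       # -> 00 4 5974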
14 changes: 14 additions & 0 deletions ir_datasets/docs/msmarco-document-v2.1.yaml
@@ -0,0 +1,14 @@
+_:
+  pretty_name: 'MSMARCO (document, version 2.1)'
+  desc: '
+<p>
+Version 2.1 of the MS MARCO document ranking dataset used in TREC RAG 2024.
+</p>
+<ul>
+<li>Version 1 of the dataset: <a class="ds-ref">msmarco-document</a></li>
+<li>Documents: Text extracted from web pages</li>
+<li>Queries: Natural language questions (from query log)</li>
+<li>TODO: add paper describing the dataset.</li>
+</ul>'
+  bibtex_ids: []
+
32 changes: 32 additions & 0 deletions ir_datasets/etc/downloads.json
@@ -4714,6 +4714,38 @@
     }
   },
 
+  "msmarco-document-v2.1": {
+    "docs": {
+      "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco_v2.1_doc.tar",
+      "size_hint": 30844989440,
+      "expected_md5": "a5950665d6448d3dbaf7135645f1e074",
+      "cache_path": "msmarco_v2.1_doc.tar",
+      "download_args": {"headers": {"X-Ms-Version": "2024-07-10"}}
+    },
+    "rag-2024-test-topics": {
+      "url": "https://trec-rag.github.io/assets/txt/topics.rag24.test.txt",
+      "size_hint": 19517,
+      "expected_md5": "5bd6c8fa0e1300233fe139bae8288d09",
+      "cache_path": "trec-rag-2024-topics-test.txt"
+    }
+  },
+
+  "msmarco-segment-v2.1": {
+    "docs": {
+      "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco_v2.1_doc_segmented.tar",
+      "size_hint": 26918768640,
+      "expected_md5": "3799e7611efffd8daeb257e9ccca4d60",
+      "cache_path": "msmarco_v2.1_doc_segmented.tar",
+      "download_args": {"headers": {"X-Ms-Version": "2024-07-10"}}
+    },
+    "rag-2024-test-topics": {
+      "url": "https://trec-rag.github.io/assets/txt/topics.rag24.test.txt",
+      "size_hint": 19517,
+      "expected_md5": "5bd6c8fa0e1300233fe139bae8288d09",
+      "cache_path": "trec-rag-2024-topics-test.txt"
+    }
+  },
+
   "msmarco-passage": {
     "collectionandqueries": {
       "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/collectionandqueries.tar.gz",