Skip to content

Commit

Permalink
Add a whole corpus sampler
Browse files Browse the repository at this point in the history
This patch adds a whole corpus sampler. This is intended to be used for
trace data collection where we need to compile an entire corpus subset
that is extracted using separate tooling in order to generate a reward.
  • Loading branch information
boomanaiden154 committed Sep 12, 2024
1 parent 2878b51 commit 237d491
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
18 changes: 18 additions & 0 deletions compiler_opt/rl/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,24 @@ def __call__(self, k: int, n: int = 10) -> List[ModuleSpec]:
return list(results)


class WholeCorpusSampler(Sampler):
"""Returns the entire corpus every time a sample is requested."""

def __init__(self, module_specs: Tuple[ModuleSpec]):
super().__init__(module_specs)

def reset(self):
pass

def __call__(self, k: int) -> List[ModuleSpec]:
"""Returns the entire corpus a list of module specs."""
if len(self._module_specs) != k:
raise ValueError(
f'The number of modules requested {k} is not equal to '
f'the number of modules in the corpus, {len(self._module_specs)}')
return list(self._module_specs)


class Corpus:
"""Represents a corpus.
Expand Down
28 changes: 28 additions & 0 deletions compiler_opt/rl/corpus_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,34 @@ def test_sample_without_replacement(self):
self.assertEqual(samples[2].name, 'small')
self.assertEqual(samples[3].name, 'smol')

def test_whole_corpus_sampler(self):
cps = corpus.create_corpus_for_testing(
location=self.create_tempdir(),
elements=[
corpus.ModuleSpec(name='xsmall', size=1),
corpus.ModuleSpec(name='small', size=5),
corpus.ModuleSpec(name='middle', size=20),
corpus.ModuleSpec(name='large', size=100)
],
sampler_type=corpus.WholeCorpusSampler)
sample = cps.sample(4, sort=True)
self.assertLen(sample, 4)
self.assertEqual(sample[0].name, 'large')
self.assertEqual(sample[1].name, 'middle')
self.assertEqual(sample[2].name, 'small')
self.assertEqual(sample[3].name, 'xsmall')

def test_whole_corpus_sampler_invalid_count(self):
cps = corpus.create_corpus_for_testing(
location=self.create_tempdir(),
elements=[
corpus.ModuleSpec(name='small', size=1),
corpus.ModuleSpec(name='middle', size=2),
],
sampler_type=corpus.WholeCorpusSampler)
with self.assertRaises(ValueError):
cps.sample(1)

def test_filter(self):
cps = corpus.create_corpus_for_testing(
location=self.create_tempdir(),
Expand Down

0 comments on commit 237d491

Please sign in to comment.