From 237d4910970b3f3658ad11120676a523053b3261 Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Thu, 12 Sep 2024 22:50:41 +0000 Subject: [PATCH] Add a whole corpus sampler This patch adds a whole corpus sampler. This is intended to be used for trace data collection where we need to compile an entire corpus subset that is extracted using separate tooling in order to generate a reward. --- compiler_opt/rl/corpus.py | 18 ++++++++++++++++++ compiler_opt/rl/corpus_test.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py index c351695e..87742442 100644 --- a/compiler_opt/rl/corpus.py +++ b/compiler_opt/rl/corpus.py @@ -224,6 +224,24 @@ def __call__(self, k: int, n: int = 10) -> List[ModuleSpec]: return list(results) +class WholeCorpusSampler(Sampler): + """Returns the entire corpus every time a sample is requested.""" + + def __init__(self, module_specs: Tuple[ModuleSpec]): + super().__init__(module_specs) + + def reset(self): + pass + + def __call__(self, k: int) -> List[ModuleSpec]: + """Returns the entire corpus as a list of module specs.""" + if len(self._module_specs) != k: + raise ValueError( + f'The number of modules requested {k} is not equal to ' + f'the number of modules in the corpus, {len(self._module_specs)}') + return list(self._module_specs) + + class Corpus: """Represents a corpus. 
diff --git a/compiler_opt/rl/corpus_test.py b/compiler_opt/rl/corpus_test.py index 846a0fc7..fc7f41f0 100644 --- a/compiler_opt/rl/corpus_test.py +++ b/compiler_opt/rl/corpus_test.py @@ -267,6 +267,34 @@ def test_sample_without_replacement(self): self.assertEqual(samples[2].name, 'small') self.assertEqual(samples[3].name, 'smol') + def test_whole_corpus_sampler(self): + cps = corpus.create_corpus_for_testing( + location=self.create_tempdir(), + elements=[ + corpus.ModuleSpec(name='xsmall', size=1), + corpus.ModuleSpec(name='small', size=5), + corpus.ModuleSpec(name='middle', size=20), + corpus.ModuleSpec(name='large', size=100) + ], + sampler_type=corpus.WholeCorpusSampler) + sample = cps.sample(4, sort=True) + self.assertLen(sample, 4) + self.assertEqual(sample[0].name, 'large') + self.assertEqual(sample[1].name, 'middle') + self.assertEqual(sample[2].name, 'small') + self.assertEqual(sample[3].name, 'xsmall') + + def test_whole_corpus_sampler_invalid_count(self): + cps = corpus.create_corpus_for_testing( + location=self.create_tempdir(), + elements=[ + corpus.ModuleSpec(name='small', size=1), + corpus.ModuleSpec(name='middle', size=2), + ], + sampler_type=corpus.WholeCorpusSampler) + with self.assertRaises(ValueError): + cps.sample(1) + def test_filter(self): cps = corpus.create_corpus_for_testing( location=self.create_tempdir(),