From 6d8c0c77a2185dc3bbc41df2ee85c5494a3c35b7 Mon Sep 17 00:00:00 2001
From: "Teodor V. Marinov"
Date: Mon, 9 Sep 2024 13:46:27 +0000
Subject: [PATCH] Addressed comments.

---
 compiler_opt/tools/combine_tfa_policies.py     | 31 +++++++++++++------
 .../tools/combine_tfa_policies_lib.py          | 11 +++++++
 .../tools/combine_tfa_policies_lib_test.py     | 11 ++++---
 3 files changed, 39 insertions(+), 14 deletions(-)

diff --git a/compiler_opt/tools/combine_tfa_policies.py b/compiler_opt/tools/combine_tfa_policies.py
index fb1a8f25..0e758a8f 100755
--- a/compiler_opt/tools/combine_tfa_policies.py
+++ b/compiler_opt/tools/combine_tfa_policies.py
@@ -15,6 +15,9 @@
 """Runs the policy combiner."""
 from absl import app
 from absl import flags
+from absl import logging
+
+import sys
 
 import gin
 
@@ -25,11 +28,15 @@ from compiler_opt.tools import combine_tfa_policies_lib as cfa_lib
 
 _COMBINE_POLICIES_NAMES = flags.DEFINE_multi_string(
-    'policies_names', [],
-    'List in order of policy names for combined policies.')
+    'policies_names',
+    [],
+    'List in order of policy names for combined policies. Order must match that of policies_paths.'  # pylint: disable=line-too-long
+)
 _COMBINE_POLICIES_PATHS = flags.DEFINE_multi_string(
-    'policies_paths', [],
-    'List in order of policy paths for combined policies.')
+    'policies_paths',
+    [],
+    'List in order of policy paths for combined policies. Order must match that of policies_names.'  # pylint: disable=line-too-long
+)
 _COMBINED_POLICY_PATH = flags.DEFINE_string(
     'combined_policy_path', '', 'Path to save the combined policy.')
 _GIN_FILES = flags.DEFINE_multi_string(
@@ -43,8 +50,11 @@ def main(_):
   flags.mark_flag_as_required('policies_names')
   flags.mark_flag_as_required('policies_paths')
   flags.mark_flag_as_required('combined_policy_path')
-  assert len(_COMBINE_POLICIES_NAMES.value) == len(
-      _COMBINE_POLICIES_PATHS.value)
+  if len(_COMBINE_POLICIES_NAMES.value) != len(_COMBINE_POLICIES_PATHS.value):
+    logging.error(
+        'Length of policies_names: %d must equal length of policies_paths: %d.',
+        len(_COMBINE_POLICIES_NAMES.value), len(_COMBINE_POLICIES_PATHS.value))
+    sys.exit(1)
   gin.add_config_file_search_path(
       'compiler_opt/rl/inlining/gin_configs/common.gin')
   gin.parse_config_files_and_bindings(
@@ -56,9 +66,12 @@
       'model_selector':
           tf.TensorSpec(shape=(2,), dtype=tf.uint64, name='model_selector')
   })
-  assert len(_COMBINE_POLICIES_NAMES.value
-            ) == 2, 'Combiner supports only two policies.'
-
+  # TODO(359): We only support combining two policies. Generalize this to
+  # handle multiple policies.
+  if len(_COMBINE_POLICIES_NAMES.value) != 2:
+    logging.error('Policy combiner only supports two policies, %d given.',
+                  len(_COMBINE_POLICIES_NAMES.value))
+    sys.exit(1)
   policy1_name = _COMBINE_POLICIES_NAMES.value[0]
   policy1_path = _COMBINE_POLICIES_PATHS.value[0]
   policy2_name = _COMBINE_POLICIES_NAMES.value[1]
diff --git a/compiler_opt/tools/combine_tfa_policies_lib.py b/compiler_opt/tools/combine_tfa_policies_lib.py
index a2303bf0..c8d09b6a 100644
--- a/compiler_opt/tools/combine_tfa_policies_lib.py
+++ b/compiler_opt/tools/combine_tfa_policies_lib.py
@@ -67,6 +67,7 @@ def _process_observation(
     high_low_tensor = self.high_low_tensor
     for name in self.sorted_keys:
       if name in ["model_selector"]:
+        # model_selector is a Tensor of shape (1,), so index [0] to extract it.
         switch_tensor = observation.pop(name)[0]
         high_low_tensor = switch_tensor
 
@@ -81,6 +82,12 @@ def _process_observation(
 
     return observation, high_low_tensor
 
   def _create_distribution(self, inlining_prediction):
+    """Ensures that even deterministic policies return a distribution.
+
+    This does not change the behavior of the action function, which is
+    what is used at inference time. The distribution function is changed
+    so that sampling is always supported, even for deterministic
+    policies."""
     probs = [inlining_prediction, 1.0 - inlining_prediction]
     logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1]))]]
     return tfp.distributions.Categorical(logits=logits)
@@ -97,6 +104,8 @@ def _action(self,
         discount=time_step.discount,
         observation=new_observation)
 
+    # TODO(359): We only support combining two policies. Generalize this to
+    # handle multiple policies.
     def f0():
       return tf.cast(
           self.tf_policies[0].action(updated_step).action[0], dtype=tf.int64)
@@ -121,6 +130,8 @@ def _distribution(
         discount=time_step.discount,
         observation=new_observation)
 
+    # TODO(359): We only support combining two policies. Generalize this to
+    # handle multiple policies.
    def f0():
      return tf.cast(
          self.tf_policies[0].distribution(updated_step).action.cdf(0)[0],
diff --git a/compiler_opt/tools/combine_tfa_policies_lib_test.py b/compiler_opt/tools/combine_tfa_policies_lib_test.py
index 2404bff1..7cd873c5 100644
--- a/compiler_opt/tools/combine_tfa_policies_lib_test.py
+++ b/compiler_opt/tools/combine_tfa_policies_lib_test.py
@@ -98,7 +98,8 @@ def _action(self, time_step, policy_state, seed):
 action_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)
 
 
-class FeatureImportanceTest(absltest.TestCase):
+class CombinedTFPolicyTest(absltest.TestCase):
+  """Tests for CombinedTFPolicy."""
 
   def test_select_add_policy(self):
     policy1 = AddOnePolicy()
@@ -116,14 +117,14 @@ def test_select_add_policy(self):
     state = ts.TimeStep(
         discount=tf.constant(np.array([0.]), dtype=tf.float32),
         observation={
-            'obs': tf.constant(np.array([0]), dtype=tf.int64),
+            'obs': tf.constant(np.array([42]), dtype=tf.int64),
             'model_selector': model_selector
         },
         reward=tf.constant(np.array([0]), dtype=tf.float64),
         step_type=tf.constant(np.array([0]), dtype=tf.int64))
 
     self.assertEqual(
-        combined_policy.action(state).action, tf.constant(1, dtype=tf.int64))
+        combined_policy.action(state).action, tf.constant(43, dtype=tf.int64))
 
   def test_select_subtract_policy(self):
     policy1 = AddOnePolicy()
@@ -141,11 +142,11 @@ def test_select_subtract_policy(self):
     state = ts.TimeStep(
         discount=tf.constant(np.array([0.]), dtype=tf.float32),
         observation={
-            'obs': tf.constant(np.array([0]), dtype=tf.int64),
+            'obs': tf.constant(np.array([42]), dtype=tf.int64),
             'model_selector': model_selector
         },
         reward=tf.constant(np.array([0]), dtype=tf.float64),
         step_type=tf.constant(np.array([0]), dtype=tf.int64))
 
     self.assertEqual(
-        combined_policy.action(state).action, tf.constant(-1, dtype=tf.int64))
+        combined_policy.action(state).action, tf.constant(41, dtype=tf.int64))
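
For reference, the arithmetic behind `_create_distribution` in the patch: given an inlining probability p, the logits `[0.0, log((1 - p) / p)]` yield a Categorical distribution with P(0) = p and P(1) = 1 - p, so sampling agrees with the policy's prediction. A minimal standalone sketch, using plain Python's softmax as a stand-in for what `tfp.distributions.Categorical` computes from logits (the value of p here is made up):

```python
import math


def categorical_probs(logits):
  # Softmax over the logits, mirroring how Categorical(logits=...) normalizes.
  exps = [math.exp(l) for l in logits]
  total = sum(exps)
  return [e / total for e in exps]


p = 0.8  # hypothetical inlining probability
logits = [0.0, math.log((1.0 - p) / p)]
print(categorical_probs(logits))  # ~[0.8, 0.2]: P(inline) = p, P(no inline) = 1 - p
```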
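
The `f0`/`f1` branches in `_action` and `_distribution` are selected based on whether the observed `model_selector` matches a policy's high/low pair. A minimal sketch of that dispatch pattern; the `tf.cond` call, the stub policies, and the selector constants here are illustrative assumptions, not the patch's exact code:

```python
import tensorflow as tf

# Hypothetical high/low selector pair registered for the first policy.
policy0_selector = tf.constant([7, 11], dtype=tf.uint64)
# Selector observed in the incoming TimeStep.
observed_selector = tf.constant([7, 11], dtype=tf.uint64)


def f0():
  # Stub for "act with the first policy".
  return tf.constant(0, dtype=tf.int64)


def f1():
  # Stub for "act with the second policy".
  return tf.constant(1, dtype=tf.int64)


action = tf.cond(
    tf.reduce_all(tf.equal(observed_selector, policy0_selector)), f0, f1)
print(action.numpy())  # 0: the observed selector matches the first policy
```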