Addressed comments.
tvmarino committed Sep 9, 2024
1 parent 59d3677 commit 6d8c0c7
Showing 3 changed files with 39 additions and 14 deletions.
31 changes: 22 additions & 9 deletions compiler_opt/tools/combine_tfa_policies.py
@@ -15,6 +15,9 @@
"""Runs the policy combiner."""
from absl import app
from absl import flags
from absl import logging

import sys

import gin

@@ -25,11 +28,15 @@
from compiler_opt.tools import combine_tfa_policies_lib as cfa_lib

_COMBINE_POLICIES_NAMES = flags.DEFINE_multi_string(
'policies_names', [],
'List in order of policy names for combined policies.')
'policies_names',
[],
'List in order of policy names for combined policies. Order must match that of policies_paths.' # pylint: disable=line-too-long
)
_COMBINE_POLICIES_PATHS = flags.DEFINE_multi_string(
'policies_paths', [],
'List in order of policy paths for combined policies.')
'policies_paths',
[],
'List in order of policy paths for combined policies. Order must match that of policies_names.' # pylint: disable=line-too-long
)
_COMBINED_POLICY_PATH = flags.DEFINE_string(
'combined_policy_path', '', 'Path to save the combined policy.')
_GIN_FILES = flags.DEFINE_multi_string(
@@ -43,8 +50,11 @@ def main(_):
flags.mark_flag_as_required('policies_names')
flags.mark_flag_as_required('policies_paths')
flags.mark_flag_as_required('combined_policy_path')
assert len(_COMBINE_POLICIES_NAMES.value) == len(
_COMBINE_POLICIES_PATHS.value)
if len(_COMBINE_POLICIES_NAMES.value) != len(_COMBINE_POLICIES_PATHS.value):
logging.error(
'Length of policies_names: %d must equal length of policies_paths: %d.',
len(_COMBINE_POLICIES_NAMES.value), len(_COMBINE_POLICIES_PATHS.value))
sys.exit(1)
gin.add_config_file_search_path(
'compiler_opt/rl/inlining/gin_configs/common.gin')
gin.parse_config_files_and_bindings(
@@ -56,9 +66,12 @@
'model_selector':
tf.TensorSpec(shape=(2,), dtype=tf.uint64, name='model_selector')
})
assert len(_COMBINE_POLICIES_NAMES.value
) == 2, 'Combiner supports only two policies.'

# TODO(359): We only support combining two policies. Generalize this to
# handle multiple policies.
if len(_COMBINE_POLICIES_NAMES.value) != 2:
logging.error('Policy combiner only supports two policies, %d given.',
len(_COMBINE_POLICIES_NAMES.value))
sys.exit(1)
policy1_name = _COMBINE_POLICIES_NAMES.value[0]
policy1_path = _COMBINE_POLICIES_PATHS.value[0]
policy2_name = _COMBINE_POLICIES_NAMES.value[1]
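
A hypothetical invocation of the combiner with the flags defined above (the policy names and paths are illustrative; absl multi-string flags are passed once per value, and the order of --policies_names must match the order of --policies_paths):

    python3 compiler_opt/tools/combine_tfa_policies.py \
      --policies_names=policy_a --policies_names=policy_b \
      --policies_paths=/tmp/policy_a --policies_paths=/tmp/policy_b \
      --combined_policy_path=/tmp/combined_policy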
11 changes: 11 additions & 0 deletions compiler_opt/tools/combine_tfa_policies_lib.py
@@ -67,6 +67,7 @@ def _process_observation(
high_low_tensor = self.high_low_tensor
for name in self.sorted_keys:
if name in ["model_selector"]:
# model_selector is a Tensor of shape (1,), so it requires indexing with [0].
switch_tensor = observation.pop(name)[0]
high_low_tensor = switch_tensor

@@ -81,6 +82,12 @@
return observation, high_low_tensor

def _create_distribution(self, inlining_prediction):
"""Ensures that even deterministic policies return a distribution.
This will not change the behavior of the action function which is
what is used at inference time. The change for the distribution
function is so that we can always support sampling even for
deterministic policies."""
probs = [inlining_prediction, 1.0 - inlining_prediction]
logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1]))]]
return tfp.distributions.Categorical(logits=logits)
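
A minimal sketch (not part of this change) checking that the logits construction above recovers probs = [p, 1 - p] for a deterministic prediction p, so sampling keeps working; assumes 0 < p < 1:

    import tensorflow as tf
    import tensorflow_probability as tfp

    p = 0.75  # hypothetical inlining_prediction
    probs = [p, 1.0 - p]
    # log(probs[1] / (1 - probs[1])) = log((1 - p) / p), and softmax of
    # [0, log((1 - p) / p)] is exactly [p, 1 - p].
    logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1]))]]
    dist = tfp.distributions.Categorical(logits=logits)
    print(dist.probs_parameter())  # approximately [[0.75, 0.25]]
    print(dist.sample(5))  # 0 with probability p, 1 with probability 1 - p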
@@ -97,6 +104,8 @@ def _action(self,
discount=time_step.discount,
observation=new_observation)

# TODO(359): We only support combining two policies. Generalize this to
# handle multiple policies.
def f0():
return tf.cast(
self.tf_policies[0].action(updated_step).action[0], dtype=tf.int64)
@@ -121,6 +130,8 @@ def _distribution(
discount=time_step.discount,
observation=new_observation)

# TODO(359): We only support combining two policies. Generalize this to
# handle multiple policies.
def f0():
return tf.cast(
self.tf_policies[0].distribution(updated_step).action.cdf(0)[0],
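
The f0/f1 branches above return each policy's action (or distribution), and the surrounding code, outside the shown hunks, selects between them based on the model_selector value. A minimal sketch of that selection pattern, assuming a tf.cond-style dispatch (the library's actual predicate may differ):

    import tensorflow as tf

    def f0():
      return tf.constant(42, dtype=tf.int64)  # stand-in for policy 0's action

    def f1():
      return tf.constant(7, dtype=tf.int64)  # stand-in for policy 1's action

    use_policy0 = tf.constant(True)  # stand-in for the model_selector check
    action = tf.cond(use_policy0, f0, f1)  # yields 42 here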
11 changes: 6 additions & 5 deletions compiler_opt/tools/combine_tfa_policies_lib_test.py
@@ -98,7 +98,8 @@ def _action(self, time_step, policy_state, seed):
action_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)


class FeatureImportanceTest(absltest.TestCase):
class CombinedTFPolicyTest(absltest.TestCase):
"""Test for CombinedTFPolicy."""

def test_select_add_policy(self):
policy1 = AddOnePolicy()
@@ -116,14 +117,14 @@ def test_select_add_policy(self):
state = ts.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'obs': tf.constant(np.array([0]), dtype=tf.int64),
'obs': tf.constant(np.array([42]), dtype=tf.int64),
'model_selector': model_selector
},
reward=tf.constant(np.array([0]), dtype=tf.float64),
step_type=tf.constant(np.array([0]), dtype=tf.int64))

self.assertEqual(
combined_policy.action(state).action, tf.constant(1, dtype=tf.int64))
combined_policy.action(state).action, tf.constant(43, dtype=tf.int64))

def test_select_subtract_policy(self):
policy1 = AddOnePolicy()
@@ -141,11 +142,11 @@ def test_select_subtract_policy(self):
state = ts.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'obs': tf.constant(np.array([0]), dtype=tf.int64),
'obs': tf.constant(np.array([42]), dtype=tf.int64),
'model_selector': model_selector
},
reward=tf.constant(np.array([0]), dtype=tf.float64),
step_type=tf.constant(np.array([0]), dtype=tf.int64))

self.assertEqual(
combined_policy.action(state).action, tf.constant(-1, dtype=tf.int64))
combined_policy.action(state).action, tf.constant(41, dtype=tf.int64))
