
Commit

Resolved _distribution and common.gin comments.
tvmarino committed Sep 9, 2024
1 parent 6d8c0c7 commit 3b0cefd
Showing 2 changed files with 2 additions and 37 deletions.
2 changes: 0 additions & 2 deletions compiler_opt/tools/combine_tfa_policies.py
@@ -55,8 +55,6 @@ def main(_):
         'Length of policies_names: %d must equal length of policies_paths: %d.',
         len(_COMBINE_POLICIES_NAMES.value), len(_COMBINE_POLICIES_PATHS.value))
     sys.exit(1)
-  gin.add_config_file_search_path(
-      'compiler_opt/rl/inlining/gin_configs/common.gin')
   gin.parse_config_files_and_bindings(
       _GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False)
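
With the hard-coded search path gone, common.gin is no longer picked up implicitly; it presumably has to be listed among the gin files handed to the tool. A minimal sketch of that assumption, reusing the gin call and the path that appear in the diff above:

    import gin

    # The caller now passes the full path to common.gin explicitly instead of
    # relying on a search path registered inside the tool.
    gin.parse_config_files_and_bindings(
        ['compiler_opt/rl/inlining/gin_configs/common.gin'],
        bindings=[],
        skip_unknown=False)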

37 changes: 2 additions & 35 deletions compiler_opt/tools/combine_tfa_policies_lib.py
@@ -81,17 +81,6 @@ def _process_observation(
 
     return observation, high_low_tensor
 
-  def _create_distribution(self, inlining_prediction):
-    """Ensures that even deterministic policies return a distribution.
-
-    This will not change the behavior of the action function which is
-    what is used at inference time. The change for the distribution
-    function is so that we can always support sampling even for
-    deterministic policies."""
-    probs = [inlining_prediction, 1.0 - inlining_prediction]
-    logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1]))]]
-    return tfp.distributions.Categorical(logits=logits)
-
   def _action(self,
               time_step: ts.TimeStep,
               policy_state: types.NestedTensorSpec,
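
For reference, the deleted helper maps a scalar prediction p to a two-class Categorical whose probabilities work out to [p, 1 - p], because softmax([0, log((1 - p) / p)]) = [p, 1 - p]; that is what let even a deterministic policy be sampled. A standalone sketch of the same construction, detached from the class and purely illustrative:

    import tensorflow as tf
    import tensorflow_probability as tfp

    def create_distribution(inlining_prediction):
      # Same construction as the deleted helper: logits [0, log((1 - p) / p)].
      p = inlining_prediction
      logits = [[0.0, tf.math.log((1.0 - p) / p)]]
      return tfp.distributions.Categorical(logits=logits)

    dist = create_distribution(0.8)
    print(dist.probs_parameter())  # approximately [[0.8, 0.2]]
    print(dist.sample())           # 0 with probability 0.8, 1 with probability 0.2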
@@ -122,28 +111,6 @@ def f1():
   def _distribution(
       self, time_step: ts.TimeStep,
       policy_state: types.NestedTensorSpec) -> policy_step.PolicyStep:
-    new_observation = time_step.observation
-    new_observation, switch_tensor = self._process_observation(new_observation)
-    updated_step = ts.TimeStep(
-        step_type=time_step.step_type,
-        reward=time_step.reward,
-        discount=time_step.discount,
-        observation=new_observation)
-
-    # TODO(359): We only support combining two policies.Generalize this to
-    # handle multiple policies.
-    def f0():
-      return tf.cast(
-          self.tf_policies[0].distribution(updated_step).action.cdf(0)[0],
-          dtype=tf.float32)
-
-    def f1():
-      return tf.cast(
-          self.tf_policies[1].distribution(updated_step).action.cdf(0)[0],
-          dtype=tf.float32)
-
-    distribution = tf.cond(
-        tf.math.reduce_all(tf.equal(switch_tensor, self.high_low_tensor)), f0,
-        f1)
+    """Placeholder for distribution as every TFPolicy requires it."""
     return policy_step.PolicyStep(
-        action=self._create_distribution(distribution), state=policy_state)
+        action=tfp.distributions.Deterministic(2.), state=policy_state)
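
The replacement keeps _distribution only as a stub that satisfies the TFPolicy interface, presumably leaving the actual policy switching to _action (left intact above); tfp.distributions.Deterministic(2.) is a point mass at 2.0, so the returned distribution carries no real information. A tiny standalone illustration of that distribution:

    import tensorflow_probability as tfp

    point_mass = tfp.distributions.Deterministic(2.)
    print(point_mass.sample(3))  # [2. 2. 2.]
    print(point_mass.prob(2.))   # 1.0
    print(point_mass.prob(0.))   # 0.0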
