
Policy combiner #357

Merged
merged 15 commits on Sep 10, 2024
89 changes: 89 additions & 0 deletions compiler_opt/tools/combine_tfa_policies.py
@@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs the policy combiner."""
from absl import app
from absl import flags
from absl import logging

import sys

import gin

import tensorflow as tf

from compiler_opt.rl import policy_saver
from compiler_opt.rl import registry
from compiler_opt.tools import combine_tfa_policies_lib as cfa_lib

_COMBINE_POLICIES_NAMES = flags.DEFINE_multi_string(
    'policies_names', [],
    'Ordered list of names for the policies to combine. '
    'Order must match that of policies_paths.')
_COMBINE_POLICIES_PATHS = flags.DEFINE_multi_string(
    'policies_paths', [],
    'Ordered list of paths for the policies to combine. '
    'Order must match that of policies_names.')
_COMBINED_POLICY_PATH = flags.DEFINE_string(
'combined_policy_path', '', 'Path to save the combined policy.')
_GIN_FILES = flags.DEFINE_multi_string(
'gin_files', [], 'List of paths to gin configuration files.')
_GIN_BINDINGS = flags.DEFINE_multi_string(
'gin_bindings', [],
'Gin bindings to override the values set in the config files.')


def main(_):
flags.mark_flag_as_required('policies_names')
flags.mark_flag_as_required('policies_paths')
flags.mark_flag_as_required('combined_policy_path')
if len(_COMBINE_POLICIES_NAMES.value) != len(_COMBINE_POLICIES_PATHS.value):
logging.error(
'Length of policies_names: %d must equal length of policies_paths: %d.',
len(_COMBINE_POLICIES_NAMES.value), len(_COMBINE_POLICIES_PATHS.value))
sys.exit(1)
gin.parse_config_files_and_bindings(
_GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False)

problem_config = registry.get_configuration()
expected_signature, action_spec = problem_config.get_signature_spec()
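  # Extend the observation spec with the model_selector feature the combined
  # policy uses to dispatch between the wrapped policies.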
expected_signature.observation.update({
'model_selector':
tf.TensorSpec(shape=(2,), dtype=tf.uint64, name='model_selector')
})
  # TODO(359): We only support combining two policies. Generalize this to
  # handle multiple policies.
if len(_COMBINE_POLICIES_NAMES.value) != 2:
logging.error('Policy combiner only supports two policies, %d given.',
len(_COMBINE_POLICIES_NAMES.value))
sys.exit(1)
policy1_name = _COMBINE_POLICIES_NAMES.value[0]
policy1_path = _COMBINE_POLICIES_PATHS.value[0]
policy2_name = _COMBINE_POLICIES_NAMES.value[1]
policy2_path = _COMBINE_POLICIES_PATHS.value[1]
policy1 = tf.saved_model.load(policy1_path, tags=None, options=None)
policy2 = tf.saved_model.load(policy2_path, tags=None, options=None)
combined_policy = cfa_lib.CombinedTFPolicy(
tf_policies={
policy1_name: policy1,
policy2_name: policy2
},
time_step_spec=expected_signature,
action_spec=action_spec)
combined_policy_path = _COMBINED_POLICY_PATH.value
policy_dict = {'combined_policy': combined_policy}
saver = policy_saver.PolicySaver(policy_dict=policy_dict)
saver.save(combined_policy_path)


if __name__ == '__main__':
app.run(main)
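
For reference, a hypothetical invocation of the combiner (the policy names,
paths, and gin file below are placeholders; the gin config must register the
problem the policies were trained for):

  python3 compiler_opt/tools/combine_tfa_policies.py \
    --policies_names=policy_a --policies_names=policy_b \
    --policies_paths=/path/to/policy_a --policies_paths=/path/to/policy_b \
    --combined_policy_path=/path/to/combined_policy \
    --gin_files=/path/to/problem_config.gin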
116 changes: 116 additions & 0 deletions compiler_opt/tools/combine_tfa_policies_lib.py
@@ -0,0 +1,116 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Combines two tf-agent policies with the given state and action spec."""
from typing import Dict, Optional, Tuple

import tensorflow as tf
import hashlib

import tf_agents
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.trajectories import policy_step
import tensorflow_probability as tfp


class CombinedTFPolicy(tf_agents.policies.TFPolicy):
"""Policy which combines two target policies."""

def __init__(self, *args, tf_policies: Dict[str, tf_agents.policies.TFPolicy],
**kwargs):
super().__init__(*args, **kwargs)

self.tf_policies = []
self.tf_policy_names = []
for name, policy in tf_policies.items():
self.tf_policies.append(policy)
self.tf_policy_names.append(name)

self.expected_signature = self.time_step_spec
self.sorted_keys = sorted(self.expected_signature.observation.keys())

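    # Identify each policy by the 128-bit md5 digest of its name, split into
    # a (high, low) pair of uint64s. The stacked (N, 2) tensor is used to
    # validate incoming model_selector values.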
high_low_tensors = []
for name in self.tf_policy_names:
m = hashlib.md5()
m.update(name.encode("utf-8"))
high_low_tensors.append(
tf.stack([
tf.constant(
int.from_bytes(m.digest()[8:], "little"), dtype=tf.uint64),
tf.constant(
int.from_bytes(m.digest()[:8], "little"), dtype=tf.uint64)
]))
self.high_low_tensors = tf.stack(high_low_tensors)
# Related LLVM commit: https://github.com/llvm/llvm-project/pull/96276
m = hashlib.md5()
m.update(self.tf_policy_names[0].encode("utf-8"))
self.high = int.from_bytes(m.digest()[8:], "little")
self.low = int.from_bytes(m.digest()[:8], "little")
self.high_low_tensor = tf.constant([self.high, self.low], dtype=tf.uint64)

def _process_observation(
self, observation: types.NestedSpecTensorOrArray
) -> Tuple[types.NestedSpecTensorOrArray, types.TensorOrArray]:
assert "model_selector" in self.sorted_keys
high_low_tensor = self.high_low_tensor
for name in self.sorted_keys:
if name in ["model_selector"]:
        # model_selector has a batch dimension of 1; index [0] to drop it and
        # get the (2,) uint64 hash tensor.
switch_tensor = observation.pop(name)[0]
high_low_tensor = switch_tensor

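    # Fail if the incoming selector does not match the hash pair of any
    # policy this combiner was built with.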
tf.debugging.Assert(
tf.equal(
tf.reduce_any(
tf.reduce_all(
tf.equal(high_low_tensor, self.high_low_tensors),
axis=1)), True),
[high_low_tensor, self.high_low_tensors])

return observation, high_low_tensor

def _action(self,
time_step: ts.TimeStep,
policy_state: types.NestedTensorSpec,
seed: Optional[types.Seed] = None) -> policy_step.PolicyStep:
new_observation = time_step.observation
new_observation, switch_tensor = self._process_observation(new_observation)
updated_step = ts.TimeStep(
step_type=time_step.step_type,
reward=time_step.reward,
discount=time_step.discount,
observation=new_observation)

# TODO(359): We only support combining two policies. Generalize this to
# handle multiple policies.
def f0():
return tf.cast(
self.tf_policies[0].action(updated_step).action[0], dtype=tf.int64)

def f1():
return tf.cast(
self.tf_policies[1].action(updated_step).action[0], dtype=tf.int64)

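    # self.high_low_tensor is the hash pair of the first policy: dispatch to
    # f0 on a match, otherwise to f1.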
action = tf.cond(
tf.math.reduce_all(tf.equal(switch_tensor, self.high_low_tensor)), f0,
f1)
return policy_step.PolicyStep(action=action, state=policy_state)

def _distribution(
self, time_step: ts.TimeStep,
policy_state: types.NestedTensorSpec) -> policy_step.PolicyStep:
"""Placeholder for distribution as every TFPolicy requires it."""
return policy_step.PolicyStep(
action=tfp.distributions.Deterministic(2.), state=policy_state)
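
One way the TODO above could be addressed, shown as a sketch (a hypothetical
method on CombinedTFPolicy, not part of this PR; it assumes
self.high_low_tensors holds one (high, low) row per policy in the same order
as self.tf_policies, which is how __init__ builds it):

  def _action_generalized(self, time_step, policy_state, seed=None):
    observation, selector = self._process_observation(time_step.observation)
    updated_step = ts.TimeStep(
        step_type=time_step.step_type,
        reward=time_step.reward,
        discount=time_step.discount,
        observation=observation)
    # One boolean per policy: does the selector equal that policy's hash pair?
    matches = tf.reduce_all(tf.equal(selector, self.high_low_tensors), axis=1)
    # _process_observation already asserted that a row matches; argmax picks
    # the matching policy index as an int32 scalar for tf.switch_case.
    index = tf.argmax(tf.cast(matches, tf.int32), output_type=tf.int32)
    branch_fns = [
        # Bind each policy via a default argument to avoid late binding.
        lambda p=policy: tf.cast(p.action(updated_step).action[0], tf.int64)
        for policy in self.tf_policies
    ]
    return policy_step.PolicyStep(
        action=tf.switch_case(index, branch_fns=branch_fns),
        state=policy_state)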
152 changes: 152 additions & 0 deletions compiler_opt/tools/combine_tfa_policies_lib_test.py
@@ -0,0 +1,152 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the combine_tfa_policies_lib.py module"""

from absl.testing import absltest

import tensorflow as tf
from compiler_opt.tools import combine_tfa_policies_lib
from tf_agents.trajectories import time_step as ts
import tf_agents
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import policy_step
from tf_agents.typing import types
import hashlib
import numpy as np


def client_side_model_selector_calculation(policy_name: str) -> types.Tensor:
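  """Computes the model_selector tensor the combined policy expects.

  Mirrors the hashing in CombinedTFPolicy: md5(policy_name) split into
  (high, low) uint64 halves, with a leading batch dimension of 1.
  """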
m = hashlib.md5()
m.update(policy_name.encode('utf-8'))
high = int.from_bytes(m.digest()[8:], 'little')
low = int.from_bytes(m.digest()[:8], 'little')
model_selector = tf.constant([[high, low]], dtype=tf.uint64)
return model_selector


class AddOnePolicy(tf_agents.policies.TFPolicy):
"""Test policy which increments the obs feature."""

def __init__(self):
obs_spec = {'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)}
time_step_spec = ts.time_step_spec(obs_spec)

act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)

super().__init__(time_step_spec=time_step_spec, action_spec=act_spec)

def _distribution(self, time_step):
"""Boilerplate function for TFPolicy."""
pass

def _variables(self):
"""Boilerplate function for TFPolicy."""
return ()

def _action(self, time_step, policy_state, seed):
"""Boilerplate function for TFPolicy."""
observation = time_step.observation['obs'][0]
action = tf.reshape(observation + 1, (1,))
return policy_step.PolicyStep(action, policy_state)


class SubtractOnePolicy(tf_agents.policies.TFPolicy):
"""Test policy which decrements the obs feature."""

def __init__(self):
obs_spec = {'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)}
time_step_spec = ts.time_step_spec(obs_spec)

act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)

super().__init__(time_step_spec=time_step_spec, action_spec=act_spec)

def _distribution(self, time_step):
"""Boilerplate function for TFPolicy."""
pass

def _variables(self):
"""Boilerplate function for TFPolicy."""
return ()

def _action(self, time_step, policy_state, seed):
"""Boilerplate function for TFPolicy."""
observation = time_step.observation['obs'][0]
action = tf.reshape(observation - 1, (1,))
return policy_step.PolicyStep(action, policy_state)


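# Specs shared by both tests: the raw 'obs' feature plus the model_selector
# hash pair consumed by the combined policy.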
observation_spec = ts.time_step_spec({
'obs':
tf.TensorSpec(dtype=tf.int32, shape=(), name='obs'),
'model_selector':
tf.TensorSpec(shape=(2,), dtype=tf.uint64, name='model_selector')
})

action_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)


class CombinedTFPolicyTest(absltest.TestCase):
"""Test for CombinedTFPolicy."""

def test_select_add_policy(self):
policy1 = AddOnePolicy()
policy2 = SubtractOnePolicy()
combined_policy = combine_tfa_policies_lib.CombinedTFPolicy(
tf_policies={
'add_one': policy1,
'subtract_one': policy2
},
time_step_spec=observation_spec,
action_spec=action_spec)

model_selector = client_side_model_selector_calculation('add_one')

state = ts.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'obs': tf.constant(np.array([42]), dtype=tf.int64),
'model_selector': model_selector
},
reward=tf.constant(np.array([0]), dtype=tf.float64),
step_type=tf.constant(np.array([0]), dtype=tf.int64))

self.assertEqual(
combined_policy.action(state).action, tf.constant(43, dtype=tf.int64))

def test_select_subtract_policy(self):
policy1 = AddOnePolicy()
policy2 = SubtractOnePolicy()
combined_policy = combine_tfa_policies_lib.CombinedTFPolicy(
tf_policies={
'add_one': policy1,
'subtract_one': policy2
},
time_step_spec=observation_spec,
action_spec=action_spec)

model_selector = client_side_model_selector_calculation('subtract_one')

state = ts.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'obs': tf.constant(np.array([42]), dtype=tf.int64),
'model_selector': model_selector
},
reward=tf.constant(np.array([0]), dtype=tf.float64),
step_type=tf.constant(np.array([0]), dtype=tf.int64))

self.assertEqual(
combined_policy.action(state).action, tf.constant(41, dtype=tf.int64))