Policy combiner #357

Merged
merged 15 commits on Sep 10, 2024
47 changes: 47 additions & 0 deletions compiler_opt/tools/combine_tfa_policies.py
@@ -0,0 +1,47 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs the policy combiner."""
from absl import app

import tensorflow as tf

from compiler_opt.rl import policy_saver
from compiler_opt.tools import combine_tfa_policies_lib as cfa_lib


def main(_):
expected_signature = cfa_lib.get_input_signature()
action_spec = cfa_lib.get_action_spec()
policy1_name = input("First policy name: ")
policy1_path = input(policy1_name + " path: ")
policy2_name = input("Second policy name: ")
policy2_path = input(policy2_name + " path: ")
policy1 = tf.saved_model.load(policy1_path, tags=None, options=None)
policy2 = tf.saved_model.load(policy2_path, tags=None, options=None)
combined_policy = cfa_lib.CombinedTFPolicy(
tf_policies={
policy1_name: policy1,
policy2_name: policy2
},
time_step_spec=expected_signature,
action_spec=action_spec)
combined_policy_path = input("Save combined policy path: ")
policy_dict = {"combined_policy": combined_policy}
saver = policy_saver.PolicySaver(policy_dict=policy_dict)
saver.save(combined_policy_path)


if __name__ == "__main__":
app.run(main)
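A non-interactive variant of the same flow, for scripted use, could take the policy names and paths as absl flags instead of input() prompts. This is a minimal sketch; the flag names and the wrapper itself are hypothetical and not part of this PR:

from absl import app
from absl import flags
import tensorflow as tf

from compiler_opt.rl import policy_saver
from compiler_opt.tools import combine_tfa_policies_lib as cfa_lib

_POLICY1_NAME = flags.DEFINE_string("policy1_name", None, "Name of the first policy.")
_POLICY1_PATH = flags.DEFINE_string("policy1_path", None, "SavedModel directory of the first policy.")
_POLICY2_NAME = flags.DEFINE_string("policy2_name", None, "Name of the second policy.")
_POLICY2_PATH = flags.DEFINE_string("policy2_path", None, "SavedModel directory of the second policy.")
_OUTPUT_PATH = flags.DEFINE_string("output_path", None, "Directory to save the combined policy to.")


def main(_):
  # Load the two base policies and combine them under the shared signature.
  combined_policy = cfa_lib.CombinedTFPolicy(
      tf_policies={
          _POLICY1_NAME.value: tf.saved_model.load(_POLICY1_PATH.value),
          _POLICY2_NAME.value: tf.saved_model.load(_POLICY2_PATH.value),
      },
      time_step_spec=cfa_lib.get_input_signature(),
      action_spec=cfa_lib.get_action_spec())
  # Save the result the same way the interactive tool does.
  saver = policy_saver.PolicySaver(policy_dict={"combined_policy": combined_policy})
  saver.save(_OUTPUT_PATH.value)


if __name__ == "__main__":
  app.run(main)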
193 changes: 193 additions & 0 deletions compiler_opt/tools/combine_tfa_policies_lib.py
@@ -0,0 +1,193 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Combines two tf-agent policies with the given state and action spec."""
from typing import Dict, Optional

import gin
import tensorflow as tf
import hashlib

import tf_agents
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.trajectories import policy_step
from tf_agents.specs import tensor_spec
import tensorflow_probability as tfp


class CombinedTFPolicy(tf_agents.policies.TFPolicy):
"""Policy which combines two target policies."""

def __init__(self, *args, tf_policies: Dict[str, tf_agents.policies.TFPolicy],
**kwargs):
super(CombinedTFPolicy, self).__init__(*args, **kwargs)

self.tf_policies = []
self.tf_policy_names = []
for name, policy in tf_policies.items():
self.tf_policies.append(policy)
self.tf_policy_names.append(name)

self.expected_signature = self.time_step_spec
self.sorted_keys = sorted(self.expected_signature.observation.keys())

high_low_tensors = []
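# Each policy is identified by the (high, low) uint64 halves of the md5
# digest of its name; the "model_selector" observation carries the same pair
# at inference time so each request can be routed to the matching policy.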
for name in self.tf_policy_names:
m = hashlib.md5()
m.update(name.encode("utf-8"))
high_low_tensors.append(
tf.stack([
tf.constant(
int.from_bytes(m.digest()[8:], "little"), dtype=tf.uint64),
tf.constant(
int.from_bytes(m.digest()[:8], "little"), dtype=tf.uint64)
]))
self.high_low_tensors = tf.stack(high_low_tensors)

m = hashlib.md5()
m.update(self.tf_policy_names[0].encode("utf-8"))
self.high = int.from_bytes(m.digest()[8:], "little")
self.low = int.from_bytes(m.digest()[:8], "little")
self.high_low_tensor = tf.constant([self.high, self.low], dtype=tf.uint64)

def _process_observation(self, observation):
for name in self.sorted_keys:
if name in ["model_selector"]:
switch_tensor = observation.pop(name)[0]
high_low_tensor = switch_tensor

tf.debugging.Assert(
tf.equal(
tf.reduce_any(
tf.reduce_all(
tf.equal(high_low_tensor, self.high_low_tensors),
axis=1)), True),
[high_low_tensor, self.high_low_tensors])
return observation, switch_tensor

def _create_distribution(self, inlining_prediction):
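# Wraps the scalar probability in a Categorical over {0, 1}:
# softmax([0, log(p1 / (1 - p1))]) with p1 = 1 - inlining_prediction yields
# P(action=0) = inlining_prediction and P(action=1) = 1 - inlining_prediction.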
probs = [inlining_prediction, 1.0 - inlining_prediction]
logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1]))]]
return tfp.distributions.Categorical(logits=logits)

def _action(self,
time_step: ts.TimeStep,
policy_state: types.NestedTensorSpec,
seed: Optional[types.Seed] = None) -> policy_step.PolicyStep:
new_observation = time_step.observation
new_observation, switch_tensor = self._process_observation(new_observation)
updated_step = ts.TimeStep(
step_type=time_step.step_type,
reward=time_step.reward,
discount=time_step.discount,
observation=new_observation)

def f0():
return tf.cast(
self.tf_policies[0].action(updated_step).action[0], dtype=tf.int64)

def f1():
return tf.cast(
self.tf_policies[1].action(updated_step).action[0], dtype=tf.int64)

# Route to the first policy when the selector matches its hash pair;
# otherwise fall through to the second policy.
action = tf.cond(
tf.math.reduce_all(tf.equal(switch_tensor, self.high_low_tensor)), f0,
f1)
return tf_agents.trajectories.PolicyStep(action=action, state=policy_state)

def _distribution(
self, time_step: ts.TimeStep,
policy_state: types.NestedTensorSpec) -> policy_step.PolicyStep:
new_observation = time_step.observation
new_observation, switch_tensor = self._process_observation(new_observation)
updated_step = ts.TimeStep(
step_type=time_step.step_type,
reward=time_step.reward,
discount=time_step.discount,
observation=new_observation)

def f0():
return tf.cast(
self.tf_policies[0].distribution(updated_step).action.cdf(0)[0],
dtype=tf.float32)

def f1():
return tf.cast(
self.tf_policies[1].distribution(updated_step).action.cdf(0)[0],
dtype=tf.float32)

# Same routing for the distribution path: take P(action=0) from the first
# policy's CDF at 0 when the selector matches, otherwise from the second.
distribution = tf.cond(
tf.math.reduce_all(tf.equal(switch_tensor, self.high_low_tensor)), f0,
f1)
return tf_agents.trajectories.PolicyStep(
action=self._create_distribution(distribution), state=policy_state)


@gin.configurable()
def get_input_signature():
"""Returns a list of inlining features to be used with the combined models."""
# int64 features
inputs = dict((key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key))
for key in [
"caller_basic_block_count",
"caller_conditionally_executed_blocks",
"caller_users",
"callee_basic_block_count",
"callee_conditionally_executed_blocks",
"callee_users",
"nr_ctant_params",
"node_count",
"edge_count",
"callsite_height",
"cost_estimate",
"inlining_default",
"sroa_savings",
"sroa_losses",
"load_elimination",
"call_penalty",
"call_argument_setup",
"load_relative_intrinsic",
"lowered_call_arg_setup",
"indirect_call_penalty",
"jump_table_penalty",
"case_cluster_penalty",
"switch_penalty",
"unsimplified_common_instructions",
"num_loops",
"dead_blocks",
"simplified_instructions",
"constant_args",
"constant_offset_ptr_args",
"callsite_cost",
"cold_cc_penalty",
"last_call_to_static_bonus",
"is_multiple_blocks",
"nested_inlines",
"nested_inline_cost_estimate",
"threshold",
"is_callee_avail_external",
"is_caller_avail_external",
])
inputs.update({
"model_selector":
tf.TensorSpec(shape=(2,), dtype=tf.uint64, name="model_selector")
})
return ts.time_step_spec(inputs)


@gin.configurable()
def get_action_spec():
return tensor_spec.BoundedTensorSpec(
dtype=tf.int64, shape=(), name="inlining_decision", minimum=0, maximum=1)
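At inference time the combined policy routes each request via the "model_selector" observation, which must hold the (high, low) uint64 halves of the md5 digest of the desired policy's name, exactly as computed in CombinedTFPolicy.__init__. A minimal sketch of building that feature; the helper name is hypothetical:

import hashlib
import tensorflow as tf

def make_model_selector(policy_name: str) -> tf.Tensor:
  # Split the 16-byte md5 digest into two little-endian uint64 values,
  # mirroring the encoding CombinedTFPolicy uses for each policy name.
  digest = hashlib.md5(policy_name.encode("utf-8")).digest()
  high = int.from_bytes(digest[8:], "little")
  low = int.from_bytes(digest[:8], "little")
  return tf.constant([high, low], dtype=tf.uint64)

# Example: placing make_model_selector("policy1") in the "model_selector"
# feature selects the base policy that was registered under the name "policy1".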
149 changes: 149 additions & 0 deletions compiler_opt/tools/combine_tfa_policies_lib_test.py
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the combine_tfa_policies_lib.py module"""

from absl.testing import absltest

import tensorflow as tf
from compiler_opt.tools import combine_tfa_policies_lib
from tf_agents.trajectories import time_step as ts
import tf_agents
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import policy_step
import hashlib
import numpy as np


class AddOnePolicy(tf_agents.policies.TFPolicy):
"""Test policy which adds one to obs feature."""

def __init__(self):
obs_spec = {
'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)
}
time_step_spec = ts.time_step_spec(obs_spec)

act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)

super(AddOnePolicy, self).__init__(
time_step_spec=time_step_spec, action_spec=act_spec)

def _distribution(self, time_step):
pass

def _variables(self):
return ()

def _action(self, time_step, policy_state, seed):
observation = time_step.observation['obs'][0]
action = tf.reshape(observation + 1, (1,))
return policy_step.PolicyStep(action, policy_state)


class SubtractOnePolicy(tf_agents.policies.TFPolicy):
"""Test policy which subtracts one to obs feature."""

def __init__(self):
obs_spec = {
'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)
}
time_step_spec = ts.time_step_spec(obs_spec)

act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)

super(SubtractOnePolicy, self).__init__(
time_step_spec=time_step_spec, action_spec=act_spec)

def _distribution(self, time_step):
pass

def _variables(self):
return ()

def _action(self, time_step, policy_state, seed):
observation = time_step.observation['obs'][0]
action = tf.reshape(observation - 1, (1,))
return policy_step.PolicyStep(action, policy_state)


observation_spec = ts.time_step_spec({
'obs':
tf.TensorSpec(dtype=tf.int32, shape=(), name='obs'),
'model_selector':
tf.TensorSpec(shape=(2,), dtype=tf.uint64, name='model_selector')
})

action_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)


class CombinedTFPolicyTest(absltest.TestCase):

def test_select_add_policy(self):
policy1 = AddOnePolicy()
policy2 = SubtractOnePolicy()
combined_policy = combine_tfa_policies_lib.CombinedTFPolicy(
tf_policies={
'add_one': policy1,
'subtract_one': policy2
},
time_step_spec=observation_spec,
action_spec=action_spec)

m = hashlib.md5()
m.update('add_one'.encode('utf-8'))
high = int.from_bytes(m.digest()[8:], 'little')
low = int.from_bytes(m.digest()[:8], 'little')
model_selector = tf.constant([[high, low]], dtype=tf.uint64)

state = tf_agents.trajectories.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'obs': tf.constant(np.array([0]), dtype=tf.int64),
'model_selector': model_selector
},
reward=tf.constant(np.array([0]), dtype=tf.float64),
step_type=tf.constant(np.array([0]), dtype=tf.int64))

self.assertEqual(
combined_policy.action(state).action, tf.constant(1, dtype=tf.int64))

def test_select_subtract_policy(self):
policy1 = AddOnePolicy()
policy2 = SubtractOnePolicy()
combined_policy = combine_tfa_policies_lib.CombinedTFPolicy(
tf_policies={
'add_one': policy1,
'subtract_one': policy2
},
time_step_spec=observation_spec,
action_spec=action_spec)

m = hashlib.md5()
m.update('subtract_one'.encode('utf-8'))
high = int.from_bytes(m.digest()[8:], 'little')
low = int.from_bytes(m.digest()[:8], 'little')
model_selector = tf.constant([[high, low]], dtype=tf.uint64)

state = tf_agents.trajectories.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'obs': tf.constant(np.array([0]), dtype=tf.int64),
'model_selector': model_selector
},
reward=tf.constant(np.array([0]), dtype=tf.float64),
step_type=tf.constant(np.array([0]), dtype=tf.int64))

self.assertEqual(
combined_policy.action(state).action, tf.constant(-1, dtype=tf.int64))