From c0c3e4d0455f0ab943ff648ea54b8b517bc28492 Mon Sep 17 00:00:00 2001 From: Abena Laast Date: Tue, 18 Jul 2023 22:49:38 +0000 Subject: [PATCH 1/8] add policy_utils --- compiler_opt/es/policy_utils.py | 128 ++++++++++++++++ compiler_opt/es/policy_utils_test.py | 210 +++++++++++++++++++++++++++ 2 files changed, 338 insertions(+) create mode 100644 compiler_opt/es/policy_utils.py create mode 100644 compiler_opt/es/policy_utils_test.py diff --git a/compiler_opt/es/policy_utils.py b/compiler_opt/es/policy_utils.py new file mode 100644 index 00000000..fdff994d --- /dev/null +++ b/compiler_opt/es/policy_utils.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +############################################################################### +# +# +# This is a port of the code by Krzysztof Choromanski, Deepali Jain and Vikas +# Sindhwani, based on the portfolio of Blackbox optimization algorithms listed +# below: +# +# "On Blackbox Backpropagation and Jacobian Sensing"; K. Choromanski, +# V. Sindhwani, NeurIPS 2017 +# "Optimizing Simulations with Noise-Tolerant Structured Exploration"; K. +# Choromanski, A. Iscen, V. Sindhwani, J. Tan, E. Coumans, ICRA 2018 +# "Structured Evolution with Compact Architectures for Scalable Policy +# Optimization"; K. Choromanski, M. Rowland, V. Sindhwani, R. Turner, A. +# Weller, ICML 2018, https://arxiv.org/abs/1804.02395 +# "From Complexity to Simplicity: Adaptive ES-Active Subspaces for Blackbox +# Optimization"; K. Choromanski, A. Pacchiano, J. Parker-Holder, Y. Tang, V. +# Sindhwani, NeurIPS 2019 +# "i-Sim2Real: Reinforcement Learning on Robotic Policies in Tight Human-Robot +# Interaction Loops"; L. Graesser, D. D'Ambrosio, A. Singh, A. Bewley, D. Jain, +# K. Choromanski, P. Sanketi , CoRL 2022, https://arxiv.org/abs/2207.06572 +# "Agile Catching with Whole-Body MPC and Blackbox Policy Learning"; S. +# Abeyruwan, A. Bewley, N. Boffi, K. Choromanski, D. D'Ambrosio, D. Jain, P. +# Sanketi, A. Shankar, V. Sindhwani, S. Singh, J. Slotine, S. Tu, L4DC, +# https://arxiv.org/abs/2306.08205 +# "Robotic Table Tennis: A Case Study into a High Speed Learning System"; A. +# Bewley, A. Shankar, A. Iscen, A. Singh, C. Lynch, D. D'Ambrosio, D. Jain, +# E. Coumans, G. Versom, G. Kouretas, J. Abelian, J. Boyd, K. Oslund, +# K. Reymann, K. Choromanski, L. Graesser, M. Ahn, N. Jaitly, N. Lazic, +# P. Sanketi, P. Xu, P. Sermanet, R. Mahjourian, S. Abeyruwan, S. Kataoka, +# S. Moore, T. Nguyen, T. Ding, V. Sindhwani, V. Vanhoucke, W. Gao, Y. 
Kuang, +# to be presented at RSS 2023 +############################################################################### +"""Util functions to create and edit a tf_agent policy.""" + +import gin +import numpy as np +import numpy.typing as npt +import tensorflow as tf +from typing import Union + +from tf_agents.networks import network +from tf_agents.policies import actor_policy, greedy_policy, tf_policy +from compiler_opt.rl import policy_saver, registry + + +@gin.configurable(module='policy_utils') +def create_actor_policy(actor_network_ctor: network.DistributionNetwork, + greedy: bool = False) -> tf_policy.TFPolicy: + """Creates an actor policy.""" + problem_config = registry.get_configuration() + time_step_spec, action_spec = problem_config.get_signature_spec() + layers = tf.nest.map_structure( + problem_config.get_preprocessing_layer_creator(), + time_step_spec.observation) + + actor_network = actor_network_ctor( + input_tensor_spec=time_step_spec.observation, + output_tensor_spec=action_spec, + preprocessing_layers=layers) + + policy = actor_policy.ActorPolicy( + time_step_spec=time_step_spec, + action_spec=action_spec, + actor_network=actor_network) + + if greedy: + policy = greedy_policy.GreedyPolicy(policy) + + return policy + + +def get_vectorized_parameters_from_policy( + policy: Union[tf_policy.TFPolicy, tf.Module]) -> npt.NDArray[np.float32]: + if isinstance(policy, tf_policy.TFPolicy): + variables = policy.variables() + elif policy.model_variables: + variables = policy.model_variables + + parameters = [var.numpy().flatten() for var in variables] + parameters = np.concatenate(parameters, axis=0) + return parameters + + +def set_vectorized_parameters_for_policy( + policy: Union[tf_policy.TFPolicy, + tf.Module], parameters: npt.NDArray[np.float32]) -> None: + if isinstance(policy, tf_policy.TFPolicy): + variables = policy.variables() + else: + try: + getattr(policy, 'model_variables') + except AttributeError as e: + raise TypeError('policy must be a TFPolicy or a loaded SavedModel') from e + variables = policy.model_variables + + param_pos = 0 + for variable in variables: + shape = tf.shape(variable).numpy() + num_ele = np.prod(shape) + param = np.reshape(parameters[param_pos:param_pos + num_ele], shape) + variable.assign(param) + param_pos += num_ele + if param_pos != len(parameters): + raise ValueError( + f'Parameter dimensions are not matched! Expected {len(parameters)} ' + 'but only found {param_pos}.') + + +def save_policy(policy: tf_policy.TFPolicy, parameters: npt.NDArray[np.float32], + save_folder: str, policy_name: str) -> None: + set_vectorized_parameters_for_policy(policy, parameters) + saver = policy_saver.PolicySaver({policy_name: policy}) + saver.save(save_folder) diff --git a/compiler_opt/es/policy_utils_test.py b/compiler_opt/es/policy_utils_test.py new file mode 100644 index 00000000..a8cd2ead --- /dev/null +++ b/compiler_opt/es/policy_utils_test.py @@ -0,0 +1,210 @@ +# coding=utf-8 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +############################################################################### +# +# +# This is a port of the code by Krzysztof Choromanski, Deepali Jain and Vikas +# Sindhwani, based on the portfolio of Blackbox optimization algorithms listed +# below: +# +# "On Blackbox Backpropagation and Jacobian Sensing"; K. Choromanski, +# V. Sindhwani, NeurIPS 2017 +# "Optimizing Simulations with Noise-Tolerant Structured Exploration"; K. +# Choromanski, A. Iscen, V. Sindhwani, J. Tan, E. Coumans, ICRA 2018 +# "Structured Evolution with Compact Architectures for Scalable Policy +# Optimization"; K. Choromanski, M. Rowland, V. Sindhwani, R. Turner, A. +# Weller, ICML 2018, https://arxiv.org/abs/1804.02395 +# "From Complexity to Simplicity: Adaptive ES-Active Subspaces for Blackbox +# Optimization"; K. Choromanski, A. Pacchiano, J. Parker-Holder, Y. Tang, V. +# Sindhwani, NeurIPS 2019 +# "i-Sim2Real: Reinforcement Learning on Robotic Policies in Tight Human-Robot +# Interaction Loops"; L. Graesser, D. D'Ambrosio, A. Singh, A. Bewley, D. Jain, +# K. Choromanski, P. Sanketi , CoRL 2022, https://arxiv.org/abs/2207.06572 +# "Agile Catching with Whole-Body MPC and Blackbox Policy Learning"; S. +# Abeyruwan, A. Bewley, N. Boffi, K. Choromanski, D. D'Ambrosio, D. Jain, P. +# Sanketi, A. Shankar, V. Sindhwani, S. Singh, J. Slotine, S. Tu, L4DC, +# https://arxiv.org/abs/2306.08205 +# "Robotic Table Tennis: A Case Study into a High Speed Learning System"; A. +# Bewley, A. Shankar, A. Iscen, A. Singh, C. Lynch, D. D'Ambrosio, D. Jain, +# E. Coumans, G. Versom, G. Kouretas, J. Abelian, J. Boyd, K. Oslund, +# K. Reymann, K. Choromanski, L. Graesser, M. Ahn, N. Jaitly, N. Lazic, +# P. Sanketi, P. Xu, P. Sermanet, R. Mahjourian, S. Abeyruwan, S. Kataoka, +# S. Moore, T. Nguyen, T. Ding, V. Sindhwani, V. Vanhoucke, W. Gao, Y. 
Kuang, +# to be presented at RSS 2023 +############################################################################### +"""Tests for policy_utils.""" + +from absl.testing import absltest +import numpy as np +import os +import tensorflow as tf +from tf_agents.networks import actor_distribution_network +from tf_agents.policies import actor_policy + +from compiler_opt.es import policy_utils +from compiler_opt.rl import policy_saver, registry +from compiler_opt.rl.inlining import InliningConfig +from compiler_opt.rl.inlining import config as inlining_config +from compiler_opt.rl.regalloc import config as regalloc_config +from compiler_opt.rl.regalloc import RegallocEvictionConfig, regalloc_network + + +class ConfigTest(absltest.TestCase): + + def test_inlining_config(self): + problem_config = registry.get_configuration(implementation=InliningConfig) + time_step_spec, action_spec = problem_config.get_signature_spec() + creator = inlining_config.get_observation_processing_layer_creator( + quantile_file_dir='compiler_opt/rl/inlining/vocab/', + with_sqrt=False, + with_z_score_normalization=False) + layers = tf.nest.map_structure(creator, time_step_spec.observation) + + actor_network = actor_distribution_network.ActorDistributionNetwork( + input_tensor_spec=time_step_spec.observation, + output_tensor_spec=action_spec, + preprocessing_layers=layers, + preprocessing_combiner=tf.keras.layers.Concatenate(), + fc_layer_params=(64, 64, 64, 64), + dropout_layer_params=None, + activation_fn=tf.keras.activations.relu) + + policy = actor_policy.ActorPolicy( + time_step_spec=time_step_spec, + action_spec=action_spec, + actor_network=actor_network) + + self.assertIsNotNone(policy) + self.assertIsInstance( + policy._actor_network, # pylint: disable=protected-access + actor_distribution_network.ActorDistributionNetwork) + + def test_regalloc_config(self): + problem_config = registry.get_configuration( + implementation=RegallocEvictionConfig) + time_step_spec, action_spec = problem_config.get_signature_spec() + creator = regalloc_config.get_observation_processing_layer_creator( + quantile_file_dir='compiler_opt/rl/regalloc/vocab', + with_sqrt=False, + with_z_score_normalization=False) + layers = tf.nest.map_structure(creator, time_step_spec.observation) + + actor_network = regalloc_network.RegAllocNetwork( + input_tensor_spec=time_step_spec.observation, + output_tensor_spec=action_spec, + preprocessing_layers=layers, + preprocessing_combiner=tf.keras.layers.Concatenate(), + fc_layer_params=(64, 64, 64, 64), + dropout_layer_params=None, + activation_fn=tf.keras.activations.relu) + + policy = actor_policy.ActorPolicy( + time_step_spec=time_step_spec, + action_spec=action_spec, + actor_network=actor_network) + + self.assertIsNotNone(policy) + self.assertIsInstance( + policy._actor_network, # pylint: disable=protected-access + regalloc_network.RegAllocNetwork) + + +class VectorTest(absltest.TestCase): + + def test_set_vectorized_parameters_for_policy(self): + # create a policy + problem_config = registry.get_configuration(implementation=InliningConfig) + time_step_spec, action_spec = problem_config.get_signature_spec() + creator = inlining_config.get_observation_processing_layer_creator( + quantile_file_dir='compiler_opt/rl/inlining/vocab/', + with_sqrt=False, + with_z_score_normalization=False) + layers = tf.nest.map_structure(creator, time_step_spec.observation) + + actor_network = actor_distribution_network.ActorDistributionNetwork( + input_tensor_spec=time_step_spec.observation, + output_tensor_spec=action_spec, 
+ preprocessing_layers=layers, + preprocessing_combiner=tf.keras.layers.Concatenate(), + fc_layer_params=(64, 64, 64, 64), + dropout_layer_params=None, + activation_fn=tf.keras.activations.relu) + + policy = actor_policy.ActorPolicy( + time_step_spec=time_step_spec, + action_spec=action_spec, + actor_network=actor_network) + saver = policy_saver.PolicySaver({'policy': policy}) + + # save the policy + testing_path = self.create_tempdir() + policy_save_path = os.path.join(testing_path, 'temp_output/policy') + saver.save(policy_save_path) + + # set the values of the policy variables + length_of_a_perturbation = 17218 + params = np.arange(length_of_a_perturbation, dtype=np.float32) + policy_utils.set_vectorized_parameters_for_policy(policy, params) + # iterate through variables and check their values + idx = 0 + for variable in policy.variables(): # pylint: disable=not-callable + nums = variable.numpy().flatten() + for num in nums: + if idx != num: + raise AssertionError(f'values at index {idx} do not match') + idx += 1 + + def test_get_vectorized_parameters_from_policy(self): + # create a policy + problem_config = registry.get_configuration(implementation=InliningConfig) + time_step_spec, action_spec = problem_config.get_signature_spec() + creator = inlining_config.get_observation_processing_layer_creator( + quantile_file_dir='compiler_opt/rl/inlining/vocab/', + with_sqrt=False, + with_z_score_normalization=False) + layers = tf.nest.map_structure(creator, time_step_spec.observation) + + actor_network = actor_distribution_network.ActorDistributionNetwork( + input_tensor_spec=time_step_spec.observation, + output_tensor_spec=action_spec, + preprocessing_layers=layers, + preprocessing_combiner=tf.keras.layers.Concatenate(), + fc_layer_params=(64, 64, 64, 64), + dropout_layer_params=None, + activation_fn=tf.keras.activations.relu) + + policy = actor_policy.ActorPolicy( + time_step_spec=time_step_spec, + action_spec=action_spec, + actor_network=actor_network) + saver = policy_saver.PolicySaver({'policy': policy}) + + # save the policy + testing_path = self.create_tempdir() + policy_save_path = os.path.join(testing_path, 'temp_output/policy') + saver.save(policy_save_path) + + length_of_a_perturbation = 17218 + params = np.arange(length_of_a_perturbation, dtype=np.float32) + # functionality verified in previous test + policy_utils.set_vectorized_parameters_for_policy(policy, params) + # vectorize and check if the outcome is the same as the start + output = policy_utils.get_vectorized_parameters_from_policy(policy) + np.testing.assert_array_almost_equal(output, params) + + +if __name__ == '__main__': + absltest.main() From 35425e8c5f8730f3022995b8f11261a5a0f9bfca Mon Sep 17 00:00:00 2001 From: Abena Laast Date: Wed, 19 Jul 2023 18:12:45 +0000 Subject: [PATCH 2/8] add tests for loaded policies, revise error handling, add docstrings, edit type annotations, remove credit message --- compiler_opt/es/policy_utils.py | 71 ++++++++++------------------ compiler_opt/es/policy_utils_test.py | 64 ++++++++++++------------- 2 files changed, 53 insertions(+), 82 deletions(-) diff --git a/compiler_opt/es/policy_utils.py b/compiler_opt/es/policy_utils.py index fdff994d..6402596c 100644 --- a/compiler_opt/es/policy_utils.py +++ b/compiler_opt/es/policy_utils.py @@ -12,45 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -############################################################################### -# -# -# This is a port of the code by Krzysztof Choromanski, Deepali Jain and Vikas -# Sindhwani, based on the portfolio of Blackbox optimization algorithms listed -# below: -# -# "On Blackbox Backpropagation and Jacobian Sensing"; K. Choromanski, -# V. Sindhwani, NeurIPS 2017 -# "Optimizing Simulations with Noise-Tolerant Structured Exploration"; K. -# Choromanski, A. Iscen, V. Sindhwani, J. Tan, E. Coumans, ICRA 2018 -# "Structured Evolution with Compact Architectures for Scalable Policy -# Optimization"; K. Choromanski, M. Rowland, V. Sindhwani, R. Turner, A. -# Weller, ICML 2018, https://arxiv.org/abs/1804.02395 -# "From Complexity to Simplicity: Adaptive ES-Active Subspaces for Blackbox -# Optimization"; K. Choromanski, A. Pacchiano, J. Parker-Holder, Y. Tang, V. -# Sindhwani, NeurIPS 2019 -# "i-Sim2Real: Reinforcement Learning on Robotic Policies in Tight Human-Robot -# Interaction Loops"; L. Graesser, D. D'Ambrosio, A. Singh, A. Bewley, D. Jain, -# K. Choromanski, P. Sanketi , CoRL 2022, https://arxiv.org/abs/2207.06572 -# "Agile Catching with Whole-Body MPC and Blackbox Policy Learning"; S. -# Abeyruwan, A. Bewley, N. Boffi, K. Choromanski, D. D'Ambrosio, D. Jain, P. -# Sanketi, A. Shankar, V. Sindhwani, S. Singh, J. Slotine, S. Tu, L4DC, -# https://arxiv.org/abs/2306.08205 -# "Robotic Table Tennis: A Case Study into a High Speed Learning System"; A. -# Bewley, A. Shankar, A. Iscen, A. Singh, C. Lynch, D. D'Ambrosio, D. Jain, -# E. Coumans, G. Versom, G. Kouretas, J. Abelian, J. Boyd, K. Oslund, -# K. Reymann, K. Choromanski, L. Graesser, M. Ahn, N. Jaitly, N. Lazic, -# P. Sanketi, P. Xu, P. Sermanet, R. Mahjourian, S. Abeyruwan, S. Kataoka, -# S. Moore, T. Nguyen, T. Ding, V. Sindhwani, V. Vanhoucke, W. Gao, Y. 
Kuang, -# to be presented at RSS 2023 -############################################################################### """Util functions to create and edit a tf_agent policy.""" import gin import numpy as np import numpy.typing as npt import tensorflow as tf +from tensorflow.python.trackable import autotrackable from typing import Union from tf_agents.networks import network @@ -58,6 +26,7 @@ from compiler_opt.rl import policy_saver, registry +# TODO(abenalaast): Issue #280 @gin.configurable(module='policy_utils') def create_actor_policy(actor_network_ctor: network.DistributionNetwork, greedy: bool = False) -> tf_policy.TFPolicy: @@ -85,11 +54,15 @@ def create_actor_policy(actor_network_ctor: network.DistributionNetwork, def get_vectorized_parameters_from_policy( - policy: Union[tf_policy.TFPolicy, tf.Module]) -> npt.NDArray[np.float32]: + policy: Union[tf_policy.TFPolicy, autotrackable.AutoTrackable] +) -> npt.NDArray[np.float32]: + """Returns a policy's variable values as a single np array.""" if isinstance(policy, tf_policy.TFPolicy): variables = policy.variables() - elif policy.model_variables: + elif hasattr(policy, 'model_variables'): variables = policy.model_variables + else: + raise ValueError('policy must be a TFPolicy or a loaded SavedModel') parameters = [var.numpy().flatten() for var in variables] parameters = np.concatenate(parameters, axis=0) @@ -97,32 +70,36 @@ def get_vectorized_parameters_from_policy( def set_vectorized_parameters_for_policy( - policy: Union[tf_policy.TFPolicy, - tf.Module], parameters: npt.NDArray[np.float32]) -> None: + policy: Union[tf_policy.TFPolicy, autotrackable.AutoTrackable], + parameters: npt.NDArray[np.float32]) -> None: + """Separates values in parameters into the policy's shapes + and sets the policy variables to those values""" if isinstance(policy, tf_policy.TFPolicy): variables = policy.variables() - else: - try: - getattr(policy, 'model_variables') - except AttributeError as e: - raise TypeError('policy must be a TFPolicy or a loaded SavedModel') from e + elif hasattr(policy, 'model_variables'): variables = policy.model_variables + else: + raise ValueError('policy must be a TFPolicy or a loaded SavedModel') param_pos = 0 for variable in variables: shape = tf.shape(variable).numpy() - num_ele = np.prod(shape) - param = np.reshape(parameters[param_pos:param_pos + num_ele], shape) + num_elems = np.prod(shape) + param = np.reshape(parameters[param_pos:param_pos + num_elems], shape) variable.assign(param) - param_pos += num_ele + param_pos += num_elems if param_pos != len(parameters): raise ValueError( f'Parameter dimensions are not matched! Expected {len(parameters)} ' 'but only found {param_pos}.') -def save_policy(policy: tf_policy.TFPolicy, parameters: npt.NDArray[np.float32], - save_folder: str, policy_name: str) -> None: +def save_policy(policy: Union[tf_policy.TFPolicy, autotrackable.AutoTrackable], + parameters: npt.NDArray[np.float32], save_folder: str, + policy_name: str) -> None: + """Assigns a policy the name policy_name + and saves it to the directory of save_folder + with the values in parameters.""" set_vectorized_parameters_for_policy(policy, parameters) saver = policy_saver.PolicySaver({policy_name: policy}) saver.save(save_folder) diff --git a/compiler_opt/es/policy_utils_test.py b/compiler_opt/es/policy_utils_test.py index a8cd2ead..eb37722a 100644 --- a/compiler_opt/es/policy_utils_test.py +++ b/compiler_opt/es/policy_utils_test.py @@ -12,56 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - -############################################################################### -# -# -# This is a port of the code by Krzysztof Choromanski, Deepali Jain and Vikas -# Sindhwani, based on the portfolio of Blackbox optimization algorithms listed -# below: -# -# "On Blackbox Backpropagation and Jacobian Sensing"; K. Choromanski, -# V. Sindhwani, NeurIPS 2017 -# "Optimizing Simulations with Noise-Tolerant Structured Exploration"; K. -# Choromanski, A. Iscen, V. Sindhwani, J. Tan, E. Coumans, ICRA 2018 -# "Structured Evolution with Compact Architectures for Scalable Policy -# Optimization"; K. Choromanski, M. Rowland, V. Sindhwani, R. Turner, A. -# Weller, ICML 2018, https://arxiv.org/abs/1804.02395 -# "From Complexity to Simplicity: Adaptive ES-Active Subspaces for Blackbox -# Optimization"; K. Choromanski, A. Pacchiano, J. Parker-Holder, Y. Tang, V. -# Sindhwani, NeurIPS 2019 -# "i-Sim2Real: Reinforcement Learning on Robotic Policies in Tight Human-Robot -# Interaction Loops"; L. Graesser, D. D'Ambrosio, A. Singh, A. Bewley, D. Jain, -# K. Choromanski, P. Sanketi , CoRL 2022, https://arxiv.org/abs/2207.06572 -# "Agile Catching with Whole-Body MPC and Blackbox Policy Learning"; S. -# Abeyruwan, A. Bewley, N. Boffi, K. Choromanski, D. D'Ambrosio, D. Jain, P. -# Sanketi, A. Shankar, V. Sindhwani, S. Singh, J. Slotine, S. Tu, L4DC, -# https://arxiv.org/abs/2306.08205 -# "Robotic Table Tennis: A Case Study into a High Speed Learning System"; A. -# Bewley, A. Shankar, A. Iscen, A. Singh, C. Lynch, D. D'Ambrosio, D. Jain, -# E. Coumans, G. Versom, G. Kouretas, J. Abelian, J. Boyd, K. Oslund, -# K. Reymann, K. Choromanski, L. Graesser, M. Ahn, N. Jaitly, N. Lazic, -# P. Sanketi, P. Xu, P. Sermanet, R. Mahjourian, S. Abeyruwan, S. Kataoka, -# S. Moore, T. Nguyen, T. Ding, V. Sindhwani, V. Vanhoucke, W. Gao, Y. 
Kuang, -# to be presented at RSS 2023 -############################################################################### """Tests for policy_utils.""" from absl.testing import absltest import numpy as np import os import tensorflow as tf +from tensorflow.python.trackable import autotrackable from tf_agents.networks import actor_distribution_network -from tf_agents.policies import actor_policy +from tf_agents.policies import actor_policy, tf_policy from compiler_opt.es import policy_utils from compiler_opt.rl import policy_saver, registry -from compiler_opt.rl.inlining import InliningConfig from compiler_opt.rl.inlining import config as inlining_config +from compiler_opt.rl.inlining import InliningConfig from compiler_opt.rl.regalloc import config as regalloc_config from compiler_opt.rl.regalloc import RegallocEvictionConfig, regalloc_network +# TODO(abenalaast): Issue #280 class ConfigTest(absltest.TestCase): def test_inlining_config(self): @@ -167,6 +136,21 @@ def test_set_vectorized_parameters_for_policy(self): raise AssertionError(f'values at index {idx} do not match') idx += 1 + # get saved model to test a loaded policy + sm = tf.saved_model.load(policy_save_path + '/policy') + self.assertIsInstance(sm, autotrackable.AutoTrackable) + self.assertNotIsInstance(sm, tf_policy.TFPolicy) + params = params[::-1] + policy_utils.set_vectorized_parameters_for_policy(sm, params) + val = length_of_a_perturbation - 1 + for variable in sm.model_variables: + nums = variable.numpy().flatten() + for num in nums: + if val != num: + raise AssertionError( + f'values at index {length_of_a_perturbation - val} do not match') + val -= 1 + def test_get_vectorized_parameters_from_policy(self): # create a policy problem_config = registry.get_configuration(implementation=InliningConfig) @@ -205,6 +189,16 @@ def test_get_vectorized_parameters_from_policy(self): output = policy_utils.get_vectorized_parameters_from_policy(policy) np.testing.assert_array_almost_equal(output, params) + # get saved model to test a loaded policy + sm = tf.saved_model.load(policy_save_path + '/policy') + self.assertIsInstance(sm, autotrackable.AutoTrackable) + self.assertNotIsInstance(sm, tf_policy.TFPolicy) + params = params[::-1] + policy_utils.set_vectorized_parameters_for_policy(sm, params) + # vectorize and check if the outcome is the same as the start + output = policy_utils.get_vectorized_parameters_from_policy(sm) + np.testing.assert_array_almost_equal(output, params) + if __name__ == '__main__': absltest.main() From abe22016d5094c16820758835c4f678a56a44166 Mon Sep 17 00:00:00 2001 From: Abena Laast Date: Wed, 19 Jul 2023 20:07:03 +0000 Subject: [PATCH 3/8] include passed object in ValueError message, add test to check that tfpolicy and loaded policy variable orders match --- compiler_opt/es/policy_utils.py | 8 ++-- compiler_opt/es/policy_utils_test.py | 68 ++++++++++++++++++++++++---- 2 files changed, 63 insertions(+), 13 deletions(-) diff --git a/compiler_opt/es/policy_utils.py b/compiler_opt/es/policy_utils.py index 6402596c..a391ed39 100644 --- a/compiler_opt/es/policy_utils.py +++ b/compiler_opt/es/policy_utils.py @@ -62,7 +62,8 @@ def get_vectorized_parameters_from_policy( elif hasattr(policy, 'model_variables'): variables = policy.model_variables else: - raise ValueError('policy must be a TFPolicy or a loaded SavedModel') + raise ValueError(f'Policy must be a TFPolicy or a loaded SavedModel. 
' + f'Passed policy: {policy}') parameters = [var.numpy().flatten() for var in variables] parameters = np.concatenate(parameters, axis=0) @@ -79,7 +80,8 @@ def set_vectorized_parameters_for_policy( elif hasattr(policy, 'model_variables'): variables = policy.model_variables else: - raise ValueError('policy must be a TFPolicy or a loaded SavedModel') + raise ValueError(f'Policy must be a TFPolicy or a loaded SavedModel. ' + f'Passed policy: {policy}') param_pos = 0 for variable in variables: @@ -91,7 +93,7 @@ def set_vectorized_parameters_for_policy( if param_pos != len(parameters): raise ValueError( f'Parameter dimensions are not matched! Expected {len(parameters)} ' - 'but only found {param_pos}.') + f'but only found {param_pos}.') def save_policy(policy: Union[tf_policy.TFPolicy, autotrackable.AutoTrackable], diff --git a/compiler_opt/es/policy_utils_test.py b/compiler_opt/es/policy_utils_test.py index eb37722a..2d5dd4db 100644 --- a/compiler_opt/es/policy_utils_test.py +++ b/compiler_opt/es/policy_utils_test.py @@ -127,9 +127,12 @@ def test_set_vectorized_parameters_for_policy(self): length_of_a_perturbation = 17218 params = np.arange(length_of_a_perturbation, dtype=np.float32) policy_utils.set_vectorized_parameters_for_policy(policy, params) - # iterate through variables and check their values + expected_variable_shapes = [(71, 64), (64), (64, 64), (64), (64, 64), (64), + (64, 64), (64), (64, 2), (2)] + # iterate through variables and check their shapes and values idx = 0 - for variable in policy.variables(): # pylint: disable=not-callable + for i, variable in enumerate(policy.variables()): # pylint: disable=not-callable + self.assertEqual(variable.shape, expected_variable_shapes[i]) nums = variable.numpy().flatten() for num in nums: if idx != num: @@ -140,16 +143,15 @@ def test_set_vectorized_parameters_for_policy(self): sm = tf.saved_model.load(policy_save_path + '/policy') self.assertIsInstance(sm, autotrackable.AutoTrackable) self.assertNotIsInstance(sm, tf_policy.TFPolicy) - params = params[::-1] policy_utils.set_vectorized_parameters_for_policy(sm, params) - val = length_of_a_perturbation - 1 - for variable in sm.model_variables: + idx = 0 + for i, variable in enumerate(sm.model_variables): + self.assertEqual(variable.shape, expected_variable_shapes[i]) nums = variable.numpy().flatten() for num in nums: - if val != num: - raise AssertionError( - f'values at index {length_of_a_perturbation - val} do not match') - val -= 1 + if idx != num: + raise AssertionError(f'values at index {idx} do not match') + idx += 1 def test_get_vectorized_parameters_from_policy(self): # create a policy @@ -193,12 +195,58 @@ def test_get_vectorized_parameters_from_policy(self): sm = tf.saved_model.load(policy_save_path + '/policy') self.assertIsInstance(sm, autotrackable.AutoTrackable) self.assertNotIsInstance(sm, tf_policy.TFPolicy) - params = params[::-1] policy_utils.set_vectorized_parameters_for_policy(sm, params) # vectorize and check if the outcome is the same as the start output = policy_utils.get_vectorized_parameters_from_policy(sm) np.testing.assert_array_almost_equal(output, params) + def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self): + # create a policy + problem_config = registry.get_configuration(implementation=InliningConfig) + time_step_spec, action_spec = problem_config.get_signature_spec() + creator = inlining_config.get_observation_processing_layer_creator( + quantile_file_dir='compiler_opt/rl/inlining/vocab/', + with_sqrt=False, + with_z_score_normalization=False) 
+ layers = tf.nest.map_structure(creator, time_step_spec.observation) + + actor_network = actor_distribution_network.ActorDistributionNetwork( + input_tensor_spec=time_step_spec.observation, + output_tensor_spec=action_spec, + preprocessing_layers=layers, + preprocessing_combiner=tf.keras.layers.Concatenate(), + fc_layer_params=(64, 64, 64, 64), + dropout_layer_params=None, + activation_fn=tf.keras.activations.relu) + + policy = actor_policy.ActorPolicy( + time_step_spec=time_step_spec, + action_spec=action_spec, + actor_network=actor_network) + saver = policy_saver.PolicySaver({'policy': policy}) + + # save the policy + testing_path = self.create_tempdir() + policy_save_path = os.path.join(testing_path, 'temp_output/policy') + saver.save(policy_save_path) + + length_of_a_perturbation = 17218 + params = np.arange(length_of_a_perturbation, dtype=np.float32) + # set the values of the variables + policy_utils.set_vectorized_parameters_for_policy(policy, params) + # save the changes + saver.save(policy_save_path) + # vectorize the tfpolicy + tf_params = policy_utils.get_vectorized_parameters_from_policy(policy) + + # get loaded policy + sm = tf.saved_model.load(policy_save_path + '/policy') + # vectorize the loaded policy + loaded_params = policy_utils.get_vectorized_parameters_from_policy(sm) + + # assert that they result in the same order of values + np.testing.assert_array_almost_equal(tf_params, loaded_params) + if __name__ == '__main__': absltest.main() From 96dcfe2f4ec39bcbc6bfc183c798d7234b8ccf5a Mon Sep 17 00:00:00 2001 From: Abena Laast Date: Wed, 19 Jul 2023 23:52:24 +0000 Subject: [PATCH 4/8] replace AutoTrackable type annotation with HasModelVariables Protocol --- compiler_opt/es/policy_utils.py | 16 +++++++++------- compiler_opt/es/policy_utils_test.py | 3 --- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/compiler_opt/es/policy_utils.py b/compiler_opt/es/policy_utils.py index a391ed39..0ec225c7 100644 --- a/compiler_opt/es/policy_utils.py +++ b/compiler_opt/es/policy_utils.py @@ -18,12 +18,15 @@ import numpy as np import numpy.typing as npt import tensorflow as tf -from tensorflow.python.trackable import autotrackable -from typing import Union +from typing import Protocol, Sequence +from compiler_opt.rl import policy_saver, registry from tf_agents.networks import network from tf_agents.policies import actor_policy, greedy_policy, tf_policy -from compiler_opt.rl import policy_saver, registry + + +class HasModelVariables(Protocol): + model_variables: Sequence[float] # TODO(abenalaast): Issue #280 @@ -54,8 +57,7 @@ def create_actor_policy(actor_network_ctor: network.DistributionNetwork, def get_vectorized_parameters_from_policy( - policy: Union[tf_policy.TFPolicy, autotrackable.AutoTrackable] -) -> npt.NDArray[np.float32]: + policy: tf_policy.TFPolicy | HasModelVariables) -> npt.NDArray[np.float32]: """Returns a policy's variable values as a single np array.""" if isinstance(policy, tf_policy.TFPolicy): variables = policy.variables() @@ -71,7 +73,7 @@ def get_vectorized_parameters_from_policy( def set_vectorized_parameters_for_policy( - policy: Union[tf_policy.TFPolicy, autotrackable.AutoTrackable], + policy: tf_policy.TFPolicy | HasModelVariables, parameters: npt.NDArray[np.float32]) -> None: """Separates values in parameters into the policy's shapes and sets the policy variables to those values""" @@ -96,7 +98,7 @@ def set_vectorized_parameters_for_policy( f'but only found {param_pos}.') -def save_policy(policy: Union[tf_policy.TFPolicy, 
autotrackable.AutoTrackable], +def save_policy(policy: tf_policy.TFPolicy | HasModelVariables, parameters: npt.NDArray[np.float32], save_folder: str, policy_name: str) -> None: """Assigns a policy the name policy_name diff --git a/compiler_opt/es/policy_utils_test.py b/compiler_opt/es/policy_utils_test.py index 2d5dd4db..d57051f8 100644 --- a/compiler_opt/es/policy_utils_test.py +++ b/compiler_opt/es/policy_utils_test.py @@ -18,7 +18,6 @@ import numpy as np import os import tensorflow as tf -from tensorflow.python.trackable import autotrackable from tf_agents.networks import actor_distribution_network from tf_agents.policies import actor_policy, tf_policy @@ -141,7 +140,6 @@ def test_set_vectorized_parameters_for_policy(self): # get saved model to test a loaded policy sm = tf.saved_model.load(policy_save_path + '/policy') - self.assertIsInstance(sm, autotrackable.AutoTrackable) self.assertNotIsInstance(sm, tf_policy.TFPolicy) policy_utils.set_vectorized_parameters_for_policy(sm, params) idx = 0 @@ -193,7 +191,6 @@ def test_get_vectorized_parameters_from_policy(self): # get saved model to test a loaded policy sm = tf.saved_model.load(policy_save_path + '/policy') - self.assertIsInstance(sm, autotrackable.AutoTrackable) self.assertNotIsInstance(sm, tf_policy.TFPolicy) policy_utils.set_vectorized_parameters_for_policy(sm, params) # vectorize and check if the outcome is the same as the start From f9d098e1d4368ac358c47d0ced3cde68f1b57664 Mon Sep 17 00:00:00 2001 From: Abena Laast Date: Thu, 20 Jul 2023 17:48:24 +0000 Subject: [PATCH 5/8] use Any type as placeholder for Protocol --- compiler_opt/es/policy_utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/compiler_opt/es/policy_utils.py b/compiler_opt/es/policy_utils.py index 0ec225c7..2b87c620 100644 --- a/compiler_opt/es/policy_utils.py +++ b/compiler_opt/es/policy_utils.py @@ -18,15 +18,16 @@ import numpy as np import numpy.typing as npt import tensorflow as tf -from typing import Protocol, Sequence +from typing import Any, Protocol, Sequence, Union from compiler_opt.rl import policy_saver, registry from tf_agents.networks import network from tf_agents.policies import actor_policy, greedy_policy, tf_policy +# TODO(abenalaast): Replace Any type annotation with HasModelVariables Protocol class HasModelVariables(Protocol): - model_variables: Sequence[float] + model_variables = Sequence[Sequence[float]] # TODO(abenalaast): Issue #280 @@ -57,7 +58,7 @@ def create_actor_policy(actor_network_ctor: network.DistributionNetwork, def get_vectorized_parameters_from_policy( - policy: tf_policy.TFPolicy | HasModelVariables) -> npt.NDArray[np.float32]: + policy: Union[tf_policy.TFPolicy, Any]) -> npt.NDArray[np.float32]: """Returns a policy's variable values as a single np array.""" if isinstance(policy, tf_policy.TFPolicy): variables = policy.variables() @@ -73,8 +74,8 @@ def get_vectorized_parameters_from_policy( def set_vectorized_parameters_for_policy( - policy: tf_policy.TFPolicy | HasModelVariables, - parameters: npt.NDArray[np.float32]) -> None: + policy: Union[tf_policy.TFPolicy, + Any], parameters: npt.NDArray[np.float32]) -> None: """Separates values in parameters into the policy's shapes and sets the policy variables to those values""" if isinstance(policy, tf_policy.TFPolicy): @@ -98,9 +99,9 @@ def set_vectorized_parameters_for_policy( f'but only found {param_pos}.') -def save_policy(policy: tf_policy.TFPolicy | HasModelVariables, - parameters: npt.NDArray[np.float32], save_folder: str, - policy_name: 
str) -> None: +def save_policy(policy: Union[tf_policy.TFPolicy, + Any], parameters: npt.NDArray[np.float32], + save_folder: str, policy_name: str) -> None: """Assigns a policy the name policy_name and saves it to the directory of save_folder with the values in parameters.""" From 97637374c1d506cfceffbc540c982c65103cfe14 Mon Sep 17 00:00:00 2001 From: Abena Laast Date: Thu, 20 Jul 2023 19:26:05 +0000 Subject: [PATCH 6/8] implement HasModelVariables Protocol --- compiler_opt/es/policy_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/compiler_opt/es/policy_utils.py b/compiler_opt/es/policy_utils.py index 2b87c620..8053d9aa 100644 --- a/compiler_opt/es/policy_utils.py +++ b/compiler_opt/es/policy_utils.py @@ -18,16 +18,15 @@ import numpy as np import numpy.typing as npt import tensorflow as tf -from typing import Any, Protocol, Sequence, Union +from typing import Protocol, Sequence from compiler_opt.rl import policy_saver, registry from tf_agents.networks import network from tf_agents.policies import actor_policy, greedy_policy, tf_policy -# TODO(abenalaast): Replace Any type annotation with HasModelVariables Protocol class HasModelVariables(Protocol): - model_variables = Sequence[Sequence[float]] + model_variables: Sequence[tf.Variable] # TODO(abenalaast): Issue #280 @@ -58,7 +57,8 @@ def create_actor_policy(actor_network_ctor: network.DistributionNetwork, def get_vectorized_parameters_from_policy( - policy: Union[tf_policy.TFPolicy, Any]) -> npt.NDArray[np.float32]: + policy: 'tf_policy.TFPolicy | HasModelVariables' +) -> npt.NDArray[np.float32]: """Returns a policy's variable values as a single np array.""" if isinstance(policy, tf_policy.TFPolicy): variables = policy.variables() @@ -74,8 +74,8 @@ def get_vectorized_parameters_from_policy( def set_vectorized_parameters_for_policy( - policy: Union[tf_policy.TFPolicy, - Any], parameters: npt.NDArray[np.float32]) -> None: + policy: 'tf_policy.TFPolicy | HasModelVariables', + parameters: npt.NDArray[np.float32]) -> None: """Separates values in parameters into the policy's shapes and sets the policy variables to those values""" if isinstance(policy, tf_policy.TFPolicy): @@ -99,9 +99,9 @@ def set_vectorized_parameters_for_policy( f'but only found {param_pos}.') -def save_policy(policy: Union[tf_policy.TFPolicy, - Any], parameters: npt.NDArray[np.float32], - save_folder: str, policy_name: str) -> None: +def save_policy(policy: 'tf_policy.TFPolicy | HasModelVariables', + parameters: npt.NDArray[np.float32], save_folder: str, + policy_name: str) -> None: """Assigns a policy the name policy_name and saves it to the directory of save_folder with the values in parameters.""" From 485b8b1a2e70a336152bcd75f06c41414b3b134e Mon Sep 17 00:00:00 2001 From: Abena Laast Date: Mon, 24 Jul 2023 22:23:06 +0000 Subject: [PATCH 7/8] Restructure value tests and rename variables for clarity --- compiler_opt/es/policy_utils_test.py | 62 ++++++++++++++-------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/compiler_opt/es/policy_utils_test.py b/compiler_opt/es/policy_utils_test.py index d57051f8..68a4eb45 100644 --- a/compiler_opt/es/policy_utils_test.py +++ b/compiler_opt/es/policy_utils_test.py @@ -92,6 +92,12 @@ def test_regalloc_config(self): class VectorTest(absltest.TestCase): + expected_variable_shapes = [(71, 64), (64), (64, 64), (64), (64, 64), (64), + (64, 64), (64), (64, 2), (2)] + expected_length_of_a_perturbation = sum( + np.prod(shape) for shape in expected_variable_shapes) + params = 
np.arange(expected_length_of_a_perturbation, dtype=np.float32) + def test_set_vectorized_parameters_for_policy(self): # create a policy problem_config = registry.get_configuration(implementation=InliningConfig) @@ -115,41 +121,37 @@ def test_set_vectorized_parameters_for_policy(self): time_step_spec=time_step_spec, action_spec=action_spec, actor_network=actor_network) - saver = policy_saver.PolicySaver({'policy': policy}) # save the policy + saver = policy_saver.PolicySaver({'policy': policy}) testing_path = self.create_tempdir() policy_save_path = os.path.join(testing_path, 'temp_output/policy') saver.save(policy_save_path) # set the values of the policy variables - length_of_a_perturbation = 17218 - params = np.arange(length_of_a_perturbation, dtype=np.float32) - policy_utils.set_vectorized_parameters_for_policy(policy, params) - expected_variable_shapes = [(71, 64), (64), (64, 64), (64), (64, 64), (64), - (64, 64), (64), (64, 2), (2)] + policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params) # iterate through variables and check their shapes and values - idx = 0 + expected_values = [*VectorTest.params] for i, variable in enumerate(policy.variables()): # pylint: disable=not-callable - self.assertEqual(variable.shape, expected_variable_shapes[i]) - nums = variable.numpy().flatten() - for num in nums: - if idx != num: - raise AssertionError(f'values at index {idx} do not match') - idx += 1 + self.assertEqual(variable.shape, VectorTest.expected_variable_shapes[i]) + variable_values = variable.numpy().flatten() + np.testing.assert_array_almost_equal( + expected_values[:len(variable_values)], variable_values) + expected_values = expected_values[len(variable_values):] + self.assertEmpty(expected_values) # get saved model to test a loaded policy sm = tf.saved_model.load(policy_save_path + '/policy') self.assertNotIsInstance(sm, tf_policy.TFPolicy) - policy_utils.set_vectorized_parameters_for_policy(sm, params) - idx = 0 + policy_utils.set_vectorized_parameters_for_policy(sm, VectorTest.params) + expected_values = [*VectorTest.params] for i, variable in enumerate(sm.model_variables): - self.assertEqual(variable.shape, expected_variable_shapes[i]) - nums = variable.numpy().flatten() - for num in nums: - if idx != num: - raise AssertionError(f'values at index {idx} do not match') - idx += 1 + self.assertEqual(variable.shape, VectorTest.expected_variable_shapes[i]) + variable_values = variable.numpy().flatten() + np.testing.assert_array_almost_equal( + expected_values[:len(variable_values)], variable_values) + expected_values = expected_values[len(variable_values):] + self.assertEmpty(expected_values) def test_get_vectorized_parameters_from_policy(self): # create a policy @@ -174,28 +176,26 @@ def test_get_vectorized_parameters_from_policy(self): time_step_spec=time_step_spec, action_spec=action_spec, actor_network=actor_network) - saver = policy_saver.PolicySaver({'policy': policy}) # save the policy + saver = policy_saver.PolicySaver({'policy': policy}) testing_path = self.create_tempdir() policy_save_path = os.path.join(testing_path, 'temp_output/policy') saver.save(policy_save_path) - length_of_a_perturbation = 17218 - params = np.arange(length_of_a_perturbation, dtype=np.float32) # functionality verified in previous test - policy_utils.set_vectorized_parameters_for_policy(policy, params) + policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params) # vectorize and check if the outcome is the same as the start output = 
policy_utils.get_vectorized_parameters_from_policy(policy) - np.testing.assert_array_almost_equal(output, params) + np.testing.assert_array_almost_equal(output, VectorTest.params) # get saved model to test a loaded policy sm = tf.saved_model.load(policy_save_path + '/policy') self.assertNotIsInstance(sm, tf_policy.TFPolicy) - policy_utils.set_vectorized_parameters_for_policy(sm, params) + policy_utils.set_vectorized_parameters_for_policy(sm, VectorTest.params) # vectorize and check if the outcome is the same as the start output = policy_utils.get_vectorized_parameters_from_policy(sm) - np.testing.assert_array_almost_equal(output, params) + np.testing.assert_array_almost_equal(output, VectorTest.params) def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self): # create a policy @@ -220,17 +220,15 @@ def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self): time_step_spec=time_step_spec, action_spec=action_spec, actor_network=actor_network) - saver = policy_saver.PolicySaver({'policy': policy}) # save the policy + saver = policy_saver.PolicySaver({'policy': policy}) testing_path = self.create_tempdir() policy_save_path = os.path.join(testing_path, 'temp_output/policy') saver.save(policy_save_path) - length_of_a_perturbation = 17218 - params = np.arange(length_of_a_perturbation, dtype=np.float32) # set the values of the variables - policy_utils.set_vectorized_parameters_for_policy(policy, params) + policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params) # save the changes saver.save(policy_save_path) # vectorize the tfpolicy From 74b7605bd1c982a31576cbacd9232b4294456868 Mon Sep 17 00:00:00 2001 From: Abena Laast Date: Tue, 25 Jul 2023 22:31:56 +0000 Subject: [PATCH 8/8] Use os.path.join to form paths --- compiler_opt/es/policy_utils_test.py | 47 +++++++++++++++++++--------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/compiler_opt/es/policy_utils_test.py b/compiler_opt/es/policy_utils_test.py index 68a4eb45..e830f92d 100644 --- a/compiler_opt/es/policy_utils_test.py +++ b/compiler_opt/es/policy_utils_test.py @@ -29,14 +29,15 @@ from compiler_opt.rl.regalloc import RegallocEvictionConfig, regalloc_network -# TODO(abenalaast): Issue #280 class ConfigTest(absltest.TestCase): + # TODO(abenalaast): Issue #280 def test_inlining_config(self): problem_config = registry.get_configuration(implementation=InliningConfig) time_step_spec, action_spec = problem_config.get_signature_spec() + quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab') creator = inlining_config.get_observation_processing_layer_creator( - quantile_file_dir='compiler_opt/rl/inlining/vocab/', + quantile_file_dir=quantile_file_dir, with_sqrt=False, with_z_score_normalization=False) layers = tf.nest.map_structure(creator, time_step_spec.observation) @@ -60,12 +61,14 @@ def test_inlining_config(self): policy._actor_network, # pylint: disable=protected-access actor_distribution_network.ActorDistributionNetwork) + # TODO(abenalaast): Issue #280 def test_regalloc_config(self): problem_config = registry.get_configuration( implementation=RegallocEvictionConfig) time_step_spec, action_spec = problem_config.get_signature_spec() + quantile_file_dir = os.path.join('compiler_opt', 'rl', 'regalloc', 'vocab') creator = regalloc_config.get_observation_processing_layer_creator( - quantile_file_dir='compiler_opt/rl/regalloc/vocab', + quantile_file_dir=quantile_file_dir, with_sqrt=False, with_z_score_normalization=False) layers = tf.nest.map_structure(creator, 
time_step_spec.observation) @@ -97,13 +100,16 @@ class VectorTest(absltest.TestCase): expected_length_of_a_perturbation = sum( np.prod(shape) for shape in expected_variable_shapes) params = np.arange(expected_length_of_a_perturbation, dtype=np.float32) + POLICY_NAME = 'test_policy_name' + # TODO(abenalaast): Issue #280 def test_set_vectorized_parameters_for_policy(self): # create a policy problem_config = registry.get_configuration(implementation=InliningConfig) time_step_spec, action_spec = problem_config.get_signature_spec() + quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab') creator = inlining_config.get_observation_processing_layer_creator( - quantile_file_dir='compiler_opt/rl/inlining/vocab/', + quantile_file_dir=quantile_file_dir, with_sqrt=False, with_z_score_normalization=False) layers = tf.nest.map_structure(creator, time_step_spec.observation) @@ -123,14 +129,15 @@ def test_set_vectorized_parameters_for_policy(self): actor_network=actor_network) # save the policy - saver = policy_saver.PolicySaver({'policy': policy}) + saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy}) testing_path = self.create_tempdir() - policy_save_path = os.path.join(testing_path, 'temp_output/policy') + policy_save_path = os.path.join(testing_path, 'temp_output', 'policy') saver.save(policy_save_path) # set the values of the policy variables policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params) # iterate through variables and check their shapes and values + # deep copy params in order to destructively iterate over values expected_values = [*VectorTest.params] for i, variable in enumerate(policy.variables()): # pylint: disable=not-callable self.assertEqual(variable.shape, VectorTest.expected_variable_shapes[i]) @@ -138,12 +145,15 @@ def test_set_vectorized_parameters_for_policy(self): np.testing.assert_array_almost_equal( expected_values[:len(variable_values)], variable_values) expected_values = expected_values[len(variable_values):] + # all values in the copy should have been removed at this point self.assertEmpty(expected_values) # get saved model to test a loaded policy - sm = tf.saved_model.load(policy_save_path + '/policy') + load_path = os.path.join(policy_save_path, VectorTest.POLICY_NAME) + sm = tf.saved_model.load(load_path) self.assertNotIsInstance(sm, tf_policy.TFPolicy) policy_utils.set_vectorized_parameters_for_policy(sm, VectorTest.params) + # deep copy params in order to destructively iterate over values expected_values = [*VectorTest.params] for i, variable in enumerate(sm.model_variables): self.assertEqual(variable.shape, VectorTest.expected_variable_shapes[i]) @@ -151,14 +161,17 @@ def test_set_vectorized_parameters_for_policy(self): np.testing.assert_array_almost_equal( expected_values[:len(variable_values)], variable_values) expected_values = expected_values[len(variable_values):] + # all values in the copy should have been removed at this point self.assertEmpty(expected_values) + # TODO(abenalaast): Issue #280 def test_get_vectorized_parameters_from_policy(self): # create a policy problem_config = registry.get_configuration(implementation=InliningConfig) time_step_spec, action_spec = problem_config.get_signature_spec() + quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab') creator = inlining_config.get_observation_processing_layer_creator( - quantile_file_dir='compiler_opt/rl/inlining/vocab/', + quantile_file_dir=quantile_file_dir, with_sqrt=False, with_z_score_normalization=False) layers = 
tf.nest.map_structure(creator, time_step_spec.observation) @@ -178,9 +191,9 @@ def test_get_vectorized_parameters_from_policy(self): actor_network=actor_network) # save the policy - saver = policy_saver.PolicySaver({'policy': policy}) + saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy}) testing_path = self.create_tempdir() - policy_save_path = os.path.join(testing_path, 'temp_output/policy') + policy_save_path = os.path.join(testing_path, 'temp_output', 'policy') saver.save(policy_save_path) # functionality verified in previous test @@ -190,19 +203,22 @@ def test_get_vectorized_parameters_from_policy(self): np.testing.assert_array_almost_equal(output, VectorTest.params) # get saved model to test a loaded policy - sm = tf.saved_model.load(policy_save_path + '/policy') + load_path = os.path.join(policy_save_path, VectorTest.POLICY_NAME) + sm = tf.saved_model.load(load_path) self.assertNotIsInstance(sm, tf_policy.TFPolicy) policy_utils.set_vectorized_parameters_for_policy(sm, VectorTest.params) # vectorize and check if the outcome is the same as the start output = policy_utils.get_vectorized_parameters_from_policy(sm) np.testing.assert_array_almost_equal(output, VectorTest.params) + # TODO(abenalaast): Issue #280 def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self): # create a policy problem_config = registry.get_configuration(implementation=InliningConfig) time_step_spec, action_spec = problem_config.get_signature_spec() + quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab') creator = inlining_config.get_observation_processing_layer_creator( - quantile_file_dir='compiler_opt/rl/inlining/vocab/', + quantile_file_dir=quantile_file_dir, with_sqrt=False, with_z_score_normalization=False) layers = tf.nest.map_structure(creator, time_step_spec.observation) @@ -222,9 +238,9 @@ def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self): actor_network=actor_network) # save the policy - saver = policy_saver.PolicySaver({'policy': policy}) + saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy}) testing_path = self.create_tempdir() - policy_save_path = os.path.join(testing_path, 'temp_output/policy') + policy_save_path = os.path.join(testing_path, 'temp_output', 'policy') saver.save(policy_save_path) # set the values of the variables @@ -235,7 +251,8 @@ def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self): tf_params = policy_utils.get_vectorized_parameters_from_policy(policy) # get loaded policy - sm = tf.saved_model.load(policy_save_path + '/policy') + load_path = os.path.join(policy_save_path, VectorTest.POLICY_NAME) + sm = tf.saved_model.load(load_path) # vectorize the loaded policy loaded_params = policy_utils.get_vectorized_parameters_from_policy(sm)
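
The series above leaves compiler_opt/es/policy_utils.py with a small surface: create_actor_policy builds a TFPolicy from the gin-registered problem configuration, get_vectorized_parameters_from_policy and set_vectorized_parameters_for_policy round-trip all policy variables through a single flat float32 vector, and save_policy writes the values back and exports a named SavedModel. As an illustration only, not part of any patch in this series, the sketch below shows one way an ES worker might combine these functions; the perturb_and_save helper and its sigma parameter are hypothetical names, and it assumes the caller already holds a TFPolicy such as the ActorPolicy instances constructed in policy_utils_test.py.

import numpy as np
from tf_agents.policies import tf_policy

from compiler_opt.es import policy_utils


def perturb_and_save(policy: tf_policy.TFPolicy, save_folder: str,
                     policy_name: str, sigma: float = 0.01) -> None:
  """Hypothetical helper: applies one Gaussian perturbation and exports it."""
  # Flatten every policy variable into one float32 vector, in variable order.
  params = policy_utils.get_vectorized_parameters_from_policy(policy)

  # Perturb the flat vector; casting the noise to float32 keeps the later
  # variable.assign() calls dtype-consistent.
  noise = np.random.standard_normal(params.shape).astype(np.float32)
  perturbed = params + sigma * noise

  # save_policy writes the perturbed values back into the policy's variables
  # (via set_vectorized_parameters_for_policy) and exports a SavedModel under
  # save_folder/policy_name, the same layout the tests read back with
  # tf.saved_model.load(os.path.join(policy_save_path, POLICY_NAME)).
  policy_utils.save_policy(policy, perturbed, save_folder, policy_name)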