diff --git a/tf_agents/utils/nest_utils.py b/tf_agents/utils/nest_utils.py index 8eb0c3779..f7543e5b5 100644 --- a/tf_agents/utils/nest_utils.py +++ b/tf_agents/utils/nest_utils.py @@ -29,6 +29,7 @@ import tensorflow as tf from tf_agents.typing import types from tf_agents.utils import composite +import tree import wrapt # TODO(b/128613858): Update to a public facing API. @@ -122,11 +123,15 @@ def assert_same_structure( exception = type(e) if exception: - str1 = tf.nest.map_structure( - lambda _: _DOT, nest1, expand_composites=expand_composites + str1 = tree.flatten_with_path( + tf.nest.map_structure( + lambda _: _DOT, nest1, expand_composites=expand_composites + ) ) - str2 = tf.nest.map_structure( - lambda _: _DOT, nest2, expand_composites=expand_composites + str2 = tree.flatten_with_path( + tf.nest.map_structure( + lambda _: _DOT, nest2, expand_composites=expand_composites + ) ) raise exception( '{}:\n {}\nvs.\n {}\nValues:\n {}\nvs.\n {}.'.format(