From b0254eba8ed95494ca470cf99e4f4fab4b26c6e2 Mon Sep 17 00:00:00 2001 From: Paul Wohlhart Date: Mon, 29 Jan 2024 08:58:04 -0800 Subject: [PATCH] Use tree.flatten_with_path on structures in assert_same_structure exception message. This sorts the structure by key alphabetically and thus makes it much easier to compare. PiperOrigin-RevId: 602402622 Change-Id: I71122ec670b8661253589a6a6d8cb867588a50db --- tf_agents/utils/nest_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tf_agents/utils/nest_utils.py b/tf_agents/utils/nest_utils.py index 8eb0c3779..f7543e5b5 100644 --- a/tf_agents/utils/nest_utils.py +++ b/tf_agents/utils/nest_utils.py @@ -29,6 +29,7 @@ import tensorflow as tf from tf_agents.typing import types from tf_agents.utils import composite +import tree import wrapt # TODO(b/128613858): Update to a public facing API. @@ -122,11 +123,15 @@ def assert_same_structure( exception = type(e) if exception: - str1 = tf.nest.map_structure( - lambda _: _DOT, nest1, expand_composites=expand_composites + str1 = tree.flatten_with_path( + tf.nest.map_structure( + lambda _: _DOT, nest1, expand_composites=expand_composites + ) ) - str2 = tf.nest.map_structure( - lambda _: _DOT, nest2, expand_composites=expand_composites + str2 = tree.flatten_with_path( + tf.nest.map_structure( + lambda _: _DOT, nest2, expand_composites=expand_composites + ) ) raise exception( '{}:\n {}\nvs.\n {}\nValues:\n {}\nvs.\n {}.'.format(