
Commit 20476b4

Internal change.
PiperOrigin-RevId: 590298665
Change-Id: Ie389bec30e47c40cf0da7d86f91a29c56c290725
TF-Agents Team authored and copybara-github committed Dec 12, 2023
1 parent 06705c6 commit 20476b4
Showing 3 changed files with 25 additions and 12 deletions.
3 changes: 2 additions & 1 deletion tf_agents/policies/py_tf_policy.py
@@ -14,6 +14,7 @@
 # limitations under the License.

 """Converts TensorFlow Policies into Python Policies."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -76,7 +77,7 @@ def __init__(
     )

     self._tf_policy = policy
-    self.session = None
+    self.session: tf.compat.v1.Session = None

     self._policy_state_spec = tensor_spec.to_nest_array_spec(
         self._tf_policy.policy_state_spec
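The py_tf_policy.py hunks above only add a type annotation to a session attribute that starts out as None. A minimal sketch of that pattern, using a hypothetical SessionHolder class that is not part of this commit; a strictly typed variant marks the attribute Optional, since it holds None until a session is attached (the commit itself uses the bare tf.compat.v1.Session annotation):

from typing import Optional

import tensorflow as tf


class SessionHolder:
  """Hypothetical class illustrating the annotated-session pattern."""

  def __init__(self):
    # None until a session is attached; Optional[...] makes that explicit
    # to type checkers.
    self.session: Optional[tf.compat.v1.Session] = None

  def run(self, op):
    # Guard before use, so a session-less call fails loudly.
    if self.session is None:
      raise ValueError("No TensorFlow session has been attached.")
    return self.session.run(op)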
32 changes: 22 additions & 10 deletions tf_agents/replay_buffers/reverb_utils.py
@@ -117,6 +117,11 @@ def __init__(
     self._overflow_episode = False
     self._writer_has_data = False

+  def _get_writer(self):
+    if self._writer is None:
+      raise ValueError("Could not obtain writer from py_client.")
+    return self._writer
+
   def update_priority(self, priority: Union[float, int]) -> None:
     """Update the table priority.
@@ -181,7 +186,7 @@ def __call__(self, trajectory: trajectory_lib.Trajectory) -> None:
       return

     if not self._overflow_episode:
-      self._writer.append(trajectory)
+      self._get_writer().append(trajectory)
       self._writer_has_data = True
       self._cached_steps += 1

@@ -197,9 +202,11 @@ def _write_cached_steps(self) -> None:
     # Only writes to Reverb when the writer has cached trajectories.
     if self._writer_has_data:
       # No need to truncate since the truncation is done in the class.
-      trajectory = tf.nest.map_structure(lambda h: h[:], self._writer.history)
+      trajectory = tf.nest.map_structure(
+          lambda h: h[:], self._get_writer().history
+      )
       for table_name in self._table_names:
-        self._writer.create_item(
+        self._get_writer().create_item(
             table=table_name, trajectory=trajectory, priority=self._priority
         )
       self._writer_has_data = False
@@ -214,7 +221,7 @@ def flush(self):
     By calling this method before the `learner.run()`, we ensure that there is
     enough data to be consumed.
     """
-    self._writer.flush()
+    self._get_writer().flush()

   def reset(self, write_cached_steps: bool = True) -> None:
     """Resets the state of the observer.
@@ -347,6 +354,11 @@ def __init__(
     self._cached_steps = 0
     self._last_trajectory = None

+  def _get_writer(self):
+    if self._writer is None:
+      raise ValueError("Could not obtain writer from py_client.")
+    return self._writer
+
   @property
   def py_client(self) -> types.ReverbClient:
     return self._py_client
@@ -367,7 +379,7 @@ def __call__(self, trajectory: trajectory_lib.Trajectory) -> None:
       there is *no* batch dimension.
     """
     self._last_trajectory = trajectory
-    self._writer.append(trajectory)
+    self._get_writer().append(trajectory)
     self._cached_steps += 1

     # If the fixed sequence length is reached, write the sequence.
@@ -395,10 +407,10 @@ def _write_cached_steps(self) -> None:

     if self._sequence_lengths_reached():
       trajectory = tf.nest.map_structure(
-          lambda d: d[-self._sequence_length :], self._writer.history
+          lambda d: d[-self._sequence_length :], self._get_writer().history
       )
       for table_name in self._table_names:
-        self._writer.create_item(
+        self._get_writer().create_item(
             table=table_name, trajectory=trajectory, priority=self._priority
         )

@@ -412,7 +424,7 @@ def flush(self):
     By calling this method before the `learner.run()`, we ensure that there is
     enough data to be consumed.
     """
-    self._writer.flush()
+    self._get_writer().flush()

   def _get_padding_step(
       self, example_trajectory: trajectory_lib.Trajectory
@@ -452,7 +464,7 @@ def reset(self, write_cached_steps: bool = True) -> None:
       pad_range = range(self._sequence_length - self._cached_steps)

       for _ in pad_range:
-        self._writer.append(zero_step)
+        self._get_writer().append(zero_step)
         self._cached_steps += 1
       self._write_cached_steps()
     # Write the cached trajectories without padding, if the cache contains
@@ -522,7 +534,7 @@ def __call__(self, trajectory: trajectory_lib.Trajectory) -> None:
     # We record the last step called to generate the padding step when padding
     # is enabled.
     self._last_trajectory = trajectory
-    self._writer.append(trajectory)
+    self._get_writer().append(trajectory)
     self._cached_steps += 1

     self._write_cached_steps()
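Every substantive edit in reverb_utils.py is the same mechanical change: direct self._writer access is routed through the new _get_writer() helper, so a missing writer surfaces as a descriptive ValueError at the first use instead of an AttributeError on NoneType. A minimal, dependency-free sketch of the guard pattern, using a hypothetical ObserverSkeleton class (in the real observers the writer comes from a Reverb py_client, hence the error message):

class ObserverSkeleton:
  """Hypothetical skeleton of the guard pattern added in this commit."""

  def __init__(self, writer=None):
    # In the real observers the writer is created from a Reverb py_client
    # and may be None if that creation did not happen or did not succeed.
    self._writer = writer

  def _get_writer(self):
    # Same shape as the helper added above: fail with a descriptive
    # ValueError rather than an AttributeError on None.
    if self._writer is None:
      raise ValueError("Could not obtain writer from py_client.")
    return self._writer

  def __call__(self, trajectory):
    # Call sites use the guard instead of touching self._writer directly.
    self._get_writer().append(trajectory)

With this sketch, ObserverSkeleton()(some_trajectory) raises ValueError("Could not obtain writer from py_client.") — the failure names the actual problem rather than "'NoneType' object has no attribute 'append'".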
2 changes: 1 addition & 1 deletion tf_agents/utils/session_utils.py
@@ -96,7 +96,7 @@ def session(self, session):
     """

   @property
-  def session(self):
+  def session(self) -> tf.compat.v1.Session:
     """Returns the TensorFlow session-like object used by this object.

     Returns:
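The session_utils.py change likewise only adds a return annotation to the session property. A short sketch of an annotated property/setter pair in the same style, using a hypothetical SessionOwner class (the raise-on-missing behavior here is illustrative, not taken from the commit):

import tensorflow as tf


class SessionOwner:
  """Hypothetical class with an annotated session property."""

  def __init__(self):
    self._session = None

  @property
  def session(self) -> tf.compat.v1.Session:
    # The return annotation documents what callers get, so type checkers
    # can verify uses such as owner.session.run(...).
    if self._session is None:
      raise RuntimeError("No TensorFlow session was set on this object.")
    return self._session

  @session.setter
  def session(self, session: tf.compat.v1.Session) -> None:
    self._session = session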
