From d910b8aefdb3a247f441557ba413d6a352456904 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 21 Jun 2024 09:17:51 -0700 Subject: [PATCH] Fix the soft_actor_critic model (#2326) Summary: Unfortunately, https://github.com/pytorch/benchmark/pull/2318 has a bug that breaks the `soft_actor_critic` model. Pull Request resolved: https://github.com/pytorch/benchmark/pull/2326 Reviewed By: aaronenyeshi Differential Revision: D58871386 Pulled By: xuzhao9 fbshipit-source-id: 5f8b5fbe00722ccb647b08a8089fd52a7719208c --- torchbenchmark/models/soft_actor_critic/nets.py | 12 ++++++------ torchbenchmark/models/soft_actor_critic/sac.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchbenchmark/models/soft_actor_critic/nets.py b/torchbenchmark/models/soft_actor_critic/nets.py index d938904c42..fd2f1690c4 100644 --- a/torchbenchmark/models/soft_actor_critic/nets.py +++ b/torchbenchmark/models/soft_actor_critic/nets.py @@ -4,7 +4,7 @@ from torch import distributions as pyd from torch import nn -from . import utils +from . import sac_utils from torchbenchmark.util.distribution import SquashedNormal def weight_init(m): @@ -30,11 +30,11 @@ def __init__(self, obs_shape, out_dim=50): self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1) self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1) - output_height, output_width = utils.compute_conv_output( + output_height, output_width = sac_utils.compute_conv_output( obs_shape[1:], kernel_size=(3, 3), stride=(2, 2) ) for _ in range(3): - output_height, output_width = utils.compute_conv_output( + output_height, output_width = sac_utils.compute_conv_output( (output_height, output_width), kernel_size=(3, 3), stride=(1, 1) ) @@ -63,15 +63,15 @@ def __init__(self, obs_shape, out_dim=50): self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) - output_height, output_width = utils.compute_conv_output( + output_height, output_width = sac_utils.compute_conv_output( obs_shape[1:], kernel_size=(8, 8), stride=(4, 4) ) - output_height, output_width = utils.compute_conv_output( + output_height, output_width = sac_utils.compute_conv_output( (output_height, output_width), kernel_size=(4, 4), stride=(2, 2) ) - output_height, output_width = utils.compute_conv_output( + output_height, output_width = sac_utils.compute_conv_output( (output_height, output_width), kernel_size=(3, 3), stride=(1, 1) ) diff --git a/torchbenchmark/models/soft_actor_critic/sac.py b/torchbenchmark/models/soft_actor_critic/sac.py index dd590a8d02..c67da4fca5 100644 --- a/torchbenchmark/models/soft_actor_critic/sac.py +++ b/torchbenchmark/models/soft_actor_critic/sac.py @@ -3,7 +3,7 @@ import numpy as np import torch -from . import envs, nets, replay, utils +from . import envs, nets, replay, sac_utils class SACAgent: