diff --git a/torchbenchmark/models/soft_actor_critic/nets.py b/torchbenchmark/models/soft_actor_critic/nets.py index d938904c4..fd2f1690c 100644 --- a/torchbenchmark/models/soft_actor_critic/nets.py +++ b/torchbenchmark/models/soft_actor_critic/nets.py @@ -4,7 +4,7 @@ from torch import distributions as pyd from torch import nn -from . import utils +from . import sac_utils from torchbenchmark.util.distribution import SquashedNormal def weight_init(m): @@ -30,11 +30,11 @@ def __init__(self, obs_shape, out_dim=50): self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1) self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1) - output_height, output_width = utils.compute_conv_output( + output_height, output_width = sac_utils.compute_conv_output( obs_shape[1:], kernel_size=(3, 3), stride=(2, 2) ) for _ in range(3): - output_height, output_width = utils.compute_conv_output( + output_height, output_width = sac_utils.compute_conv_output( (output_height, output_width), kernel_size=(3, 3), stride=(1, 1) ) @@ -63,15 +63,15 @@ def __init__(self, obs_shape, out_dim=50): self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) - output_height, output_width = utils.compute_conv_output( + output_height, output_width = sac_utils.compute_conv_output( obs_shape[1:], kernel_size=(8, 8), stride=(4, 4) ) - output_height, output_width = utils.compute_conv_output( + output_height, output_width = sac_utils.compute_conv_output( (output_height, output_width), kernel_size=(4, 4), stride=(2, 2) ) - output_height, output_width = utils.compute_conv_output( + output_height, output_width = sac_utils.compute_conv_output( (output_height, output_width), kernel_size=(3, 3), stride=(1, 1) ) diff --git a/torchbenchmark/models/soft_actor_critic/sac.py b/torchbenchmark/models/soft_actor_critic/sac.py index dd590a8d0..c67da4fca 100644 --- a/torchbenchmark/models/soft_actor_critic/sac.py +++ b/torchbenchmark/models/soft_actor_critic/sac.py @@ -3,7 +3,7 @@ import numpy as np import torch -from . import envs, nets, replay, utils +from . import envs, nets, replay, sac_utils class SACAgent: