diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index dd00146cd..e920331bc 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,9 +9,12 @@ import tensorflow as tf -from .custom_tf_addons import rotate_img -from .custom_tf_addons import transform -from .custom_tf_addons import translate +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + rotate_img +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + transform +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + translate # This signifies the max integer that the controller RNN could predict for the # augmentation scheme.