Skip to content

Commit

Permalink
Fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Apr 19, 2024
1 parent 0a3c076 commit 8be4575
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2778,7 +2778,10 @@ def _export_quantized_model(

device = "cpu"
example_input_shape = next(iter(quantization_result.calibration_dataloader))[0].size()
input_shape_with_explicit_batch = tuple(export_params.batch_size, *example_input_shape[1:])
input_shape_with_explicit_batch = tuple([export_params.batch_size, *example_input_shape[1:]])
if export_params.input_image_shape is not None:
input_shape_with_explicit_batch = input_shape_with_explicit_batch[: -len(export_params.input_image_shape)] + export_params.input_image_shape

onnx_input = torch.randn(input_shape_with_explicit_batch).to(device=device)
onnx_export_kwargs = export_params.onnx_export_kwargs or {}
model_to_export = quantization_result.quantized_model
Expand Down
7 changes: 6 additions & 1 deletion tests/recipe_training_tests/coded_qat_launch_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import unittest

from super_gradients.import_utils import import_pytorch_quantization_or_install
from torchvision.transforms import Normalize, ToTensor, RandomHorizontalFlip, RandomCrop

from super_gradients import Trainer
from super_gradients.training import modify_params_for_qat

from super_gradients.training.dataloaders.dataloaders import cifar10_train, cifar10_val
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.models import ResNet18

from super_gradients.training.utils.quantization.tensorrt.functional import modify_params_for_qat

import_pytorch_quantization_or_install()


class CodedQATLuanchTest(unittest.TestCase):
def test_qat_launch(self):
Expand Down

0 comments on commit 8be4575

Please sign in to comment.