Skip to content

Commit

Permalink
Support forward compatibility (sony#81)
Browse files Browse the repository at this point in the history
Forward compatibility with 1.3.0 isn't tested because there's a bug in that version's tests
  • Loading branch information
elad-c authored Mar 6, 2024
1 parent 84e9b19 commit 4bbc18e
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 7 deletions.
57 changes: 57 additions & 0 deletions .github/workflows/forward_compatibility_keras_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
name: Run Keras Quantizers Forward Compatibility Tests

on:
workflow_call:
inputs:
load_version:
description: 'MCT Quantizers version to load models'
required: true
type: string
python_version:
description: 'Python version'
required: true
type: string
default: '3.10.*'
tf_version:
description: 'TF version'
required: true
type: string
default: '2.12.*'

jobs:
run-tensorflow-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install Python 3
uses: actions/setup-python@v1
with:
python-version: ${{ inputs.python_version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install tensorflow==${{ inputs.tf_version }}
- name: Run save model tests with latest version
run: |
cd tests
echo "Current directory: $PWD"
export PYTHONPATH="$PWD:${PYTHONPATH}"
echo "Updated PYTHONPATH: $PYTHONPATH"
cd ..
python tests/compatibility_tests/keras_comp_tests/compatibility_weights_save_model_test_suite.py ${{ inputs.load_version }}
python tests/compatibility_tests/keras_comp_tests/compatibility_activation_save_model_test_suite.py ${{ inputs.load_version }}
- name: Checkout to MCT Quantizers requested tag for loading test models
run: |
git checkout tags/${{ inputs.load_version }}
- name: Run load model tests with load_version
run: |
cd tests
echo "Current directory: $PWD"
export PYTHONPATH="$PWD:${PYTHONPATH}"
echo "Updated PYTHONPATH: $PYTHONPATH"
cd ..
python tests/compatibility_tests/keras_comp_tests/compatibility_weights_load_model_test_suite.py ${{ inputs.load_version }}
python tests/compatibility_tests/keras_comp_tests/compatibility_activation_load_model_test_suite.py ${{ inputs.load_version }}
57 changes: 57 additions & 0 deletions .github/workflows/forward_compatibility_torch_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
name: Run Torch Quantizers Forward Compatibility Tests

on:
workflow_call:
inputs:
load_version:
description: 'MCT Quantizers version to load models'
required: true
type: string
python_version:
description: 'Python version'
required: true
type: string
default: '3.10.*'
torch_version:
description: 'Torch version'
required: true
type: string
default: '2.0.*'

jobs:
run-torch-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install Python 3
uses: actions/setup-python@v1
with:
python-version: ${{ inputs.python_version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install torch==${{ inputs.torch_version }} onnx onnxruntime onnxruntime-extensions
- name: Run save model tests with latest version
run: |
cd tests
echo "Current directory: $PWD"
export PYTHONPATH="$PWD:${PYTHONPATH}"
echo "Updated PYTHONPATH: $PYTHONPATH"
cd ..
python tests/compatibility_tests/torch_comp_tests/compatibility_weights_save_model_test_suite.py ${{ inputs.load_version }}
python tests/compatibility_tests/torch_comp_tests/compatibility_activation_save_model_test_suite.py ${{ inputs.load_version }}
- name: Checkout to MCT Quantizers requested tag for loading test models
run: |
git checkout tags/${{ inputs.load_version }}
- name: Run load model tests with load_version
run: |
cd tests
echo "Current directory: $PWD"
export PYTHONPATH="$PWD:${PYTHONPATH}"
echo "Updated PYTHONPATH: $PYTHONPATH"
cd ..
python tests/compatibility_tests/torch_comp_tests/compatibility_weights_load_model_test_suite.py ${{ inputs.load_version }}
python tests/compatibility_tests/torch_comp_tests/compatibility_activation_load_model_test_suite.py ${{ inputs.load_version }}
13 changes: 13 additions & 0 deletions .github/workflows/run_forward_comp_test_tf212_v14.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: Run Forward Compatibility Test - Tensorflow 2.12 MCTQ v1.4.0
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *

jobs:
run-forward-comp-tensorflow-2_12-v1_4:
uses: ./.github/workflows/forward_compatibility_keras_tests.yml
with:
load_version: "v1.4.0"
python_version: "3.10"
tf_version: "2.12.*"
13 changes: 13 additions & 0 deletions .github/workflows/run_forward_comp_test_tf213_v14.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: Run Forward Compatibility Test - Tensorflow 2.13 MCTQ v1.4.0
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *

jobs:
run-forward-comp-tensorflow-2_13-v1_4:
uses: ./.github/workflows/forward_compatibility_keras_tests.yml
with:
load_version: "v1.4.0"
python_version: "3.10"
tf_version: "2.13.*"
13 changes: 13 additions & 0 deletions .github/workflows/run_forward_comp_test_tf214_v14.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: Run Forward Compatibility Test - Tensorflow 2.14 MCTQ v1.4.0
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *

jobs:
run-forward-comp-tensorflow-2_14-v1_4:
uses: ./.github/workflows/forward_compatibility_keras_tests.yml
with:
load_version: "v1.4.0"
python_version: "3.10"
tf_version: "2.14.*"
13 changes: 13 additions & 0 deletions .github/workflows/run_forward_comp_test_torch20_v14.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: Run Forward Compatibility Test - Pytorch 2.0 MCTQ v1.4.0
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *

jobs:
run-comp-torch-2_0-v1_4:
uses: ./.github/workflows/forward_compatibility_torch_tests.yml
with:
load_version: "v1.4.0"
python_version: "3.10"
torch_version: "2.0.*"
13 changes: 13 additions & 0 deletions .github/workflows/run_forward_comp_test_torch21_v14.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: Run Forward Compatibility Test - Pytorch 2.1 MCTQ v1.4.0
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *

jobs:
run-comp-torch-2_1-v1_4:
uses: ./.github/workflows/forward_compatibility_torch_tests.yml
with:
load_version: "v1.4.0"
python_version: "3.10"
torch_version: "2.1.*"
14 changes: 8 additions & 6 deletions mct_quantizers/keras/quantize_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,15 @@ def get_config(self):
"""
base_config = super(KerasQuantizationWrapper, self).get_config()
config = {WEIGHTS_QUANTIZERS: {k: keras.utils.serialize_keras_object(v) for k, v in self.weights_quantizers.items()},
WEIGHTS_VALUES: {k: self.serialize_fn(v) for k, v in self.weight_values.items()}}

config = {WEIGHTS_QUANTIZERS: {k: keras.utils.serialize_keras_object(v) for k, v in self.weights_quantizers.items()}}
# Only create the wrapper attributes that handle positional weights if they exist, so the wrapper is forward
# compatible with older MCTQ versions (at least until MCT will start quantizing positional weights)
if len(self.weight_values) > 0:
config[WEIGHTS_VALUES] = {k: self.serialize_fn(v) for k, v in self.weight_values.items()}
config[OP_CALL_ARGS] = self.op_call_args
config[OP_CALL_KWARGS] = self.op_call_kwargs
config[IS_INPUT_AS_LIST] = self.is_inputs_as_list
return_config = {**base_config, **config}
return_config[OP_CALL_ARGS] = self.op_call_args
return_config[OP_CALL_KWARGS] = self.op_call_kwargs
return_config[IS_INPUT_AS_LIST] = self.is_inputs_as_list
return_config[MCTQ_VERSION] = self._mctq_version

return return_config
Expand Down
2 changes: 1 addition & 1 deletion mct_quantizers/pytorch/quantize_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self,
representing the input index in the function\layer's inputs.
Args:
layer: A pytorch module or as function.
module: A pytorch module or as function.
weights_quantizers: A dictionary between a weight's name or position to its quantizer.
weight_values: A dictionary between a weight's position to its value.
op_call_args: A list containing the layer's call arguments.
Expand Down

0 comments on commit 4bbc18e

Please sign in to comment.