-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Valentina <valentina-kustikova@users.noreply.github.com> Co-authored-by: Кустикова Валентина Дмитриевна <kustikova.v@itmm.unn.net> Co-authored-by: valentina-kustikova <valentina.kustikova@gmail.com>
- Loading branch information
1 parent
d944b0b
commit 40c3036
Showing
7 changed files
with
216 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Conversion to the PaddlePaddle format | ||
|
||
PaddlePaddle converter supports conversion to the PaddlePaddle format | ||
from PyTorch and ONNX formats. | ||
|
||
## PaddlePaddle converter usage | ||
|
||
Usage of the script: | ||
|
||
```bash | ||
python srcf2paddle.py -m <path/to/input/model> -f <source_framework> \ | ||
-p <PyTorch/module/name> -d <output_directory> | ||
``` | ||
|
||
### Paddle converter parameters | ||
|
||
- `-m / --model_path` is a path to an .onnx or .pth file with the original model. | ||
- `-f / --framework` is a source framework for convertion to the PaddlePaddle format. | ||
- `-p / --pytorch_module_name` is a module name for the PyTorch model (it is required | ||
if source framework is PyTorch). | ||
- `-d / --save_dir` is a directory for converted model to be saved to. | ||
|
||
### Examples of usage | ||
|
||
```bash | ||
python srcf2paddle.py -m .\public\googlenet-v3-pytorch\inception_v3_google-1a9a5a14.pth \ | ||
-f pytorch -p InceptionV3 -d pd | ||
``` | ||
|
||
```bash | ||
python srcf2paddle.py -m .\public\ctdet_coco_dlav0_512\ctdet_coco_dlav0_512.onnx \ | ||
-f onnx -d pd | ||
``` | ||
|
||
# Conversion from the PaddlePaddle to the ONNX format | ||
|
||
paddle2onnx converter supports conversion to the ONNX format from the PaddlePaddle | ||
format. | ||
|
||
## PaddlePaddle converter usage | ||
|
||
Usage of the script: | ||
|
||
```bash | ||
python paddle2onnx.py -d .\pd_pth\inference_model -f model.pdmodel \ | ||
-p model.pdiparams -m inference.onnx -o 11 | ||
``` | ||
|
||
### Converter parameters | ||
|
||
- `-d / --model_dir` is a path to the directory with the original model. | ||
- `-f / --model_filename` is a model file name. | ||
- `-p / --params_filename` is a parameters file name. | ||
- `-m / --model_path` is a path to the resulting .onnx file. | ||
- `-o / --opset_version` is a desired opset version of the ONNX model. | ||
|
||
### Examples of usage | ||
|
||
```bash | ||
python paddle2onnx.py -d .\pd_pth\inference_model -f model.pdmodel \ | ||
-p model.pdiparams -m inference.onnx -o 11 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import argparse | ||
import logging as log | ||
import sys | ||
from pathlib import Path | ||
import os | ||
|
||
sys.path.append(str(Path(__file__).parent.parent.parent.parent)) | ||
|
||
|
||
def cli_argument_parser(): | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument('-d', '--model_dir', | ||
help='Directory to save model in.', | ||
required=True, | ||
type=str, | ||
dest='model_dir') | ||
parser.add_argument('-f', '--model_filename', | ||
help='Name of the model file name.', | ||
required=True, | ||
type=str, | ||
dest='model_filename') | ||
parser.add_argument('-p', '--params_filename', | ||
help='Name of the parameters file name.', | ||
required=True, | ||
type=str, | ||
dest='params_filename') | ||
parser.add_argument('-m', '--model_path', | ||
help='Path to an .onnx file.', | ||
required=True, | ||
type=str, | ||
dest='model_path') | ||
parser.add_argument('-o', '--opset_version', | ||
help='', | ||
required=True, | ||
type=str, | ||
dest='opset_version') | ||
args = parser.parse_args() | ||
|
||
return args | ||
|
||
|
||
def convert_paddle_to_onnx(model_dir: str, model_filename: str, params_filename: str, | ||
model_path: str, opset_version: str): | ||
os.system(f"""paddle2onnx --model_dir {model_dir} --model_filename {model_filename} | ||
--params_filename {params_filename} | ||
--save_file {model_path} | ||
--opset_version {opset_version} | ||
--enable_onnx_checker True""") | ||
|
||
|
||
def main(): | ||
log.basicConfig(format='[ %(levelname)s ] %(message)s', | ||
level=log.INFO, stream=sys.stdout) | ||
args = cli_argument_parser() | ||
convert_paddle_to_onnx(model_dir=args.model_dir, model_filename=args.model_filename, | ||
params_filename=args.params_filename, model_path=args.model_path, | ||
opset_version=args.opset_version) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
x2paddle | ||
paddle2onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import torch | ||
import numpy as np | ||
import torchvision.models as models | ||
from x2paddle.convert import pytorch2paddle | ||
import argparse | ||
import logging as log | ||
import sys | ||
from pathlib import Path | ||
import os | ||
|
||
sys.path.append(str(Path(__file__).parent.parent.parent.parent)) | ||
|
||
|
||
def get_model_by_name(model_name): | ||
try: | ||
model_constructor = getattr(models, model_name) | ||
model = model_constructor() | ||
return model | ||
except AttributeError: | ||
raise ValueError(f'Model {model_name} is not found in torchvision.models') | ||
|
||
|
||
def cli_argument_parser(): | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument('-m', '--model_path', | ||
help='Path to an .onnx or .pth file.', | ||
required=True, | ||
type=str, | ||
dest='model_path') | ||
parser.add_argument('-f', '--framework', | ||
help='Original model framework (ONNX or PyTorch)', | ||
required=True, | ||
type=str, | ||
choices=['onnx', 'pytorch'], | ||
dest='framework') | ||
parser.add_argument('-p', '--pytorch_module_name', | ||
help='Module name for PyTorch model.', | ||
required=False, | ||
type=str, | ||
choices=['AlexNet', 'VGG', 'ResNet', 'SqueezeNet', 'DenseNet', | ||
'InceptionV3', 'GoogLeNet', 'ShuffleNetV2', 'MobileNetV2', | ||
'MobileNetV3', 'MNASNet', 'EfficientNet'], | ||
dest='module_name') | ||
parser.add_argument('-d', '--save_dir', | ||
help='Directory for converted model to be saved to.', | ||
required=True, | ||
type=str, | ||
dest='save_dir') | ||
args = parser.parse_args() | ||
|
||
return args | ||
|
||
|
||
def convert_pytorch_to_paddle(model_path: str, module_name, save_dir: str): | ||
|
||
model = get_model_by_name(module_name) | ||
model.load_state_dict(torch.load(model_path)) | ||
model.eval() | ||
|
||
input_data = np.random.rand(1, 3, 224, 224).astype('float32') | ||
pytorch2paddle(model, | ||
save_dir=save_dir, | ||
jit_type='trace', | ||
input_examples=[torch.tensor(input_data)]) | ||
|
||
|
||
def convert_onnx_to_paddle(model_path: str, save_dir: str): | ||
print(f'x2paddle --framework=onnx --model={model_path} --save_dir={save_dir}') | ||
os.system(f'x2paddle --framework=onnx --model={model_path} --save_dir={save_dir}') | ||
|
||
|
||
def main(): | ||
log.basicConfig(format='[ %(levelname)s ] %(message)s', | ||
level=log.INFO, stream=sys.stdout) | ||
args = cli_argument_parser() | ||
if args.framework == 'pytorch' and not args.module_name: | ||
raise ValueError('Module name for pytorch is not specified') | ||
elif args.framework == 'pytorch' and args.module_name: | ||
convert_pytorch_to_paddle(model_path=args.model_path, | ||
module_name=args.module_name, save_dir=args.save_dir) | ||
elif args.framework == 'onnx': | ||
convert_onnx_to_paddle(model_path=args.model_path, save_dir=args.save_dir) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |