Model to TorchScript

Introduction

This directory contains Python code that patches the Segment Anything Model (SAM) and saves it as TorchScript to a new file. The implementation is intended as a proof of concept and a starting point.

This TorchScript model takes a single image as input and outputs the mask for the center of the image.

We used the segment_anything Python package provided by the Segment Anything Model repository.

Installation & Usage

Create a virtual environment and install the dependencies:

python3 -m venv .venv

source .venv/bin/activate

pip install -r requirements.txt

Run the entry-point script:

python sam_convert.py

How to Convert your Own Models to TorchScript

Tracing vs Scripting

The PyTorch JIT API exposes two ways to convert your model to TorchScript: tracing (torch.jit.trace) and scripting (torch.jit.script).

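The advantage of tracing is that it can create a smaller, more optimized model. The disadvantage is that it only supports a subset of Python's features: it records the operations executed for the example inputs, so conditional branches that depend on the input are frozen into whichever path was taken during tracing.

Scripting, on the other hand, is more flexible and can handle more Python features, including control flow. However, the resulting model can be less optimized and slower than a traced one.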
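The difference shows up on a toy module with a data-dependent branch. This is a minimal sketch; the Gate module is hypothetical and unrelated to SAM. Tracing with a positive input bakes in the "+ 1" branch, while scripting keeps the if statement:

```python
import torch


class Gate(torch.nn.Module):
    # Hypothetical toy module with a data-dependent branch
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x.sum() > 0):
            return x + 1
        return x - 1


model = Gate().eval()
pos = torch.ones(3)   # sum > 0: takes the "+ 1" branch
neg = -torch.ones(3)  # sum < 0: should take the "- 1" branch

# Tracing records only the branch taken for `pos` (PyTorch emits a TracerWarning here)
traced = torch.jit.trace(model, (pos,))
# Scripting compiles the `if` statement itself
scripted = torch.jit.script(model)

print(traced(neg))    # tensor([0., 0., 0.])   -> wrong branch baked in
print(scripted(neg))  # tensor([-2., -2., -2.]) -> matches eager mode
```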

The first step is to evaluate whether you can trace your model. You can read more about the limitations here.

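As a summary of the limitations, you cannot reliably trace a model that:

  • Has data-dependent control flow (e.g. if statements or loops whose behavior depends on the input values)
  • Takes or returns data structures other than tensors or tuples/lists/dicts of tensors
  • Has functions or modules whose behavior depends on the input data
  • Depends on external computation that the tracer does not record (e.g. NumPy or plain Python operations on tensor values)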
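One rough way to evaluate whether tracing is safe for your model is to trace it on one example input and compare the traced module against eager mode on other inputs. The sketch below assumes a model with a single tensor output; trace_matches_eager and Branchy are hypothetical names for illustration. A passing check only shows that the tested inputs follow the traced path, not that the trace is correct in general.

```python
from typing import Sequence, Tuple

import torch


def trace_matches_eager(model: torch.nn.Module,
                        example: Tuple[torch.Tensor, ...],
                        test_inputs: Sequence[Tuple[torch.Tensor, ...]],
                        atol: float = 1e-6) -> bool:
    """Hypothetical helper: trace on `example`, then compare with eager mode on other inputs."""
    model.eval()
    traced = torch.jit.trace(model, example)
    with torch.no_grad():
        for inputs in test_inputs:
            if not torch.allclose(traced(*inputs), model(*inputs), atol=atol):
                return False
    return True


class Branchy(torch.nn.Module):
    # Hypothetical module whose branch depends on the input values
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x.mean() > 0):
            return x + 1
        return x - 1


linear = torch.nn.Linear(4, 2)
print(trace_matches_eager(linear, (torch.rand(1, 4),), [(torch.rand(5, 4),)]))  # True
print(trace_matches_eager(Branchy(), (torch.ones(4),), [(-torch.ones(4),)]))   # False
```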

Wrap the Model

If you would like to add a wrapper to the model, you can create a new model class that extends torch.nn.Module and add the wrapper logic in the forward method.

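import torch

class MyModel(torch.nn.Module):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model  # the original model being wrapped

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add wrapper logic here (e.g. pre- or post-processing)
        return self.model(x)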

You can have a look at sam_predict_base_model.py for an example of how to wrap the SAM model.
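As a sketch of what wrapper logic can look like, the hypothetical MaskWrapper below normalizes a single 3xHxW image, adds a batch dimension, and thresholds the model's first output channel into a boolean mask. The Conv2d stand-in, the normalization constants, and the threshold are illustrative assumptions, not what sam_predict_base_model.py does:

```python
import torch


class MaskWrapper(torch.nn.Module):
    """Hypothetical wrapper: normalize one CHW image, run the model, threshold the logits."""

    def __init__(self, model: torch.nn.Module, threshold: float = 0.0):
        super().__init__()
        self.model = model
        self.threshold = threshold
        # Illustrative constants, registered as buffers so they are saved with the module
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1))

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        x = (image - self.mean) / self.std    # normalize a 3xHxW image
        logits = self.model(x.unsqueeze(0))   # add a batch dimension: 1x3xHxW
        return logits[0, 0] > self.threshold  # HxW boolean mask from the first channel


# Stand-in for the real model: any module mapping Nx3xHxW to NxCxHxW
wrapped = MaskWrapper(torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)).eval()
scripted = torch.jit.script(wrapped)

image = torch.rand(3, 64, 64)
mask = scripted(image)
print(mask.shape, mask.dtype)  # torch.Size([64, 64]) torch.bool
```

Keeping the pre- and post-processing inside forward means the saved TorchScript file carries it along, so callers only need to pass a raw image tensor.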

Convert the Model to TorchScript

Next, you can try to convert your model to TorchScript. In this example, we used scripting (torch.jit.script) because the SAM model has control flow in its layers.

You can pass example inputs to torch.jit.script through its example_inputs argument. The TorchScript compiler uses these inputs to infer the types of arguments that have no type annotations.

import torch

# `model` is your original model (e.g. SAM); MyModel is the wrapper defined above
model = MyModel(model)
# Set the model to evaluation mode
model.eval()

example_inputs = [
    (torch.rand(3, 256, 256),)  # Each tuple holds the positional arguments of one forward() call
]  # You can add more example inputs

# Convert the model to TorchScript
scripted_model = torch.jit.script(model, example_inputs={model: example_inputs})

Doing this for the first time will likely result in an error if your model uses Python-specific logic. You will have to modify or patch your model to remove the unsupported features, which brings us to the next step. If you do not get any errors, you can skip the next step.
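A small helper can make this trial-and-error loop quicker by returning the compiler's error message instead of raising. The sketch below uses *args in forward as one example of Python-specific logic that TorchScript rejects; try_script, VarargsWrapper and FixedWrapper are hypothetical and not taken from SAM:

```python
from typing import Optional, Tuple

import torch


def try_script(module: torch.nn.Module) -> Tuple[Optional[torch.jit.ScriptModule], Optional[str]]:
    """Hypothetical helper: return (scripted_module, None) on success, else (None, error_message)."""
    try:
        return torch.jit.script(module), None
    except Exception as err:  # TorchScript frontend/compilation errors
        return None, str(err)


class VarargsWrapper(torch.nn.Module):
    # Python-specific: *args works in eager mode but TorchScript cannot compile it
    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        return inputs[0] * 2


class FixedWrapper(torch.nn.Module):
    # Modified version: an explicit, typed argument list
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


scripted, error = try_script(VarargsWrapper())
print(scripted is None)  # True
print(error)             # the compiler's explanation of the unsupported feature

scripted, error = try_script(FixedWrapper())
print(error is None)     # True
```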

Modify/Patch the Model

In our case, the SAM model had a lot of custom logic in its forward methods, which made it difficult to script. We therefore had to patch the model to remove that logic.

You can either modify or patch the model. Modifying it is the preferred approach, as the changes are easier to maintain and keep track of. However, if you would like to make minimal changes without maintaining a fork of the model, you can patch it instead.

In our case, we patched the model by using the mock library. The patches are located under the patches/ directory. You can read more about the mock library here.

Each file is named after the corresponding original file in the segment_anything package and exports a patches variable containing a tuple of patches.

The patches/__init__.py file imports these patches and applies them.

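You can patch a class method the following way:

import mock

# <class>, "<method name to patch>" and <patched function> are placeholders for your own code
example_patch = mock.patch.object(<class>, "<method name to patch>", <patched function>)

with example_patch as mock_example:
    # Inside this block, the method is replaced by <patched function>
    pass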

To avoid nested with statements, you can call the __enter__() method directly (the patcher's start() method does the same). The patch then stays active until it is explicitly undone:

example_patch.__enter__()

# Do something
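The same pattern can be sketched with the standard library's unittest.mock (the standalone mock package exposes the same patch API). Encoder and patched_forward are hypothetical stand-ins, and contextlib.ExitStack is used here as an alternative to calling __enter__() by hand: it enters every patcher in a tuple, similar to the patches variable described above, and undoes them all on exit.

```python
from contextlib import ExitStack
from unittest import mock  # the standalone `mock` package offers the same API


class Encoder:
    # Hypothetical stand-in for a class from the patched package
    def forward(self, x):
        return x + 1


def patched_forward(self, x):
    # Hypothetical replacement, e.g. a TorchScript-friendly rewrite
    return x * 10


# A tuple of patchers, like the `patches` variable exported by each file in patches/
patches = (
    mock.patch.object(Encoder, "forward", patched_forward),
)

with ExitStack() as stack:
    for patcher in patches:
        stack.enter_context(patcher)  # same effect as patcher.__enter__()
    inside = Encoder().forward(2)     # uses patched_forward -> 20

after = Encoder().forward(2)          # original method restored -> 3
print(inside, after)                  # 20 3
```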

Tips

Fixing the SAM model required a lot of trial and error. Here are some tips that might help you:

  • Start by converting as little code as possible and gradually add more. This makes it easier to track down errors.
  • The error messages are not always helpful and can sometimes be misleading. You may have to debug the code to find the actual cause.
  • Avoid 'Pythonic' code that TorchScript cannot type. For example, the SAM model used lists containing multiple types of data.
  • Use Optional type hints to initialize variables to None. For example, x: Optional[torch.Tensor] = None (with Optional imported from typing). This tells the TorchScript compiler the type of the variable, and a later if x is not None: check narrows it to torch.Tensor.
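The Optional tip can be sketched as follows. first_positive is a hypothetical scripted function: the variable starts as None, and the is not None check lets the compiler refine it to torch.Tensor before it is returned:

```python
from typing import List, Optional

import torch


@torch.jit.script
def first_positive(xs: List[torch.Tensor]) -> torch.Tensor:
    # Optional[...] tells the compiler this variable may hold None or a Tensor
    found: Optional[torch.Tensor] = None
    for x in xs:
        if found is None and bool(x.sum() > 0):
            found = x
    # Inside the `is not None` branch, `found` is refined from Optional[Tensor] to Tensor
    if found is not None:
        return found
    return torch.zeros(1)


print(first_positive([-torch.ones(2), torch.ones(2)]))  # tensor([1., 1.])
print(first_positive([-torch.ones(2)]))                 # tensor([0.])
```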

Save the TorchScript Model

Once you have successfully converted your model to TorchScript, you can save it to a file:

scripted_model.save("model.pt")
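A quick way to confirm that the saved file is usable is to load it back with torch.jit.load and compare its output with the in-memory module. The Doubler module below is a hypothetical stand-in for your scripted model, and the file is written to a temporary directory:

```python
import os
import tempfile

import torch


class Doubler(torch.nn.Module):
    # Hypothetical stand-in for your converted model
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


scripted_model = torch.jit.script(Doubler().eval())

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    scripted_model.save(path)
    # Loading does not require the original Python class definition
    loaded = torch.jit.load(path, map_location="cpu")
    x = torch.rand(3, 8, 8)
    same = torch.equal(loaded(x), scripted_model(x))

print(same)  # True
```

The same file can also be loaded from C++ with torch::jit::load, which is one of the main reasons to export to TorchScript in the first place.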

Resources