This repository is a Maven project that wraps the Segment Anything Model. You can read more about the conversion process in /pytorch_convert/README.md.
To interface with the TorchScript model, we use the DJL (Deep Java Library) framework. DJL is a deep learning framework for Java that supports PyTorch, TensorFlow, and MXNet, and it provides a Java API to load and run TorchScript models.
The project is structured as follows:
- /pytorch_convert: Python code to patch the Segment Anything Model (SAM) and save it as TorchScript to a new file.
- /src/main/java/djlsam/Sam.java: Java code to load the TorchScript model and run inference.
- /src/main/java/djlsam/translators: Java classes to convert the input/output tensors to/from the TorchScript model.
- /src/main/test/djlsam/SamTest.java: Java code to test the model.
- /src/resources/images: Test images.
- /src/resources/pytorch_models: TorchScript models.
It is recommended to use an IDE such as IntelliJ IDEA to run the project.
To install the dependencies, run the following command:
mvn clean install
To run the tests, run the following command:
mvn test
Before implementing a model with the DJL framework, you should first convert your model to TorchScript. You can also find an example in the DJL documentation here.
Add the following dependencies to your pom.xml file:
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.22.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <!-- keep this version in sync with the ai.djl api version above -->
    <version>0.22.0</version>
    <scope>runtime</scope>
</dependency>
Note: You can find the latest version of the dependencies here.
Create a new class for your model. Within the class you can load the TorchScript model and run inference. You can find an example here.
The main idea is to create the following objects:
Translator<Image, SamRawOutput> translator;
Criteria<Image, SamRawOutput> criteria;
ZooModel<Image, SamRawOutput> model;
Predictor<Image, SamRawOutput> predictor;
Each of these objects is parameterized by an input and an output type, which must match the input and output types of the translator object.
DJL has many input/output types as well as translators already implemented. You can find them here.
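To make the type matching concrete, here is a framework-free sketch. MiniTranslator, MiniPredictor and lengthDoubler are hypothetical names that only mirror the roles of DJL's Translator and Predictor; they are not DJL's API. In DJL, Predictor.predict likewise runs the translator's processInput, the model's forward pass, and processOutput, in that order.

```java
import java.util.function.Function;

// Hypothetical, simplified stand-ins for DJL's Translator and Predictor.
// They only show how the generic type parameters must line up.
public final class TypeFlowSketch {

    /** Plays the role of a translator: I -> raw numbers -> O. */
    public interface MiniTranslator<I, O> {
        float[] processInput(I input);

        O processOutput(float[] rawOutput);
    }

    /** Plays the role of a predictor: it only accepts a translator with matching types. */
    public static final class MiniPredictor<I, O> {
        private final MiniTranslator<I, O> translator;
        private final Function<float[], float[]> forward; // stands in for the model's forward pass

        public MiniPredictor(MiniTranslator<I, O> translator, Function<float[], float[]> forward) {
            this.translator = translator;
            this.forward = forward;
        }

        public O predict(I input) {
            float[] raw = translator.processInput(input);        // pre-processing
            return translator.processOutput(forward.apply(raw)); // inference + post-processing
        }
    }

    /** Builds a String -> Integer pipeline whose "model" doubles its input. */
    public static MiniPredictor<String, Integer> lengthDoubler() {
        MiniTranslator<String, Integer> translator = new MiniTranslator<String, Integer>() {
            @Override
            public float[] processInput(String input) {
                return new float[] {input.length()};
            }

            @Override
            public Integer processOutput(float[] rawOutput) {
                return (int) rawOutput[0];
            }
        };
        // A MiniPredictor<String, Boolean> built from this translator would not compile.
        return new MiniPredictor<>(translator, x -> new float[] {x[0] * 2});
    }

    public static void main(String[] args) {
        System.out.println(lengthDoubler().predict("sam")); // prints 6
    }
}
```

Because the predictor's type parameters are tied to the translator's, a mismatch is caught by the compiler rather than at inference time.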
The translator object is used to convert the input/output tensors to/from the TorchScript model. You can find an example here.
It overrides the following methods:
- processInput(TranslatorContext ctx, Image input): converts the input image to an NDList object.
- processOutput(TranslatorContext ctx, NDList list): converts the output NDList object to a SamRawOutput object.
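The layout change at the heart of many processInput implementations can be sketched in plain Java. This assumes the model expects planar CHW float input scaled to [0, 1], which is what DJL's ToTensor transform produces from the HWC array returned by Image.toNDArray; the class name HwcToChw is hypothetical. It deliberately leaves out SAM-specific resizing and normalization, since this README does not say whether those happen in the translator or inside the patched TorchScript model.

```java
import java.util.Arrays;

public final class HwcToChw {

    /**
     * Converts interleaved HWC pixel data (RGBRGB...) into planar CHW floats
     * scaled to [0, 1], i.e. a transpose(2, 0, 1) followed by a division by 255.
     */
    public static float[] convert(int[] hwc, int height, int width, int channels) {
        if (hwc.length != height * width * channels) {
            throw new IllegalArgumentException("data length does not match H*W*C");
        }
        float[] chw = new float[hwc.length];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int c = 0; c < channels; c++) {
                    int src = (y * width + x) * channels + c; // index in HWC layout
                    int dst = (c * height + y) * width + x;   // index in CHW layout
                    chw[dst] = hwc[src] / 255f;
                }
            }
        }
        return chw;
    }

    public static void main(String[] args) {
        // A 1x2 RGB image: pixel 0 is pure red, pixel 1 is pure green.
        int[] hwc = {255, 0, 0, 0, 255, 0};
        float[] chw = convert(hwc, 1, 2, 3);
        // Planar layout: R plane, then G plane, then B plane.
        System.out.println(Arrays.toString(chw)); // prints [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]
    }
}
```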
The SamRawOutput object is a custom wrapper class that contains the output tensors of the model. You can find an example here.
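A minimal sketch of what such a wrapper might look like, assuming the model returns one mask-logit map per candidate mask plus one predicted IoU score per mask (the layout of the original SAM outputs). RawOutputSketch and its fields are hypothetical; the real SamRawOutput fields are not described here. Thresholding at 0 follows the default mask threshold of the original SAM implementation, which applies to logits rather than probabilities.

```java
import java.util.Arrays;

/** Hypothetical stand-in for a SamRawOutput-style wrapper (field layout assumed). */
public final class RawOutputSketch {
    private final float[][] maskLogits; // [N][height * width], row-major
    private final float[] iouScores;    // [N], one predicted IoU per mask
    private final int height;
    private final int width;

    public RawOutputSketch(float[][] maskLogits, float[] iouScores, int height, int width) {
        if (iouScores.length == 0 || maskLogits.length != iouScores.length) {
            throw new IllegalArgumentException("expected one IoU score per mask");
        }
        this.maskLogits = maskLogits;
        this.iouScores = iouScores;
        this.height = height;
        this.width = width;
    }

    /** Index of the mask with the highest predicted IoU. */
    public int bestMaskIndex() {
        int best = 0;
        for (int i = 1; i < iouScores.length; i++) {
            if (iouScores[i] > iouScores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Binarizes one mask by thresholding its logits at 0. */
    public boolean[] binaryMask(int index) {
        float[] logits = maskLogits[index];
        boolean[] mask = new boolean[height * width];
        for (int i = 0; i < mask.length; i++) {
            mask[i] = logits[i] > 0f;
        }
        return mask;
    }

    public static void main(String[] args) {
        RawOutputSketch out = new RawOutputSketch(
                new float[][] {{-1f, 2f, -3f, 4f}, {5f, -6f, 7f, -8f}},
                new float[] {0.91f, 0.42f},
                2, 2);
        int best = out.bestMaskIndex();
        System.out.println(best + " " + Arrays.toString(out.binaryMask(best)));
        // prints 0 [false, true, false, true]
    }
}
```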
The criteria object specifies the input and output types of the model, as well as where to load the model from and which translator to use. You can find an example here.
Note: The path of the TorchScript model must be a directory that contains the .pt file, and the .pt file must have the same name as the directory.
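The directory rule can be checked before loading, which gives a clearer error than a failed load. This is a plain-Java sketch; ModelDirCheck and the directory name sam_vit_b are hypothetical. The directory it validates is the path you would pass to Criteria.Builder.optModelPath.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public final class ModelDirCheck {

    /** The file the rule above expects: <dir>/<dirName>.pt */
    public static Path expectedModelFile(Path modelDir) {
        String dirName = modelDir.getFileName().toString();
        return modelDir.resolve(dirName + ".pt");
    }

    /** True if modelDir is a directory containing a .pt file named after it. */
    public static boolean isValidModelDir(Path modelDir) {
        return Files.isDirectory(modelDir) && Files.isRegularFile(expectedModelFile(modelDir));
    }

    public static void main(String[] args) throws IOException {
        Path root = Files.createTempDirectory("models");
        Path modelDir = Files.createDirectory(root.resolve("sam_vit_b")); // hypothetical name

        System.out.println(isValidModelDir(modelDir)); // false: no .pt file yet
        Files.createFile(modelDir.resolve("model.pt"));
        System.out.println(isValidModelDir(modelDir)); // false: name does not match the directory
        Files.createFile(modelDir.resolve("sam_vit_b.pt"));
        System.out.println(isValidModelDir(modelDir)); // true
    }
}
```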
Calling criteria.loadModel() creates the model object. You can find an example here.
Finally, the predictor object is created by calling model.newPredictor(). You can find an example here.
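In DJL, ZooModel and Predictor both implement AutoCloseable and hold native PyTorch resources, so they are usually opened with try-with-resources, loading the model once and reusing the predictor across inputs. The sketch below uses hypothetical stand-in classes (LifecycleSketch, MiniModel, MiniPredictor) that only record events, so it runs without DJL and makes the closing order visible.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical mirror of the loadModel -> newPredictor -> predict lifecycle (not DJL classes). */
public final class LifecycleSketch {
    public static final List<String> EVENTS = new ArrayList<>();

    public static final class MiniModel implements AutoCloseable {
        public MiniModel() { EVENTS.add("model loaded"); }

        public MiniPredictor newPredictor() { return new MiniPredictor(); }

        @Override
        public void close() { EVENTS.add("model closed"); }
    }

    public static final class MiniPredictor implements AutoCloseable {
        MiniPredictor() { EVENTS.add("predictor created"); }

        public String predict(String input) {
            EVENTS.add("predict " + input);
            return input.toUpperCase();
        }

        @Override
        public void close() { EVENTS.add("predictor closed"); }
    }

    public static void run(String... inputs) {
        // Resources in one try statement are closed in reverse declaration order:
        // the predictor first, then the model it was created from.
        try (MiniModel model = new MiniModel();
             MiniPredictor predictor = model.newPredictor()) {
            for (String input : inputs) {
                predictor.predict(input); // reuse one predictor for several inputs
            }
        }
    }

    public static void main(String[] args) {
        run("a", "b");
        System.out.println(EVENTS);
        // prints [model loaded, predictor created, predict a, predict b, predictor closed, model closed]
    }
}
```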
- You can use the NDManager object to create NDArray objects.
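NDManager.create(float[] data, Shape shape) reads the flat data in row-major order, and NDArray.toFloatArray() returns the elements flattened the same way. Indexing into such a flat array by hand needs the row-major offset, sketched below; RowMajor.offset is a hypothetical helper, not part of DJL.

```java
public final class RowMajor {

    /** Row-major (C-order) offset of `index` in an array with the given shape. */
    public static long offset(long[] shape, long... index) {
        if (shape.length != index.length) {
            throw new IllegalArgumentException("rank mismatch");
        }
        long offset = 0;
        for (int d = 0; d < shape.length; d++) {
            if (index[d] < 0 || index[d] >= shape[d]) {
                throw new IndexOutOfBoundsException("index " + index[d] + " out of bounds for dim " + d);
            }
            offset = offset * shape[d] + index[d]; // Horner-style accumulation over dimensions
        }
        return offset;
    }

    public static void main(String[] args) {
        long[] shape = {3, 4, 5}; // e.g. [masks, height, width]
        System.out.println(offset(shape, 2, 1, 3)); // prints 48 (2*20 + 1*5 + 3)
    }
}
```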