Skip to content

Commit

Permalink
[js/rn] Support create boolean tensor (microsoft#17052)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

For some use case need to create boolean tensor.

I've tested on [this
project](https://github.com/hans00/react-native-transformers-example)

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Add handle `ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL`

And it required microsoft#15556 (It seems not include in latest release
(v1.15.1))
  • Loading branch information
hans00 authored Sep 14, 2023
1 parent 32f5658 commit ad369a1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,34 @@ public void createInputTensor_double() throws Exception {
outputTensor.close();
}

@Test
public void createInputTensor_bool() throws Exception {
OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, new boolean[] {false, true});

JavaOnlyMap inputTensorMap = new JavaOnlyMap();

JavaOnlyArray dims = new JavaOnlyArray();
dims.pushInt(2);
inputTensorMap.putArray("dims", dims);

inputTensorMap.putString("type", TensorHelper.JsTensorTypeBool);

ByteBuffer dataByteBuffer = ByteBuffer.allocate(2);
dataByteBuffer.put((byte)0);
dataByteBuffer.put((byte)1);
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));

OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);

Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
Assert.assertEquals(inputTensor.toString(), outputTensor.toString());
Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array());

inputTensor.close();
outputTensor.close();
}

@Test
public void createOutputTensor_bool() throws Exception {
MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ private static OnnxTensor createInputTensor(TensorInfo.OnnxTensorType tensorType
tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.UINT8);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
ByteBuffer buffer = values;
tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.BOOL);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
Expand Down

0 comments on commit ad369a1

Please sign in to comment.