Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/microsoft/onnxruntime into wgpu
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jun 30, 2023
2 parents e4f2fb0 + 2fd25de commit 3db9066
Show file tree
Hide file tree
Showing 55 changed files with 4,183 additions and 1,838 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"profiles": {
"WSL": {
"commandName": "WSL2",
"distributionName": ""
}
}
}
Original file line number Diff line number Diff line change
@@ -1,32 +1,41 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using Microsoft.ML.OnnxRuntime.Tensors;
using System.Linq;

namespace Microsoft.ML.OnnxRuntime.InferenceSample
{
public class InferenceSampleApi : IDisposable
{
/// <summary>
/// Loads the embedded SqueezeNet model, creates the inference session, and
/// prepares one OrtValue input per model input, all backed by the same
/// sample buffer loaded from the embedded "TestData.bench.in" resource.
/// </summary>
public InferenceSampleApi()
{
    _model = LoadModelFromEmbeddedResource("TestData.squeezenet.onnx");

    // this is the data for only one input tensor for this model
    var inputData = LoadTensorFromEmbeddedResource("TestData.bench.in");

    // Creating an InferenceSession and loading the model is an expensive operation, so generally you would
    // do this once. InferenceSession.Run can be called multiple times, and concurrently.
    CreateInferenceSession();

    // setup sample input data
    var inputMeta = _inferenceSession.InputMetadata;
    _inputData = new List<OrtValue>(inputMeta.Count);
    _orderedInputNames = new List<string>(inputMeta.Count);

    foreach (var name in inputMeta.Keys)
    {
        // We create an OrtValue in this case over the buffer of potentially different shapes.
        // It is Okay as long as the specified shape does not exceed the actual length of the buffer
        var shape = Array.ConvertAll<int, long>(inputMeta[name].Dimensions, Convert.ToInt64);
        Debug.Assert(shape.Aggregate(1L, (a, v) => a * v) <= inputData.LongLength);

        // NOTE(review): CreateTensorValueFromMemory wraps the managed buffer without copying,
        // so 'inputData' must stay alive for the lifetime of the OrtValue — TODO confirm
        // this sample relies on the _inputData list keeping the values (and buffer) rooted.
        var ortValue = OrtValue.CreateTensorValueFromMemory(inputData, shape);
        _inputData.Add(ortValue);

        _orderedInputNames.Add(name);
    }
}

Expand All @@ -40,30 +49,47 @@ public void CreateInferenceSession(SessionOptions options = null)
options = new SessionOptions { LogId = "Sample" };
}

inferenceSession = new InferenceSession(model, options);
_inferenceSession = new InferenceSession(_model, options);
}

/// <summary>
/// Runs inference using the pre-created session and inputs, and dumps every
/// output tensor's float values to the console.
/// </summary>
public void Execute()
{
    // Run the inference
    // 'results' is an IDisposableReadOnlyCollection<OrtValue> container
    using (var results = _inferenceSession.Run(null, _orderedInputNames, _inputData, _inferenceSession.OutputNames))
    {
        // dump the results; results are ordered to match OutputNames
        for (int i = 0; i < results.Count; ++i)
        {
            var name = _inferenceSession.OutputNames[i];
            Console.WriteLine("Output for {0}", name);
            // We can now access the native buffer directly from the OrtValue, no copy is involved.
            // Spans are structs and are stack allocated. They do not add any GC pressure.
            ReadOnlySpan<float> span = results[i].GetTensorDataAsSpan<float>();
            Console.Write($"Input {i} results:");
            for (int k = 0; k < span.Length; ++k)
            {
                Console.Write($" {span[k]}");
            }
            Console.WriteLine();
        }
    }
}

/// <summary>
/// Releases the inference session and every cached input OrtValue.
/// Safe to call multiple times; disposal only runs once (guarded by _disposed).
/// </summary>
/// <param name="disposing">True when called from Dispose(); false from a finalizer.</param>
protected virtual void Dispose(bool disposing)
{
    if (disposing && !_disposed)
    {
        _inferenceSession?.Dispose();
        _inferenceSession = null;

        if (_inputData != null)
        {
            foreach (var v in _inputData)
            {
                v?.Dispose();
            }
        }

        _disposed = true;
    }
}

Expand Down Expand Up @@ -110,8 +136,10 @@ static byte[] LoadModelFromEmbeddedResource(string path)
return model;
}

// Guards Dispose(bool) so resources are released only once.
private bool _disposed = false;
// Raw bytes of the embedded ONNX model.
private readonly byte[] _model;
// Input names in the same order as the OrtValues in _inputData,
// as required by the InferenceSession.Run overload used in Execute.
private readonly List<string> _orderedInputNames;
// Pre-created input tensors; owned by this class and disposed in Dispose(bool).
private readonly List<OrtValue> _inputData;
private InferenceSession _inferenceSession;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup>
Expand Down
Loading

0 comments on commit 3db9066

Please sign in to comment.