[Performance] Python inference runs faster than C++ #22328

Open
ashwin-999 opened this issue Oct 5, 2024 · 0 comments
Labels
performance (issues related to performance regressions), platform:jetson (issues related to the NVIDIA Jetson platform)

Comments


ashwin-999 commented Oct 5, 2024

Describe the issue

I observe that the ONNX model (FP32) executed in C++ runs slower than in Python, and it's much worse with the TensorRT execution provider.

I've also tried exporting an FP16 model with keep_io_types set to True. That isn't better either; in fact, it's worse.

But first I want to focus on the FP32 model in question.
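
For reference, the FP16 export I mentioned was done roughly along these lines, using onnxconverter-common's float16 conversion with keep_io_types=True (a sketch rather than my exact script; the FP16 output file name is just a placeholder):

import onnx
from onnxconverter_common import float16

# Convert the FP32 model's weights to FP16 while keeping the graph
# inputs/outputs as FP32 (keep_io_types=True).
model_fp32 = onnx.load("epoch103.nms.fp32.simplified.onnx")
model_fp16 = float16.convert_float_to_float16(model_fp32, keep_io_types=True)
onnx.save(model_fp16, "epoch103.nms.fp16.simplified.onnx")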

To reproduce

Python script:

import onnxruntime
import numpy as np
import time

def test_run_onnx_model(model_path, input_shape, num_iterations=1):
    print(f"Available providers: {onnxruntime.get_available_providers()}")
    sess_options = onnxruntime.SessionOptions()
    # sess_options.log_severity_level = 0  # 0 = Verbose
    # sess_options.log_verbosity_level = 1  # Increase verbosity
    
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    
    session = onnxruntime.InferenceSession(model_path, sess_options, providers=providers)

    input_name = session.get_inputs()[0].name

    dummy_input = np.ones(input_shape, dtype=np.float32)

    session.run(None, {input_name: dummy_input})  # warm-up run (excluded from timing)

    total_time = 0

    for _ in range(num_iterations):
        start_time = time.time()
        output = session.run(None, {input_name: dummy_input})
        end_time = time.time()
        total_time += (end_time - start_time)

    average_time = total_time / num_iterations
    print(f"Average inference time over {num_iterations} iterations: {average_time:.4f} seconds")

if __name__ == "__main__":
    model_path = "epoch103.nms.fp32.simplified.onnx"
    input_shape = (1, 6, 640, 640) 

    test_run_onnx_model(model_path, input_shape, 1000)

C++:

Model::Model(Ort::Env* env, const std::string& model_path){
    loadModel(env, model_path);
}

void Model::loadModel(Ort::Env* env, const std::string& model_path) {

    Ort::SessionOptions session_options;

    session_options.SetIntraOpNumThreads(1);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); 
    session_options.EnableMemPattern();
    session_options.EnableCpuMemArena();

    // CUDA Execution Provider
    OrtCUDAProviderOptions cuda_options{};
    cuda_options.device_id = 0; 
    cuda_options.arena_extend_strategy = 0;
    cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchExhaustive;
    cuda_options.do_copy_in_default_stream = 1; 
    session_options.AppendExecutionProvider_CUDA(cuda_options);

    // // TensorRT Execution Provider
    // OrtTensorRTProviderOptions trt_options{};
    // trt_options.device_id = 0;
    // trt_options.trt_max_workspace_size = 1 << 30;
    // trt_options.trt_max_partition_iterations = 1000;
    // trt_options.trt_min_subgraph_size = 1;
    // trt_options.trt_fp16_enable = true;
    // trt_options.trt_int8_enable = false;
    // trt_options.trt_engine_cache_enable = true;
    // trt_options.trt_engine_cache_path = "/dummy/path";
    // session_options.AppendExecutionProvider_TensorRT(trt_options);

    auto providers = Ort::GetAvailableProviders();
    std::cout << "Available providers: ";
    std::copy(providers.begin(), providers.end(), std::ostream_iterator<std::string>(std::cout, " "));
    std::cout << std::endl;
    
    session_ = std::make_unique<Ort::Session>(*env, model_path.c_str(), session_options);
    
    // Get input names and dimensions
    Ort::AllocatorWithDefaultOptions allocator;
    size_t num_input_nodes = session_->GetInputCount();
    input_node_names_.resize(num_input_nodes);
    input_node_dims_.resize(num_input_nodes);

    for (size_t i = 0; i < num_input_nodes; i++) {
        auto input_name = session_->GetInputNameAllocated(i, allocator);
        input_node_name_allocated_strings_.push_back(std::move(input_name));
        input_node_names_[i] = input_node_name_allocated_strings_.back().get();

        Ort::TypeInfo input_node_info = session_->GetInputTypeInfo( i );
        auto tensor_info = input_node_info.GetTensorTypeAndShapeInfo();
        input_node_dims_[i] = tensor_info.GetShape();

    }

    auto input_shape = input_node_dims_[0];
    for (auto dim : input_shape) {
        input_tensor_size *= dim;
    }

    
    // Get output names and dimensions
    size_t num_output_nodes = session_->GetOutputCount();
    output_node_names_.resize(num_output_nodes);
    output_node_dims_.resize(num_output_nodes);

    for (size_t i = 0; i < num_output_nodes; i++) {
        auto output_name = session_->GetOutputNameAllocated(i, allocator);
        output_node_name_allocated_strings_.push_back(std::move(output_name));
        output_node_names_[i] = output_node_name_allocated_strings_.back().get();
        
        Ort::TypeInfo output_node_info = session_->GetOutputTypeInfo( i );
        auto tensor_info = output_node_info.GetTensorTypeAndShapeInfo();
        output_node_dims_[i] = tensor_info.GetShape();
    }

    warmUpModel(10);

}

void Model::runInference(const std::vector<float>& input, std::vector<SomeObject>& detections) {

    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info_, 
                                                        const_cast<float*>(input.data()), 
                                                        input.size(), 
                                                        input_node_dims_[0].data(), 
                                                        input_node_dims_[0].size());

    auto output_tensors = session_->Run(Ort::RunOptions{nullptr}, 
                                        input_node_names_.data(), &input_tensor, 1, 
                                        output_node_names_.data(), 1);
    
    const float* output_data = output_tensors[0].GetTensorData<float>();
    auto output_shape = output_tensors[0].GetTensorTypeAndShapeInfo().GetShape();
    
    num_detections_ = output_shape[0];

    detections.reserve(num_detections_);

    // code to populate detections omitted (confirmed this part isn't slow; it's just a for loop that fills detections)

}

void Model::warmUpModel(int num_iterations) {
    
    std::vector<float> dummy_input(input_tensor_size, 0.0f);
    std::vector<SomeObject> dummy_detections;
    auto start = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < num_iterations; ++i) {
        runInference(dummy_input, dummy_detections);
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
    // print duration
}

Both the Python and C++ code run the same model, with inference executed 1000 times in each.

  • Python: 42.5 ms/iter on average
  • C++ with the CUDA EP (TensorRT provider options commented out): 47.8 ms/iter on average
  • C++ with the TensorRT EP (CUDA provider options commented out): 86.0 ms/iter on average

I suspect it's the large input copies between CPU and GPU that are slowing things down. I also don't like that I'm creating the Ort::Value input_tensor on every single call. Can someone help me with I/O binding? I've searched quite extensively online; there are plenty of Python examples, but the C++ ones I've found either don't compile, segfault, or give me 0 detections. The output node is dynamic, so its shape/size isn't known ahead of time.
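
For reference, here's roughly the IoBinding approach I've been trying to get working inside runInference (a sketch only; the member names match the code above, and binding the dynamic-shape output to a CPU MemoryInfo so that ORT allocates it is my best guess at the intended usage):

// Sketch: reuse the existing input buffer and let ORT allocate the dynamic-shape output.
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info_,
                                                    const_cast<float*>(input.data()),
                                                    input.size(),
                                                    input_node_dims_[0].data(),
                                                    input_node_dims_[0].size());

Ort::IoBinding binding(*session_);
binding.BindInput(input_node_names_[0], input_tensor);

// The output shape isn't known ahead of time, so bind a MemoryInfo instead of a
// preallocated tensor and let ORT allocate a buffer of whatever size it needs.
Ort::MemoryInfo output_mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
binding.BindOutput(output_node_names_[0], output_mem_info);

session_->Run(Ort::RunOptions{nullptr}, binding);

std::vector<Ort::Value> outputs = binding.GetOutputValues();
const float* output_data = outputs[0].GetTensorData<float>();
auto output_shape = outputs[0].GetTensorTypeAndShapeInfo().GetShape();

I realize this still copies the input from host to device on every call; to avoid that I'd presumably need to allocate the input on the GPU (e.g. via the CUDA allocator) and bind that device buffer instead, which is part of what I'm asking about.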

Here are a few questions that I have in particular:

  1. Would you say I/O is the primary suspect?
  2. Beyond I/O, what else could I be doing to improve the runtime?
  3. Why is TensorRT so slow?

For some context, all of the runtimes reported above were measured on a Jetson Orin, running inside a Docker container.

Urgency

No response

Platform

Linux

OS Version

22.04

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.20.0, commit id: e91ff94

ONNX Runtime API

Python

Architecture

ARM64

Execution Provider

Default CPU, CUDA, TensorRT

Execution Provider Library Version

CUDA Version: 12.2

Model File

No response

Is this a quantized model?

No

@ashwin-999 added the performance label Oct 5, 2024
@github-actions bot added the platform:jetson label Oct 5, 2024