Skip to content

Commit

Permalink
Add CI minimal build with all options disabled. Fix python binding co…
Browse files Browse the repository at this point in the history
…de if sparse tensors are disabled. (microsoft#9898)

* Add 2 builds to validate the cmake defines for excluding optional components work in both full and minimal builds.

* Create empty config for no-ops build

* Create empty config for no-ops build - attempt #2

* Create empty config for no-ops build - attempt microsoft#3

* Update python binding code to work when sparse tensors are disabled.
  • Loading branch information
skottmckay authored Dec 2, 2021
1 parent 3f5c1e1 commit 912e50f
Show file tree
Hide file tree
Showing 6 changed files with 438 additions and 324 deletions.
52 changes: 33 additions & 19 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ void addOrtValueMethods(pybind11::module& m) {
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
CreateGenericMLValue(nullptr, GetCudaAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToCudaMemCpy);
#elif USE_ROCM
if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
}
if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
}

// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy);
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy);

#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
Expand Down Expand Up @@ -97,9 +97,9 @@ void addOrtValueMethods(pybind11::module& m) {
}
allocator = GetCudaAllocator(device.Id());
#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
Expand All @@ -111,6 +111,7 @@ void addOrtValueMethods(pybind11::module& m) {
return ml_value;
})

#if !defined(DISABLE_SPARSE_TENSORS)
.def_static("ort_value_from_sparse_tensor", [](const PySparseTensor* py_sparse_tensor) -> std::unique_ptr<OrtValue> {
return py_sparse_tensor->AsOrtValue();
})
Expand All @@ -121,6 +122,7 @@ void addOrtValueMethods(pybind11::module& m) {
}
return std::make_unique<PySparseTensor>(*ort_value);
})
#endif
// Get a pointer to Tensor data
.def("data_ptr", [](OrtValue* ml_value) -> int64_t {
// TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors
Expand All @@ -138,21 +140,31 @@ void addOrtValueMethods(pybind11::module& m) {
.def("device_name", [](const OrtValue* ort_value) -> std::string {
if (ort_value->IsTensor()) {
return std::string(GetDeviceName(ort_value->Get<Tensor>().Location().device));
} else if (ort_value->IsSparseTensor()) {
}
#if !defined(DISABLE_SPARSE_TENSORS)
else if (ort_value->IsSparseTensor()) {
return std::string(GetDeviceName(ort_value->Get<SparseTensor>().Location().device));
} else {
ORT_THROW("Only OrtValues that are Tensors/SparseTensors are currently supported");
}

ORT_THROW("Only OrtValues that are Tensors/SparseTensors are currently supported");
#else
ORT_THROW("Only OrtValues that are Tensors are supported in this build");
#endif
})
.def("shape", [](const OrtValue* ort_value) -> py::list {
py::list shape_arr;
#if !defined(DISABLE_SPARSE_TENSORS)
// OrtValue can only be a Tensor/SparseTensor, make this generic to handle non-Tensors
ORT_ENFORCE(ort_value->IsTensor() || ort_value->IsSparseTensor(),
"Only OrtValues that are Tensors/SpareTensors are currently supported");

py::list shape_arr;
const auto& dims = (ort_value->IsTensor())
? ort_value->Get<Tensor>().Shape().GetDims()
: ort_value->Get<SparseTensor>().DenseShape().GetDims();
#else
ORT_ENFORCE(ort_value->IsTensor(), "Only OrtValues that are Tensors are supported in this build");
const auto& dims = ort_value->Get<Tensor>().Shape().GetDims();
#endif

for (auto dim : dims) {
// For sequence tensors - we would append a list of dims to the outermost list
Expand All @@ -168,9 +180,11 @@ void addOrtValueMethods(pybind11::module& m) {
if (ort_value->IsTensor()) {
auto elem_type = ort_value->Get<Tensor>().GetElementType();
type_proto = DataTypeImpl::TensorTypeFromONNXEnum(elem_type)->GetTypeProto();
#if !defined(DISABLE_SPARSE_TENSORS)
} else if (ort_value->IsSparseTensor()) {
auto elem_type = ort_value->Get<SparseTensor>().GetElementType();
type_proto = DataTypeImpl::SparseTensorTypeFromONNXEnum(elem_type)->GetTypeProto();
#endif
} else if (ort_value->IsTensorSequence()) {
auto elem_type = ort_value->Get<TensorSeq>().DataType()->AsPrimitiveDataType()->GetDataType();
type_proto = DataTypeImpl::SequenceTensorTypeFromONNXEnum(elem_type)->GetTypeProto();
Expand Down Expand Up @@ -204,9 +218,9 @@ void addOrtValueMethods(pybind11::module& m) {
#ifdef USE_CUDA
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetCudaToHostMemCpyFunction());
#elif USE_ROCM
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetRocmToHostMemCpyFunction());
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetRocmToHostMemCpyFunction());
#else
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, nullptr);
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, nullptr);
#endif
return obj;
})
Expand Down
Loading

0 comments on commit 912e50f

Please sign in to comment.