Skip to content

Commit

Permalink
Merge pull request tensorflow#61184 from benbarsdell:dlpack-relax-str…
Browse files Browse the repository at this point in the history
…ides

PiperOrigin-RevId: 552672132
  • Loading branch information
tensorflower-gardener committed Aug 1, 2023
2 parents 5ea362c + cfb8642 commit 379184f
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 7 deletions.
19 changes: 19 additions & 0 deletions tensorflow/c/eager/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,25 @@ cc_library(
alwayslink = 1,
)

tf_cuda_cc_test(
name = "dlpack_test",
size = "small",
srcs = [
"dlpack_test.cc",
],
args = [],
tags = [],
deps = [
":c_api",
":dlpack",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/strings",
"@dlpack",
],
)

# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
# right now, remove this public rule when no longer needed (it should be
# replaced by TF Lite)
Expand Down
24 changes: 17 additions & 7 deletions tensorflow/c/eager/dlpack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,18 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
// data.
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
int ndim) {
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
return false;
}
for (int i = ndim - 2; i >= 0; --i) {
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
return false;
bool valid = true;
int64_t expected_stride = 1;
for (int i = ndim - 1; i >= 0; --i) {
// Empty tensors are always compact regardless of strides.
if (shape_arr[i] == 0) return true;
// Note that dimensions with size=1 can have any stride.
if (shape_arr[i] != 1 && stride_arr[i] != expected_stride) {
valid = false;
}
expected_stride *= shape_arr[i];
}
return true;
return valid;
}
} // namespace

Expand Down Expand Up @@ -350,6 +353,13 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
const int64_t* dims = dl_tensor->shape;
void* data = dl_tensor->data;

if (dl_tensor->byte_offset != 0) {
status->status = tensorflow::errors::InvalidArgument(
"Unsupported byte_offset (", dl_tensor->byte_offset,
") from DLPack, must be zero");
return nullptr;
}

size_t total_bytes = dl_tensor->dtype.bits / 8;
for (int i = 0; i < num_dims; i++) {
total_bytes *= dims[i];
Expand Down
113 changes: 113 additions & 0 deletions tensorflow/c/eager/dlpack_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/dlpack.h"

#include <vector>

#include "absl/strings/str_join.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

void TestHandleFromDLPack(TF_Status* status, TFE_Context* ctx,
std::vector<int64_t> shape,
std::vector<int64_t> strides) {
size_t num_elements = 1;
for (int i = 0; i < static_cast<int32_t>(shape.size()); ++i) {
num_elements *= shape[i];
}
std::vector<float> data(num_elements);
for (size_t j = 0; j < num_elements; ++j) {
data[j] = j;
}
DLManagedTensor dlm_in = {};
DLTensor* dltensor_in = &dlm_in.dl_tensor;
dltensor_in->data = data.data();
dltensor_in->device = {kDLCPU, 0};
dltensor_in->ndim = static_cast<int32_t>(shape.size());
dltensor_in->dtype = {kDLFloat, 32, 1};
dltensor_in->shape = shape.data();
dltensor_in->strides = strides.data();
TFE_TensorHandle* handle = TFE_HandleFromDLPack(&dlm_in, status, ctx);
ASSERT_NE(handle, nullptr)
<< TF_Message(status) << " (shape=[" << absl::StrJoin(shape, ",")
<< "], strides=[" << absl::StrJoin(strides, ",") << "])";

auto* dlm_out =
static_cast<DLManagedTensor*>(TFE_HandleToDLPack(handle, status));
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const DLTensor* dltensor_out = &dlm_out->dl_tensor;
EXPECT_EQ(dltensor_out->device.device_type, dltensor_in->device.device_type);
EXPECT_EQ(dltensor_out->device.device_id, dltensor_in->device.device_id);
EXPECT_EQ(dltensor_out->ndim, dltensor_in->ndim);
EXPECT_EQ(dltensor_out->dtype.code, dltensor_in->dtype.code);
EXPECT_EQ(dltensor_out->dtype.bits, dltensor_in->dtype.bits);
EXPECT_EQ(dltensor_out->dtype.lanes, dltensor_in->dtype.lanes);
for (int i = 0; i < dltensor_in->ndim; ++i) {
EXPECT_EQ(dltensor_out->shape[i], dltensor_in->shape[i]);
if (dltensor_out->strides) {
if (i == dltensor_in->ndim - 1) {
EXPECT_EQ(dltensor_out->strides[i], 1);
} else {
EXPECT_EQ(dltensor_out->strides[i],
dltensor_out->shape[i + 1] * dltensor_out->strides[i + 1]);
}
}
}
const float* data_in = static_cast<const float*>(dltensor_in->data);
const float* data_out = static_cast<const float*>(dltensor_out->data);
for (size_t j = 0; j < num_elements; ++j) {
EXPECT_EQ(data_out[j], data_in[j]);
}

TFE_CallDLManagedTensorDeleter(dlm_out);
TFE_DeleteTensorHandle(handle);
}

TEST(DLPack, HandleFromDLPackStrides) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);

TestHandleFromDLPack(status, ctx, {}, {});
TestHandleFromDLPack(status, ctx, {4}, {});
TestHandleFromDLPack(status, ctx, {4}, {1});
TestHandleFromDLPack(status, ctx, {4, 3, 2}, {});
TestHandleFromDLPack(status, ctx, {4, 3, 2}, {6, 2, 1});
// Test that dims with size=1 can have any stride.
TestHandleFromDLPack(status, ctx, {1}, {1});
TestHandleFromDLPack(status, ctx, {1}, {0});
TestHandleFromDLPack(status, ctx, {4, 1, 2}, {2, 1, 1});
TestHandleFromDLPack(status, ctx, {4, 1, 2}, {2, 0, 1});
TestHandleFromDLPack(status, ctx, {4, 3, 1}, {3, 1, 1});
TestHandleFromDLPack(status, ctx, {4, 3, 1}, {3, 1, 0});
// Test that empty tensors can have any strides.
TestHandleFromDLPack(status, ctx, {4, 0, 2}, {0, 2, 1});
TestHandleFromDLPack(status, ctx, {4, 0, 2}, {0, 1, 1});
TestHandleFromDLPack(status, ctx, {4, 0, 2}, {0, 0, 1});
TestHandleFromDLPack(status, ctx, {4, 0, 2}, {0, 2, 0});

TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}

} // namespace
} // namespace tensorflow

0 comments on commit 379184f

Please sign in to comment.