Skip to content

Commit

Permalink
Merge pull request #926 from NiklasGustafsson/missing
Browse files Browse the repository at this point in the history
Missing methods implemented
  • Loading branch information
NiklasGustafsson authored Feb 21, 2023
2 parents 64b6999 + 35dbb1b commit 17a35ca
Show file tree
Hide file tree
Showing 39 changed files with 3,394 additions and 2,040 deletions.
5 changes: 5 additions & 0 deletions RELEASENOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ Adding allow_tf32<br/>
Adding overloads of Module.save() and Module.load() taking a 'Stream' argument.<br/>
Adding torch.softmax() and Tensor.softmax() as aliases for torch.special.softmax()<br/>
Adding torch.from_file()<br/>
Adding a number of missing pointwise Tensor operations.<br/>
Adding select_scatter, diagonal_scatter, and slice_scatter<br/>
Adding torch.set_printoptions<br/>
Adding torch.cartesian_prod, combinations, and cov.<br/>
Adding torch.cdist, diag_embed, rot90, triu_indices, tril_indices<br/>

__Fixed Bugs__:

Expand Down
17 changes: 17 additions & 0 deletions src/Native/LibTorchSharp/THSLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ Tensor THSLinalg_det(const Tensor tensor)
CATCH_TENSOR(torch::linalg::det(*tensor));
}

Tensor THSTensor_logdet(const Tensor tensor)
{
CATCH_TENSOR(torch::logdet(*tensor));
}

Tensor THSLinalg_slogdet(const Tensor tensor, Tensor* logabsdet)
{
std::tuple<at::Tensor, at::Tensor> res;
Expand All @@ -63,6 +68,13 @@ Tensor THSLinalg_eig(const Tensor tensor, Tensor* eigenvectors)
return ResultTensor(std::get<0>(res));
}

Tensor THSTensor_geqrf(const Tensor tensor, Tensor* tau)
{
std::tuple<at::Tensor, at::Tensor> res;
CATCH(res = torch::geqrf(*tensor);)
*tau = ResultTensor(std::get<1>(res));
return ResultTensor(std::get<0>(res));
}

#if 0
Tensor THSTensor_eig(const Tensor tensor, bool vectors, Tensor* eigenvectors)
Expand Down Expand Up @@ -98,6 +110,11 @@ Tensor THSLinalg_eigvalsh(const Tensor tensor, const char UPLO)
CATCH_TENSOR(torch::linalg::eigvalsh(*tensor, _uplo));
}

Tensor THSLinalg_householder_product(const Tensor tensor, const Tensor tau)
{
CATCH_TENSOR(torch::linalg::householder_product(*tensor, *tau));
}

Tensor THSLinalg_inv(const Tensor tensor)
{
CATCH_TENSOR(torch::linalg::inv(*tensor));
Expand Down
102 changes: 95 additions & 7 deletions src/Native/LibTorchSharp/THSTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ Tensor THSTensor_any_along_dimension(const Tensor tensor, const int64_t dim, boo
{
CATCH_TENSOR(tensor->any(dim, keepdim));
}

Tensor THSTensor_adjoint(const Tensor tensor)
{
CATCH_TENSOR(tensor->adjoint());
}

Tensor THSTensor_argmax(const Tensor tensor)
{
CATCH_TENSOR(tensor->argmax());
Expand All @@ -86,6 +92,11 @@ Tensor THSTensor_argmin_along_dimension(const Tensor tensor, const int64_t dim,
CATCH_TENSOR(tensor->argmin(dim, keepdim));
}

Tensor THSTensor_argwhere(const Tensor tensor)
{
CATCH_TENSOR(tensor->argwhere());
}

Tensor THSTensor_atleast_1d(const Tensor tensor)
{
CATCH_TENSOR(torch::atleast_1d(*tensor));
Expand Down Expand Up @@ -159,6 +170,11 @@ void THSTensor_vector_to_parameters(const Tensor vec, const Tensor* tensors, con
CATCH(torch::nn::utils::vector_to_parameters(*vec, toTensors<at::Tensor>((torch::Tensor**)tensors, length)););
}

Tensor THSTensor_cartesian_prod(const Tensor* tensors, const int length)
{
CATCH_TENSOR(torch::cartesian_prod(toTensors<at::Tensor>((torch::Tensor**)tensors, length)));
}

double THSTensor_clip_grad_norm_(const Tensor* tensors, const int length, const double max_norm, const double norm_type)
{
double res = 0.0;
Expand Down Expand Up @@ -258,6 +274,11 @@ Tensor THSTensor_clone(const Tensor tensor)
CATCH_TENSOR(tensor->clone());
}

Tensor THSTensor_combinations(const Tensor tensor, const int r, const bool with_replacement)
{
CATCH_TENSOR(torch::combinations(*tensor, r, with_replacement));
}

Tensor THSTensor_copy_(const Tensor input, const Tensor other, const bool non_blocking)
{
CATCH_TENSOR(input->copy_(*other, non_blocking));
Expand Down Expand Up @@ -285,6 +306,13 @@ int THSTensor_is_contiguous(const Tensor tensor)
return result;
}

int64_t THSTensor_is_nonzero(const Tensor tensor)
{
bool result = false;
CATCH(result = tensor->is_nonzero();)
return result;
}

Tensor THSTensor_copysign(const Tensor input, const Tensor other)
{
CATCH_TENSOR(input->copysign(*other));
Expand All @@ -295,13 +323,6 @@ Tensor THSTensor_corrcoef(const Tensor tensor)
CATCH_TENSOR(tensor->corrcoef());
}

Tensor THSTensor_cov(const Tensor input, int64_t correction, const Tensor fweights, const Tensor aweights)
{
c10::optional<at::Tensor> fw = (fweights == nullptr) ? c10::optional<at::Tensor>() : *fweights;
c10::optional<at::Tensor> aw = (aweights == nullptr) ? c10::optional<at::Tensor>() : *aweights;
CATCH_TENSOR(input->cov(correction, fw, aw));
}

bool THSTensor_is_cpu(const Tensor tensor)
{
bool result = true;
Expand Down Expand Up @@ -402,6 +423,11 @@ int THSTensor_device_type(const Tensor tensor)
return (int)device.type();
}

Tensor THSTensor_diag_embed(const Tensor tensor, const int64_t offset, const int64_t dim1, const int64_t dim2)
{
CATCH_TENSOR(tensor->diag_embed(offset, dim1, dim2));
}

Tensor THSTensor_diff(const Tensor tensor, const int64_t n, const int64_t dim, const Tensor prepend, const Tensor append)
{
c10::optional<at::Tensor> prep = prepend != nullptr ? *prepend : c10::optional<at::Tensor>(c10::nullopt);
Expand Down Expand Up @@ -473,6 +499,11 @@ Tensor THSTensor_repeat_interleave_int64(const Tensor tensor, const int64_t repe
CATCH_TENSOR(tensor->repeat_interleave(repeats, _dim, _output_size));
}

int THSTensor_result_type(const Tensor left, const Tensor right)
{
CATCH_RETURN_RES(int, -1, res = (int)torch::result_type(*left, *right));
}

Tensor THSTensor_movedim(const Tensor tensor, const int64_t* src, const int src_len, const int64_t* dst, const int dst_len)
{
CATCH_TENSOR(tensor->movedim(at::ArrayRef<int64_t>(src, src_len), at::ArrayRef<int64_t>(dst, dst_len)));
Expand Down Expand Up @@ -1070,6 +1101,11 @@ Tensor THSTensor_outer(const Tensor left, const Tensor right)
CATCH_TENSOR(left->outer(*right));
}

Tensor THSTensor_ormqr(const Tensor input, const Tensor tau, const Tensor other, bool left, bool transpose)
{
CATCH_TENSOR(torch::ormqr(*input, *tau, *other, left, transpose));
}

Tensor THSTensor_mH(const Tensor tensor)
{
CATCH_TENSOR(tensor->mH());
Expand Down Expand Up @@ -1161,6 +1197,11 @@ Tensor THSTensor_reshape(const Tensor tensor, const int64_t* shape, const int le
CATCH_TENSOR(tensor->reshape(at::ArrayRef<int64_t>(shape, length)));
}

Tensor THSTensor_rot90(const Tensor tensor, const int64_t k, const int64_t dim1, const int64_t dim2)
{
CATCH_TENSOR(tensor->rot90(k, { dim1, dim2 }));
}

Tensor THSTensor_roll(const Tensor tensor, const int64_t* shifts, const int shLength, const int64_t* dims, const int dimLength)
{
CATCH_TENSOR(
Expand Down Expand Up @@ -1194,6 +1235,36 @@ Tensor THSTensor_scatter_(
CATCH_TENSOR(tensor->scatter_(dim, *index, *source));
}

Tensor THSTensor_select_scatter(
const Tensor tensor,
const Tensor source,
const int64_t dim,
const int64_t index)
{
CATCH_TENSOR(torch::select_scatter(*tensor, *source, dim, index));
}

Tensor THSTensor_diagonal_scatter(
const Tensor tensor,
const Tensor source,
const int64_t offset,
const int64_t dim1,
const int64_t dim2)
{
CATCH_TENSOR(torch::diagonal_scatter(*tensor, *source, offset, dim1, dim2));
}

Tensor THSTensor_slice_scatter(
const Tensor tensor,
const Tensor source,
const int64_t dim,
const int64_t *start,
const int64_t *end,
const int64_t step)
{
CATCH_TENSOR(torch::slice_scatter(*tensor, *source, dim, start == nullptr ? c10::optional<int64_t>() : c10::optional<int64_t>(*start), end == nullptr ? c10::optional<int64_t>() : c10::optional<int64_t>(*end), step));
}

Tensor THSTensor_scatter_add(
const Tensor tensor,
const int64_t dim,
Expand Down Expand Up @@ -1762,6 +1833,23 @@ Tensor THSTensor_tril(const Tensor tensor, const int64_t diagonal)
CATCH_TENSOR(tensor->tril(diagonal));
}

Tensor THSTensor_tril_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index)
{
auto options = at::TensorOptions()
.dtype(at::ScalarType(scalar_type))
.device(c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index));
CATCH_TENSOR(torch::tril_indices(row, col, offset, options));
}

Tensor THSTensor_triu_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index)
{
auto options = at::TensorOptions()
.dtype(at::ScalarType(scalar_type))
.device(c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index));
CATCH_TENSOR(torch::triu_indices(row, col, offset, options));
}


Tensor THSTensor_transpose(const Tensor tensor, const int64_t dim1, const int64_t dim2)
{
CATCH_TENSOR(tensor->transpose(dim1, dim2));
Expand Down
Loading

0 comments on commit 17a35ca

Please sign in to comment.