Skip to content

Commit

Permalink
update rule of into_shape
Browse files Browse the repository at this point in the history
- RSTSR now always returns owned tensor after calling `into_owned`, and added `change_shape` for previous behavior
  • Loading branch information
ajz34 committed Jan 3, 2025
1 parent 215fb6d commit cf3a9d0
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 30 deletions.
18 changes: 7 additions & 11 deletions listings/features-default/tests/arithmetics_and_broadcasting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ fn example_basic_arithmetics() {
// ANCHOR_END: basic_arithmetics_01

// ANCHOR: basic_arithmetics_02
let mat = rt::arange(12).into_shape([3, 4]).into_owned();
let vec = rt::arange(4).into_shape([4]).into_owned();
let mat = rt::arange(12).into_shape([3, 4]);
let vec = rt::arange(4).into_shape([4]);

// matrix multiplication
let res = &mat % mat.t();
Expand Down Expand Up @@ -69,7 +69,7 @@ fn example_basic_arithmetics() {
#[test]
fn example_op_percent() {
// ANCHOR: star_as_elem_mult
let mat = rt::arange(12).into_shape([3, 4]).into_owned();
let mat = rt::arange(12).into_shape([3, 4]);
let vec = rt::arange(4);

// element-wise matrix multiplication
Expand Down Expand Up @@ -135,9 +135,7 @@ fn example_lt_os_mp2() {
// ANCHOR: lt_os_mp2_01
// task definition
let (naux, nocc, nvir) = (8, 2, 4); // subscripts (P, i, a)
let y = rt::arange(naux * nocc * nvir)
.into_shape([naux, nocc, nvir])
.into_owned();
let y = rt::arange(naux * nocc * nvir).into_shape([naux, nocc, nvir]);
let ei = rt::arange(nocc);
let ea = rt::arange(nvir);
// ANCHOR_END: lt_os_mp2_01
Expand Down Expand Up @@ -167,11 +165,9 @@ fn example_ao2mo_vo() {
// ANCHOR: ao2mo_vo_01
// task definition
let (naux, nocc, nvir, nao, _) = (8, 2, 4, 6, 6); // subscripts (P, i, a, μ, ν)
let y_ao = rt::arange(naux * nao * nao)
.into_shape([naux, nao, nao])
.into_owned();
let c_occ = rt::arange(nao * nocc).into_shape([nao, nocc]).into_owned();
let c_vir = rt::arange(nao * nvir).into_shape([nao, nvir]).into_owned();
let y_ao = rt::arange(naux * nao * nao).into_shape([naux, nao, nao]);
let c_occ = rt::arange(nao * nocc).into_shape([nao, nocc]);
let c_vir = rt::arange(nao * nvir).into_shape([nao, nvir]);
// ANCHOR_END: ao2mo_vo_01

// ANCHOR: ao2mo_vo_02
Expand Down
20 changes: 10 additions & 10 deletions listings/features-default/tests/indexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rstsr_core::prelude::*;
fn example_index_by_num() {
// ANCHOR: example_index_by_num_01
// generate 3-D tensor A_ijk
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();
let a = rt::arange(24).into_shape([4, 3, 2]);
println!("{:}", a);

// B_jk = A_ijk where i = 2
Expand Down Expand Up @@ -39,7 +39,7 @@ fn example_index_by_num() {
fn example_index_by_range() {
// ANCHOR: example_index_by_range_01
// generate 3-D tensor A_ijk
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();
let a = rt::arange(24).into_shape([4, 3, 2]);
println!("{:}", a);

// B_ijk = A_ijk where 1 <= i < 3
Expand Down Expand Up @@ -84,7 +84,7 @@ fn example_index_by_range() {
// ANCHOR_END: example_index_by_range_04

// ANCHOR: example_index_by_range_05
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();
let a = rt::arange(24).into_shape([4, 3, 2]);
let b = a.slice((.., 1..3, ..2)); // equivalently `a.slice(s![.., 1..3, ..2])`
println!("{:}", b);
// output:
Expand Down Expand Up @@ -138,7 +138,7 @@ fn example_slice_with_strides() {
#[test]
fn example_insert_axes() {
// ANCHOR: example_insert_axes_01
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();
let a = rt::arange(24).into_shape([4, 3, 2]);

// insert new axis at the beginning
let b = a.slice(NewAxis);
Expand All @@ -161,7 +161,7 @@ fn example_insert_axes() {
#[should_panic]
fn example_insert_axes_panic() {
// ANCHOR: example_insert_axes_02
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();
let a = rt::arange(24).into_shape([4, 3, 2]);

// insert new axis at the beginning
let b = a.slice(Some(2));
Expand All @@ -173,7 +173,7 @@ fn example_insert_axes_panic() {
#[test]
fn example_ellipsis() {
// ANCHOR: example_ellipsis_01
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();
let a = rt::arange(24).into_shape([4, 3, 2]);

// using ellipsis to select index from last dimension
// equivallently to `a.slice((.., .., 0))` for 3-D tensor
Expand Down Expand Up @@ -203,7 +203,7 @@ fn example_mixed_indexing() {
#[test]
fn example_elementwise_safe() {
// ANCHOR: example_elementwise_safe
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();
let a = rt::arange(24).into_shape([4, 3, 2]);

let val = a[[2, 2, 1]];
println!("{:}", val);
Expand Down Expand Up @@ -234,7 +234,7 @@ fn example_elementwise_safe() {
#[should_panic]
fn example_elementwise_safe_panic() {
// ANCHOR: example_elementwise_safe_panic
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();
let a = rt::arange(24).into_shape([4, 3, 2]);

let val = a[[2, 2, 3]];
println!("{:}", val);
Expand All @@ -245,7 +245,7 @@ fn example_elementwise_safe_panic() {
#[test]
fn example_elementwise_unchecked() {
// ANCHOR: example_elementwise_unchecked
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();
let a = rt::arange(24).into_shape([4, 3, 2]);

let val = unsafe { a.index_uncheck([2, 2, 1]) };
println!("{:}", val);
Expand All @@ -256,7 +256,7 @@ fn example_elementwise_unchecked() {
#[test]
fn example_elementwise_unchecked_not_desired() {
// ANCHOR: example_elementwise_unchecked_not_desired
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();
let a = rt::arange(24).into_shape([4, 3, 2]);

let val = unsafe { a.index_uncheck([2, 2, 3]) };
println!("{:}", val);
Expand Down
11 changes: 4 additions & 7 deletions listings/features-default/tests/structure_and_ownership.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@ fn example_tensor_ownership() {
let tensor = rt::arange(12);
let ptr_1 = tensor.rawvec().as_ptr();

// this will give cow tensor with 2-D shape
// this will give owned tensor with 2-D shape
// since previous tensor is contiguous, this will not copy memory
let tensor = tensor.into_shape([3, 4]);

// convert cow tensor to owned tensor
let mut tensor = tensor.into_owned();
let mut tensor = tensor.into_shape([3, 4]);
tensor += 1; // inplace operation
let ptr_2 = tensor.rawvec().as_ptr();

Expand Down Expand Up @@ -83,7 +80,7 @@ fn example_to_scalar() {
fn example_dim_conversion() {
// ANCHOR: dim_conversion
// fixed dimension
let a = rt::arange(12).into_shape([3, 4]).into_owned();
let a = rt::arange(12).into_shape([3, 4]);
println!("{:?}", a);
// output: 2-Dim, contiguous: Cc

Expand All @@ -99,7 +96,7 @@ fn example_dim_conversion() {
// ANCHOR_END: dim_conversion

// ANCHOR: dyn_dim_construct
let a = rt::arange(12).into_shape(vec![3, 4]).into_owned();
let a = rt::arange(12).into_shape(vec![3, 4]);
println!("{:?}", a);
// output: 2-Dim (dyn), contiguous: Cc
// ANCHOR_END: dyn_dim_construct
Expand Down
4 changes: 2 additions & 2 deletions listings/features-default/tests/tensor_creation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ fn example_03() {

// if you feel function `into_shape_assume_contig` ugly, following code also works
let vec = vec![1, 2, 3, 4, 5, 6];
let tensor = rt::asarray(vec).into_shape([2, 3]).into_owned();
let tensor = rt::asarray(vec).into_shape([2, 3]);
println!("{:}", tensor);

// and even more concise
Expand Down Expand Up @@ -183,7 +183,7 @@ fn example_diag() {
// [ 0 2 0]
// [ 0 0 3]]

let tensor = rt::arange(9).into_shape([3, 3]).into_owned();
let tensor = rt::arange(9).into_shape([3, 3]);
let diag = tensor.diag();
println!("{:}", diag);
// output: [ 0 4 8]
Expand Down

0 comments on commit cf3a9d0

Please sign in to comment.