Skip to content

Commit

Permalink
Clean up mju_sqrMatTDSparse
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711486739
Change-Id: I5f9145dcd8522a9cb372cbdf60611d7d05cb524e
  • Loading branch information
yuvaltassa authored and copybara-github committed Jan 2, 2025
1 parent 4510c6d commit 6c880eb
Showing 1 changed file with 34 additions and 20 deletions.
54 changes: 34 additions & 20 deletions src/engine/engine_util_sparse.c
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,8 @@ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
int* markers = mjSTACKALLOC(d, nc, int);

for (int i=0; i < nc; i++) {
int* cols = res_colind+res_rowadr[i];
int rowadr_i = res_rowadr[i];
int* cols = res_colind + rowadr_i;

res_rownnz[i] = 0;
buffer[i] = 0;
Expand All @@ -755,18 +756,26 @@ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
}

// iterate through each row of M'
int end = rowadrT[i] + rownnzT[i];
for (int r = rowadrT[i]; r < end; r++) {
int adrT = rowadrT[i];
int end_r = adrT + rownnzT[i];
for (int r = adrT; r < end_r; r++) {
int t = colindT[r];
mjtNum v = diag ? matT[r] * diag[t] : matT[r];
for (int c=rowadr[t]; c < rowadr[t]+rownnz[t]; c++) {
int adr = rowadr[t];
int end_c = adr + rownnz[t];
for (int c=adr; c < end_c; c++) {
int cc = colind[c];

// ignore upper triangle
if (cc > i) {
break;
}

buffer[cc] += v*mat[c];
// add value to buffer
if (diag) {
buffer[cc] += matT[r] * diag[t] * mat[c];
} else {
buffer[cc] += matT[r] * mat[c];
}

// only need to insert nnz if not marked
if (!markers[cc]) {
Expand Down Expand Up @@ -810,31 +819,36 @@ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
}
}

end = res_rownnz[i];
end_r = res_rownnz[i];

// rowsuperT: reuse sparsity, copy into res
if (rowsuperT && rowsuperT[i]) {
for (int r=0; r < end; r++) {
res[res_rowadr[i] + r] = buffer[cols[r]];
buffer[cols[r]] = 0;
for (int r=0; r < end_r; r++) {
int c = cols[r];
res[rowadr_i + r] = buffer[c];
buffer[c] = 0;
}
} else {
// clear out buffers since sparsity cannot be reused
for (int r=0; r < end; r++) {
int cc = cols[r];
res[res_rowadr[i] + r] = buffer[cc];
res_colind[res_rowadr[i] + r] = cc;
buffer[cc] = 0;
markers[cc] = 0;
}

// clear out buffers, sparsity cannot be reused
else {
for (int r=0; r < end_r; r++) {
int c = cols[r];
int adr = rowadr_i + r;
res[adr] = buffer[c];
res_colind[adr] = c;
buffer[c] = 0;
markers[c] = 0;
}
}
}


// fill upper triangle
for (int i=0; i < nc; i++) {
int end = res_rowadr[i] + res_rownnz[i] - 1;
for (int j=res_rowadr[i]; j < end; j++) {
int start = res_rowadr[i];
int end = start + res_rownnz[i] - 1;
for (int j=start; j < end; j++) {
int adr = res_rowadr[res_colind[j]] + res_rownnz[res_colind[j]]++;
res[adr] = res[j];
res_colind[adr] = i;
Expand Down

0 comments on commit 6c880eb

Please sign in to comment.