Skip to content

Commit

Permalink
Fix branched caching
Browse files Browse the repository at this point in the history
  • Loading branch information
lxsmnsyc committed Feb 24, 2024
1 parent be396f6 commit fbe9ffb
Showing 1 changed file with 111 additions and 47 deletions.
158 changes: 111 additions & 47 deletions packages/forgetti/src/core/optimizer-scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ export default class OptimizerScope {
this.isInLoop = isInLoop;
}

createHeader(type: 'memo' | 'ref' = 'memo'): t.Identifier {
createHeader(type: 'memo' | 'ref'): t.Identifier {
if (type === 'ref') {
if (!this.ref) {
this.ref = this.path.scope.generateUidIdentifier('ref');
Expand All @@ -86,18 +86,18 @@ export default class OptimizerScope {
}

getMemoDeclarations(): t.VariableDeclaration[] | undefined {
if (this.memo || this.ref) {
if (this.memo) {
// This is for generating branched caching.
// Parent means that we want to create the cache
// from the parent (or root)
if (this.parent) {
const header = this.parent.createHeader();
const header = this.parent.createHeader('memo');
const index = this.parent.createIndex('memo');

return [
t.variableDeclaration('let', [
t.variableDeclarator(
this.createHeader(),
this.createHeader('memo'),
t.callExpression(
getImportIdentifier(this.ctx, this.path, RUNTIME_BRANCH),
[header, index, t.numericLiteral(this.indecesMemo)],
Expand All @@ -108,48 +108,70 @@ export default class OptimizerScope {
}

const outputDeclarations = [];

if (this.memo) {
outputDeclarations.push(
t.variableDeclaration('let', [
t.variableDeclarator(
this.memo,
t.callExpression(
getImportIdentifier(this.ctx, this.path, RUNTIME_CACHE),
[
getImportIdentifier(
this.ctx,
this.path,
this.ctx.preset.runtime.useMemo,
),
t.numericLiteral(this.indecesMemo),
],
),
outputDeclarations.push(
t.variableDeclaration('let', [
t.variableDeclarator(
this.memo,
t.callExpression(
getImportIdentifier(this.ctx, this.path, RUNTIME_CACHE),
[
getImportIdentifier(
this.ctx,
this.path,
this.ctx.preset.runtime.useMemo,
),
t.numericLiteral(this.indecesMemo),
],
),
]),
);
}
if (this.ref) {
outputDeclarations.push(
),
]),
);
return outputDeclarations;
}
return undefined;
}

getRefDeclarations(): t.VariableDeclaration[] | undefined {
if (this.ref) {
// This is for generating branched caching.
// Parent means that we want to create the cache
// from the parent (or root)
if (this.parent) {
const header = this.parent.createHeader('ref');
const index = this.parent.createIndex('ref');

return [
t.variableDeclaration('let', [
t.variableDeclarator(
this.ref,
this.createHeader('ref'),
t.callExpression(
getImportIdentifier(this.ctx, this.path, RUNTIME_REF),
[
getImportIdentifier(
this.ctx,
this.path,
this.ctx.preset.runtime.useRef,
),
t.numericLiteral(this.indecesRef),
],
getImportIdentifier(this.ctx, this.path, RUNTIME_BRANCH),
[header, index, t.numericLiteral(this.indecesRef)],
),
),
]),
);
];
}

const outputDeclarations = [];
outputDeclarations.push(
t.variableDeclaration('let', [
t.variableDeclarator(
this.ref,
t.callExpression(
getImportIdentifier(this.ctx, this.path, RUNTIME_REF),
[
getImportIdentifier(
this.ctx,
this.path,
this.ctx.preset.runtime.useRef,
),
t.numericLiteral(this.indecesRef),
],
),
),
]),
);
return outputDeclarations;
}
return undefined;
Expand Down Expand Up @@ -177,13 +199,13 @@ export default class OptimizerScope {
if (!this.parent) {
return undefined;
}
const header = this.parent.createHeader();
const header = this.parent.createHeader('memo');
const index = this.parent.createIndex('memo');
const id = this.createLoopIndex();

return t.variableDeclaration('let', [
t.variableDeclarator(
this.createHeader(),
this.createHeader('memo'),
t.callExpression(
getImportIdentifier(this.ctx, this.path, RUNTIME_BRANCH),
// Looped branches cannot be statically analyzed
Expand All @@ -195,8 +217,30 @@ export default class OptimizerScope {
]);
}

getLoopDeclaration(): t.VariableDeclaration {
const header = this.createHeader();
getLoopRefDeclaration(): t.VariableDeclaration | undefined {
if (!this.parent) {
return undefined;
}
const header = this.parent.createHeader('ref');
const index = this.parent.createIndex('ref');
const id = this.createLoopIndex();

return t.variableDeclaration('let', [
t.variableDeclarator(
this.createHeader('ref'),
t.callExpression(
getImportIdentifier(this.ctx, this.path, RUNTIME_BRANCH),
// Looped branches cannot be statically analyzed
[header, index, t.numericLiteral(0)],
),
),
// This is for tracking the dynamic size
t.variableDeclarator(id, t.numericLiteral(0)),
]);
}

getLoopMemoHeaderDeclaration(): t.VariableDeclaration {
const header = this.createHeader('memo');
const index = this.createLoopIndex();
const localIndex = this.path.scope.generateUidIdentifier('loopId');
return t.variableDeclaration('let', [
Expand All @@ -211,15 +255,35 @@ export default class OptimizerScope {
]);
}

getLoopRefHeaderDeclaration(): t.VariableDeclaration {
const header = this.createHeader('ref');
const index = this.createLoopIndex();
const localIndex = this.path.scope.generateUidIdentifier('loopId');
return t.variableDeclaration('let', [
t.variableDeclarator(localIndex, t.updateExpression('++', index)),
t.variableDeclarator(
this.createLoopHeader(),
t.callExpression(
getImportIdentifier(this.ctx, this.path, RUNTIME_BRANCH),
[header, localIndex, t.numericLiteral(this.indecesRef)],
),
),
]);
}

getStatements(): t.Statement[] {
const result = [...this.statements];
const header = this.isInLoop
? [this.getLoopDeclaration()]
const memoHeader = this.isInLoop
? [this.getLoopMemoHeaderDeclaration()]
: this.getMemoDeclarations();
if (header) {
return mergeVariableDeclaration([...header, ...result]);
}
return mergeVariableDeclaration(result);
const refHeader = this.isInLoop
? [this.getLoopRefHeaderDeclaration()]
: this.getRefDeclarations();
return mergeVariableDeclaration([
...(memoHeader || []),
...(refHeader || []),
...result,
]);
}

push(...statements: t.Statement[]): void {
Expand Down

0 comments on commit fbe9ffb

Please sign in to comment.