Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support mixed normal/scaled graph in autoscale #25

Merged
merged 1 commit into from
Nov 21, 2023

Conversation

balancap
Copy link
Contributor

The AutoScale interpreter needs to be generalized to support mixed graph, where some tensors are still using normal JAX arrays.

It means we need some form of rules + promotions related to:

  • When to use scaled primitives => when at least one input is a ScaledArray;
  • When to automatically promote simple arrays to ScaledArray;

@balancap balancap force-pushed the support-mixed-jax-scaled-graph-in-autoscale branch 5 times, most recently from 44f190e to c9d4433 Compare November 21, 2023 12:17
The `AutoScale` interpreter needs to be generalized to support mixed graph, where some tensors are still using normal JAX arrays.

It means we need some form of rules + promotions related to:
* When to use scaled primitives. By default, a `FORWARD` type rule, i.e. using scaled op if any input is scaled;
* When to automatically promote simple arrays to ScaledArray. By default, just promoting scalars (maybe numpy constants in future?);
@balancap balancap force-pushed the support-mixed-jax-scaled-graph-in-autoscale branch from c9d4433 to af7ade8 Compare November 21, 2023 12:19
@balancap balancap merged commit 642a58e into main Nov 21, 2023
2 checks passed
@balancap balancap deleted the support-mixed-jax-scaled-graph-in-autoscale branch November 21, 2023 13:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant