diff --git a/README.md b/README.md index 83faca5..4da490c 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,7 @@ pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-resea ## Documentation * [Draft Scaled Arithmetics design document](docs/design.md); +* [Scaled operators coverage](docs/operators.md) ## Development diff --git a/docs/operators.md b/docs/operators.md new file mode 100644 index 0000000..1cf2ba2 --- /dev/null +++ b/docs/operators.md @@ -0,0 +1,122 @@ +# JAX Scaled Operators coverage + +Summary of JAX LAX operators supported in `autoscale` graph transformation. + +## [JAX LAX operations](https://jax.readthedocs.io/en/latest/jax.lax.html) + +| Operation | Supported | Remarks | +| ---------------------- | ------------------ |-------- | +| `abs` | :x: | | +| `add` | :white_check_mark: | | +| `acos` | :x: | | +| `approx_max_k` | :x: | | +| `approx_min_k` | :x: | | +| `argmax` | :x: | | +| `argmin` | :x: | | +| `asin` | :x: | | +| `atan` | :x: | | +| `atan2` | :x: | | +| `batch_matmul` | :x: | | +| `bessel_i0e` | :x: | | +| `bessel_i1e` | :x: | | +| `betainc` | :x: | | +| `bitcast_convert_type` | :white_check_mark: | | +| `bitwise_not` | :x: | | +| `bitwise_and` | :x: | | +| `bitwise_or` | :x: | | +| `bitwise_xor` | :x: | | +| `population_count` | :x: | | +| `broadcast` | :white_check_mark: | | +| `broadcast_in_dim` | :white_check_mark: | | +| `cbrt` | :x: | | +| `ceil` | :x: | | +| `clamp` | :x: | | +| `collapse` | :white_check_mark: | | +| `complex` | :x: | | +| `concatenate` | :white_check_mark: | | +| `conj` | :x: | | +| `conv` | :x: | | +| `convert_element_type` | :white_check_mark: | | +| `conv_general_dilated` | :x: | | +| `conv_transpose` | :x: | | +| `cos` | :white_check_mark: | | +| `cosh` | :x: | | +| `cummax` | :x: | | +| `cummin` | :x: | | +| `cumprod` | :x: | | +| `cumsum` | :x: | | +| `digamma` | :x: | | +| `div` | :white_check_mark: | | +| `dot` | :white_check_mark: | | +| `dot_general` | :white_check_mark: | Limited set of configurations. See below. | +| `dynamic_slice` | :x: | | +| `dynamic_update_slice` | :x: | | +| `eq` | :white_check_mark: | | +| `erf` | :x: | | +| `erfc` | :x: | | +| `erf_inv` | :x: | | +| `exp` | :white_check_mark: | | +| `expand_dims` | :white_check_mark: | | +| `expm1` | :x: | | +| `fft` | :x: | | +| `floor` | :x: | | +| `full` | :question: | | +| `full_like` | :question: | | +| `gather` | :x: | | +| `ge` | :white_check_mark: | | +| `gt` | :white_check_mark: | | +| `igamma` | :x: | | +| `igammac` | :x: | | +| `imag` | :x: | | +| `index_in_dim` | :x: | | +| `index_take` | :x: | | +| `iota` | :white_check_mark: | | +| `is_finite` | :white_check_mark: | | +| `le` | :white_check_mark: | | +| `lt` | :white_check_mark: | | +| `lgamma` | :x: | | +| `log` | :white_check_mark: | | +| `log1p` | :x: | | +| `logistic` | :x: | | +| `max` | :white_check_mark: | | +| `min` | :white_check_mark: | | +| `mul` | :white_check_mark: | | +| `ne` | :white_check_mark: | | +| `neg` | :white_check_mark: | | +| `nextafter` | :x: | | +| `pad` | :x: | | +| `polygamma` | :x: | | +| `pow` | :x: | | +| `real` | :x: | | +| `reciprocal` | :x: | | +| `reduce` | :white_check_mark: | | +| `reshape` | :white_check_mark: | | +| `rem` | :x: | | +| `rev` | :x: | | +| `round` | :x: | | +| `rsqrt` | :x: | | +| `scatter` | :x: | | +| `scatter_add` | :x: | | +| `scatter_max` | :x: | | +| `scatter_min` | :x: | | +| `scatter_mul` | :x: | | +| `select` | :white_check_mark: | | +| `shift_left` | :x: | | +| `shift_right_arithmetic`| :x: | | +| `shift_right_logical` | :x: | | +| `slice` | :white_check_mark: | | +| `slice_in_dim` | :white_check_mark: | | +| `sign` | :x: | | +| `sin` | :white_check_mark: | | +| `sinh` | :x: | | +| `sort` | :x: | | +| `sort_key_val` | :x: | | +| `sqrt` | :x: | | +| `square` | :x: | | +| `squeeze` | :x: | | +| `sub` | :white_check_mark: | | +| `tan` | :x: | | +| `tie_in` | :x: | Deprecated in JAX | +| `top_k` | :x: | | +| `transpose` | :white_check_mark: | | +| `zeta` | :x: | |