Skip to content

Commit

Permalink
fix nanobind differences (and others) (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental authored Dec 25, 2024
1 parent 8984cf8 commit e7592ab
Show file tree
Hide file tree
Showing 17 changed files with 183 additions and 75 deletions.
12 changes: 3 additions & 9 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,12 @@ jobs:
fail-fast: false
matrix:
os: [ ubuntu-22.04, macos-13, macos-14, windows-2022 ]
py_version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
py_version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]

exclude:
- os: macos-13
py_version: "3.8"

- os: macos-13
py_version: "3.9"

- os: macos-14
py_version: "3.8"

- os: macos-14
py_version: "3.9"

Expand Down Expand Up @@ -174,7 +168,7 @@ jobs:
fail-fast: false
matrix:
os: [ ubuntu-22.04 ]
py_version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
py_version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]

steps:
- name: Checkout
Expand All @@ -189,7 +183,7 @@ jobs:
install: |
apt-get update -q -y
apt-get install -y wget build-essential
apt-get install -y wget build-essential git
mkdir -p ~/miniconda3
wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh -O miniconda.sh
Expand Down
2 changes: 2 additions & 0 deletions examples/mwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def pats():
.finalize_memref_to_llvm()
# Convert Func to LLVM (always needed).
.convert_func_to_llvm()
.convert_arith_to_llvm()
.convert_cf_to_llvm()
# Convert Index to LLVM (always needed).
.convert_index_to_llvm()
# Convert remaining unrealized_casts (always needed).
Expand Down
2 changes: 2 additions & 0 deletions examples/vectorization_e2e.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@
" .finalize_memref_to_llvm()\n",
" # Convert Func to LLVM (always needed).\n",
" .convert_func_to_llvm()\n",
" .convert_arith_to_llvm()\n",
" .convert_cf_to_llvm()\n",
" # Convert Index to LLVM (always needed).\n",
" .convert_index_to_llvm()\n",
" # Convert remaining unrealized_casts (always needed).\n",
Expand Down
2 changes: 1 addition & 1 deletion mlir/extras/ast/canonicalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def transform_ast(
module_code_o = compile(module, f.__code__.co_filename, "exec")
new_f_code_o = find_func_in_code_object(module_code_o, f.__name__)
n_lines = len(inspect.getsource(f).splitlines())
line_starts = list(findlinestarts(new_f_code_o))
line_starts = list(filter(lambda el: el[1], findlinestarts(new_f_code_o)))
if (
max([l for _, l in line_starts]) - min([l for _, l in line_starts]) + 1
> n_lines
Expand Down
21 changes: 20 additions & 1 deletion mlir/extras/dialects/ext/_shaped_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


# mixin that requires `is_constant`
class ShapedValue:
def ShapedValue(cls):
@cached_property
def literal_value(self) -> np.ndarray:
if not self.is_constant:
Expand Down Expand Up @@ -42,3 +42,22 @@ def n_elements(self) -> int:
@cached_property
def dtype(self) -> Type:
return self._shaped_type.element_type

setattr(cls, "literal_value", literal_value)
cls.literal_value.__set_name__(None, "literal_value")
setattr(cls, "_shaped_type", _shaped_type)
cls._shaped_type.__set_name__(None, "_shaped_type")

setattr(cls, "has_static_shape", has_static_shape)
setattr(cls, "has_rank", has_rank)

setattr(cls, "rank", rank)
cls.rank.__set_name__(None, "rank")
setattr(cls, "shape", shape)
cls.shape.__set_name__(None, "shape")
setattr(cls, "n_elements", n_elements)
cls.n_elements.__set_name__(None, "n_elements")
setattr(cls, "dtype", dtype)
cls.dtype.__set_name__(None, "dtype")

return cls
9 changes: 8 additions & 1 deletion mlir/extras/dialects/ext/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional, Tuple, Union

from bytecode import ConcreteBytecode
from einspect.structs import PyTypeObject

from ...ast.canonicalize import StrictTransformer, Canonicalizer, BytecodePatcher
from ...ast.util import ast_call
Expand Down Expand Up @@ -138,7 +139,13 @@ def index_cast(
)


class ArithValueMeta(type(Value)):
nb_meta_cls = type(Value)

_Py_TPFLAGS_BASETYPE = 1 << 10
PyTypeObject.from_object(nb_meta_cls).tp_flags |= _Py_TPFLAGS_BASETYPE


class ArithValueMeta(nb_meta_cls):
"""Metaclass that orchestrates the Python object protocol
(i.e., calling __new__ and __init__) for Indexing dialect extension values
(created using `mlir_value_subclass`).
Expand Down
4 changes: 3 additions & 1 deletion mlir/extras/dialects/ext/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def __init__(

def _is_decl(self):
# magic constant found from looking at the code for an empty fn
if sys.version_info.minor == 12:
if sys.version_info.minor == 13:
return self.body_builder.__code__.co_code == b"\x95\x00g\x00"
elif sys.version_info.minor == 12:
return self.body_builder.__code__.co_code == b"\x97\x00y\x00"
elif sys.version_info.minor == 11:
return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00"
Expand Down
3 changes: 2 additions & 1 deletion mlir/extras/dialects/ext/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def store(


@register_value_caster(MemRefType.static_typeid)
class MemRef(Value, ShapedValue):
@ShapedValue
class MemRef(Value):
def __str__(self):
return f"{self.__class__.__name__}({self.get_name()}, {self.type})"

Expand Down
3 changes: 2 additions & 1 deletion mlir/extras/dialects/ext/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def insert_slice(

# TODO(max): unify vector/memref/tensor
@register_value_caster(RankedTensorType.static_typeid)
class Tensor(ShapedValue, ArithValue):
@ShapedValue
class Tensor(ArithValue):
def __getitem__(self, idx: tuple) -> "Tensor":
loc = get_user_code_loc()

Expand Down
5 changes: 3 additions & 2 deletions mlir/extras/dialects/ext/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@


@register_value_caster(VectorType.static_typeid)
class Vector(ShapedValue, ArithValue):
@ShapedValue
class Vector(ArithValue):
def __getitem__(self, idx: tuple) -> "Vector":
loc = get_user_code_loc()

Expand Down Expand Up @@ -105,7 +106,7 @@ def transfer_read(
if isinstance(padding, int):
padding = constant(padding, type=source.type.element_type)
if in_bounds is None:
in_bounds = [None] * len(permutation_map.results)
raise ValueError("in_bounds cannot be None")

return _transfer_read(
vector=vector_t,
Expand Down
71 changes: 44 additions & 27 deletions mlir/extras/runtime/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ def affine_expand_index_ops(self):
self.add_pass("affine-expand-index-ops")
return self

def affine_expand_index_ops_as_affine(self):
"""Lower affine operations operating on indices into affine.apply operations"""
self.add_pass("affine-expand-index-ops-as-affine")
return self

def affine_loop_coalescing(self):
"""Coalesce nested loops with independent bounds into a single loop"""
self.add_pass("affine-loop-coalescing")
Expand Down Expand Up @@ -1363,10 +1368,6 @@ def convert_func_to_llvm(
returns are updated accordingly. Block argument types are updated to use
LLVM IR types.
Note that until https://github.com/llvm/llvm-project/issues/70982 is resolved,
this pass includes patterns that lower `arith` and `cf` to LLVM. This is legacy
code due to when they were all converted in the same pass.
Args:
use-bare-ptr-memref-call-conv: Replace FuncOp's MemRef arguments with bare pointers to the MemRef element types
index-bitwidth: Bitwidth of the index type, 0 to use size of machine word
Expand Down Expand Up @@ -1398,12 +1399,12 @@ def convert_gpu_launch_to_vulkan_launch(self):
self.add_pass("convert-gpu-launch-to-vulkan-launch")
return self

def convert_gpu_to_llvm_spv(self, index_bitwidth: int = None):
def convert_gpu_to_llvm_spv(self, use_64bit_index: bool = None):
"""Generate LLVM operations to be ingested by a SPIR-V backend for gpu operations
Args:
index-bitwidth: Bitwidth of the index type, 0 to use size of machine word
use-64bit-index: Use 64-bit integers to convert index types
"""
self.add_pass("convert-gpu-to-llvm-spv", index_bitwidth=index_bitwidth)
self.add_pass("convert-gpu-to-llvm-spv", use_64bit_index=use_64bit_index)
return self

def convert_gpu_to_nvvm(
Expand Down Expand Up @@ -1597,6 +1598,20 @@ def convert_memref_to_spirv(
)
return self

def convert_mesh_to_mpi(self):
"""Convert Mesh dialect to MPI dialect.
This pass converts communication operations from the Mesh dialect to the
MPI dialect.
If it finds a global named "static_mpi_rank" it will use that splat value
instead of calling MPI_Comm_rank. This allows optimizations like constant
shape propagation and fusion because shard/partition sizes depend on the
rank.
"""
self.add_pass("convert-mesh-to-mpi")
return self

def convert_nvgpu_to_nvvm(self):
"""Convert NVGPU dialect to NVVM dialect
Expand Down Expand Up @@ -1715,17 +1730,26 @@ def convert_tensor_to_spirv(self, emulate_lt_32_bit_scalar_types: bool = None):
)
return self

def convert_to_llvm(self, filter_dialects: List[str] = None):
def convert_to_llvm(self, filter_dialects: List[str] = None, dynamic: bool = None):
"""Convert to LLVM via dialect interfaces found in the input IR
This is a generic pass to convert to LLVM, it uses the
`ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
the injection of conversion patterns.
If `dynamic` is set to `true`, the pass will look for
`ConvertToLLVMAttrInterface` attributes and use them to further configure
the conversion process. This option also uses the `DataLayoutAnalysis`
analysis to configure the type converter. Enabling this option incurs in
extra overhead.
Args:
filter-dialects: Test conversion patterns of only the specified dialects
dynamic: Use op conversion attributes to configure the conversion
"""
self.add_pass("convert-to-llvm", filter_dialects=filter_dialects)
self.add_pass(
"convert-to-llvm", filter_dialects=filter_dialects, dynamic=dynamic
)
return self

def convert_to_spirv(
Expand Down Expand Up @@ -2082,23 +2106,6 @@ def finalize_memref_to_llvm(
)
return self

def finalizing_bufferize(self):
"""Finalize a partial bufferization
A bufferize pass that finalizes a partial bufferization by removing
remaining `bufferization.to_tensor` and `bufferization.to_buffer` operations.
The removal of those operations is only possible if the operations only
exist in pairs, i.e., all uses of `bufferization.to_tensor` operations are
`bufferization.to_buffer` operations.
This pass will fail if not all operations can be removed or if any operation
with tensor typed operands remains.
"""
self.add_pass("finalizing-bufferize")
return self

def fold_memref_alias_ops(self):
"""Fold memref alias ops into consumer load/store ops
Expand Down Expand Up @@ -2201,6 +2208,7 @@ def gpu_module_to_binary(
l: List[str] = None,
opts: str = None,
format: str = None,
section: str = None,
):
"""Transforms a GPU module into a GPU binary.
Expand All @@ -2219,9 +2227,15 @@ def gpu_module_to_binary(
l: Extra files to link to.
opts: Command line options to pass to the tools.
format: The target representation of the compilation process.
section: ELF section where binary is to be located.
"""
self.add_pass(
"gpu-module-to-binary", toolkit=toolkit, l=l, opts=opts, format=format
"gpu-module-to-binary",
toolkit=toolkit,
l=l,
opts=opts,
format=format,
section=section,
)
return self

Expand Down Expand Up @@ -2893,6 +2907,7 @@ def one_shot_bufferize(
no_analysis_func_filter: List[str] = None,
function_boundary_type_conversion: str = None,
must_infer_memory_space: bool = None,
use_encoding_for_memory_space: bool = None,
test_analysis_only: bool = None,
print_conflicts: bool = None,
unknown_type_conversion: str = None,
Expand Down Expand Up @@ -3017,6 +3032,7 @@ def one_shot_bufferize(
no-analysis-func-filter: Skip analysis of functions with these symbol names.Set copyBeforeWrite to true when bufferizing them.
function-boundary-type-conversion: Controls layout maps when bufferizing function signatures.
must-infer-memory-space: The memory space of an memref types must always be inferred. If unset, a default memory space of 0 is used otherwise.
use-encoding-for-memory-space: Use the Tensor encoding attribute for the memory space. Exclusive to the 'must-infer-memory-space' option
test-analysis-only: Test only: Only run inplaceability analysis and annotate IR
print-conflicts: Test only: Annotate IR with RaW conflicts. Requires test-analysis-only.
unknown-type-conversion: Controls layout maps for non-inferrable memref types.
Expand All @@ -3036,6 +3052,7 @@ def one_shot_bufferize(
no_analysis_func_filter=no_analysis_func_filter,
function_boundary_type_conversion=function_boundary_type_conversion,
must_infer_memory_space=must_infer_memory_space,
use_encoding_for_memory_space=use_encoding_for_memory_space,
test_analysis_only=test_analysis_only,
print_conflicts=print_conflicts,
unknown_type_conversion=unknown_type_conversion,
Expand Down
11 changes: 6 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
PyYAML>=6.0.2
astpretty
PyYAML
astunparse
black
bytecode
inflection
numpy~=1.0
astunparse
cloudpickle>=3.0.0
einspect==0.5.16
einspect @ git+https://github.com/makslevental/einspect@makslevental/bump-py.3.13
inflection
ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16
numpy>=1.19.5, <=2.1.2
1 change: 1 addition & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_simple_parfor(ctx: MLIRContext, backend: LLVMJITBackend):
.convert_arith_to_llvm()
.finalize_memref_to_llvm()
.convert_func_to_llvm()
.convert_cf_to_llvm()
.reconcile_unrealized_casts(),
generate_kernel_wrapper=True,
generate_return_consumer=True,
Expand Down
4 changes: 3 additions & 1 deletion tests/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def demo_fun1():
def test_declare_byte_rep(ctx: MLIRContext):
def demo_fun1(): ...

if sys.version_info.minor == 12:
if sys.version_info.minor == 13:
assert demo_fun1.__code__.co_code == b"\x95\x00g\x00"
elif sys.version_info.minor == 12:
assert demo_fun1.__code__.co_code == b"\x97\x00y\x00"
elif sys.version_info.minor == 11:
assert demo_fun1.__code__.co_code == b"\x97\x00d\x00S\x00"
Expand Down
Loading

0 comments on commit e7592ab

Please sign in to comment.