Skip to content

Commit

Permalink
Add device affinities for arguments in AOT (#231)
Browse files Browse the repository at this point in the history
We don't have support for providing device affinities for function
arguments, which need to end up as MLIR function argument attributes.

This change adds a class DeviceAffinity and provides the ability to
supply affinities when exporting Torch functions/modules or when tracing
in IREE-Trubine itself.

Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
  • Loading branch information
sogartar authored Oct 23, 2024
1 parent 67b253a commit ae9a51c
Show file tree
Hide file tree
Showing 12 changed files with 315 additions and 19 deletions.
87 changes: 83 additions & 4 deletions iree/turbine/aot/compiled_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
ModuleBuilderOptions,
)

from .tensor_traits import DeviceAffinity


__all__ = [
"CompiledModule",
Expand Down Expand Up @@ -107,12 +109,27 @@ def __call__(self, *args, **kwargs):
return self.py_value(*args, **kwargs)


class ExportTargetDef:
def __init__(
self,
target: Union[Callable, ExportedProgram],
*,
arg_device: dict[int, DeviceAffinity] | None = None,
):
self.target = target
self.arg_device = arg_device

def __call__(self, *args, **kwargs):
return self.target(*args, **kwargs)


class ExportProcDef:
__slots__ = [
"callable",
"export_name",
"signature",
"file_line_loc",
"arg_device",
]

def __init__(
Expand All @@ -122,14 +139,22 @@ def __init__(
*,
signature,
file_line_loc: Optional[Tuple[str, int]] = None,
arg_device: dict[int, DeviceAffinity] | None = None,
):
self.export_name = export_name
self.callable = callable
self.signature = signature
self.file_line_loc = file_line_loc
self.arg_device = arg_device

def copy(self) -> "ExportProcDef":
return ExportProcDef(self.export_name, self.callable, signature=self.signature)
return ExportProcDef(
self.export_name,
self.callable,
signature=self.signature,
file_line_loc=self.file_line_loc,
arg_device=self.arg_device,
)

def __repr__(self):
return f"<def {self.export_name}({self.signature})>"
Expand All @@ -142,14 +167,19 @@ def __init__(
*,
export_name: Optional[str] = None,
public: bool = False,
arg_device: dict[int, DeviceAffinity] | None = None,
):
self.export_name = export_name
self.exported_program = ep
self.public = public
self.arg_device = arg_device

def copy(self) -> "ExportedProgramDef":
return ExportedProgramDef(
self.exported_program, export_name=self.export_name, public=self.public
self.exported_program,
export_name=self.export_name,
public=self.public,
arg_device=self.arg_device,
)

def __repr__(self):
Expand Down Expand Up @@ -207,6 +237,19 @@ def globals_defs(self) -> Generator[Tuple[str, GlobalsDef], None, None]:
) # type: ignore

def def_attribute(self, key, value):
if isinstance(value, ExportTargetDef):
if not isinstance(value.target, ExportedProgram):
# We expect exported function.
assert callable(value.target) and inspect.isfunction(value.target)
return self.def_export_proc(key, value.target, value.arg_device)

value = ExportedProgramDef(
value.target,
export_name=key,
public=not key.startswith("_"),
arg_device=value.arg_device,
)

# Some decorators, the only thing we do is convert them to PyOnlyDef.
# Do that first so the generic descriptor code below handles them.
if isinstance(value, builtins.jittable):
Expand All @@ -233,6 +276,15 @@ def def_attribute(self, key, value):
logging.debug("DEFINE PY_ONLY: %s = %r", key, value)
self.add_export(key, value)
return value
if isinstance(value, ExportTargetDef) and isinstance(
value.target, ExportedProgram
):
value = ExportedProgramDef(
value.target,
export_name=key,
public=not key.startswith("_"),
arg_device=value.arg_device,
)
if isinstance(value, ExportedProgramDef):
if value.export_name is None:
value = value.copy()
Expand All @@ -250,7 +302,12 @@ def def_attribute(self, key, value):
f"compiled module: {value!r}"
)

def def_export_proc(self, name, f) -> ExportProcDef:
def def_export_proc(
self,
name,
f,
arg_device: dict[int, DeviceAffinity] | None = None,
) -> ExportProcDef:
logging.debug("DEFINE EXPORT: %s = %r", name, f)
# Get a reasonable location.
file_line_loc = None
Expand Down Expand Up @@ -292,7 +349,13 @@ def def_export_proc(self, name, f) -> ExportProcDef:
)
input_sig.append(param_desc)

info = ExportProcDef(name, f, signature=input_sig, file_line_loc=file_line_loc)
info = ExportProcDef(
name,
f,
signature=input_sig,
file_line_loc=file_line_loc,
arg_device=arg_device,
)
self.add_export(name, info)
return info

Expand Down Expand Up @@ -568,6 +631,20 @@ def save_mlir(inst: "CompiledModule", path: Union[Path, str]):

jittable = staticmethod(builtins.jittable)

@staticmethod
def signature_info(
*,
arg_device: dict[int, DeviceAffinity] | None = None,
) -> Callable:
"""Annotate an export target function.
This annotation is only required when additional information needs to be
provided."""

def _decorator(f: Callable):
return ExportTargetDef(f, arg_device=arg_device)

return _decorator

def __getattr__(self, name):
info = CompiledModule.get_info(self)
try:
Expand Down Expand Up @@ -633,6 +710,7 @@ def __new__(
ep_def.exported_program,
symbol_name=ep_def.export_name or "main",
symbol_visibility=None if ep_def.public else "private",
arg_device=ep_def.arg_device,
)

# Instantiate procs.
Expand Down Expand Up @@ -661,6 +739,7 @@ def invoke_with_self(*args, **kwargs):
posargs=proc_def.signature,
kwargs={}, # TODO(#128): kwargs
loc=loc,
arg_device=proc_def.arg_device,
)
trace.trace_py_func(invoke_with_self)
info.shadow_dict[key] = _uncallable_public_export
Expand Down
20 changes: 18 additions & 2 deletions iree/turbine/aot/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from .fx_programs import FxPrograms
from . import decompositions

from .tensor_traits import DeviceAffinity

__all__ = [
"export",
"ExportOutput",
Expand Down Expand Up @@ -177,6 +179,7 @@ def export(
function_name: Optional[str] = None,
strict_export: bool = True,
import_symbolic_shape_expressions: bool = False,
arg_device: dict[int, DeviceAffinity] | None = None,
) -> ExportOutput:
"""Exports a torch.nn.Module.
Expand All @@ -199,6 +202,7 @@ def export(
*,
module_name: Optional[str] = None,
function_name: Optional[str] = None,
arg_device: dict[int, DeviceAffinity] | None = None,
) -> ExportOutput:
"""Exports a single entry-point module consisting of an ExportedProgram."""
...
Expand Down Expand Up @@ -226,6 +230,7 @@ def export(
function_name: Optional[str] = None,
strict_export: bool = True,
import_symbolic_shape_expressions: bool = False,
arg_device: dict[int, DeviceAffinity] | None = None,
) -> ExportOutput:
"""Generic export of supported entities.
Expand All @@ -247,6 +252,10 @@ def export(
must be empty.
kwargs: Example keyword arguments.
dynamic_shapes: Dynamic shape specs to pass to torch.export.
arg_device: device affinities for the exported function
arguments. On what devices should the program expect its arguments.
It is a mapping of argument index to device affinity of the flattened
arguments.
Returns:
An ExportOutput object that wraps the compilation and provides
Expand All @@ -266,12 +275,14 @@ def export(
"This is an experimental feature in PyTorch that the IREE Turbine project is still evaluating. Please report issues or experiences."
)

from .compiled_module import ExportTargetDef

TransformedModule: Any
current_decomps = decompositions.current_aot_decompositions()
if isinstance(mdl, torch.export.ExportedProgram):
TransformedModule = CompiledModule.create_from_dict(
"LambdaCompiledModule",
{(function_name or "main"): mdl},
{(function_name or "main"): ExportTargetDef(mdl, arg_device=arg_device)},
export_name=module_name or "module",
options=ModuleBuilderOptions(
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
Expand Down Expand Up @@ -311,7 +322,12 @@ def export(

TransformedModule = CompiledModule.create_from_dict(
"LambdaCompiledModule",
{(function_name or "main"): exported_program},
{
(function_name or "main"): ExportTargetDef(
exported_program,
arg_device=arg_device,
)
},
export_name=module_name or "module",
options=ModuleBuilderOptions(
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
Expand Down
22 changes: 18 additions & 4 deletions iree/turbine/aot/fx_programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@
import os
from pathlib import Path
from typing import Any, Optional, Union
from .compiled_module import ExportTargetDef

import functools

import torch
import torch.nn as nn

from .decompositions import current_aot_decompositions
from .tensor_traits import DeviceAffinity

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .compiled_module import ExportTargetDef

# The dynamic_shapes support showed up in the Torch 2.3 timeframe.
_supports_dynamic_shapes = hasattr(torch.export, "Dim")
Expand Down Expand Up @@ -61,7 +68,7 @@ class FxPrograms:
"""

def __init__(self):
self.programs: dict[str, torch.export.ExportedProgram] = {}
self.programs: dict[str, ExportTargetDef] = {}

def save(self, path: Union[str, os.PathLike]) -> int:
"""Saves the set of exported programs to a descriptor file.
Expand All @@ -86,7 +93,9 @@ def permute_path(name):
count_deduped = 0

# Save each.
for program_name, ep in self.programs.items():
for program_name, export_def in self.programs.items():
ep = export_def.target
assert isinstance(ep, torch.export.ExportedProgram)
# First validate the ep with normal rules, which we will then
# disable since we are violating the spec.
ep._validate()
Expand Down Expand Up @@ -129,7 +138,7 @@ def load(path: Union[str, os.PathLike]) -> "FxPrograms":
ep = torch.export.load(path.parent / program_file_name)
_unsharify_state_dict(shared_state_dict, ep.state_dict)
_unsharify_state_dict(shared_constants, _get_optional_constants(ep))
instance.programs[program_name] = ep
instance.programs[program_name] = ExportTargetDef(ep)
return instance


Expand Down Expand Up @@ -169,6 +178,7 @@ def export_program(
dynamic_shapes=None,
strict: bool = True,
name: Optional[str] = None,
arg_device: dict[int, DeviceAffinity] | None = None,
):
if f is None:
return functools.partial(
Expand All @@ -178,6 +188,7 @@ def export_program(
strict=strict,
dynamic_shapes=dynamic_shapes,
name=name,
arg_device=arg_device,
)

if name is None:
Expand Down Expand Up @@ -234,7 +245,10 @@ def new_forward(self, *forward_args, **forward_kwargs):

_patch_op_dispatch_for_export()
program = program.run_decompositions(current_decomps)
fx_builder.programs[name] = program
fx_builder.programs[name] = ExportTargetDef(
program,
arg_device=arg_device,
)
return program


Expand Down
Loading

0 comments on commit ae9a51c

Please sign in to comment.