Add DeviceTensorTrait to export annotations (#240)
Exported globals can now be pinned to a specific device. This works with
the sharding tooling to ensure globals are not misassigned to the
default-affinity device.
rsuderman authored Oct 24, 2024
1 parent 50e17a5 commit b81c6df
Showing 3 changed files with 37 additions and 0 deletions.
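
For context, a minimal usage sketch (not part of this commit): a parameter is tagged with DeviceTensorTrait before export so the emitted global carries the promised-device affinity. The top-level iree.turbine.aot import path, the toy SimpleLinear module, and the export(..., args=...) call shape are assumptions inferred from the test file below.

import torch
import torch.nn as nn

# Assumed public import path; the test file below imports these names from the aot package.
from iree.turbine.aot import (
    DeviceTensorTrait,
    export,
    externalize_module_parameters,
)


class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        # Toy parameter standing in for a real model weight.
        self.weight = nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.weight


model = SimpleLinear()
# Tag the weight with device ordinal 1; create_tensor_global then attaches
# iree.abi.affinity = #hal.device.promise<@__device_1> to the emitted global.
DeviceTensorTrait(ordinal=1).set(model.weight)
externalize_module_parameters(model)
exported = export(model, args=(torch.randn(2, 4),))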
6 changes: 6 additions & 0 deletions iree/turbine/aot/support/ir_utils.py
@@ -67,6 +67,7 @@

from ..tensor_traits import (
    DeviceAffinity,
    DeviceTensorTrait,
    ExternalTensorTrait,
)

@@ -272,6 +273,7 @@ def create_tensor_global(
    ) -> Tuple[str, Operation, IrType]:
        element_type = self.torch_dtype_to_iree_type(t.dtype)
        external, external_scope, external_name = attrs.infer_external_from_tensor(t)
        device = DeviceTensorTrait.get(t)

        # Always create globals at the top. Then after created, if there was
        # a prior one, move the new one to after it to maintain declaration
@@ -287,6 +289,10 @@ def create_tensor_global(
            ir_attrs["noinline"] = UnitAttr.get()
        if attrs.mutable:
            ir_attrs["is_mutable"] = UnitAttr.get()
        if device:
            ir_attrs["iree.abi.affinity"] = Attribute.parse(
                f"#hal.device.promise<@__device_{device.ordinal}>"
            )
        if external:
            # Emit named external reference.
            external_scope_attr = StringAttr.get(external_scope or "model")
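Continuing the sketch above, one hedged way to check the new `if device:` branch is to look for the affinity attribute in the exported IR; the mlir_module accessor on the export output is an assumption about the aot API.

# Inspect the exported IR: the weight's global is expected to carry the
# attribute attached by the `if device:` branch above (device ordinal 1).
asm = str(exported.mlir_module)
assert "iree.abi.affinity" in asm
assert "#hal.device.promise<@__device_1>" in asm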
21 changes: 21 additions & 0 deletions iree/turbine/aot/tensor_traits.py
@@ -12,6 +12,7 @@

__all__ = [
    "DeviceAffinity",
    "DeviceTensorTrait",
    "ExternalTensorTrait",
]

@@ -31,6 +32,26 @@ def __repr__(self) -> str:
        return f"DeviceAffinity({self.ordinal})"


@dataclass
class DeviceTensorTrait:
    """Represents a 'trait' that can be applied to a Tensor to signal that
    it is to be loaded to a specific device at execution time.
    """

    ordinal: int

    @staticmethod
    def get(from_tensor: torch.Tensor) -> Optional["DeviceTensorTrait"]:
        existing = getattr(from_tensor, "_turbine_device_tensor_trait", None)
        if existing is None:
            return None
        assert isinstance(existing, DeviceTensorTrait)
        return existing

    def set(self, to_tensor: torch.Tensor):
        to_tensor._turbine_device_tensor_trait = self  # type: ignore


@dataclass
class ExternalTensorTrait:
    """Represents a 'trait' that can be applied to a Tensor to signal that
10 changes: 10 additions & 0 deletions tests/aot/params_test.py
@@ -16,6 +16,7 @@
    export,
    externalize_module_parameters,
    save_module_parameters,
    DeviceTensorTrait,
    ExternalTensorTrait,
    ParameterArchive,
    ParameterArchiveBuilder,
@@ -119,6 +120,15 @@ def testExternalTensorTrait(self):
        self.assertIs(ExternalTensorTrait.get(t), trait)


class DeviceTensorTest(unittest.TestCase):
    def testDeviceTensorTrait(self):
        t = torch.ones([2, 3], dtype=torch.float32)
        trait = DeviceTensorTrait(ordinal=7)
        self.assertIsNone(trait.get(t))
        trait.set(t)
        self.assertIs(DeviceTensorTrait.get(t), trait)


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
