
Add scuba logging to edge APIs
Differential Revision: D66385141

Pull Request resolved: #7103
tarun292 authored Dec 2, 2024
1 parent c9d7b6e commit 2326fff
Showing 5 changed files with 54 additions and 87 deletions.
66 changes: 15 additions & 51 deletions exir/emit/test/test_emit.py
@@ -340,10 +340,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         exir.print_program.pretty_print(program)
 
         deboxed_int_list = []
-        for item in program.execution_plan[0].values[5].val.items:  # pyre-ignore[16]
-            deboxed_int_list.append(
-                program.execution_plan[0].values[item].val.int_val  # pyre-ignore[16]
-            )
+        for item in program.execution_plan[0].values[5].val.items:
+            deboxed_int_list.append(program.execution_plan[0].values[item].val.int_val)
 
         self.assertEqual(IntList(deboxed_int_list), IntList([2, 0, 1]))

@@ -459,11 +457,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Check the mul operator's stack trace contains f -> g -> h
         self.assertTrue(
             "return torch.mul(x, torch.randn(3, 2))"
-            in program.execution_plan[0]  # pyre-ignore[16]
-            .chains[0]
-            .stacktrace[1]
-            .items[-1]
-            .context
+            in program.execution_plan[0].chains[0].stacktrace[1].items[-1].context
         )
         self.assertEqual(
             program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f"
@@ -616,11 +610,7 @@ def false_fn(y: torch.Tensor) -> torch.Tensor:
             if not isinstance(inst.instr_args, KernelCall):
                 continue
 
-            op = (
-                program.execution_plan[0]
-                .operators[inst.instr_args.op_index]  # pyre-ignore[16]
-                .name
-            )
+            op = program.execution_plan[0].operators[inst.instr_args.op_index].name
 
             if "mm" in op:
                 num_mm += 1
@@ -657,19 +647,13 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         # generate the tensor on which this iteration will operate on.
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[0]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[0].instr_args.op_index
             ].name,
             "aten::sym_size",
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[1]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[1].instr_args.op_index
             ].name,
             "aten::select_copy",
         )
@@ -681,28 +665,19 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         # We check here that both of these have been generated.
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-5]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-5].instr_args.op_index
             ].name,
             "executorch_prim::et_copy_index",
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-4]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-4].instr_args.op_index
             ].name,
             "executorch_prim::add",
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-3]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-3].instr_args.op_index
             ].name,
             "executorch_prim::eq",
         )
@@ -716,10 +691,7 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-1]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-1].instr_args.op_index
             ].name,
             "executorch_prim::sub",
         )
@@ -1300,9 +1272,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # this triggers the actual emission of the graph
         program = program_mul._emitter_output.program
         node = None
-        program.execution_plan[0].chains[0].instructions[  # pyre-ignore[16]
-            0
-        ].instr_args.op_index
+        program.execution_plan[0].chains[0].instructions[0].instr_args.op_index
 
         # Find the multiplication node in the graph that was emitted.
         for node in program_mul.exported_program().graph.nodes:
@@ -1314,7 +1284,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Find the multiplication instruction in the program that was emitted.
         for idx in range(len(program.execution_plan[0].chains[0].instructions)):
             instruction = program.execution_plan[0].chains[0].instructions[idx]
-            op_index = instruction.instr_args.op_index  # pyre-ignore[16]
+            op_index = instruction.instr_args.op_index
             if "mul" in program.execution_plan[0].operators[op_index].name:
                 break

@@ -1453,9 +1423,7 @@ def forward(self, x, y):
         exec_prog._emitter_output.program
         self.assertIsNotNone(exec_prog.delegate_map)
         self.assertIsNotNone(exec_prog.delegate_map.get("forward"))
-        self.assertIsNotNone(
-            exec_prog.delegate_map.get("forward").get(0)  # pyre-ignore[16]
-        )
+        self.assertIsNotNone(exec_prog.delegate_map.get("forward").get(0))
         self.assertEqual(
             exec_prog.delegate_map.get("forward").get(0).get("name"),
             "BackendWithCompilerExample",
@@ -1568,9 +1536,7 @@ def forward(self, x):
         model = model.to_executorch()
         model.dump_executorch_program(True)
         self.assertTrue(
-            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
-            .values[0]
-            .val.allocation_info
+            model.executorch_program.execution_plan[0].values[0].val.allocation_info
             is not None
         )
         executorch_module = _load_for_executorch_from_buffer(model.buffer)
@@ -1611,9 +1577,7 @@ def forward(self, x):
         )
         model.dump_executorch_program(True)
         self.assertTrue(
-            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
-            .values[0]
-            .val.allocation_info
+            model.executorch_program.execution_plan[0].values[0].val.allocation_info
             is not None
         )
         executorch_module = _load_for_executorch_from_buffer(model.buffer)
3 changes: 2 additions & 1 deletion exir/program/TARGETS
@@ -1,4 +1,5 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

@@ -43,7 +44,7 @@ python_library(
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/exir/passes:weights_to_outputs_pass",
         "//executorch/exir/verification:verifier",
-    ],
+    ] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])
 )
 
 python_library(
22 changes: 22 additions & 0 deletions exir/program/_program.py
@@ -75,8 +75,24 @@
 
 Val = Any
 
+from typing import Any, Callable
+
 from torch.library import Library
 
+try:
+    from executorch.exir.program.fb.logger import et_logger
+except ImportError:
+    # Define a stub decorator that does nothing
+    def et_logger(api_name: str) -> Callable[[Any], Any]:
+        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+            def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+                return func(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
+
 # This is the reserved namespace that is used to register ops to that will
 # be prevented from being decomposed during to_edge_transform_and_lower.
 edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP"
@@ -957,6 +973,7 @@ def _gen_edge_manager_for_partitioners(
     return edge_manager
 
 
+@et_logger("to_edge_transform_and_lower")
 def to_edge_transform_and_lower(
     programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
     transform_passes: Optional[
@@ -1110,6 +1127,7 @@ def to_edge_with_preserved_ops(
     )
 
 
+@et_logger("to_edge")
 def to_edge(
     programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
     constant_methods: Optional[Dict[str, Any]] = None,
@@ -1204,8 +1222,10 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram:
         """
         Returns the ExportedProgram specified by 'method_name'.
         """
+
         return self._edge_programs[method_name]
 
+    @et_logger("transform")
     def transform(
         self,
         passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
@@ -1253,6 +1273,7 @@ def transform(
             new_programs, copy.deepcopy(self._config_methods), compile_config
         )
 
+    @et_logger("to_backend")
     def to_backend(
         self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
     ) -> "EdgeProgramManager":
@@ -1296,6 +1317,7 @@ def to_backend(
             new_edge_programs, copy.deepcopy(self._config_methods), config
         )
 
+    @et_logger("to_executorch")
     def to_executorch(
         self,
         config: Optional[ExecutorchBackendConfig] = None,
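
The internal, Scuba-backed et_logger lives under exir/program/fb and is not part of this diff; OSS builds fall back to the no-op stub above. As a rough sketch only (the logging sink and field names are assumptions, not the actual internal code), a decorator of this shape records the API name, duration, and outcome of each decorated call:

# Hypothetical sketch -- the real fb logger implementation is not included in this commit.
import functools
import logging
import time
from typing import Any, Callable


def et_logger(api_name: str) -> Callable[[Any], Any]:
    """Log the name, duration, and outcome of an edge API call (illustrative only)."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            start = time.perf_counter()
            try:
                result = func(self, *args, **kwargs)
                # The internal version would emit a structured event (e.g. a Scuba row) here.
                logging.info("%s succeeded in %.3fs", api_name, time.perf_counter() - start)
                return result
            except Exception:
                logging.exception("%s failed after %.3fs", api_name, time.perf_counter() - start)
                raise

        return wrapper

    return decorator

Because the decorated targets include both module-level functions (to_edge, to_edge_transform_and_lower) and EdgeProgramManager methods (transform, to_backend, to_executorch), the wrapper forwards its first positional argument unchanged, matching the stub in the diff.
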
26 changes: 8 additions & 18 deletions exir/tests/test_joint_graph.py
@@ -73,25 +73,21 @@ def forward(self, x, y):
 
         # assert that the weight and bias have proper data_buffer_idx and allocation_info
         self.assertEqual(
-            et.executorch_program.execution_plan[0]  # pyre-ignore
-            .values[0]
-            .val.data_buffer_idx,
+            et.executorch_program.execution_plan[0].values[0].val.data_buffer_idx,
             1,
         )
         self.assertEqual(
-            et.executorch_program.execution_plan[0]  # pyre-ignore
-            .values[1]
-            .val.data_buffer_idx,
+            et.executorch_program.execution_plan[0].values[1].val.data_buffer_idx,
             2,
         )
         self.assertEqual(
-            et.executorch_program.execution_plan[0]  # pyre-ignore
+            et.executorch_program.execution_plan[0]
             .values[0]
             .val.allocation_info.memory_offset_low,
             0,
         )
         self.assertEqual(
-            et.executorch_program.execution_plan[0]  # pyre-ignore
+            et.executorch_program.execution_plan[0]
             .values[1]
             .val.allocation_info.memory_offset_low,
             48,
@@ -106,7 +102,7 @@ def forward(self, x, y):
 
         self.assertTrue(torch.allclose(loss, et_outputs[0]))
         self.assertTrue(
-            torch.allclose(m.linear.weight.grad, et_outputs[1])  # pyre-ignore[6]
+            torch.allclose(m.linear.weight.grad, et_outputs[1])  # pyre-ignore
         )
         self.assertTrue(torch.allclose(m.linear.bias.grad, et_outputs[2]))
         self.assertTrue(torch.allclose(m.linear.weight, et_outputs[3]))
@@ -118,23 +114,17 @@ def forward(self, x, y):
 
         # gradient outputs start at index 1
         self.assertEqual(
-            et.executorch_program.execution_plan[1]  # pyre-ignore
-            .values[0]
-            .val.int_val,
+            et.executorch_program.execution_plan[1].values[0].val.int_val,
             1,
         )
 
         self.assertEqual(
-            et.executorch_program.execution_plan[2]  # pyre-ignore
-            .values[0]
-            .val.string_val,
+            et.executorch_program.execution_plan[2].values[0].val.string_val,
             "linear.weight",
         )
 
         # parameter outputs start at index 3
         self.assertEqual(
-            et.executorch_program.execution_plan[3]  # pyre-ignore
-            .values[0]
-            .val.int_val,
+            et.executorch_program.execution_plan[3].values[0].val.int_val,
             3,
         )
24 changes: 7 additions & 17 deletions exir/tests/test_remove_view_copy.py
@@ -196,24 +196,14 @@ def test_spec(self) -> None:
         instructions = plan.chains[0].instructions
         self.assertEqual(len(instructions), 7)
 
-        self.assertEqual(
-            instructions[0].instr_args.op_index, 0  # pyre-ignore
-        )  # view @ idx2
-        self.assertEqual(
-            instructions[1].instr_args.op_index, 0  # pyre-ignore
-        )  # view @ idx3
-        self.assertEqual(
-            instructions[2].instr_args.op_index, 1  # pyre-ignore
-        )  # aten:mul @ idx6
-        self.assertEqual(
-            instructions[3].instr_args.op_index, 0  # pyre-ignore
-        )  # view @ idx7
-        self.assertEqual(
-            instructions[4].instr_args.op_index, 1  # pyre-ignore
-        )  # aten:mul @ idx9
+        self.assertEqual(instructions[0].instr_args.op_index, 0)  # view @ idx2
+        self.assertEqual(instructions[1].instr_args.op_index, 0)  # view @ idx3
+        self.assertEqual(instructions[2].instr_args.op_index, 1)  # aten:mul @ idx6
+        self.assertEqual(instructions[3].instr_args.op_index, 0)  # view @ idx7
+        self.assertEqual(instructions[4].instr_args.op_index, 1)  # aten:mul @ idx9
         self.assertEqual(
-            instructions[5].instr_args.op_index, 2  # pyre-ignore
+            instructions[5].instr_args.op_index, 2
         )  # aten:view_copy @ idx11
         self.assertEqual(
-            instructions[6].instr_args.op_index, 2  # pyre-ignore
+            instructions[6].instr_args.op_index, 2
         )  # aten:view_copy @ idx11
