KT unflatten issue with torch.export
Summary:
# context

Current error when calling the deserialized (unflattened) model:
```
  1) torchrec.fb.ir.tests.test_serializer.TestSerializer: test_deserialized_device_vle
    1) RuntimeError: Node ir_dynamic_batch_emb_lookup_default referenced nonexistent value id_list_features__values! Run Graph.lint() to diagnose such issues

    While executing %ir_dynamic_batch_emb_lookup_default : [num_users=1] = call_function[target=torch.ops.torchrec.ir_dynamic_batch_emb_lookup.default](args = ([%id_list_features__values, None, %id_list_features__lengths, None], %floordiv, [4, 5]), kwargs = {})
    Original traceback:
    File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/009ebbab256a7e75/torchrec/fb/ir/tests/__test_serializer__/test_serializer#link-tree/torchrec/fb/ir/tests/test_serializer.py", line 142, in forward
        return self.sparse_arch(id_list_features)
      File "torchrec/fb/ir/tests/test_serializer.py", line 446, in test_deserialized_device_vle
        output = deserialized_model(features_batch_3.to(device))
      File "torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "torch/nn/modules/module.py", line 1747, in _call_impl
        return forward_call(*args, **kwargs)
      File "torch/export/unflatten.py", line 482, in forward
        tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
      File "torch/fx/interpreter.py", line 146, in run
        self.env[node] = self.run_node(node)
      File "torch/fx/interpreter.py", line 200, in run_node
        args, kwargs = self.fetch_args_kwargs_from_env(n)
      File "torch/fx/interpreter.py", line 372, in fetch_args_kwargs_from_env
        args = self.map_nodes_to_values(n.args, n)
      File "torch/fx/interpreter.py", line 394, in map_nodes_to_values
        return map_arg(args, load_arg)
      File "torch/fx/node.py", line 760, in map_arg
        return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
      File "torch/fx/node.py", line 768, in map_aggregate
        t = tuple(map_aggregate(elem, fn) for elem in a)
      File "torch/fx/node.py", line 768, in <genexpr>
        t = tuple(map_aggregate(elem, fn) for elem in a)
      File "torch/fx/node.py", line 772, in map_aggregate
        return immutable_list(map_aggregate(elem, fn) for elem in a)
      File "torch/fx/node.py", line 772, in <genexpr>
        return immutable_list(map_aggregate(elem, fn) for elem in a)
      File "torch/fx/node.py", line 778, in map_aggregate
        return fn(a)
      File "torch/fx/node.py", line 760, in <lambda>
        return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
      File "torch/fx/interpreter.py", line 391, in load_arg
        raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
```
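For context, here is a minimal, torch-free sketch of what goes wrong (all names below are hypothetical stand-ins, not the real torchrec or pytree API): unflatten replays the exported graph with `fx.Interpreter`, so flattened inputs must arrive in the exact placeholder order recorded in the spec. A flatten that follows the runtime instance's own key order instead of the spec's recorded order yields misaligned values, which surfaces as a "referenced nonexistent value" error:

```python
# Hypothetical stand-ins for TreeSpec/KeyedTensor to illustrate the ordering bug.
from typing import List, NamedTuple


class TreeSpec(NamedTuple):
    context: tuple  # (keys, length_per_key) recorded at export time


class KT:  # minimal stand-in for KeyedTensor
    def __init__(self, keys, values):
        self._d = dict(zip(keys, values))

    def keys(self):
        return list(self._d)


def flatten_naive(kt: KT, spec: TreeSpec) -> List[int]:
    # BUG: ignores the spec; emits values in the instance's own key order
    return [kt._d[k] for k in kt.keys()]


def flatten_spec_aware(kt: KT, spec: TreeSpec) -> List[int]:
    # FIX: reorder values into the key order recorded in the spec's context,
    # analogous to KeyedTensor.regroup([kt], [_keys]) in the patch
    keys, _length_per_key = spec.context
    return [kt._d[k] for k in keys]


spec = TreeSpec(context=(["f1", "f2"], [4, 5]))
kt = KT(["f2", "f1"], [20, 10])      # runtime instance with a different key order
print(flatten_naive(kt, spec))       # [20, 10] — misaligned with the exported graph
print(flatten_spec_aware(kt, spec))  # [10, 20] — matches placeholder order
```

This mirrors the fix in the diff: `_kt_flatten_spec` now regroups the `KeyedTensor` by the keys stored in `spec.context` rather than flattening in instance order.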

Differential Revision: D59238744
Huanyu He authored and facebook-github-bot committed Sep 12, 2024
1 parent ff6dc0a commit 75982ff
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchrec/sparse/jagged_tensor.py
```diff
@@ -2943,7 +2943,9 @@ def _kt_unflatten(


 def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]:
-    return _kt_flatten(kt)[0]
+    _keys, _length_per_key = spec.context
+    res = KeyedTensor.regroup([kt], [_keys])
+    return [res[0]]


 # The assumption here in torch.exporting KeyedTensor is that _length_per_key is static
```
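The trailing comment notes that `_length_per_key` is assumed static under `torch.export`. A torch-free illustration of why: the lengths (e.g. the `[4, 5]` visible in the traceback) are baked into the exported graph as constants that slice the flat value tensor into per-key segments, so changing them at runtime would invalidate the graph. `split_by_lengths` below is a hypothetical helper for illustration, not torchrec API:

```python
# Hypothetical sketch: per-key lengths act as fixed split sizes for the
# flat values, which is why they must stay static across exported calls.
def split_by_lengths(values, keys, length_per_key):
    out, offset = {}, 0
    for key, n in zip(keys, length_per_key):
        out[key] = values[offset:offset + n]  # slice of length n for this key
        offset += n
    return out


flat = list(range(9))  # stand-in for the flat value tensor
print(split_by_lengths(flat, ["f1", "f2"], [4, 5]))
# {'f1': [0, 1, 2, 3], 'f2': [4, 5, 6, 7, 8]}
```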
