Skip to content

Commit e159c75

Browse files
committed
expand trick and propagate device
1 parent 259ea74 commit e159c75

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

opt_einsum_fx/_efficient_shape_prop.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,21 @@ class EfficientShapeProp(torch.fx.Interpreter):
5353

5454
def run_node(self, n: Node) -> Any:
5555
if n.op == "call_function" and n.target in _EINSUM_FUNCS:
56-
equation, *operands = n.args
57-
shapes = [op.meta['tensor_meta'].shape for op in operands]
56+
args, kwargs = self.fetch_args_kwargs_from_env(n)
57+
equation, *operands = args
58+
shapes = [op.shape for op in operands]
5859

59-
assert len({op.meta['tensor_meta'].dtype for op in operands}) == 1
60-
meta = SimpleMeta(einsum_shape(equation, *shapes), operands[0].meta['tensor_meta'].dtype)
61-
result = torch.zeros((1 for _ in meta.shape), dtype=meta.dtype, device='cpu').expand(meta.shape)
60+
assert len({op.dtype for op in operands}) == 1
61+
meta = SimpleMeta(einsum_shape(equation, *shapes), operands[0].dtype)
62+
result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=operands[0].device).expand(meta.shape)
6263
elif n.op == "call_function" and n.target == torch.tensordot:
63-
shape_a, shape_b = [op.meta['tensor_meta'].shape for op in n.args]
64-
shape_a = [dim for i, dim in enumerate(shape_a) if i not in n.kwargs['dims'][0]]
65-
shape_b = [dim for i, dim in enumerate(shape_b) if i not in n.kwargs['dims'][1]]
64+
args, kwargs = self.fetch_args_kwargs_from_env(n)
65+
shape_a = [dim for i, dim in enumerate(args[0].shape) if i not in kwargs['dims'][0]]
66+
shape_b = [dim for i, dim in enumerate(args[1].shape) if i not in kwargs['dims'][1]]
6667

67-
assert len({op.meta['tensor_meta'].dtype for op in n.args}) == 1
68-
meta = SimpleMeta(shape_a + shape_b, n.args[0].meta['tensor_meta'].dtype)
69-
result = torch.zeros(meta.shape, dtype=meta.dtype, device='cpu')
68+
assert len({op.dtype for op in args}) == 1
69+
meta = SimpleMeta(shape_a + shape_b, args[0].dtype)
70+
result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=args[0].device).expand(meta.shape)
7071
else:
7172
result = super().run_node(n)
7273

0 commit comments

Comments
 (0)