Skip to content

Commit 259ea74

Browse files
Update opt_einsum_fx/_efficient_shape_prop.py
Co-authored-by: Alby M. <[email protected]>
1 parent 11857f7 commit 259ea74

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

opt_einsum_fx/_efficient_shape_prop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def run_node(self, n: Node) -> Any:
5858

5959
assert len({op.meta['tensor_meta'].dtype for op in operands}) == 1
6060
meta = SimpleMeta(einsum_shape(equation, *shapes), operands[0].meta['tensor_meta'].dtype)
61-
result = torch.zeros(meta.shape, dtype=meta.dtype, device='cpu')
61+
result = torch.zeros((1 for _ in meta.shape), dtype=meta.dtype, device='cpu').expand(meta.shape)
6262
elif n.op == "call_function" and n.target == torch.tensordot:
6363
shape_a, shape_b = [op.meta['tensor_meta'].shape for op in n.args]
6464
shape_a = [dim for i, dim in enumerate(shape_a) if i not in n.kwargs['dims'][0]]

0 commit comments

Comments
 (0)