@@ -53,20 +53,21 @@ class EfficientShapeProp(torch.fx.Interpreter):
5353
5454 def run_node (self , n : Node ) -> Any :
5555 if n .op == "call_function" and n .target in _EINSUM_FUNCS :
56- equation , * operands = n .args
57- shapes = [op .meta ['tensor_meta' ].shape for op in operands ]
56+ args , kwargs = self .fetch_args_kwargs_from_env (n )
57+ equation , * operands = args
58+ shapes = [op .shape for op in operands ]
5859
59- assert len ({op .meta [ 'tensor_meta' ]. dtype for op in operands }) == 1
60- meta = SimpleMeta (einsum_shape (equation , * shapes ), operands [0 ].meta [ 'tensor_meta' ]. dtype )
61- result = torch .zeros ((1 for _ in meta .shape ), dtype = meta .dtype , device = 'cpu' ).expand (meta .shape )
60+ assert len ({op .dtype for op in operands }) == 1
61+ meta = SimpleMeta (einsum_shape (equation , * shapes ), operands [0 ].dtype )
62+ result = torch .zeros ((1 ,) * len ( meta .shape ), dtype = meta .dtype , device = operands [ 0 ]. device ).expand (meta .shape )
6263 elif n .op == "call_function" and n .target == torch .tensordot :
63- shape_a , shape_b = [ op . meta [ 'tensor_meta' ]. shape for op in n . args ]
64- shape_a = [dim for i , dim in enumerate (shape_a ) if i not in n . kwargs ['dims' ][0 ]]
65- shape_b = [dim for i , dim in enumerate (shape_b ) if i not in n . kwargs ['dims' ][1 ]]
64+ args , kwargs = self . fetch_args_kwargs_from_env ( n )
65+ shape_a = [dim for i , dim in enumerate (args [ 0 ]. shape ) if i not in kwargs ['dims' ][0 ]]
66+ shape_b = [dim for i , dim in enumerate (args [ 1 ]. shape ) if i not in kwargs ['dims' ][1 ]]
6667
67- assert len ({op .meta [ 'tensor_meta' ]. dtype for op in n . args }) == 1
68- meta = SimpleMeta (shape_a + shape_b , n . args [0 ]. meta [ 'tensor_meta' ].dtype )
69- result = torch .zeros (meta .shape , dtype = meta .dtype , device = 'cpu' )
68+ assert len ({op .dtype for op in args }) == 1
69+ meta = SimpleMeta (shape_a + shape_b , args [0 ].dtype )
70+ result = torch .zeros (( 1 ,) * len ( meta .shape ) , dtype = meta .dtype , device = args [ 0 ]. device ). expand ( meta . shape )
7071 else :
7172 result = super ().run_node (n )
7273
0 commit comments