Skip to content

Commit 63099f9

Browse files
authored
Fix memory analysis crash when the same address is freed multiple times (#560)
1 parent 742260c commit 63099f9

File tree

2 files changed

+6
-10
lines changed

2 files changed

+6
-10
lines changed

tb_plugin/torch_tb_profiler/profiler/memory_parser.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,6 @@ def get_preprocessed_records(self):
305305
memory_records = sorted(self.memory_records, key=lambda r: r.ts)
306306

307307
alloc = {} # allocation events may or may not have paired free event
308-
free = {} # free events that does not have paired alloc event
309308
prev_ts = float('-inf') # ensure ordered memory records is ordered
310309
for i, r in enumerate(memory_records):
311310
if r.addr is None:
@@ -326,10 +325,4 @@ def get_preprocessed_records(self):
326325
r.op_name = alloc_r.op_name
327326
r.parent_op_name = alloc_r.parent_op_name
328327
del alloc[addr]
329-
else:
330-
assert addr not in free
331-
free[addr] = i
332-
333-
if free:
334-
logger.debug(f'{len(free)} memory records do not have associated operator.')
335328
return memory_records

tb_plugin/torch_tb_profiler/run.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
from collections import defaultdict
55
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
66

7-
from . import consts
7+
from . import consts, utils
88
from .profiler.diffrun import compare_op_tree, diff_summary
99
from .profiler.memory_parser import MemoryMetrics, MemoryRecord, MemorySnapshot
1010
from .profiler.module_op import Stats
1111
from .profiler.node import OperatorNode
1212
from .utils import Canonicalizer, DisplayRounder
1313

14+
logger = utils.get_logger()
15+
1416

1517
class Run(object):
1618
""" A profiler run. For visualization purpose only.
@@ -341,7 +343,7 @@ def get_op_name_or_ctx(record: MemoryRecord):
341343
# profile json data prior to pytorch 1.10 do not have addr
342344
# we should ignore them
343345
continue
344-
assert prev_ts < r.ts
346+
assert prev_ts <= r.ts
345347
prev_ts = r.ts
346348
addr = r.addr
347349
size = r.bytes
@@ -362,7 +364,8 @@ def get_op_name_or_ctx(record: MemoryRecord):
362364
])
363365
del alloc[addr]
364366
else:
365-
assert addr not in free
367+
if addr in free:
368+
logger.warning(f'Address {addr} is freed multiple times')
366369
free[addr] = i
367370

368371
for i in alloc.values():

0 commit comments

Comments
 (0)