Add some error handling for VRAM monitor

EyeDeck 2022-09-18 05:20:33 -04:00
parent 7e77938230
commit fabaf4bddb
2 changed files with 29 additions and 17 deletions

modules/memmon.py

@@ -22,6 +22,13 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag = threading.Event()
         self.data = defaultdict(int)
 
+        try:
+            torch.cuda.mem_get_info()
+            torch.cuda.memory_stats(self.device)
+        except Exception as e:  # AMD or whatever
+            print(f"Warning: caught exception '{e}', memory monitor disabled")
+            self.disabled = True
+
     def run(self):
         if self.disabled:
             return
@@ -62,13 +69,14 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag.set()
 
     def read(self):
-        free, total = torch.cuda.mem_get_info()
-        self.data["total"] = total
+        if not self.disabled:
+            free, total = torch.cuda.mem_get_info()
+            self.data["total"] = total
 
-        torch_stats = torch.cuda.memory_stats(self.device)
-        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
-        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
-        self.data["system_peak"] = total - self.data["min_free"]
+            torch_stats = torch.cuda.memory_stats(self.device)
+            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
+            self.data["system_peak"] = total - self.data["min_free"]
 
         return self.data
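
The pattern above is a one-time capability probe: each torch.cuda call that read() later depends on is attempted once in the constructor, and any failure trips a disabled flag that turns the monitor into a no-op. A minimal standalone sketch of the same idea, assuming only public torch.cuda APIs (the class name and the "free" field are illustrative, not from the repo):

from collections import defaultdict

import torch

class VramProbe:
    """Illustrative sketch (not repo code): probe CUDA stats once, disable on failure."""

    def __init__(self, device=None):
        self.device = device
        self.disabled = False
        self.data = defaultdict(int)
        try:
            # Probe each call read() depends on; if it raises now, it would raise later too.
            torch.cuda.mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def read(self):
        # No-op on backends where the probe failed; callers get empty defaults.
        if not self.disabled:
            free, total = torch.cuda.mem_get_info()  # returns (free_bytes, total_bytes)
            self.data["free"], self.data["total"] = free, total
        return self.data

On a build where these queries raise, the constructor prints the warning once and read() simply returns the empty defaults, which is what lets the callers below skip the stats display cleanly.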

modules/ui.py

@@ -119,7 +119,8 @@ def save_files(js_data, images, index):
 
 def wrap_gradio_call(func):
     def f(*args, **kwargs):
-        shared.mem_mon.monitor()
+        if opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled:
+            shared.mem_mon.monitor()
         t = time.perf_counter()
 
         try:
@@ -136,17 +137,20 @@ def wrap_gradio_call(func):
         elapsed = time.perf_counter() - t
 
-        mem_stats = {k: -(v//-(1024*1024)) for k,v in shared.mem_mon.stop().items()}
-        active_peak = mem_stats['active_peak']
-        reserved_peak = mem_stats['reserved_peak']
-        sys_peak = '?' if opts.memmon_poll_rate <= 0 else mem_stats['system_peak']
-        sys_total = mem_stats['total']
-        sys_pct = '?' if opts.memmon_poll_rate <= 0 else round(sys_peak/sys_total * 100, 2)
-        vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.&#013;" \
-                       "Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.&#013;" \
-                       "Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
+        if opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled:
+            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
+            active_peak = mem_stats['active_peak']
+            reserved_peak = mem_stats['reserved_peak']
+            sys_peak = mem_stats['system_peak']
+            sys_total = mem_stats['total']
+            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+            vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.&#013;" \
+                           "Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.&#013;" \
+                           "Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
 
-        vram_html = '' if opts.memmon_poll_rate == 0 else f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+            vram_html = f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+        else:
+            vram_html = ''
 
         # last item is always HTML
         res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
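
Two small details in this hunk are easy to miss: max(sys_total, 1) guards the percentage against division by zero if the total was never populated, and -(v//-(1024*1024)) is ceiling division, converting bytes to MiB rounded up so a small nonzero peak never displays as 0 MiB. A quick check of that idiom (the helper name is illustrative):

def bytes_to_mib_ceil(v: int) -> int:
    # -(v // -n) == ceil(v / n) for positive n: Python's floor division
    # rounds toward negative infinity, so negating the divisor and the
    # result rounds the quotient up instead of down.
    return -(v // -(1024 * 1024))

assert bytes_to_mib_ceil(0) == 0
assert bytes_to_mib_ceil(1) == 1                # 1 byte -> 1 MiB, not 0
assert bytes_to_mib_ceil(1024 * 1024) == 1      # exact multiples are unchanged
assert bytes_to_mib_ceil(1024 * 1024 + 1) == 2  # anything over rounds up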