Make SwinIR interruptible
This commit is contained in:
parent
5c1cb9263f
commit
f993525820
1 changed files with 7 additions and 1 deletions
|
@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import modelloader, devices, script_callbacks, shared
|
from modules import modelloader, devices, script_callbacks, shared
|
||||||
from modules.shared import cmd_opts, opts
|
from modules.shared import cmd_opts, opts, state
|
||||||
from swinir_model_arch import SwinIR as net
|
from swinir_model_arch import SwinIR as net
|
||||||
from swinir_model_arch_v2 import Swin2SR as net2
|
from swinir_model_arch_v2 import Swin2SR as net2
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||||
|
|
||||||
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
||||||
for h_idx in h_idx_list:
|
for h_idx in h_idx_list:
|
||||||
|
if state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
for w_idx in w_idx_list:
|
for w_idx in w_idx_list:
|
||||||
|
if state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||||
out_patch = model(in_patch)
|
out_patch = model(in_patch)
|
||||||
out_patch_mask = torch.ones_like(out_patch)
|
out_patch_mask = torch.ones_like(out_patch)
|
||||||
|
|
Loading…
Reference in a new issue