fix for unet hijack breaking the train tab
This commit is contained in:
parent
789d47f832
commit
15e89ef0f6
1 changed files with 5 additions and 2 deletions
|
@ -36,8 +36,11 @@ th = TorchHijackForUnet()
|
||||||
|
|
||||||
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||||
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||||
for y in cond.keys():
|
|
||||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
if isinstance(cond, dict):
|
||||||
|
for y in cond.keys():
|
||||||
|
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue