diff --git a/modules/devices.py b/modules/devices.py index 67165bf6..93d82bbc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -44,8 +44,18 @@ def get_optimal_device(): def torch_gc(): if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + from modules import shared + + device_id = shared.cmd_opts.device_id + + if device_id is not None: + cuda_device = f"cuda:{device_id}" + else: + cuda_device = "cuda" + + with torch.cuda.device(cuda_device): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def enable_tf32():