From 7001bffe0247804793dfabb69ac96d832572ccd0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 15:43:25 +0300 Subject: [PATCH] fix AND broken for long prompts --- modules/prompt_parser.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f00256f2..15666073 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -239,6 +239,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch) + # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes + # and won't be able to torch.stack them. So this fixes that. + token_count = max([x.shape[0] for x in tensors]) + for i in range(len(tensors)): + if tensors[i].shape[0] != token_count: + last_vector = tensors[i][-1:] + last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) + tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)