add hypernetwork multipliers

This commit is contained in:
AUTOMATIC 2022-10-13 20:12:37 +03:00
parent a10b0e11fc
commit 354ef0da3b
6 changed files with 27 additions and 5 deletions

View file

@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
def __init__(self, dim, state_dict=None): def __init__(self, dim, state_dict=None):
super().__init__() super().__init__()
@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module):
self.to(devices.device) self.to(devices.device)
def forward(self, x): def forward(self, x):
return x + (self.linear2(self.linear1(x))) return x + (self.linear2(self.linear1(x))) * self.multiplier
def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
class Hypernetwork: class Hypernetwork:

View file

@ -238,7 +238,8 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
@ -348,6 +349,8 @@ class Options:
item = self.data_labels.get(key) item = self.data_labels.get(key)
item.onchange = func item.onchange = func
func()
def dumpjson(self): def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d) return json.dumps(d)

View file

@ -1244,7 +1244,10 @@ def create_ui(wrap_gradio_gpu_call):
def refresh(): def refresh():
info.refresh() info.refresh()
refreshed_args = info.component_args() if callable(info.component_args) else info.component_args refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
res.choices = refreshed_args["choices"]
for k, v in refreshed_args.items():
setattr(res, k, v)
return gr.update(**(refreshed_args or {})) return gr.update(**(refreshed_args or {}))
refresh_button.click( refresh_button.click(

View file

@ -107,6 +107,10 @@ def apply_hypernetwork(p, x, xs):
hypernetwork.load_hypernetwork(name) hypernetwork.load_hypernetwork(name)
def apply_hypernetwork_strength(p, x, xs):
hypernetwork.apply_strength(x)
def confirm_hypernetworks(p, xs): def confirm_hypernetworks(p, xs):
for x in xs: for x in xs:
if x.lower() in ["", "none"]: if x.lower() in ["", "none"]:
@ -165,6 +169,7 @@ axis_options = [
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers), AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints), AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks), AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
@ -250,7 +255,7 @@ class Script(scripts.Script):
y_values = gr.Textbox(label="Y values", visible=False, lines=1) y_values = gr.Textbox(label="Y values", visible=False, lines=1)
draw_legend = gr.Checkbox(label='Draw legend', value=True) draw_legend = gr.Checkbox(label='Draw legend', value=True)
include_lone_images = gr.Checkbox(label='Include Separate Images', value=True) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)
return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
@ -377,6 +382,8 @@ class Script(scripts.Script):
modules.sd_models.reload_model_weights(shared.sd_model) modules.sd_models.reload_model_weights(shared.sd_model)
hypernetwork.load_hypernetwork(opts.sd_hypernetwork) hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers

View file

@ -522,6 +522,9 @@ canvas[key="mask"] {
z-index: 200; z-index: 200;
width: 8em; width: 8em;
} }
#quicksettings .gr-box > div > div > input.gr-text-input {
top: -1.12em;
}
.row.gr-compact{ .row.gr-compact{
overflow: visible; overflow: visible;

View file

@ -72,7 +72,6 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
def initialize(): def initialize():
modelloader.cleanup_models() modelloader.cleanup_models()
modules.sd_models.setup_model() modules.sd_models.setup_model()
@ -86,6 +85,7 @@ def initialize():
shared.sd_model = modules.sd_models.load_model() shared.sd_model = modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
def webui(): def webui():