add hypernetwork multipliers
This commit is contained in:
parent
a10b0e11fc
commit
354ef0da3b
6 changed files with 27 additions and 5 deletions
|
@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
|
multiplier = 1.0
|
||||||
|
|
||||||
def __init__(self, dim, state_dict=None):
|
def __init__(self, dim, state_dict=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
self.to(devices.device)
|
self.to(devices.device)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x + (self.linear2(self.linear1(x)))
|
return x + (self.linear2(self.linear1(x))) * self.multiplier
|
||||||
|
|
||||||
|
|
||||||
|
def apply_strength(value=None):
|
||||||
|
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||||
|
|
||||||
|
|
||||||
class Hypernetwork:
|
class Hypernetwork:
|
||||||
|
|
|
@ -238,7 +238,8 @@ options_templates.update(options_section(('training', "Training"), {
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||||
|
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||||
|
@ -348,6 +349,8 @@ class Options:
|
||||||
item = self.data_labels.get(key)
|
item = self.data_labels.get(key)
|
||||||
item.onchange = func
|
item.onchange = func
|
||||||
|
|
||||||
|
func()
|
||||||
|
|
||||||
def dumpjson(self):
|
def dumpjson(self):
|
||||||
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
|
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
|
||||||
return json.dumps(d)
|
return json.dumps(d)
|
||||||
|
|
|
@ -1244,7 +1244,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
def refresh():
|
def refresh():
|
||||||
info.refresh()
|
info.refresh()
|
||||||
refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
|
refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
|
||||||
res.choices = refreshed_args["choices"]
|
|
||||||
|
for k, v in refreshed_args.items():
|
||||||
|
setattr(res, k, v)
|
||||||
|
|
||||||
return gr.update(**(refreshed_args or {}))
|
return gr.update(**(refreshed_args or {}))
|
||||||
|
|
||||||
refresh_button.click(
|
refresh_button.click(
|
||||||
|
|
|
@ -107,6 +107,10 @@ def apply_hypernetwork(p, x, xs):
|
||||||
hypernetwork.load_hypernetwork(name)
|
hypernetwork.load_hypernetwork(name)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_hypernetwork_strength(p, x, xs):
|
||||||
|
hypernetwork.apply_strength(x)
|
||||||
|
|
||||||
|
|
||||||
def confirm_hypernetworks(p, xs):
|
def confirm_hypernetworks(p, xs):
|
||||||
for x in xs:
|
for x in xs:
|
||||||
if x.lower() in ["", "none"]:
|
if x.lower() in ["", "none"]:
|
||||||
|
@ -165,6 +169,7 @@ axis_options = [
|
||||||
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
|
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
|
||||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
|
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
|
||||||
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
|
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
|
||||||
|
AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
|
||||||
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
|
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
|
||||||
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
|
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
|
||||||
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
|
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
|
||||||
|
@ -250,7 +255,7 @@ class Script(scripts.Script):
|
||||||
y_values = gr.Textbox(label="Y values", visible=False, lines=1)
|
y_values = gr.Textbox(label="Y values", visible=False, lines=1)
|
||||||
|
|
||||||
draw_legend = gr.Checkbox(label='Draw legend', value=True)
|
draw_legend = gr.Checkbox(label='Draw legend', value=True)
|
||||||
include_lone_images = gr.Checkbox(label='Include Separate Images', value=True)
|
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
|
||||||
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)
|
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)
|
||||||
|
|
||||||
return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
|
return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
|
||||||
|
@ -377,6 +382,8 @@ class Script(scripts.Script):
|
||||||
modules.sd_models.reload_model_weights(shared.sd_model)
|
modules.sd_models.reload_model_weights(shared.sd_model)
|
||||||
|
|
||||||
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
|
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
|
||||||
|
hypernetwork.apply_strength()
|
||||||
|
|
||||||
|
|
||||||
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
|
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
|
||||||
|
|
||||||
|
|
|
@ -522,6 +522,9 @@ canvas[key="mask"] {
|
||||||
z-index: 200;
|
z-index: 200;
|
||||||
width: 8em;
|
width: 8em;
|
||||||
}
|
}
|
||||||
|
#quicksettings .gr-box > div > div > input.gr-text-input {
|
||||||
|
top: -1.12em;
|
||||||
|
}
|
||||||
|
|
||||||
.row.gr-compact{
|
.row.gr-compact{
|
||||||
overflow: visible;
|
overflow: visible;
|
||||||
|
|
2
webui.py
2
webui.py
|
@ -72,7 +72,6 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||||
|
|
||||||
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
||||||
|
|
||||||
|
|
||||||
def initialize():
|
def initialize():
|
||||||
modelloader.cleanup_models()
|
modelloader.cleanup_models()
|
||||||
modules.sd_models.setup_model()
|
modules.sd_models.setup_model()
|
||||||
|
@ -86,6 +85,7 @@ def initialize():
|
||||||
shared.sd_model = modules.sd_models.load_model()
|
shared.sd_model = modules.sd_models.load_model()
|
||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
||||||
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
||||||
|
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
||||||
|
|
||||||
|
|
||||||
def webui():
|
def webui():
|
||||||
|
|
Loading…
Reference in a new issue