hysts
committed on
Commit
·
4bd7dce
1
Parent(s):
f689e96
Allow changing LoRA scaling alpha
Browse files
- app_inference.py +6 -0
- inference.py +8 -4
app_inference.py
CHANGED
|
@@ -99,6 +99,11 @@ def create_inference_demo(pipe: InferencePipeline,
|
|
| 99 |
max_lines=1,
|
| 100 |
placeholder='Example: "A picture of a sks dog in a bucket"'
|
| 101 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
seed = gr.Slider(label='Seed',
|
| 103 |
minimum=0,
|
| 104 |
maximum=100000,
|
|
@@ -149,6 +154,7 @@ def create_inference_demo(pipe: InferencePipeline,
|
|
| 149 |
inputs = [
|
| 150 |
lora_model_id,
|
| 151 |
prompt,
|
|
|
|
| 152 |
seed,
|
| 153 |
num_steps,
|
| 154 |
guidance_scale,
|
|
|
|
| 99 |
max_lines=1,
|
| 100 |
placeholder='Example: "A picture of a sks dog in a bucket"'
|
| 101 |
)
|
| 102 |
+
alpha = gr.Slider(label='LoRA alpha',
|
| 103 |
+
minimum=0,
|
| 104 |
+
maximum=2,
|
| 105 |
+
step=0.05,
|
| 106 |
+
value=1)
|
| 107 |
seed = gr.Slider(label='Seed',
|
| 108 |
minimum=0,
|
| 109 |
maximum=100000,
|
|
|
|
| 154 |
inputs = [
|
| 155 |
lora_model_id,
|
| 156 |
prompt,
|
| 157 |
+
alpha,
|
| 158 |
seed,
|
| 159 |
num_steps,
|
| 160 |
guidance_scale,
|
inference.py
CHANGED
|
@@ -73,6 +73,7 @@ class InferencePipeline:
|
|
| 73 |
self,
|
| 74 |
lora_model_id: str,
|
| 75 |
prompt: str,
|
|
|
|
| 76 |
seed: int,
|
| 77 |
n_steps: int,
|
| 78 |
guidance_scale: float,
|
|
@@ -83,8 +84,11 @@ class InferencePipeline:
|
|
| 83 |
self.load_pipe(lora_model_id)
|
| 84 |
|
| 85 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 86 |
-
out = self.pipe(
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
| 90 |
return out.images[0]
|
|
|
|
| 73 |
self,
|
| 74 |
lora_model_id: str,
|
| 75 |
prompt: str,
|
| 76 |
+
lora_scale: float,
|
| 77 |
seed: int,
|
| 78 |
n_steps: int,
|
| 79 |
guidance_scale: float,
|
|
|
|
| 84 |
self.load_pipe(lora_model_id)
|
| 85 |
|
| 86 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 87 |
+
out = self.pipe(
|
| 88 |
+
prompt,
|
| 89 |
+
num_inference_steps=n_steps,
|
| 90 |
+
guidance_scale=guidance_scale,
|
| 91 |
+
generator=generator,
|
| 92 |
+
cross_attention_kwargs={'scale': lora_scale},
|
| 93 |
+
) # type: ignore
|
| 94 |
return out.images[0]
|