hysts
commited on
Commit
·
ef06562
1
Parent(s):
f47c26b
Fix
Browse files- app_inference.py +23 -6
app_inference.py
CHANGED
|
@@ -61,6 +61,13 @@ class InferenceUtil:
|
|
| 61 |
instance_prompt = getattr(card.data, 'instance_prompt', '')
|
| 62 |
return base_model, instance_prompt
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
def create_inference_demo(pipe: InferencePipeline,
|
| 66 |
hf_token: str | None = None) -> gr.Blocks:
|
|
@@ -117,12 +124,22 @@ def create_inference_demo(pipe: InferencePipeline,
|
|
| 117 |
with gr.Column():
|
| 118 |
result = gr.Image(label='Result')
|
| 119 |
|
| 120 |
-
model_source.change(
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
lora_model_id.change(fn=app.load_model_info,
|
| 127 |
inputs=lora_model_id,
|
| 128 |
outputs=[
|
|
|
|
| 61 |
instance_prompt = getattr(card.data, 'instance_prompt', '')
|
| 62 |
return base_model, instance_prompt
|
| 63 |
|
| 64 |
+
def reload_lora_model_list_and_update_model_info(
|
| 65 |
+
self, model_source: str) -> tuple[dict, str, str]:
|
| 66 |
+
model_list_update = self.reload_lora_model_list(model_source)
|
| 67 |
+
model_list = model_list_update['choices']
|
| 68 |
+
model_info = self.load_model_info(model_list[0] if model_list else '')
|
| 69 |
+
return model_list_update, *model_info
|
| 70 |
+
|
| 71 |
|
| 72 |
def create_inference_demo(pipe: InferencePipeline,
|
| 73 |
hf_token: str | None = None) -> gr.Blocks:
|
|
|
|
| 124 |
with gr.Column():
|
| 125 |
result = gr.Image(label='Result')
|
| 126 |
|
| 127 |
+
model_source.change(
|
| 128 |
+
fn=app.reload_lora_model_list_and_update_model_info,
|
| 129 |
+
inputs=model_source,
|
| 130 |
+
outputs=[
|
| 131 |
+
lora_model_id,
|
| 132 |
+
base_model_used_for_training,
|
| 133 |
+
instance_prompt_used_for_training,
|
| 134 |
+
])
|
| 135 |
+
reload_button.click(
|
| 136 |
+
fn=app.reload_lora_model_list_and_update_model_info,
|
| 137 |
+
inputs=model_source,
|
| 138 |
+
outputs=[
|
| 139 |
+
lora_model_id,
|
| 140 |
+
base_model_used_for_training,
|
| 141 |
+
instance_prompt_used_for_training,
|
| 142 |
+
])
|
| 143 |
lora_model_id.change(fn=app.load_model_info,
|
| 144 |
inputs=lora_model_id,
|
| 145 |
outputs=[
|