ginipick committed · Commit a9938b5 · verified · 1 Parent(s): 3293514

Create app-backup.py

Files changed (1):
  app-backup.py +270 -0
app-backup.py ADDED
@@ -0,0 +1,270 @@
+ import random
+ import time
+ import os
+ import torch
+ import gradio as gr
+ import numpy as np
+ import spaces
+ from diffusers import DiffusionPipeline
+ from PIL import Image
+
+ # --- [Optional Patch] ---------------------------------------------------------
+ # This patch works around JSON-schema parsing issues in Gradio/Gradio-Client,
+ # e.g. boolean schemas and boolean "additionalProperties" values.
+ import gradio_client.utils
+ original_json_schema = gradio_client.utils._json_schema_to_python_type
+
+ def patched_json_schema(schema, defs=None):
+     # Handle boolean schemas directly
+     if isinstance(schema, bool):
+         return "bool"
+
+     # If 'additionalProperties' is a boolean, replace it with a generic type
+     try:
+         if "additionalProperties" in schema and isinstance(schema["additionalProperties"], bool):
+             schema["additionalProperties"] = {"type": "any"}
+     except (TypeError, KeyError):
+         pass
+
+     # Attempt to parse normally; fall back to "any" on error
+     try:
+         return original_json_schema(schema, defs)
+     except Exception:
+         return "any"
+
+ gradio_client.utils._json_schema_to_python_type = patched_json_schema
+ # -----------------------------------------------------------------------------
+
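+ # Minimal sanity check of the patch (illustrative, safe to delete): a bare
+ # boolean schema exercises the branch the patch adds above.
+ assert patched_json_schema(True) == "bool"
+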
+ # ----------------------------- Model Loading ----------------------------------
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ repo_id = "black-forest-labs/FLUX.1-dev"
+ adapter_id = "openfree/flux-chatgpt-ghibli-lora"
+
+ def load_model_with_retry(max_retries=5):
+     for attempt in range(max_retries):
+         try:
+             print(f"Loading model attempt {attempt + 1}/{max_retries}...")
+             pipeline = DiffusionPipeline.from_pretrained(
+                 repo_id,
+                 torch_dtype=torch.bfloat16,
+                 use_safetensors=True,
+                 resume_download=True,  # deprecated in recent huggingface_hub (resuming is the default)
+             )
+             print("Base model loaded successfully, now loading LoRA weights...")
+             pipeline.load_lora_weights(adapter_id)
+             pipeline = pipeline.to(device)
+             print("Pipeline is ready!")
+             return pipeline
+         except Exception as e:
+             if attempt < max_retries - 1:
+                 wait_time = 10 * (attempt + 1)  # linear backoff: 10 s, 20 s, ...
+                 print(f"Error loading model: {e}. Retrying in {wait_time} seconds...")
+                 time.sleep(wait_time)
+             else:
+                 raise RuntimeError(f"Failed to load model after {max_retries} attempts: {e}") from e
+
+ pipeline = load_model_with_retry()
+
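+ # Note (rough estimate): FLUX.1-dev is a ~12B-parameter transformer plus large
+ # text encoders; even in bfloat16 the full pipeline needs a high-memory GPU
+ # (or CPU/sequential offloading) to run.
+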
+ # ----------------------------- Inference Function -----------------------------
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
+ @spaces.GPU(duration=120)
+ def inference(
+     prompt: str,
+     seed: int,
+     randomize_seed: bool,
+     width: int,
+     height: int,
+     guidance_scale: float,
+     num_inference_steps: int,
+     lora_scale: float,
+ ):
+     # If "randomize_seed" is selected, choose a random seed
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     try:
+         image = pipeline(
+             prompt=prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             width=width,
+             height=height,
+             generator=generator,
+             joint_attention_kwargs={"scale": lora_scale},
+         ).images[0]
+         return image, seed
+     except Exception as e:
+         print(f"Error during inference: {e}")
+         # Return a red error image of the requested size plus the seed used
+         error_img = Image.new('RGB', (width, height), color='red')
+         return error_img, seed
+
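+ # Illustrative direct call (hedged: arbitrary example values, bypasses the UI):
+ # img, used_seed = inference("a Ghibli-style cottage in a meadow", seed=0,
+ #                            randomize_seed=True, width=512, height=512,
+ #                            guidance_scale=3.5, num_inference_steps=30,
+ #                            lora_scale=1.0)
+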
+ # ----------------------------- Florence-2 Captioner ---------------------------
+ import subprocess
+ # Install flash-attn at runtime (Florence-2's remote code can use it), skipping
+ # its CUDA build. Merge with os.environ so PATH etc. survive in the subprocess.
+ subprocess.run(
+     'pip install flash-attn --no-build-isolation',
+     env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+     shell=True
+ )
+
+ from transformers import AutoProcessor, AutoModelForCausalLM
+
+ # Pre-load both captioning models and their processors
+ models = {
+     'gokaygokay/Florence-2-Flux-Large': AutoModelForCausalLM.from_pretrained(
+         'gokaygokay/Florence-2-Flux-Large', trust_remote_code=True
+     ).eval(),
+     'gokaygokay/Florence-2-Flux': AutoModelForCausalLM.from_pretrained(
+         'gokaygokay/Florence-2-Flux', trust_remote_code=True
+     ).eval(),
+ }
+
+ processors = {
+     'gokaygokay/Florence-2-Flux-Large': AutoProcessor.from_pretrained(
+         'gokaygokay/Florence-2-Flux-Large', trust_remote_code=True
+     ),
+     'gokaygokay/Florence-2-Flux': AutoProcessor.from_pretrained(
+         'gokaygokay/Florence-2-Flux', trust_remote_code=True
+     ),
+ }
+
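+ # Both models start on CPU; caption_image() below moves the selected one to
+ # `device` on first use.
+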
+ @spaces.GPU
+ def caption_image(image, model_name='gokaygokay/Florence-2-Flux-Large'):
+     """
+     Run the selected Florence-2 model to generate a detailed caption.
+     """
+     from PIL import Image as PILImage
+
+     task_prompt = "<DESCRIPTION>"
+     user_prompt = task_prompt + "Describe this image in great detail."
+
+     # Convert the numpy input from gr.Image to an RGB PIL image
+     image = PILImage.fromarray(image)
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+
+     # Move the selected model (and inputs) to the pipeline's device so a GPU is
+     # actually used when available; .to() is effectively a no-op afterwards.
+     model = models[model_name].to(device)
+     processor = processors[model_name]
+
+     inputs = processor(text=user_prompt, images=image, return_tensors="pt").to(device)
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=1024,
+         num_beams=3,
+         repetition_penalty=1.10,
+     )
+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+     parsed_answer = processor.post_process_generation(
+         generated_text, task=task_prompt, image_size=(image.width, image.height)
+     )
+     return parsed_answer["<DESCRIPTION>"]
+
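+ # Illustrative direct call (hedged: "photo.png" is a placeholder path):
+ # print(caption_image(np.array(Image.open("photo.png"))))
+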
+ # ----------------------------- Gradio UI --------------------------------------
+ with gr.Blocks(analytics_enabled=False) as demo:
+     with gr.Tabs():
+         # ------------------ TAB 1: Image Generation ----------------------------
+         with gr.TabItem("FLUX Ghibli LoRA Generator"):
+             gr.Markdown("## Generate an image with the FLUX Ghibli LoRA")
+
+             with gr.Row():
+                 with gr.Column():
+                     prompt = gr.Textbox(
+                         label="Prompt",
+                         placeholder="Describe your Ghibli-style image...",
+                         lines=3
+                     )
+                     with gr.Row():
+                         seed = gr.Slider(
+                             label="Seed",
+                             minimum=0,
+                             maximum=MAX_SEED,
+                             step=1,
+                             value=42
+                         )
+                         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                     with gr.Row():
+                         width = gr.Slider(
+                             label="Width",
+                             minimum=256,
+                             maximum=MAX_IMAGE_SIZE,
+                             step=32,
+                             value=512
+                         )
+                         height = gr.Slider(
+                             label="Height",
+                             minimum=256,
+                             maximum=MAX_IMAGE_SIZE,
+                             step=32,
+                             value=512
+                         )
+                     with gr.Row():
+                         guidance_scale = gr.Slider(
+                             label="Guidance scale",
+                             minimum=0.0,
+                             maximum=10.0,
+                             step=0.1,
+                             value=3.5
+                         )
+                         num_inference_steps = gr.Slider(
+                             label="Steps",
+                             minimum=1,
+                             maximum=50,
+                             step=1,
+                             value=30
+                         )
+                         lora_scale = gr.Slider(
+                             label="LoRA scale",
+                             minimum=0.0,
+                             maximum=1.0,
+                             step=0.1,
+                             value=1.0
+                         )
+                     generate_button = gr.Button("Generate Image")
+
+                 with gr.Column():
+                     output_image = gr.Image(label="Generated Image")
+                     output_seed = gr.Number(label="Seed Used")
+
+             # Link the button to the inference function
+             generate_button.click(
+                 inference,
+                 inputs=[
+                     prompt,
+                     seed,
+                     randomize_seed,
+                     width,
+                     height,
+                     guidance_scale,
+                     num_inference_steps,
+                     lora_scale,
+                 ],
+                 outputs=[output_image, output_seed]
+             )
+
+         # ------------------ TAB 2: Image Captioning ---------------------------
+         with gr.TabItem("Florence-2 Captioner"):
+             gr.Markdown("## Generate a caption for an uploaded image using Florence-2")
+
+             with gr.Row():
+                 with gr.Column():
+                     input_img = gr.Image(label="Upload an Image")
+                     model_selector = gr.Dropdown(
+                         choices=list(models.keys()),
+                         value='gokaygokay/Florence-2-Flux-Large',
+                         label="Select Model"
+                     )
+                     caption_button = gr.Button("Generate Caption")
+                 with gr.Column():
+                     caption_output = gr.Textbox(label="Caption")
+
+             caption_button.click(
+                 caption_image,
+                 inputs=[input_img, model_selector],
+                 outputs=[caption_output]
+             )
+
+ # No queue() call here; add demo.queue() before launch() if request queuing is needed.
+ demo.launch(debug=True)