matjarm committed on
Commit 9d46b1d · 1 Parent(s): b90a4c8
Files changed (1)
  1. app.py +95 -195
app.py CHANGED
@@ -1,207 +1,107 @@
  import os
- import random
- import uuid
  import gradio as gr
- import numpy as np
  from PIL import Image
  import torch
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
- from typing import Tuple
-
- # CSS for Gradio Interface
- css = '''
- .gradio-container{max-width: 575px !important}
- h1{text-align:center}
- footer {
-     visibility: hidden
- }
- '''
-
- DESCRIPTION = """
- ## Text-to-Image Generator 🚀
- Create stunning images from text prompts using Stable Diffusion XL. Explore high-quality styles and customizable options.
- """
-
- # Example Prompts
- examples = [
-     "A beautiful sunset over the ocean, ultra-realistic, high resolution",
-     "A futuristic cityscape with flying cars, cyberpunk theme, vibrant colors",
-     "A cozy cabin in the woods during winter, detailed and realistic",
-     "A magical forest with glowing plants and creatures, fantasy art",
- ]
-
- # Model Configurations
- MODEL_OPTIONS = {
-     "LIGHTNING V5.0": "SG161222/RealVisXL_V5.0_Lightning",
-     "LIGHTNING V4.0": "SG161222/RealVisXL_V4.0_Lightning",
- }
-
- # Define Styles
- style_list = [
-     {
-         "name": "Ultra HD",
-         "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
-     },
-     {
-         "name": "4K Realistic",
-         "prompt": "realistic 4K image of {prompt}. sharp, detailed, vibrant colors, photorealistic",
-         "negative_prompt": "cartoonish, blurry, low resolution",
-     },
-     {
-         "name": "Minimal Style",
-         "prompt": "{prompt}, clean, minimalistic",
-         "negative_prompt": "",
-     },
- ]
-
- styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
- DEFAULT_STYLE_NAME = "Ultra HD"
-
- # Define Global Variables
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- MAX_IMAGE_SIZE = 4096
- MAX_SEED = np.iinfo(np.int32).max
-
- # Load Model Function
- def load_and_prepare_model(model_id):
-     pipe = StableDiffusionXLPipeline.from_pretrained(
-         model_id,
-         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-     ).to(device)
-     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
-     return pipe

  # Load Models
- models = {key: load_and_prepare_model(value) for key, value in MODEL_OPTIONS.items()}
-
- # Generate Function
- def generate_image(
-     model_choice: str,
-     prompt: str,
-     negative_prompt: str,
-     style_name: str,
-     width: int,
-     height: int,
-     guidance_scale: float,
-     num_steps: int,
-     num_images: int,
-     randomize_seed: bool,
-     seed: int,
- ):
-     # Apply Style
-     positive_style, negative_style = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
-     styled_prompt = positive_style.replace("{prompt}", prompt)
-     styled_negative_prompt = negative_style + (negative_prompt if negative_prompt else "")
-
-     # Randomize Seed if Enabled
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     generator = torch.Generator(device=device).manual_seed(seed)
-
-     # Generate Images
-     pipe = models[model_choice]
-     images = pipe(
-         prompt=[styled_prompt] * num_images,
-         negative_prompt=[styled_negative_prompt] * num_images,
-         width=width,
-         height=height,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_steps,
-         generator=generator,
-         output_type="pil",
-     ).images
-
-     # Save and Return Images
-     image_paths = []
-     for img in images:
-         unique_name = f"{uuid.uuid4()}.png"
-         img.save(unique_name)
-         image_paths.append(unique_name)
-
-     return image_paths, seed

- # Gradio Interface
- with gr.Blocks(css=css) as demo:
-     gr.Markdown(DESCRIPTION)
-
-     with gr.Row():
-         model_choice = gr.Dropdown(
-             label="Select Model",
-             choices=list(MODEL_OPTIONS.keys()),
-             value="LIGHTNING V5.0",
-         )
-
-     prompt = gr.Textbox(
-         label="Prompt",
-         placeholder="Enter your creative prompt here...",
-     )
-
-     negative_prompt = gr.Textbox(
-         label="Negative Prompt",
-         placeholder="Optional: Add details you want to avoid...",
-         value="blurry, deformed, low-quality, cartoonish",
-     )
-
-     style_name = gr.Radio(
-         label="Style",
-         choices=list(styles.keys()),
-         value=DEFAULT_STYLE_NAME,
-     )

-     with gr.Accordion("Advanced Options", open=False):
-         width = gr.Slider(label="Width", minimum=512, maximum=2048, step=8, value=1024)
-         height = gr.Slider(label="Height", minimum=512, maximum=2048, step=8, value=1024)
-         guidance_scale = gr.Slider(
-             label="Guidance Scale",
-             minimum=1,
-             maximum=20,
-             step=0.5,
-             value=7.5,
-         )
-         num_steps = gr.Slider(
-             label="Steps",
-             minimum=1,
-             maximum=50,
-             step=1,
-             value=25,
-         )
-         num_images = gr.Slider(
-             label="Number of Images",
-             minimum=1,
-             maximum=5,
-             step=1,
-             value=1,
-         )
-         randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-         seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
-
-     with gr.Row():
-         run_button = gr.Button("Generate Images")
-         result_gallery = gr.Gallery(label="Generated Images", show_label=False)
-
-     run_button.click(
-         generate_image,
-         inputs=[
-             model_choice,
-             prompt,
-             negative_prompt,
-             style_name,
-             width,
-             height,
-             guidance_scale,
-             num_steps,
-             num_images,
-             randomize_seed,
-             seed,
-         ],
-         outputs=[result_gallery, seed],
-     )

-     gr.Examples(
-         examples=examples,
-         inputs=prompt,
      )

  if __name__ == "__main__":
-     demo.queue(max_size=50).launch()

+ import cv2
  import os
  import gradio as gr
+ import requests
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
+ from transformers import BlipProcessor, BlipForConditionalGeneration
  from PIL import Image
  import torch
+ import uuid

  # Load Models
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ # Model 1: ViT-GPT2
+ model1 = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
+ feature_extractor1 = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ tokenizer1 = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

+ # Model 2: FuseCap
+ processor2 = BlipProcessor.from_pretrained("noamrot/FuseCap")
+ model2 = BlipForConditionalGeneration.from_pretrained("noamrot/FuseCap").to(device)
+
+ # Model 3: BLIP Large
+ processor3 = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+ model3 = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
+
+
+ # Frame Extraction and Captioning Logic
+ def process_video(video_path):
+     vidObj = cv2.VideoCapture(video_path)
+     count = 0
+     success = True
+     frame_captions = {"Model 1": [], "Model 2": [], "Model 3": []}

+     while success:
+         success, frame = vidObj.read()
+
+         if not success:
+             break
+
+         # Process every 20th frame
+         if count % 20 == 0:
+             image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+             # Model 1: ViT-GPT2
+             pixel_values = feature_extractor1(images=[image], return_tensors="pt").pixel_values.to(device)
+             output_ids = model1.generate(pixel_values, max_length=16, num_beams=4)
+             caption1 = tokenizer1.decode(output_ids[0], skip_special_tokens=True)
+             frame_captions["Model 1"].append(caption1)
+
+             # Model 2: FuseCap
+             inputs = processor2(image, "a picture of ", return_tensors="pt").to(device)
+             out2 = model2.generate(**inputs, num_beams=3)
+             caption2 = processor2.decode(out2[0], skip_special_tokens=True)
+             frame_captions["Model 2"].append(caption2)
+
+             # Model 3: BLIP Large
+             inputs3 = processor3(image, return_tensors="pt").to(device)
+             out3 = model3.generate(**inputs3)
+             caption3 = processor3.decode(out3[0], skip_special_tokens=True)
+             frame_captions["Model 3"].append(caption3)
+
+         count += 1
+
+     vidObj.release()
+     return frame_captions
+
+
+ # Gradio Interface
+ def generate_captions(video):
+     # gr.Video hands the upload to the callback as a path on disk, so copy
+     # the file's bytes instead of calling .read() on the argument itself
+     video_path = f"temp_{uuid.uuid4()}.mp4"
+     with open(video, "rb") as src, open(video_path, "wb") as f:
+         f.write(src.read())
+
+     # Process video and get captions
+     captions = process_video(video_path)
+
+     # Clean up temporary file
+     os.remove(video_path)
+
+     # Format output for display
+     result = ""
+     for model_name, model_captions in captions.items():
+         result += f"### {model_name}\n"
+         result += "\n".join(f"- {caption}" for caption in model_captions)
+         result += "\n\n"
+
+     return result
+
+
+ # Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("# Video Captioning with Multiple Models 🎥")
+     gr.Markdown("Upload a video to generate captions for its frames using three different models.")
+     video_input = gr.Video(label="Upload Video")
+     output = gr.Textbox(label="Generated Captions", lines=20)
+     submit_button = gr.Button("Generate Captions")
+
+     submit_button.click(
+         fn=generate_captions,
+         inputs=video_input,
+         outputs=output,
      )

  if __name__ == "__main__":
+     demo.launch()
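
For a quick sanity check of the captioning stack this commit introduces, a minimal standalone sketch (not part of the commit) that runs the BLIP-large checkpoint loaded above as "Model 3" on a single image; the input path "frame.jpg" is a placeholder:

# Minimal standalone sketch (not part of the commit): caption one image with
# the BLIP-large checkpoint that app.py loads as "Model 3".
# "frame.jpg" is a placeholder input path.
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

image = Image.open("frame.jpg").convert("RGB")
inputs = processor(image, return_tensors="pt").to(device)
out = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(out[0], skip_special_tokens=True))

The same processor / generate / decode pattern is what process_video applies to every 20th frame; running app.py itself additionally requires gradio and opencv-python (for cv2) to be installed.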