multimodalart committed
Commit 4791078 · verified · parent 5d63c5a

Update app.py

Files changed (1):
  app.py (+60, -205)
app.py CHANGED
@@ -1,214 +1,69 @@
-
- import gradio as gr
- from PIL import Image
- from gradio_imageslider import ImageSlider
-
  import requests
  import base64
- import numpy as np
- import random
  import io

- URL = "http://localhost:5000/predictions"
  SCHEMA_URL = "http://localhost:5000/openapi.json"

- def get_schema():
-     response = requests.get(SCHEMA_URL)
-     json_response = response.json()
-     print("The schema")
-     print(json_response)
-
- get_schema()
-
- HEADERS = {
-     "Content-Type": "application/json",
- }
-
- MAX_SEED = np.iinfo(np.int32).max
-
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
-
- def generate(
-     prompt: str,
-     input_image: Image,
-     negative_prompt: str = "",
-     seed: int = 0,
-     width: int = 1024,
-     height: int = 1024,
-     prior_num_inference_steps: int = 30,
-     # prior_timesteps: List[float] = None,
-     prior_guidance_scale: float = 4.0,
-     decoder_num_inference_steps: int = 12,
-     # decoder_timesteps: List[float] = None,
-     decoder_guidance_scale: float = 0.0,
-     num_images_per_prompt: int = 2,
-
- ) -> Image:
-
-     payload = {
-         "input": {
-             "hdr": 0,
-             "image": "http://localhost:7860/file=" + input_image,
-             "steps": 20,
-             "prompt": prompt,
-             "scheduler": "DDIM",
-             "creativity": 0.25,
-             "guess_mode": False,
-             "resolution": "original",
-             "resemblance": 0.75,
-             "guidance_scale": 7,
-             "negative_prompt": negative_prompt,
          }
-     }
-     print("\n\n PAYLOAD", payload)
-     response = requests.post(URL, headers=HEADERS, json=payload)
-     json_response = response.json()
-     if 'status' in json_response:
-         if json_response["status"] == "failed":
-             raise gr.Error("Failed to generate image")
-         base64_image = json_response["output"][0]
-         image_data = base64.b64decode(
-             base64_image.replace("data:image/png;base64,", ""))
-         image_stream = io.BytesIO(image_data)
-         return [Image.open(input_image), Image.open(image_stream)]
-     else:
-         raise gr.Error("Failed to generate image")
-
-
- examples = [
-     ["An astronaut riding a green horse", "examples/image2.png"],
-     ["A mecha robot in a favela by Tarsila do Amaral", "examples/image2.png"],
-     ["The spirit of a Tamagotchi wandering in the city of Los Angeles",
-      "examples/image1.png"],
-     ["A delicious feijoada ramen dish", "examples/image0.png"],
- ]
-
- with gr.Blocks() as demo:
-     with gr.Row():
-         with gr.Column():
-             input_image = gr.Image(type="filepath")
-             with gr.Group():
-                 with gr.Row():
-                     prompt = gr.Text(
-                         label="Prompt",
-                         show_label=False,
-                         max_lines=1,
-                         placeholder="Enter your prompt",
-                         container=False,
-                     )
-                     run_button = gr.Button("Run", scale=0)
-         with gr.Column():
-             result = ImageSlider(label="Result", type="pil")
-     with gr.Accordion("Advanced options", open=False):
-         negative_prompt = gr.Text(
-             label="Negative prompt",
-             max_lines=1,
-             placeholder="Enter a Negative Prompt",
-         )
-
-         seed = gr.Slider(
-             label="Seed",
-             minimum=0,
-             maximum=MAX_SEED,
-             step=1,
-             value=0,
-         )
-         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-         with gr.Row():
-             width = gr.Slider(
-                 label="Width",
-                 minimum=1024,
-                 maximum=1024,
-                 step=512,
-                 value=1024,
-             )
-             height = gr.Slider(
-                 label="Height",
-                 minimum=1024,
-                 maximum=1024,
-                 step=512,
-                 value=1024,
-             )
-             num_images_per_prompt = gr.Slider(
-                 label="Number of Images",
-                 minimum=1,
-                 maximum=2,
-                 step=1,
-                 value=1,
-             )
-         with gr.Row():
-             prior_guidance_scale = gr.Slider(
-                 label="Prior Guidance Scale",
-                 minimum=0,
-                 maximum=20,
-                 step=0.1,
-                 value=4.0,
-             )
-             prior_num_inference_steps = gr.Slider(
-                 label="Prior Inference Steps",
-                 minimum=10,
-                 maximum=30,
-                 step=1,
-                 value=20,
-             )
-
-             decoder_guidance_scale = gr.Slider(
-                 label="Decoder Guidance Scale",
-                 minimum=0,
-                 maximum=0,
-                 step=0.1,
-                 value=0.0,
-             )
-             decoder_num_inference_steps = gr.Slider(
-                 label="Decoder Inference Steps",
-                 minimum=4,
-                 maximum=12,
-                 step=1,
-                 value=10,
-             )
-
-     gr.Examples(
-         examples=examples,
-         inputs=[prompt, input_image],
-         outputs=result,
-         fn=generate,
-         cache_examples=False,
-     )
-
-     inputs = [
-         prompt,
-         input_image,
-         negative_prompt,
-         seed,
-         width,
-         height,
-         prior_num_inference_steps,
-         # prior_timesteps,
-         prior_guidance_scale,
-         decoder_num_inference_steps,
-         # decoder_timesteps,
-         decoder_guidance_scale,
-         num_images_per_prompt,
-     ]
-     gr.on(
-         triggers=[prompt.submit, negative_prompt.submit, run_button.click],
-         fn=randomize_seed_fn,
-         inputs=[seed, randomize_seed],
-         outputs=seed,
-         queue=False,
-         api_name=False,
-     ).then(
-         fn=generate,
-         inputs=inputs,
-         outputs=result,
-         api_name="run",
-     )


- if __name__ == "__main__":
-     demo.queue(max_size=20).launch()
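Both the removed version above and the added version below talk to the same local prediction server. As a standalone sketch of that round trip, assuming a Cog-style container listening on localhost:5000 that accepts a single "input" object and returns base64 data URIs in "output", as the code above expects (the helper name and example payload keys here are illustrative, not part of the commit):

import base64
import io
import requests
from PIL import Image

PREDICTIONS_URL = "http://localhost:5000/predictions"  # same local endpoint as in app.py

def run_prediction(image_url: str, prompt: str, negative_prompt: str = "") -> Image.Image:
    # Hypothetical helper; the payload keys mirror the removed generate() above.
    payload = {
        "input": {
            "image": image_url,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "steps": 20,
            "scheduler": "DDIM",
        }
    }
    response = requests.post(PREDICTIONS_URL, json=payload, timeout=600)
    response.raise_for_status()
    data = response.json()
    if data.get("status") == "failed":
        raise RuntimeError("Prediction failed")
    # The server is expected to return data URIs such as "data:image/png;base64,..."
    b64 = data["output"][0].split(",", 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(b64)))

The new version below generalizes this by reading the payload keys from the server's OpenAPI schema instead of hard-coding them.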
  import requests
  import base64
  import io
+ from PIL import Image
+ import gradio as gr
+ import json

+ # API and Schema URLs
+ API_URL = "http://localhost:5000/predictions"
  SCHEMA_URL = "http://localhost:5000/openapi.json"

+ def fetch_api_spec(url):
+     response = requests.get(url)
+     return response.json()
+
+ def create_gradio_app_from_api_spec(api_spec):
+     input_properties = api_spec['components']['schemas']['Input']['properties']
+     inputs = []
+     for prop, details in input_properties.items():
+         if 'enum' in details:
+             choices = details['enum']
+             inputs.append(gr.inputs.Dropdown(choices=choices, label=prop))
+         elif details['type'] == 'integer':
+             inputs.append(gr.inputs.Number(label=prop, default=details.get('default'), minimum=details.get('minimum'), maximum=details.get('maximum')))
+         elif details['type'] == 'number':
+             inputs.append(gr.inputs.Slider(minimum=details.get('minimum'), maximum=details.get('maximum'), default=details.get('default'), label=prop))
+         elif details['type'] == 'string' and 'format' in details and details['format'] == 'uri':
+             inputs.append(gr.inputs.Image(label=prop))
+         elif details['type'] == 'string':
+             inputs.append(gr.inputs.Textbox(label=prop, default=details.get('default')))
+         elif details['type'] == 'boolean':
+             inputs.append(gr.inputs.Checkbox(label=prop, default=details.get('default')))
+
+     def predict_function(**kwargs):
+         # Adjust the input kwargs here if the API expects images in a different format
+         payload = {
+             "input": kwargs
          }
+         response = requests.post(API_URL, headers={"Content-Type": "application/json"}, json=payload)
+         json_response = response.json()

+         if 'status' in json_response and json_response["status"] == "failed":
+             raise gr.Error("Failed to generate image")

+         output_spec = api_spec['components']['schemas']['Output']
+         if output_spec['items']['type'] == 'string' and output_spec['items']['format'] == 'uri':
+             outputs = []
+             for uri in json_response["output"]:
+                 if uri.startswith("data:image"):
+                     base64_image = uri.split(",")[1]  # Strip the data-URI prefix
+                     image_data = base64.b64decode(base64_image)
+                     image_stream = io.BytesIO(image_data)
+                     image = Image.open(image_stream)
+                     outputs.append(image)
+                 else:
+                     outputs.append(uri)
+             return outputs
+         else:
+             return json_response["output"]
+
+     iface = gr.Interface(fn=predict_function, inputs=inputs, outputs=gr.outputs.Image(type="pil"), title=api_spec['info']['title'])
+     return iface
+
+ # Fetch API Specification
+ api_spec = fetch_api_spec(SCHEMA_URL)
+
+ # Create and Launch Gradio App
+ gradio_app = create_gradio_app_from_api_spec(api_spec)
+ gradio_app.launch()
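Note that the added code builds its UI with the legacy gr.inputs.* / gr.outputs.* namespaces, which only exist in older Gradio releases and were removed in Gradio 4.x. Below is a minimal sketch of the same schema-to-component mapping written against the current top-level components, assuming the same Cog-style OpenAPI "Input" schema; the function name and fallback defaults are illustrative, not part of the commit:

import gradio as gr

def components_from_input_schema(input_properties: dict) -> list:
    """Map OpenAPI 'Input' properties to Gradio 4.x components (sketch)."""
    components = []
    for prop, details in input_properties.items():
        if "enum" in details:
            components.append(gr.Dropdown(choices=details["enum"], label=prop))
        elif details.get("type") == "integer":
            # precision=0 keeps the value an integer
            components.append(gr.Number(label=prop, value=details.get("default"), precision=0))
        elif details.get("type") == "number":
            components.append(gr.Slider(
                minimum=details.get("minimum", 0),
                maximum=details.get("maximum", 100),
                value=details.get("default"),
                label=prop,
            ))
        elif details.get("type") == "string" and details.get("format") == "uri":
            components.append(gr.Image(label=prop, type="filepath"))
        elif details.get("type") == "string":
            components.append(gr.Textbox(label=prop, value=details.get("default", "")))
        elif details.get("type") == "boolean":
            components.append(gr.Checkbox(label=prop, value=bool(details.get("default", False))))
    return components

The resulting list can be passed to gr.Interface in place of the inputs list built above. Two further adaptations would be needed in that setting: gr.Interface calls its fn with positional arguments, so a handler written as predict_function(**kwargs) would have to accept positional args (for example def predict_function(*args), zipped with the schema's property names before building the "input" payload), and since predict_function returns a list of images, the single Image output would need to become a gr.Gallery or return only the first image.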