multimodalart HF Staff commited on
Commit
a8cf54a
·
verified ·
1 Parent(s): 2ca7450

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -58
app.py CHANGED
@@ -1,70 +1,110 @@
 
1
  import requests
2
- import base64
3
  import io
4
  from PIL import Image
5
- import gradio as gr
6
- import json
 
7
 
8
- # API and Schema URLs
9
- API_URL = "http://localhost:5000/predictions"
10
  SCHEMA_URL = "http://localhost:5000/openapi.json"
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- def fetch_api_spec(url):
13
- response = requests.get(url)
14
- return response.json()
 
 
 
 
 
15
 
16
- def create_gradio_app_from_api_spec(api_spec):
17
- input_properties = api_spec['components']['schemas']['Input']['properties']
18
  inputs = []
19
- for prop, details in input_properties.items():
20
- if 'enum' in details:
21
- choices = details['enum']
22
- inputs.append(gr.Dropdown(choices=choices, label=prop, value=details.get('default')))
23
- elif details['type'] == 'integer':
24
- inputs.append(gr.Number(label=prop, value=details.get('default'), minimum=details.get('minimum'), maximum=details.get('maximum')))
25
- elif details['type'] == 'number':
26
- inputs.append(gr.Slider(minimum=details.get('minimum'), maximum=details.get('maximum'), value=details.get('default'), label=prop))
27
- elif details['type'] == 'string' and 'format' in details and details['format'] == 'uri':
28
- inputs.append(gr.Image(label=prop))
29
- elif details['type'] == 'string':
30
- inputs.append(gr.Textbox(label=prop, value=details.get('default')))
31
- elif details['type'] == 'boolean':
32
- inputs.append(gr.Checkbox(label=prop, value=details.get('default')))
33
-
34
- def predict_function(**kwargs):
35
- # Adjust the input kwargs for image inputs to convert them to the expected format by the API if needed
36
- payload = {
37
- "input": kwargs
38
- }
39
- print(payload)
40
- response = requests.post(API_URL, headers={"Content-Type": "application/json"}, json=payload)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  json_response = response.json()
42
 
43
- if 'status' in json_response and json_response["status"] == "failed":
44
  raise gr.Error("Failed to generate image")
45
 
46
- output_spec = api_spec['components']['schemas']['Output']
47
- if output_spec['items']['type'] == 'string' and output_spec['items']['format'] == 'uri':
48
- outputs = []
49
- for uri in json_response["output"]:
50
- if uri.startswith("data:image"):
51
- base64_image = uri.split(",")[1] # Strip the prefix part
52
- image_data = base64.b64decode(base64_image)
53
- image_stream = io.BytesIO(image_data)
54
- image = Image.open(image_stream)
55
- outputs.append(image)
56
- else:
57
- outputs.append(uri)
58
- return outputs
59
- else:
60
- return json_response["output"]
61
-
62
- iface = gr.Interface(fn=predict_function, inputs=inputs, outputs=gr.outputs.Image(type="pil"), title=api_spec['info']['title'])
63
- return iface
64
-
65
- # Fetch API Specification
66
- api_spec = fetch_api_spec(SCHEMA_URL)
67
-
68
- # Create and Launch Gradio App
69
- gradio_app = create_gradio_app_from_api_spec(api_spec)
70
- gradio_app.launch()
 
1
+ import gradio as gr
2
  import requests
3
+ import json
4
  import io
5
  from PIL import Image
6
+ import base64
7
+ from prance import ResolvingParser
8
+
9
 
 
 
10
  SCHEMA_URL = "http://localhost:5000/openapi.json"
11
+ FILENAME = "openapi.json"
12
+ schema_response = requests.get(SCHEMA_URL)
13
+ openapi_spec = schema_response
14
+
15
+ r = requests.get(SCHEMA_URL)
16
+ print(r.content)
17
+ with open(FILENAME, "wb") as f:
18
+ f.write(r.content)
19
+
20
+ parser = ResolvingParser(FILENAME)
21
+ print(parser.specification)
22
+
23
 
24
+ def extract_property_info(prop):
25
+ # Handle 'allOf' by merging all contained properties (assuming simple case of enum merging)
26
+ if "allOf" in prop:
27
+ combined_prop = {}
28
+ for subprop in prop["allOf"]:
29
+ combined_prop.update(subprop)
30
+ prop = combined_prop
31
+ return prop
32
 
33
+
34
+ def create_gradio_app(api_spec, api_url):
35
  inputs = []
36
+ input_schema = api_spec["components"]["schemas"]["PredictionRequest"]["properties"][
37
+ "input"
38
+ ]["properties"]
39
+
40
+ for name, prop in input_schema.items():
41
+ prop = extract_property_info(
42
+ prop
43
+ ) # Extract property info correctly for 'allOf'
44
+ print(prop)
45
+ if "enum" in prop:
46
+ input_field = gr.Dropdown(
47
+ choices=prop["enum"], label=name, value=prop.get("default")
48
+ )
49
+ elif prop["type"] == "integer":
50
+ input_field = gr.Number(
51
+ label=name,
52
+ value=prop.get("default"),
53
+ minimum=prop.get("minimum"),
54
+ maximum=prop.get("maximum"),
55
+ step=1,
56
+ )
57
+ elif prop["type"] == "number":
58
+ input_field = gr.Number(
59
+ label=name,
60
+ value=prop.get("default"),
61
+ minimum=prop.get("minimum"),
62
+ maximum=prop.get("maximum"),
63
+ )
64
+ elif prop["type"] == "boolean":
65
+ input_field = gr.Checkbox(label=name, value=prop.get("default"))
66
+ elif prop["type"] == "string" and prop.get("format") == "uri":
67
+ input_field = gr.File(label=name)
68
+ else: # Assuming string type for simplicity, can add more types as needed
69
+ input_field = gr.Textbox(label=name, value=prop.get("default"))
70
+ inputs.append(input_field)
71
+
72
+ def predict(**kwargs):
73
+ payload = {"input": {}}
74
+ for key, value in kwargs.items():
75
+ if isinstance(
76
+ value, io.BytesIO
77
+ ): # For image inputs, convert to the desired format
78
+ value.seek(0)
79
+ value = (
80
+ "data:image/jpeg;base64," + base64.b64encode(value.read()).decode()
81
+ )
82
+ payload["input"][key] = value
83
+
84
+ response = requests.post(
85
+ api_url, headers={"Content-Type": "application/json"}, json=payload
86
+ )
87
  json_response = response.json()
88
 
89
+ if "status" in json_response and json_response["status"] == "failed":
90
  raise gr.Error("Failed to generate image")
91
 
92
+ output_images = []
93
+ for output_uri in json_response["output"]:
94
+ base64_image = output_uri.replace("data:image/png;base64,", "")
95
+ image_data = base64.b64decode(base64_image)
96
+ image_stream = io.BytesIO(image_data)
97
+ output_images.append(Image.open(image_stream))
98
+
99
+ return output_images
100
+
101
+ output_component = gr.Gallery(label="Output Images")
102
+ return gr.Interface(fn=predict, inputs=inputs, outputs=output_component)
103
+
104
+
105
+ # Use the modified function with the API URL
106
+ api_spec = parser.specification
107
+
108
+ API_URL = "http://localhost:5000/predictions"
109
+ app = create_gradio_app(api_spec, API_URL)
110
+ app.launch()