Spaces:
Sleeping
Sleeping
import gradio as gr | |
import requests | |
import json | |
import io | |
import os | |
from PIL import Image | |
import base64 | |
from prance import ResolvingParser | |
SCHEMA_URL = "http://localhost:5000/openapi.json"
FILENAME = "openapi.json"
# Seconds to wait for the schema endpoint; without a timeout a hung server
# would block start-up indefinitely.
SCHEMA_TIMEOUT = 60

# Fetch the OpenAPI schema once and fail loudly on an HTTP error, rather than
# writing an error page to disk and handing it to the parser.
schema_response = requests.get(SCHEMA_URL, timeout=SCHEMA_TIMEOUT)
schema_response.raise_for_status()
print(schema_response.content)

# ResolvingParser is given a file name, so persist the raw bytes first.
with open(FILENAME, "wb") as f:
    f.write(schema_response.content)

# `specification` is the parsed document with "$ref"s resolved by prance.
parser = ResolvingParser(FILENAME)
api_spec = parser.specification
print(api_spec)
def extract_property_info(prop):
    """Flatten a property schema's ``allOf``/``anyOf``/``oneOf`` branches.

    Every sub-schema listed under those keywords is merged into one dict,
    in keyword order and left to right, with later keys overwriting earlier
    ones. The property's own ``description`` and ``default`` are then
    re-applied on top, because they describe this particular field. A
    property with no (non-empty) branches comes back as a shallow copy of
    itself, minus any empty merge keywords.

    ``prop`` itself is never modified. It is typically a live entry of the
    parsed API spec, so deleting keys from it would corrupt that spec for
    any later reader.

    Args:
        prop: A JSON-schema property dict.

    Returns:
        A new dict describing the flattened property.
    """
    merge_keywords = ("allOf", "anyOf", "oneOf")
    combined_prop = {}
    for keyword in merge_keywords:
        for subprop in prop.get(keyword, ()):
            combined_prop.update(subprop)

    if not combined_prop:
        return {key: value for key, value in prop.items() if key not in merge_keywords}

    for key in ("description", "default"):
        if key in prop:
            combined_prop[key] = prop[key]
    return combined_prop
def sort_properties_by_order(properties):
    """Return ``(name, schema)`` pairs sorted by each schema's ``x-order``.

    Schemas without ``x-order`` are placed after every ordered one. The sort
    is stable, so ties keep the mapping's iteration order.
    """
    def _order_key(item):
        _name, schema = item
        return schema.get('x-order', float('inf'))

    return sorted(properties.items(), key=_order_key)
def _build_input_component(prop):
    """Return the Gradio input component for one flattened property schema.

    Reads ``enum``, ``type``, ``format``, ``title``, ``description``,
    ``default``, ``minimum`` and ``maximum`` from ``prop``.
    """
    label = prop.get("title")
    info = prop.get("description")
    default = prop.get("default")
    # .get(): a schema without "type" must not crash the whole UI build.
    prop_type = prop.get("type")

    if "enum" in prop:
        return gr.Dropdown(choices=prop["enum"], label=label, info=info, value=default)
    if prop_type in ("integer", "number"):
        is_integer = prop_type == "integer"
        minimum = prop.get("minimum")
        maximum = prop.get("maximum")
        if minimum is None or maximum is None:
            # A slider needs both bounds; for unbounded fields fall back to a
            # plain number box instead of a slider built from None limits.
            number_kwargs = {"precision": 0} if is_integer else {}
            return gr.Number(label=label, info=info, value=default, **number_kwargs)
        slider_kwargs = {"step": 1} if is_integer else {}
        return gr.Slider(
            label=label, info=info, value=default,
            minimum=minimum, maximum=maximum, **slider_kwargs,
        )
    if prop_type == "boolean":
        return gr.Checkbox(label=label, info=info, value=default)
    if prop_type == "string" and prop.get("format") == "uri":
        return gr.File(label=label)
    # Free text; pre-fill string defaults like the other components do.
    return gr.Textbox(
        label=label, info=info,
        value=default if isinstance(default, str) else "",
    )


def _build_output_component(output_schema):
    """Return the Gradio component that displays the prediction output."""
    title = output_schema.get("title", "Output")
    schema_type = output_schema.get("type")
    if schema_type == "string":
        if output_schema.get("format") == "uri":
            return gr.Image(label=title)
        return gr.Textbox(label="Output")
    if schema_type == "array":
        if output_schema.get("items", {}).get("format") == "uri":
            return gr.Image(label=title)
        return gr.Textbox(label=title)
    # Any other schema type (e.g. an object) is shown as text, so predict()
    # always has exactly one display component next to the State.
    return gr.Textbox(label=title)


def _decode_image_output(output):
    """Turn an image output (data URI, URL, or a list of them) into a gr.Image value.

    Only the first element of a list is used. ``data:`` URIs are decoded
    regardless of their MIME type; any other non-empty value is passed
    through unchanged for Gradio to load.

    Raises:
        gr.Error: if the output is missing or empty.
    """
    if isinstance(output, list):
        output = output[0] if output else None
    if not output:
        raise gr.Error("The model returned no output")
    if isinstance(output, str) and output.startswith("data:"):
        # "data:<mime>;base64,<payload>" -- everything after the first comma
        # is the payload, whatever the MIME type (png, jpeg, webp, ...).
        _header, _sep, encoded = output.partition(",")
        return Image.open(io.BytesIO(base64.b64decode(encoded)))
    return output


def _format_text_output(output):
    """Render a non-image output as text for a Textbox."""
    if isinstance(output, str):
        return output
    if isinstance(output, list) and all(isinstance(item, str) for item in output):
        # e.g. a list of text chunks -- presumably meant to be concatenated.
        return "".join(output)
    return json.dumps(output)


def create_gradio_app(api_spec, api_url, file_base_url="http://localhost:7860/file=",
                      request_timeout=None):
    """Build a ``gr.Interface`` whose form mirrors an OpenAPI prediction schema.

    Args:
        api_spec: Resolved OpenAPI document. ``components.schemas.Input``
            must have ``properties``, and ``components.schemas.Output``
            must exist.
        api_url: Endpoint that receives ``{"input": {...}}`` JSON POSTs.
        file_base_url: Prefix prepended to the local path of an uploaded
            file, so the server can fetch it over HTTP.
        request_timeout: Seconds passed to ``requests.post``. ``None`` waits
            indefinitely.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    schemas = api_spec["components"]["schemas"]
    input_schema = schemas["Input"]["properties"]
    output_schema = schemas["Output"]

    inputs = []
    names = []
    file_fields = set()
    for name, raw_prop in sort_properties_by_order(input_schema):
        component = _build_input_component(extract_property_info(raw_prop))
        if isinstance(component, gr.File):
            file_fields.add(name)
        inputs.append(component)
        names.append(name)
    print(names)

    # The ordered field names ride along as the last input (and output), so
    # predict() can map its positional arguments back to payload keys.
    data_field = gr.State(value=names)
    inputs.append(data_field)

    print(output_schema)
    output_component = _build_output_component(output_schema)
    shows_image = isinstance(output_component, gr.Image)
    outputs = [output_component, data_field]

    def predict(*args):
        keys = args[-1]
        payload = {"input": {}}
        # len(args) == len(keys) + 1, so zip() stops before the State value.
        for key, value in zip(keys, args):
            if value is None:
                # Omit unset fields so the server can apply its own defaults.
                continue
            if key in file_fields:
                # gr.File yields a temp-file wrapper (with .name) in some
                # Gradio versions and a path string in others.
                path = str(getattr(value, "name", value))
                value = file_base_url + path if os.path.exists(path) else path
            payload["input"][key] = value
        print(payload)

        try:
            response = requests.post(
                api_url,
                headers={"Content-Type": "application/json"},
                json=payload,
                timeout=request_timeout,
            )
        except requests.RequestException as err:
            raise gr.Error("The submission failed!") from err
        print(response)
        if response.status_code != 200:
            raise gr.Error("The submission failed!")

        json_response = response.json()
        if json_response.get("status") == "failed":
            raise gr.Error("Failed to generate image")
        output = json_response.get("output")
        if shows_image:
            return _decode_image_output(output), keys
        return _format_text_output(output), keys

    # TODO: schemas with multiple outputs are not handled yet.
    return gr.Interface(fn=predict, inputs=inputs, outputs=outputs)
# Prediction endpoint on the same local server that served SCHEMA_URL.
API_URL = "http://localhost:5000/predictions"

# Build the UI from the parsed spec and serve it (share=True also requests
# a public Gradio link).
app = create_gradio_app(api_spec=api_spec, api_url=API_URL)
app.launch(share=True)