Spaces:
Sleeping
Sleeping
import gradio as gr | |
import requests | |
import json | |
import io | |
import os | |
import uuid | |
from PIL import Image | |
import base64 | |
from prance import ResolvingParser | |
SCHEMA_URL = "http://localhost:5000/openapi.json"
FILENAME = "openapi.json"

# Fetch the model server's OpenAPI schema exactly once. The previous version
# issued two separate GETs, so the body that was printed and the body that was
# parsed could differ. Fail loudly on an HTTP error instead of writing an
# error page to disk and handing it to the parser.
schema_response = requests.get(SCHEMA_URL, timeout=30)
schema_response.raise_for_status()
# The original module-level names are kept as aliases of the single response.
openapi_spec = schema_response
r = schema_response
print(r.content)
with open(FILENAME, "wb") as f:
    f.write(r.content)
# ResolvingParser loads the saved file and resolves $ref references, so the
# allOf/anyOf entries seen later contain the referenced schemas inline.
parser = ResolvingParser(FILENAME)
api_spec = parser.specification
print(parser.specification)
def extract_property_info(prop):
    """Flatten one OpenAPI input property into a single, self-contained schema.

    Subschemas listed under ``allOf``/``anyOf``/``oneOf`` are merged in order,
    with later entries winning on conflicting keys. Entries that are exactly a
    ``{"type": "null", ...}`` schema are skipped, so an optional field keeps its
    real type. The property's own keys (for example ``description``,
    ``default``, ``title``, ``x-order``) are then laid over the merged result,
    because they describe this particular field rather than the shared
    referenced schema.

    Args:
        prop: A property schema dict, e.g. one value of
            ``components.schemas.Input.properties``. It is NOT modified, so
            calling this twice on the same property gives the same result.

    Returns:
        A new dict with the merged schema. For a property without any merge
        keyword this is a shallow copy of ``prop``.
    """
    merge_keywords = ("allOf", "anyOf", "oneOf")
    combined_prop = {}
    for keyword in merge_keywords:
        for subprop in prop.get(keyword, ()):
            # Optional fields may be expressed as anyOf[<schema>, {"type": "null"}];
            # merging the null member would overwrite the useful type.
            if subprop.get("type") == "null":
                continue
            combined_prop.update(subprop)
    # Field-level keys override whatever came from the subschemas.
    combined_prop.update(
        {key: value for key, value in prop.items() if key not in merge_keywords}
    )
    return combined_prop
def sort_properties_by_order(properties):
    """Return ``(name, schema)`` pairs ordered by each schema's ``x-order``.

    Properties without an ``x-order`` key sort after all ordered ones. Ties
    keep the dict's insertion order, because ``sorted`` is stable.
    """
    def _position(entry):
        _name, schema = entry
        return schema.get('x-order', float('inf'))

    return sorted(properties.items(), key=_position)
def parse_outputs(data):
    """Flatten a JSON-like prediction output into a list of leaf values.

    - A scalar (anything that is neither a list nor a dict) becomes ``[data]``.
    - A list is flattened completely: its parsed elements are concatenated.
    - For a dict, each value is parsed. Values that are themselves lists are
      kept as one nested list element; every other value is spliced in flat.
    """
    if isinstance(data, list):
        flattened = []
        for element in data:
            # Already inside an array, so splice nested results in flat.
            flattened.extend(parse_outputs(element))
        return flattened

    if not isinstance(data, dict):
        return [data]

    collected = []
    for member in data.values():
        parsed = parse_outputs(member)
        if isinstance(member, list):
            # Keep a list-valued field together as a single nested element.
            collected.append(parsed)
        else:
            collected.extend(parsed)
    return collected
def _build_input_component(prop):
    """Return the Gradio input component that matches one flattened property schema."""
    label = prop.get("title")
    info = prop.get("description")
    default = prop.get("default")
    prop_type = prop.get("type")  # may be absent; fall through to a Textbox
    minimum = prop.get("minimum")
    maximum = prop.get("maximum")
    # Compare against None: a bound of 0 is valid but falsy.
    has_range = minimum is not None and maximum is not None

    if "enum" in prop:
        return gr.Dropdown(choices=prop["enum"], label=label, info=info, value=default)
    if prop_type == "integer":
        if has_range:
            return gr.Slider(
                label=label, info=info, value=default,
                minimum=minimum, maximum=maximum, step=1,
            )
        # precision=0 makes Gradio round to and return an int rather than a float.
        return gr.Number(label=label, info=info, value=default, precision=0)
    if prop_type == "number":
        if has_range:
            return gr.Slider(
                label=label, info=info, value=default,
                minimum=minimum, maximum=maximum,
            )
        return gr.Number(label=label, info=info, value=default)
    if prop_type == "boolean":
        return gr.Checkbox(label=label, info=info, value=default)
    if prop_type == "string" and prop.get("format") == "uri":
        return gr.File(label=label)
    return gr.Textbox(label=label, info=info, value="" if default is None else default)


def _is_empty_output(value):
    """True for values that have nothing to display (None or an empty str/list/dict)."""
    return value is None or (isinstance(value, (str, list, dict)) and not value)


def _render_output(output, keys):
    """Map one prediction output to updates for the (image, audio, text, state) outputs."""
    if not isinstance(output, str):
        # Numbers, booleans and nested lists cannot be data URIs, so show them as JSON text.
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True, value=json.dumps(output)),
            keys,
        )
    if output.startswith("data:image"):
        image_data = base64.b64decode(output.split(",", 1)[1])
        image = Image.open(io.BytesIO(image_data))
        return gr.update(visible=True, value=image), gr.update(visible=False), gr.update(visible=False), keys
    if output.startswith("data:audio"):
        audio_data = base64.b64decode(output.split(",", 1)[1])
        filename = f"{uuid.uuid4()}.wav"  # Change format as needed
        with open(filename, "wb") as audio_file:
            audio_file.write(audio_data)
        return gr.update(visible=False), gr.update(visible=True, value=filename), gr.update(visible=False), keys
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True, value=output), keys


def create_gradio_app(api_spec, api_url):
    """Build a Gradio Interface that mirrors the model's OpenAPI ``Input`` schema.

    Args:
        api_spec: Resolved OpenAPI document (dict) containing
            ``components.schemas.Input`` and ``components.schemas.Output``.
        api_url: Prediction endpoint; it receives ``{"input": {...}}`` via a
            JSON POST.

    Returns:
        A ``gr.Interface``. Its outputs are an image, an audio and a text
        slot, of which one is made visible per prediction, plus the state that
        carries the input names.

    The prediction callback raises ``gr.Error`` when the endpoint returns a
    non-200 status or reports ``status == "failed"``.
    """
    input_schema = api_spec["components"]["schemas"]["Input"]["properties"]
    output_schema = api_spec["components"]["schemas"]["Output"]
    output_label = output_schema.get("title")

    inputs = []
    names = []
    file_names = set()  # fields whose values must be sent as URLs, not local paths
    for name, raw_prop in sort_properties_by_order(input_schema):
        component = _build_input_component(extract_property_info(raw_prop))
        if isinstance(component, gr.File):
            file_names.add(name)
        inputs.append(component)
        names.append(name)

    # The ordered field names ride along as the last input (and output) so
    # predict() can map positional values back to schema keys.
    names_state = gr.State(value=names)
    inputs.append(names_state)
    outputs = [
        gr.Image(label=output_label, visible=True),
        gr.Audio(label=output_label, visible=False),
        gr.Textbox(label=output_label, visible=False),
        names_state,
    ]

    def predict(*args):
        keys = args[-1]
        payload = {"input": {}}
        for key, value in zip(keys, args):
            # Only File inputs are rewritten. A plain text value that happens to
            # match a local path must be sent unchanged.
            if key in file_names and value:
                # Depending on the Gradio version this may be a path string or a
                # file wrapper exposing .name — handle both.
                path = str(getattr(value, "name", value))
                if os.path.exists(path):
                    value = "http://localhost:7860/file=" + path
            payload["input"][key] = value
        print(payload)

        response = requests.post(api_url, headers={"Content-Type": "application/json"}, json=payload)
        if response.status_code != 200:
            raise gr.Error(f"The submission failed! (HTTP {response.status_code})")

        json_response = response.json()
        print(json_response)
        if json_response.get("status") == "failed":
            error = json_response.get("error")
            raise gr.Error(f"Failed to generate output: {error}" if error else "Failed to generate output")

        # Only the first non-empty output value is displayed.
        for output in parse_outputs(json_response.get("output")):
            if not _is_empty_output(output):
                return _render_output(output, keys)
        # Nothing to show. Still return one value per output component.
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True, value=""), keys

    return gr.Interface(fn=predict, inputs=inputs, outputs=outputs)
# Prediction endpoint of the locally running model server.
API_URL = "http://localhost:5000/predictions"

# Build the UI from the resolved spec and start serving it.
# share=True additionally asks Gradio for a public share link.
app = create_gradio_app(api_spec=api_spec, api_url=API_URL)
app.launch(share=True)