# gradio-on-cog / app.py
import gradio as gr
import requests
import json
import io
import os
import uuid
from PIL import Image
import base64
from prance import ResolvingParser
# Fetch the Cog model's OpenAPI schema and resolve all $ref pointers with prance.
SCHEMA_URL = "http://localhost:5000/openapi.json"
FILENAME = "openapi.json"

schema_response = requests.get(SCHEMA_URL)
schema_response.raise_for_status()
with open(FILENAME, "wb") as f:
    f.write(schema_response.content)

parser = ResolvingParser(FILENAME)
api_spec = parser.specification
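# Illustrative only (field names are hypothetical): a resolved Cog "Input" schema is a dict of
# properties along the lines of
#   {"prompt": {"type": "string", "title": "Prompt", "x-order": 0},
#    "num_inference_steps": {"type": "integer", "title": "Num Inference Steps",
#                            "minimum": 1, "maximum": 500, "default": 50, "x-order": 1}}
# The real keys and constraints come from the model's predict() signature.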
def extract_property_info(prop):
    """Flatten allOf/anyOf/oneOf wrappers into a single property dict."""
    combined_prop = {}
    merge_keywords = ["allOf", "anyOf", "oneOf"]
    for keyword in merge_keywords:
        if keyword in prop:
            for subprop in prop[keyword]:
                combined_prop.update(subprop)
            del prop[keyword]
    if not combined_prop:
        combined_prop = prop.copy()
    # Keep description/default from the outer schema if they were defined there.
    for key in ["description", "default"]:
        if key in prop:
            combined_prop[key] = prop[key]
    return combined_prop
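# Illustrative example of the flattening above:
#   {"allOf": [{"type": "string", "enum": ["a", "b"]}], "default": "a"}
# becomes
#   {"type": "string", "enum": ["a", "b"], "default": "a"}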
def sort_properties_by_order(properties):
    # Replicate the input order declared in the Cog predictor via the x-order extension.
    ordered_properties = sorted(properties.items(), key=lambda x: x[1].get("x-order", float("inf")))
    return ordered_properties
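# For example, properties with x-order values 2, 0, 1 come back ordered 0, 1, 2;
# properties without an x-order key sort to the end.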
def parse_outputs(data):
    if isinstance(data, dict):
        # Handle case where data is an object
        dict_values = []
        for value in data.values():
            extracted_values = parse_outputs(value)
            # For dict, we append instead of extend to maintain list structure within objects
            if isinstance(value, list):
                dict_values += [extracted_values]
            else:
                dict_values += extracted_values
        return dict_values
    elif isinstance(data, list):
        # Handle case where data is an array
        list_values = []
        for item in data:
            # Here we extend to flatten the list since we're already in an array context
            list_values += parse_outputs(item)
        return list_values
    else:
        # Handle primitive data types directly
        return [data]
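# Illustrative behavior of parse_outputs:
#   parse_outputs({"text": "hello"})          -> ["hello"]
#   parse_outputs(["data:image/png;...", 42]) -> ["data:image/png;...", 42]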
def create_gradio_app(api_spec, api_url):
    inputs = []
    outputs = []
    input_schema = api_spec["components"]["schemas"]["Input"]["properties"]
    output_schema = api_spec["components"]["schemas"]["Output"]
    ordered_input_schema = sort_properties_by_order(input_schema)
    names = []
    for name, prop in ordered_input_schema:
        prop = extract_property_info(prop)
        if "enum" in prop:
            input_field = gr.Dropdown(
                choices=prop["enum"], label=prop.get("title"), info=prop.get("description"), value=prop.get("default")
            )
        elif prop["type"] == "integer":
            # Use "is not None" so a minimum of 0 still produces a slider.
            if prop.get("minimum") is not None and prop.get("maximum") is not None:
                input_field = gr.Slider(
                    label=prop.get("title"), info=prop.get("description"), value=prop.get("default"),
                    minimum=prop.get("minimum"), maximum=prop.get("maximum"), step=1,
                )
            else:
                input_field = gr.Number(label=prop.get("title"), info=prop.get("description"), value=prop.get("default"))
        elif prop["type"] == "number":
            if prop.get("minimum") is not None and prop.get("maximum") is not None:
                input_field = gr.Slider(
                    label=prop.get("title"), info=prop.get("description"), value=prop.get("default"),
                    minimum=prop.get("minimum"), maximum=prop.get("maximum"),
                )
            else:
                input_field = gr.Number(label=prop.get("title"), info=prop.get("description"), value=prop.get("default"))
        elif prop["type"] == "boolean":
            input_field = gr.Checkbox(label=prop.get("title"), info=prop.get("description"), value=prop.get("default"))
        elif prop["type"] == "string" and prop.get("format") == "uri":
            input_field = gr.File(label=prop.get("title"))
        else:
            input_field = gr.Textbox(label=prop.get("title"), info=prop.get("description"))
        inputs.append(input_field)
        names.append(name)

    # Pass the ordered input names through a State component so predict()
    # can map its positional args back to payload keys.
    data_field = gr.State(value=names)
    inputs.append(data_field)

    # One component per possible output modality; predict() toggles their visibility.
    outputs.append(gr.Image(label=output_schema["title"], visible=True))
    outputs.append(gr.Audio(label=output_schema["title"], visible=False))
    outputs.append(gr.Textbox(label=output_schema["title"], visible=False))
    outputs.append(data_field)
    # TODO: handle schemas that declare multiple outputs
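    # Illustrative only: for many Cog models the resolved Output schema is either a single
    # string with "format": "uri" or an array of such strings; the three components above
    # cover the image, audio, and plain-text cases this app knows how to render.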
    def predict(*args):
        # The last positional arg is the gr.State holding the ordered input names.
        keys = args[-1]
        payload = {"input": {}}
        for i, key in enumerate(keys):
            value = args[i]
            # Uploaded files live on the local Gradio server; hand them to Cog by URL.
            if value and os.path.exists(str(value)):
                value = "http://localhost:7860/file=" + value
            payload["input"][key] = value
        response = requests.post(api_url, headers={"Content-Type": "application/json"}, json=payload)
        if response.status_code == 200:
            json_response = response.json()
            if json_response.get("status") == "failed":
                raise gr.Error("Failed to generate output")
            outputs = parse_outputs(json_response["output"])
            for output in outputs:
                if not output:
                    continue
                if isinstance(output, str) and output.startswith("data:image"):
                    # Process as image
                    base64_data = output.split(",", 1)[1]
                    image_data = base64.b64decode(base64_data)
                    image_stream = io.BytesIO(image_data)
                    image = Image.open(image_stream)
                    return gr.update(visible=True, value=image), gr.update(visible=False), gr.update(visible=False), keys
                elif isinstance(output, str) and output.startswith("data:audio"):
                    base64_data = output.split(",", 1)[1]
                    audio_data = base64.b64decode(base64_data)
                    audio_stream = io.BytesIO(audio_data)
                    # Here you can save the audio or return the stream for further processing
                    filename = f"{uuid.uuid4()}.wav"  # Change format as needed
                    with open(filename, "wb") as audio_file:
                        audio_file.write(audio_stream.getbuffer())
                    return gr.update(visible=False), gr.update(visible=True, value=filename), gr.update(visible=False), keys
                else:
                    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True, value=output), keys
        else:
            raise gr.Error("The submission failed!")

    return gr.Interface(fn=predict, inputs=inputs, outputs=outputs)
API_URL = "http://localhost:5000/predictions"
app = create_gradio_app(api_spec, API_URL)
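# Note: share=True also exposes the UI via a temporary public Gradio link; locally the app
# serves on http://localhost:7860, the same origin used above to hand uploaded files to Cog.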
app.launch(share=True)