Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import requests | |
| import json | |
| import io | |
| import os | |
| import uuid | |
| from PIL import Image | |
| import base64 | |
| from prance import ResolvingParser | |
# Where the OpenAPI schema of the locally running model server is fetched from,
# and the local file it is cached in before parsing.
SCHEMA_URL = "http://localhost:5000/openapi.json"
FILENAME = "openapi.json"

# Fetch the schema exactly once. raise_for_status() makes an HTTP error fail here,
# instead of writing an error page to disk and letting the parser choke on it.
schema_response = requests.get(SCHEMA_URL, timeout=30)
schema_response.raise_for_status()
# Legacy module-level aliases for the same response object, kept for compatibility.
openapi_spec = schema_response
r = schema_response
print(r.content)
with open(FILENAME, "wb") as f:
    f.write(r.content)

# ResolvingParser loads the cached file and resolves its $refs, so property schemas
# (e.g. allOf entries) can be read directly from api_spec.
parser = ResolvingParser(FILENAME)
api_spec = parser.specification
print(parser.specification)
def extract_property_info(prop):
    """Flatten one OpenAPI property schema into a single dict of display attributes.

    Subschemas listed under ``allOf``, ``anyOf`` and ``oneOf`` are merged in that
    order, left to right, with later keys winning. If anything was merged, the
    property's own ``description`` and ``default`` are layered on top of the
    result. Otherwise a shallow copy of the property, without the composition
    keywords, is returned.

    The input dict is NOT modified. The original implementation deleted the
    composition keywords from ``prop``, so a second pass over the same spec
    silently lost the merged ``enum``/``type``.

    Args:
        prop: A property schema dict. Any ``$ref`` inside it is assumed to be
            resolved already (the spec here comes from prance's ResolvingParser).

    Returns:
        A new dict. Nested values are shared with ``prop`` (shallow merge).
    """
    merge_keywords = ("allOf", "anyOf", "oneOf")

    combined_prop = {}
    for keyword in merge_keywords:
        # `or ()` tolerates a missing key as well as an explicit null.
        for subprop in prop.get(keyword) or ():
            combined_prop.update(subprop)

    if not combined_prop:
        # Nothing was merged: use the property itself, minus the (possibly
        # empty) composition keywords.
        return {key: value for key, value in prop.items() if key not in merge_keywords}

    # The property's own description/default override those of the subschemas.
    for key in ("description", "default"):
        if key in prop:
            combined_prop[key] = prop[key]
    return combined_prop
def sort_properties_by_order(properties):
    """Return ``properties.items()`` as a list sorted by each schema's ``x-order``.

    Entries without ``x-order`` sort last. Ties keep their original dict order,
    because ``sorted`` is stable.
    """
    def _position(entry):
        _field_name, field_schema = entry
        return field_schema.get('x-order', float('inf'))

    return sorted(properties.items(), key=_position)
def parse_outputs(data):
    """Collect the leaf values of a prediction's ``output`` into a list.

    - A list is flattened recursively.
    - A scalar (anything that is neither a list nor a dict) becomes a one-item list.
    - For a dict, each member's leaves are collected in value order. A member
      that is itself a list is kept as one nested list entry rather than spliced
      in, so list structure inside objects is preserved.
    """
    if isinstance(data, list):
        # Array context: splice every element's leaves into one flat list.
        flattened = []
        for element in data:
            flattened.extend(parse_outputs(element))
        return flattened

    if not isinstance(data, dict):
        return [data]

    collected = []
    for member in data.values():
        member_leaves = parse_outputs(member)
        if isinstance(member, list):
            collected.append(member_leaves)
        else:
            collected.extend(member_leaves)
    return collected
def _slider_step(minimum, maximum):
    """Return a slider increment of ~1/100 of [minimum, maximum], rounded down to a power of ten.

    Passed explicitly so float sliders do not depend on Gradio's default ``step``.
    NOTE(review): recent Gradio versions default step to 1, which would make a
    slider over e.g. [0, 1] unusable; confirm against the pinned version.
    """
    import math

    span = maximum - minimum
    if span <= 0:
        return 1
    return 10 ** math.floor(math.log10(span) - 2)


def _input_component_for(prop):
    """Build the Gradio input component matching one merged property schema."""
    label = prop.get("title")
    info = prop.get("description")
    default = prop.get("default")
    prop_type = prop.get("type")  # .get: a schema without "type" falls back to a Textbox
    minimum = prop.get("minimum")
    maximum = prop.get("maximum")
    # `is not None`, not truthiness: a bound of 0 is a perfectly valid bound.
    bounded = minimum is not None and maximum is not None

    if "enum" in prop:
        return gr.Dropdown(choices=prop["enum"], label=label, info=info, value=default)
    if prop_type == "integer" and bounded:
        return gr.Slider(
            label=label, info=info, value=default,
            minimum=minimum, maximum=maximum, step=1,
        )
    if prop_type == "number" and bounded:
        return gr.Slider(
            label=label, info=info, value=default,
            minimum=minimum, maximum=maximum, step=_slider_step(minimum, maximum),
        )
    if prop_type in ("integer", "number"):
        return gr.Number(label=label, info=info, value=default)
    if prop_type == "boolean":
        return gr.Checkbox(label=label, info=info, value=default)
    if prop_type == "string" and prop.get("format") == "uri":
        return gr.File(label=label)
    # Free text; pre-fill the schema default like every other component does.
    return gr.Textbox(label=label, info=info, value="" if default is None else default)


def _decode_data_uri(uri):
    """Split a ``data:<mime>;base64,<payload>`` URI into (mime type, decoded bytes)."""
    header, _, encoded = uri.partition(",")  # partition: no IndexError if the comma is missing
    mime_type = header[len("data:"):].split(";", 1)[0]
    return mime_type, base64.b64decode(encoded)


def _outputs_to_updates(outputs, keys):
    """Map parsed prediction outputs onto updates for the (image, audio, text, state) outputs.

    Only the first leaf value is displayed. Nested lists, which parse_outputs
    produces for list members of dict outputs, are descended into.

    Raises:
        gr.Error: if there is no displayable value.
    """
    import mimetypes

    first = outputs[0] if outputs else None
    while isinstance(first, list):
        first = first[0] if first else None
    if first is None:
        raise gr.Error("The model returned no output")

    if isinstance(first, str) and first.startswith("data:image"):
        _, image_bytes = _decode_data_uri(first)
        image = Image.open(io.BytesIO(image_bytes))
        return gr.update(visible=True, value=image), gr.update(visible=False), gr.update(visible=False), keys

    if isinstance(first, str) and first.startswith("data:audio"):
        mime_type, audio_bytes = _decode_data_uri(first)
        # Name the file after the actual encoding; fall back to .wav if the type is unknown.
        extension = mimetypes.guess_extension(mime_type) or ".wav"
        filename = f"{uuid.uuid4()}{extension}"
        with open(filename, "wb") as audio_file:
            audio_file.write(audio_bytes)
        return gr.update(visible=False), gr.update(visible=True, value=filename), gr.update(visible=False), keys

    # Anything else (plain strings, URLs, numbers, ...) is shown as text.
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True, value=str(first)), keys


def create_gradio_app(api_spec, api_url, file_url_prefix="http://localhost:7860/file="):
    """Build a Gradio Interface whose inputs mirror the spec's ``Input`` schema.

    Args:
        api_spec: Resolved OpenAPI spec dict. Must contain
            ``components.schemas.Input.properties`` and ``components.schemas.Output``.
        api_url: Endpoint that receives ``{"input": {...}}`` as a JSON POST.
        file_url_prefix: Prefix that turns a local upload path into a URL the
            prediction server can fetch. Defaults to a Gradio file route on
            localhost:7860.

    Returns:
        A ``gr.Interface`` with:
        - one input component per property, in ``x-order`` order;
        - a hidden state holding the property names;
        - image/audio/text outputs, of which exactly one is made visible per
          prediction.
    """
    input_schema = api_spec["components"]["schemas"]["Input"]["properties"]
    output_schema = api_spec["components"]["schemas"]["Output"]

    inputs = []
    names = []
    file_fields = set()
    for name, prop in sort_properties_by_order(input_schema):
        component = _input_component_for(extract_property_info(prop))
        if isinstance(component, gr.File):
            file_fields.add(name)
        inputs.append(component)
        names.append(name)
    print(names)

    # Field names travel with each submission as hidden state, so predict()
    # can map positional component values back onto schema keys.
    data_field = gr.State(value=names)
    inputs.append(data_field)

    print(output_schema)
    output_title = output_schema.get("title")
    outputs = [
        gr.Image(label=output_title, visible=True),
        gr.Audio(label=output_title, visible=False),
        gr.Textbox(label=output_title, visible=False),
        data_field,
    ]

    def predict(*args):
        """POST the form values to ``api_url`` and render the first returned output.

        ``args`` holds one value per input component, followed by the list of
        field names held in the state.
        """
        print(args)
        keys = args[-1]
        payload = {"input": {}}
        for key, value in zip(keys, args):
            if value is None:
                # Omit unset fields so the server applies its own defaults.
                continue
            if key in file_fields:
                # Uploads may arrive as a path string or as a tempfile wrapper
                # exposing .name, depending on the Gradio version; both become a URL.
                path = getattr(value, "name", value)
                if isinstance(path, str) and os.path.exists(path):
                    value = file_url_prefix + path
                else:
                    value = path
            payload["input"][key] = value
        print(payload)

        try:
            response = requests.post(api_url, headers={"Content-Type": "application/json"}, json=payload)
        except requests.RequestException as err:
            raise gr.Error("The submission failed!") from err
        print(response)
        if response.status_code != 200:
            raise gr.Error("The submission failed!")

        json_response = response.json()
        print(json_response)
        if json_response.get("status") == "failed":
            raise gr.Error("Failed to generate output")
        return _outputs_to_updates(parse_outputs(json_response.get("output")), keys)

    return gr.Interface(fn=predict, inputs=inputs, outputs=outputs)
# Prediction endpoint that the generated form POSTs its inputs to.
API_URL = "http://localhost:5000/predictions"
app = create_gradio_app(api_spec=api_spec, api_url=API_URL)
# share=True additionally requests a public Gradio share link.
app.launch(share=True)