# Spaces: Sleeping (hosting-page status captured with this file; not part of the program)
import base64
import io
import json
import mimetypes

import gradio as gr
import requests
from PIL import Image
from prance import ResolvingParser
SCHEMA_URL = "http://localhost:5000/openapi.json"
FILENAME = "openapi.json"

# Download the OpenAPI schema exactly once. The timeout keeps start-up from
# hanging indefinitely when the model server is unreachable, and
# raise_for_status() stops an HTTP error body from being saved and then
# parsed as if it were a schema.
schema_response = requests.get(SCHEMA_URL, timeout=30)
schema_response.raise_for_status()

# Aliases kept so every module-level name the script previously bound still
# exists; they all refer to the single response fetched above.
openapi_spec = schema_response
r = schema_response
print(r.content)

# The parser below is given a file name, so persist the raw schema first.
with open(FILENAME, "wb") as f:
    f.write(r.content)

# parser.specification is the parsed schema consumed further down the file.
parser = ResolvingParser(FILENAME)
print(parser.specification)
def extract_property_info(prop):
    """Flatten a property schema that uses ``allOf`` into a single dict.

    Args:
        prop: A property schema (mapping). It is never mutated.

    Returns:
        ``prop`` itself when it has no ``"allOf"`` key. Otherwise a new dict
        holding the keys of every ``allOf`` subschema (later subschemas
        override earlier ones) plus the keywords written next to ``"allOf"``
        on the property. The ``"allOf"`` key itself is not included.
    """
    if "allOf" not in prop:
        return prop

    merged = {}
    for subschema in prop["allOf"]:
        merged.update(subschema)

    # Keywords that sit beside "allOf" (e.g. "default", "description") belong
    # to this specific property, so they must survive the merge and take
    # precedence over the generic values contributed by the subschemas.
    merged.update(
        (key, value) for key, value in prop.items() if key != "allOf"
    )
    return merged
def _build_input_component(name, prop):
    """Return the Gradio input component that matches one schema property."""
    default = prop.get("default")
    # .get() so a schema without an explicit "type" falls through to the
    # textbox branch instead of raising KeyError.
    prop_type = prop.get("type")

    if "enum" in prop:
        return gr.Dropdown(choices=prop["enum"], label=name, value=default)
    if prop_type == "integer":
        return gr.Number(
            label=name,
            value=default,
            minimum=prop.get("minimum"),
            maximum=prop.get("maximum"),
            step=1,
            # precision=0 asks Gradio to hand back an int rather than a float.
            precision=0,
        )
    if prop_type == "number":
        return gr.Number(
            label=name,
            value=default,
            minimum=prop.get("minimum"),
            maximum=prop.get("maximum"),
        )
    if prop_type == "boolean":
        return gr.Checkbox(label=name, value=default)
    if prop_type == "string" and prop.get("format") == "uri":
        return gr.File(label=name)
    # Anything else is treated as free text.
    return gr.Textbox(label=name, value=default)


def _to_data_uri(value):
    """Encode an uploaded-file value as a base64 ``data:`` URI.

    Accepts an in-memory ``io.BytesIO``, a file path string, or an object
    exposing the path as ``.name``. ``None`` and unrecognised values are
    returned unchanged.
    """
    if value is None:
        return None
    if isinstance(value, io.BytesIO):
        value.seek(0)
        return "data:image/jpeg;base64," + base64.b64encode(value.read()).decode()

    # NOTE(review): depending on the Gradio version, gr.File is expected to
    # deliver either a path string or a temp-file object with `.name`.
    path = value if isinstance(value, str) else getattr(value, "name", None)
    if path is None:
        return value

    mime_type = mimetypes.guess_type(path)[0] or "application/octet-stream"
    with open(path, "rb") as uploaded:
        encoded = base64.b64encode(uploaded.read()).decode()
    return f"data:{mime_type};base64,{encoded}"


def _decode_output_image(output_uri):
    """Decode one base64 string or ``data:`` URI from the API into a PIL image."""
    if output_uri.startswith("data:"):
        # Drop the "data:<mime>;base64," header whatever the MIME type is.
        encoded = output_uri.partition(",")[2]
    else:
        encoded = output_uri
    return Image.open(io.BytesIO(base64.b64decode(encoded)))


def create_gradio_app(api_spec, api_url):
    """Build a Gradio interface from the input schema of a prediction API.

    Args:
        api_spec: Parsed OpenAPI specification. The input fields are read
            from ``components.schemas.PredictionRequest.properties.input
            .properties``.
        api_url: URL that receives the POSTed ``{"input": {...}}`` payload.

    Returns:
        A ``gr.Interface`` with one input component per schema property and
        a gallery of the decoded output images.
    """
    input_schema = api_spec["components"]["schemas"]["PredictionRequest"]["properties"][
        "input"
    ]["properties"]

    inputs = []
    input_names = []
    file_input_names = set()
    for name, prop in input_schema.items():
        prop = extract_property_info(prop)  # flatten 'allOf' wrappers
        print(prop)
        component = _build_input_component(name, prop)
        if isinstance(component, gr.File):
            file_input_names.add(name)
        inputs.append(component)
        input_names.append(name)

    def predict(*args):
        """Send the submitted values to the API and return the output images.

        Gradio calls this with one positional argument per input component,
        in the order of ``inputs``; the values are paired back up with the
        schema field names to build the payload.

        Raises:
            gr.Error: If the API reports a failed prediction or the response
                carries no output.
        """
        payload = {"input": {}}
        for name, value in zip(input_names, args):
            if name in file_input_names or isinstance(value, io.BytesIO):
                value = _to_data_uri(value)
            payload["input"][name] = value

        response = requests.post(
            api_url, headers={"Content-Type": "application/json"}, json=payload
        )
        json_response = response.json()

        output = json_response.get("output")
        if json_response.get("status") == "failed" or output is None:
            raise gr.Error("Failed to generate image")

        # A single string would otherwise be iterated character by character.
        if isinstance(output, str):
            output = [output]
        return [_decode_output_image(output_uri) for output_uri in output]

    output_component = gr.Gallery(label="Output Images")
    return gr.Interface(fn=predict, inputs=inputs, outputs=output_component)
# Build the interface from the parsed schema and the prediction endpoint.
api_spec = parser.specification
API_URL = "http://localhost:5000/predictions"
app = create_gradio_app(api_spec, API_URL)

# Start the blocking web server only when this file is executed as a script,
# so importing the module (for reuse or testing) does not launch it.
if __name__ == "__main__":
    app.launch()