# gradio-on-cog / app.py
# Author: multimodalart; last commit a8cf54a ("Update app.py"); 3.63 kB.
import base64
import io
import json
import mimetypes

import gradio as gr
import requests
from PIL import Image
from prance import ResolvingParser
SCHEMA_URL = "http://localhost:5000/openapi.json"
FILENAME = "openapi.json"

# Download the prediction server's OpenAPI schema once and save it locally so
# that prance's ResolvingParser can load it and resolve its "$ref" pointers.
schema_response = requests.get(SCHEMA_URL, timeout=30)
# Fail fast with a clear HTTP error instead of trying to parse an error page.
schema_response.raise_for_status()
print(schema_response.content)
with open(FILENAME, "wb") as f:
    f.write(schema_response.content)

parser = ResolvingParser(FILENAME)
print(parser.specification)
def extract_property_info(prop):
    """Flatten a JSON-schema property that wraps its definition in ``allOf``.

    For example, an enum input may be described as
    ``{"allOf": [<enum schema>], "default": ..., "description": ...}``.
    The sub-schemas listed under ``allOf`` are merged in order. The keys that
    sit next to ``allOf`` are then applied on top, so a per-input ``default``
    or ``description`` is kept rather than lost.

    Args:
        prop: A JSON-schema property dict. It is not modified.

    Returns:
        ``prop`` itself when it has no ``allOf`` key; otherwise a new merged
        dict without the ``allOf`` key.
    """
    if "allOf" not in prop:
        return prop
    merged = {}
    for sub_schema in prop["allOf"]:
        merged.update(sub_schema)
    # Keys defined alongside 'allOf' describe this specific input, so they take
    # precedence over the (possibly shared) sub-schemas.
    merged.update((key, value) for key, value in prop.items() if key != "allOf")
    return merged
def _build_input_component(name, prop):
    """Return the Gradio input component for one (flattened) schema property.

    Args:
        name: Property name, used as the component label.
        prop: JSON-schema property dict, already passed through
            ``extract_property_info``.
    """
    default = prop.get("default")
    # Untyped properties fall through to a textbox instead of raising KeyError.
    prop_type = prop.get("type")
    if "enum" in prop:
        return gr.Dropdown(choices=prop["enum"], label=name, value=default)
    if prop_type == "integer":
        return gr.Number(
            label=name,
            value=default,
            minimum=prop.get("minimum"),
            maximum=prop.get("maximum"),
            step=1,
            precision=0,  # round, and hand the value to predict() as an int
        )
    if prop_type == "number":
        return gr.Number(
            label=name,
            value=default,
            minimum=prop.get("minimum"),
            maximum=prop.get("maximum"),
        )
    if prop_type == "boolean":
        return gr.Checkbox(label=name, value=default)
    if prop_type == "string" and prop.get("format") == "uri":
        return gr.File(label=name)
    # Plain strings, and any type not handled above, become a textbox.
    return gr.Textbox(label=name, value=default)


def _file_to_data_uri(file_value):
    """Return the contents of an uploaded file as a base64 ``data:`` URI.

    Accepts any of:
      - an ``io.BytesIO``, labelled ``image/jpeg`` because it has no filename;
      - a filesystem path string;
      - a file wrapper that exposes its path as ``.name``.
    """
    if isinstance(file_value, io.BytesIO):
        file_value.seek(0)
        data = file_value.read()
        mime_type = "image/jpeg"
    else:
        # Depending on the Gradio version / gr.File ``type``, the value is a
        # path string or a temp-file wrapper whose ``.name`` is the path.
        path = getattr(file_value, "name", file_value)
        mime_type = mimetypes.guess_type(path)[0] or "application/octet-stream"
        with open(path, "rb") as fh:
            data = fh.read()
    encoded = base64.b64encode(data).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def _decode_output_image(output):
    """Turn one prediction output item into something ``gr.Gallery`` can show.

    A ``data:`` URI is decoded into a PIL image, whatever its image MIME type.
    Any other value (for example an http(s) URL) is returned unchanged.
    """
    if isinstance(output, str) and output.startswith("data:"):
        # Drop the "data:<mime>;base64," prefix, keeping only the payload.
        _, _, encoded = output.partition(",")
        return Image.open(io.BytesIO(base64.b64decode(encoded)))
    return output


def _request_prediction(api_url, payload):
    """POST ``payload`` to the predictions endpoint and return its outputs.

    Returns:
        A list of output items. A single string output is wrapped in a list,
        and a missing or null output becomes an empty list.

    Raises:
        gr.Error: if the response is not a JSON object, has an error HTTP
            status, or reports ``"status": "failed"``.
    """
    response = requests.post(
        api_url, headers={"Content-Type": "application/json"}, json=payload
    )
    try:
        body = response.json()
    except ValueError as err:
        raise gr.Error(
            f"Failed to generate image (HTTP {response.status_code}, "
            "non-JSON response)"
        ) from err
    if not isinstance(body, dict):
        raise gr.Error("Failed to generate image (unexpected response format)")
    if not response.ok or body.get("status") == "failed":
        # Surface the server's explanation when it provides one.
        detail = body.get("error") or body.get("detail")
        message = "Failed to generate image"
        raise gr.Error(f"{message}: {detail}" if detail else message)
    output = body.get("output")
    if output is None:
        return []
    if isinstance(output, str):
        # A single output must not be iterated character by character.
        return [output]
    return list(output)


def create_gradio_app(api_spec, api_url):
    """Build a ``gr.Interface`` that drives a prediction HTTP endpoint.

    Args:
        api_spec: Resolved OpenAPI specification (a dict). The input fields are
            read from
            ``components.schemas.PredictionRequest.properties.input.properties``.
        api_url: URL that prediction requests are POSTed to.

    Returns:
        A ``gr.Interface`` with one input component per schema property and a
        gallery that shows the decoded output images.
    """
    input_schema = api_spec["components"]["schemas"]["PredictionRequest"][
        "properties"
    ]["input"]["properties"]

    field_names = []
    inputs = []
    file_fields = set()
    for name, raw_prop in input_schema.items():
        prop = extract_property_info(raw_prop)  # flatten 'allOf' wrappers
        component = _build_input_component(name, prop)
        field_names.append(name)
        inputs.append(component)
        if isinstance(component, gr.File):
            file_fields.add(name)

    def predict(*values):
        # gr.Interface passes one positional argument per input component, in
        # the order of `inputs`; pair them back up with their field names.
        payload = {"input": {}}
        for name, value in zip(field_names, values):
            if value is None:
                # Leave unset fields out rather than sending null, presumably
                # letting the server fall back to its own defaults.
                continue
            if name in file_fields or isinstance(value, io.BytesIO):
                value = _file_to_data_uri(value)
            payload["input"][name] = value
        outputs = _request_prediction(api_url, payload)
        return [_decode_output_image(item) for item in outputs]

    output_component = gr.Gallery(label="Output Images")
    return gr.Interface(fn=predict, inputs=inputs, outputs=output_component)
# Entry point: build the UI from the resolved schema, point it at the local
# predictions endpoint, and start the Gradio server.
API_URL = "http://localhost:5000/predictions"
api_spec = parser.specification
app = create_gradio_app(api_spec=api_spec, api_url=API_URL)
app.launch()