gradio-on-cog / app.py
multimodalart's picture
Update app.py
2ca7450 verified
raw
history blame
2.98 kB
import base64
import io
import json
import os

import gradio as gr
import requests
from PIL import Image
# API and Schema URLs.
# Both endpoints live on the same Cog HTTP server. Its base address defaults to
# the local port used previously and can be overridden through the COG_BASE_URL
# environment variable (for example, to point the UI at a remote server).
# A trailing slash is stripped so the joined paths never contain "//".
COG_BASE_URL = os.environ.get("COG_BASE_URL", "http://localhost:5000").rstrip("/")
API_URL = f"{COG_BASE_URL}/predictions"
SCHEMA_URL = f"{COG_BASE_URL}/openapi.json"
def fetch_api_spec(url, timeout=30):
    """Download and decode the OpenAPI document published by the model server.

    Args:
        url: Address of the server's ``openapi.json`` endpoint.
        timeout: Seconds to wait for the server to connect/respond before
            giving up. Without a timeout an unresponsive server would block
            application start-up forever.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    # Fail here with a clear HTTP error rather than later with a JSON decode
    # error or a KeyError on a missing schema section.
    response.raise_for_status()
    return response.json()
def _resolve_schema(api_spec, schema):
    """Inline a local ``$ref`` (bare or wrapped in a one-item ``allOf``).

    pydantic-generated schemas may describe a field as
    ``{"allOf": [{"$ref": "#/components/schemas/<name>"}], "default": ...}``
    instead of carrying ``enum``/``type`` directly (presumably how enum-valued
    Cog inputs arrive; confirm against the server's openapi.json). The
    referenced schema is merged with the referencing one, whose own keys (for
    example ``default`` or ``title``) take precedence. Schemas without such a
    reference are returned unchanged.
    """
    holder = schema
    all_of = schema.get("allOf")
    if isinstance(all_of, list) and len(all_of) == 1 and isinstance(all_of[0], dict):
        holder = all_of[0]
    ref = holder.get("$ref")
    if not isinstance(ref, str) or not ref.startswith("#/"):
        return schema
    # Walk the JSON pointer ("#/components/schemas/Foo") through the document.
    target = api_spec
    for part in ref[2:].split("/"):
        target = target[part]
    merged = dict(target)
    merged.update({key: value for key, value in schema.items() if key not in ("allOf", "$ref")})
    return merged


def _is_uri_schema(schema):
    """Return True when *schema* declares a string with ``format: uri``."""
    return schema.get("type") == "string" and schema.get("format") == "uri"


def _make_input_component(name, details):
    """Map one (resolved) Input property schema to a Gradio component.

    Returns None for schemas no component is provided for.
    """
    default = details.get("default")
    prop_type = details.get("type")
    if "enum" in details:
        return gr.Dropdown(choices=details["enum"], label=name, value=default)
    if prop_type == "integer":
        # precision=0 makes the Number component return an int, not a float.
        return gr.Number(
            label=name,
            value=default,
            minimum=details.get("minimum"),
            maximum=details.get("maximum"),
            precision=0,
        )
    if prop_type == "number":
        minimum, maximum = details.get("minimum"), details.get("maximum")
        if minimum is not None and maximum is not None:
            return gr.Slider(minimum=minimum, maximum=maximum, value=default, label=name)
        # A slider needs both bounds; fall back to a free-form number box.
        return gr.Number(label=name, value=default, minimum=minimum, maximum=maximum)
    if _is_uri_schema(details):
        # NOTE(review): every URI input is treated as an image; audio/video
        # file inputs would need different components.
        return gr.Image(label=name, type="pil")
    if prop_type == "string":
        return gr.Textbox(label=name, value=default)
    if prop_type == "boolean":
        return gr.Checkbox(label=name, value=default)
    return None


def _image_to_data_uri(image):
    """Encode a PIL image as a base64 PNG ``data:`` URI for the JSON payload."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/png;base64,{encoded}"


def _decode_output_uri(uri):
    """Decode a ``data:image`` URI into a PIL image; return any other URI unchanged."""
    if not uri.startswith("data:image"):
        return uri
    _, _, encoded = uri.partition(",")  # drop the "data:image/...;base64" prefix
    return Image.open(io.BytesIO(base64.b64decode(encoded)))


def create_gradio_app_from_api_spec(api_spec):
    """Build a ``gr.Interface`` whose inputs and output mirror an OpenAPI spec.

    Args:
        api_spec: Decoded OpenAPI document (as returned by ``fetch_api_spec``).
            It must contain ``components.schemas.Input`` (with
            ``properties``), ``components.schemas.Output`` and ``info.title``.

    Returns:
        A ``gr.Interface``. On submit it POSTs ``{"input": {...}}`` to
        ``API_URL`` and renders the response's ``output`` field:
        - A list of URIs is shown in a Gallery.
        - A single URI is shown as an Image.
        - Anything else is shown as JSON.
        ``data:image`` URIs are decoded to PIL images; other URIs are passed through.
    """
    schemas = api_spec["components"]["schemas"]

    inputs = []
    input_names = []  # property name for each component, in the same order
    image_inputs = set()  # properties whose values must be sent as data URIs
    for prop, raw_details in schemas["Input"]["properties"].items():
        details = _resolve_schema(api_spec, raw_details)
        component = _make_input_component(prop, details)
        if component is None:
            # Unsupported schema: not shown in the UI and not sent to the server.
            continue
        inputs.append(component)
        input_names.append(prop)
        if _is_uri_schema(details):
            image_inputs.add(prop)

    # Decide once how predictions are rendered instead of on every request.
    output_spec = _resolve_schema(api_spec, schemas["Output"])
    items_spec = _resolve_schema(api_spec, output_spec.get("items", {}))
    if output_spec.get("type") == "array" and _is_uri_schema(items_spec):
        output_kind = "uri_list"
        output_component = gr.Gallery(label="Output")
    elif _is_uri_schema(output_spec):
        output_kind = "uri"
        output_component = gr.Image(type="pil", label="Output")
    else:
        output_kind = "raw"
        output_component = gr.JSON(label="Output")

    def predict_function(*args):
        # Gradio calls fn with one positional value per input component, in
        # the order of `inputs`; pair each value with its property name.
        payload_input = {}
        for name, value in zip(input_names, args):
            if value is None:
                # Leave unset values (e.g. no image uploaded) out of the
                # request rather than sending null.
                continue
            if name in image_inputs:
                value = _image_to_data_uri(value)
            payload_input[name] = value
        payload = {"input": payload_input}
        # Log the request without dumping multi-megabyte base64 image data.
        print({"input": {name: ("<image data URI>" if name in image_inputs else value)
                         for name, value in payload_input.items()}})
        response = requests.post(API_URL, headers={"Content-Type": "application/json"}, json=payload)
        try:
            json_response = response.json()
        except ValueError:
            # Non-JSON body (e.g. a proxy error page): report as a failure.
            raise gr.Error("Failed to generate image") from None
        if not response.ok or json_response.get("status") == "failed":
            raise gr.Error("Failed to generate image")
        output = json_response.get("output")
        if output is None or output_kind == "raw":
            return output
        if output_kind == "uri_list":
            return [_decode_output_uri(uri) for uri in output]
        return _decode_output_uri(output)

    return gr.Interface(
        fn=predict_function,
        inputs=inputs,
        outputs=output_component,
        title=api_spec["info"]["title"],
    )
# Ask the running server for the OpenAPI description of the model's inputs/outputs.
api_spec = fetch_api_spec(url=SCHEMA_URL)
# Turn that description into a Gradio interface and start serving it.
gradio_app = create_gradio_app_from_api_spec(api_spec=api_spec)
gradio_app.launch()