File size: 3,625 Bytes
a8cf54a
fa2d678
a8cf54a
fa2d678
4791078
a8cf54a
 
 
fa2d678
aa69d1f
a8cf54a
 
 
 
 
 
 
 
 
 
 
 
aa69d1f
a8cf54a
 
 
 
 
 
 
 
4791078
a8cf54a
 
4791078
a8cf54a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4791078
fa2d678
a8cf54a
4791078
fa2d678
a8cf54a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import base64
import io
import json
import mimetypes

import gradio as gr
import requests
from PIL import Image
from prance import ResolvingParser


SCHEMA_URL = "http://localhost:5000/openapi.json"
FILENAME = "openapi.json"
SCHEMA_TIMEOUT_SECONDS = 30

# Fetch the OpenAPI schema exactly once. Fail fast on HTTP errors so an error
# page is never written to disk and handed to the parser as if it were a schema.
schema_response = requests.get(SCHEMA_URL, timeout=SCHEMA_TIMEOUT_SECONDS)
schema_response.raise_for_status()
# NOTE: this is the raw requests.Response object, not a parsed dict; it is
# kept only for compatibility. The resolved spec is `parser.specification`.
openapi_spec = schema_response

print(schema_response.content)
with open(FILENAME, "wb") as f:
    f.write(schema_response.content)

# ResolvingParser loads the saved file and resolves its $ref pointers, so the
# schema dicts used below can be read without manual reference lookups.
parser = ResolvingParser(FILENAME)
print(parser.specification)


def extract_property_info(prop):
    """Flatten an OpenAPI property schema that uses ``allOf``.

    The sub-schemas listed under ``allOf`` are merged in order (later ones win
    on key conflicts). The property's own sibling keys (for example
    ``default`` or ``description``) are then applied on top, so field-specific
    information is not lost. The ``allOf`` key itself is dropped. A property
    without ``allOf`` is returned unchanged (the same object).

    Args:
        prop: A property schema dict. Assumes any ``$ref`` inside ``allOf``
            has already been resolved (as done by the caller's parser).

    Returns:
        A flat schema dict. The input dict is never mutated.
    """
    if "allOf" not in prop:
        return prop

    combined_prop = {}
    for subprop in prop["allOf"]:
        combined_prop.update(subprop)
    # Keys next to ``allOf`` describe this specific field (notably its
    # ``default``). They take precedence over the shared referenced schema.
    combined_prop.update({key: value for key, value in prop.items() if key != "allOf"})
    return combined_prop


def _build_input_component(name, prop):
    """Return the Gradio input component for one flattened property schema.

    Args:
        name: Property name, used as the component label.
        prop: Property schema dict, already flattened by ``extract_property_info``.

    Returns:
        ``gr.Dropdown`` for enums, ``gr.Number`` for integer/number,
        ``gr.Checkbox`` for boolean, ``gr.File`` for ``string`` with
        ``format: uri``, and ``gr.Textbox`` for everything else.
    """
    # A schema is not required to declare "type" (e.g. anyOf/oneOf), so use .get().
    prop_type = prop.get("type")
    default = prop.get("default")

    if "enum" in prop:
        return gr.Dropdown(choices=prop["enum"], label=name, value=default)
    if prop_type == "integer":
        # precision=0 makes Gradio return an int rather than a float.
        return gr.Number(
            label=name,
            value=default,
            minimum=prop.get("minimum"),
            maximum=prop.get("maximum"),
            step=1,
            precision=0,
        )
    if prop_type == "number":
        return gr.Number(
            label=name,
            value=default,
            minimum=prop.get("minimum"),
            maximum=prop.get("maximum"),
        )
    if prop_type == "boolean":
        return gr.Checkbox(label=name, value=default)
    if prop_type == "string" and prop.get("format") == "uri":
        return gr.File(label=name)
    # Fallback: treat any other schema as free text.
    return gr.Textbox(label=name, value=default)


def _file_input_to_data_uri(value):
    """Encode an uploaded file value as a base64 ``data:`` URI.

    Args:
        value: One of:
            - an ``io.BytesIO``, encoded as ``image/jpeg`` since there is no
              name to guess a type from;
            - a filesystem path string, which is what ``gr.File`` returns in
              recent Gradio versions;
            - a file-like wrapper exposing ``.name``, which older Gradio
              versions return.
            Anything else is returned unchanged.

    Returns:
        A ``data:<mime>;base64,<payload>`` string, or *value* unchanged.
    """
    if isinstance(value, io.BytesIO):
        value.seek(0)
        return "data:image/jpeg;base64," + base64.b64encode(value.read()).decode()

    path = value if isinstance(value, str) else getattr(value, "name", None)
    if path is None:
        return value
    mime_type = mimetypes.guess_type(path)[0] or "application/octet-stream"
    with open(path, "rb") as file_obj:
        encoded = base64.b64encode(file_obj.read()).decode()
    return f"data:{mime_type};base64,{encoded}"


def _decode_output_images(output):
    """Decode the prediction ``output`` field into a list of PIL images.

    Args:
        output: ``None``, a single base64 data-URI string, or a list of them.
            Any ``data:<mime>;base64,`` prefix is stripped, whatever the mime type.

    Returns:
        A list of ``PIL.Image.Image`` objects (empty when *output* is ``None``).
    """
    if output is None:
        return []
    if isinstance(output, str):
        # A single image must not be iterated character by character.
        output = [output]

    images = []
    for output_uri in output:
        _, separator, payload = output_uri.partition("base64,")
        encoded = payload if separator else output_uri
        images.append(Image.open(io.BytesIO(base64.b64decode(encoded))))
    return images


def create_gradio_app(api_spec, api_url):
    """Build a ``gr.Interface`` that drives a prediction endpoint.

    One input component is created for each property of
    ``components.schemas.PredictionRequest.properties.input.properties`` in
    *api_spec*. Submitting the form POSTs ``{"input": {...}}`` as JSON to
    *api_url* and shows the decoded output images in a gallery.

    Args:
        api_spec: Resolved OpenAPI specification dict.
        api_url: URL that accepts the prediction POST.

    Returns:
        The ``gr.Interface`` (not yet launched).
    """
    input_schema = api_spec["components"]["schemas"]["PredictionRequest"]["properties"][
        "input"
    ]["properties"]

    field_names = []
    file_fields = set()
    inputs = []
    for name, raw_prop in input_schema.items():
        prop = extract_property_info(raw_prop)  # flatten 'allOf' (enum refs)
        print(prop)
        component = _build_input_component(name, prop)
        if isinstance(component, gr.File):
            file_fields.add(name)
        field_names.append(name)
        inputs.append(component)

    def predict(*values):
        # Gradio passes one positional value per input component, in the
        # order of `inputs`, so this takes *values rather than keyword args.
        payload = {"input": {}}
        for name, value in zip(field_names, values):
            if value is None:
                continue  # omit empty fields so the server applies its own defaults
            if name in file_fields or isinstance(value, io.BytesIO):
                value = _file_input_to_data_uri(value)
            payload["input"][name] = value

        response = requests.post(
            api_url, headers={"Content-Type": "application/json"}, json=payload
        )
        if not response.ok:
            raise gr.Error(
                f"Prediction request failed (HTTP {response.status_code}): {response.text}"
            )
        json_response = response.json()

        if json_response.get("status") == "failed":
            error = json_response.get("error")
            raise gr.Error(
                f"Failed to generate image: {error}" if error else "Failed to generate image"
            )

        return _decode_output_images(json_response.get("output"))

    output_component = gr.Gallery(label="Output Images")
    return gr.Interface(fn=predict, inputs=inputs, outputs=output_component)


# Entry point: build the UI from the resolved schema and serve it, wired to
# the local predictions endpoint.
API_URL = "http://localhost:5000/predictions"
api_spec = parser.specification
app = create_gradio_app(api_spec, API_URL)
app.launch()