multimodalart HF Staff commited on
Commit
38a61be
·
verified ·
1 Parent(s): 48c3a03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -26
app.py CHANGED
@@ -3,6 +3,7 @@ import requests
3
  import json
4
  import io
5
  import os
 
6
  from PIL import Image
7
  import base64
8
  from prance import ResolvingParser
@@ -80,24 +81,9 @@ def create_gradio_app(api_spec, api_url):
80
  data_field = gr.State(value=names)
81
  inputs.append(data_field)
82
  print(output_schema)
83
- if output_schema["type"] == "string":
84
- if "format" in output_schema:
85
- if(output_schema["format"] == "uri"):
86
- output_component = gr.Image(label=output_schema["title"])
87
- else:
88
- output_component = gr.Textbox(label="Output")
89
- else:
90
- output_component = gr.Textbox(label="Output")
91
- outputs.append(output_component)
92
- elif output_schema["type"] == "array":
93
- if "format" in output_schema["items"]:
94
- if(output_schema["items"]["format"] == "uri"):
95
- output_component = gr.Image(label=output_schema["title"])
96
- else:
97
- output_component = gr.Textbox(label=output_schema["title"])
98
- else:
99
- output_component = gr.Textbox(label=output_schema["title"])
100
- outputs.append(output_component)
101
  outputs.append(data_field)
102
  #else if there's multiple outputs
103
 
@@ -117,19 +103,31 @@ def create_gradio_app(api_spec, api_url):
117
  json_response = response.json()
118
  print(json_response)
119
  if "status" in json_response and json_response["status"] == "failed":
120
- raise gr.Error("Failed to generate image")
121
 
122
  output_images = []
123
  for output_uri in json_response["output"]:
124
- base64_image = output_uri.replace("data:image/png;base64,", "")
125
- image_data = base64.b64decode(base64_image)
126
- image_stream = io.BytesIO(image_data)
127
- output_images.append(Image.open(image_stream))
128
-
129
- return output_images[0], keys
 
 
 
 
 
 
 
 
 
 
 
 
130
  else:
131
  raise gr.Error("The submission failed!")
132
- return gr.Interface(fn=predict, inputs=inputs, outputs=outputs if outputs else "textbox")
133
 
134
  API_URL = "http://localhost:5000/predictions"
135
  app = create_gradio_app(api_spec, API_URL)
 
3
  import json
4
  import io
5
  import os
6
+ import uuid
7
  from PIL import Image
8
  import base64
9
  from prance import ResolvingParser
 
81
  data_field = gr.State(value=names)
82
  inputs.append(data_field)
83
  print(output_schema)
84
+ outputs.append(gr.Image(label=output_schema["title"], visible=True))
85
+ outputs.append(gr.Audio(label=output_schema["title"], visible=False))
86
+ outputs.append(gr.Textbox(label=output_schema["title"], visible=False))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  outputs.append(data_field)
88
  #else if there's multiple outputs
89
 
 
103
  json_response = response.json()
104
  print(json_response)
105
  if "status" in json_response and json_response["status"] == "failed":
106
+ raise gr.Error("Failed to generate output")
107
 
108
  output_images = []
109
  for output_uri in json_response["output"]:
110
+ if output_uri.startswith("data:image"):
111
+ # Process as image
112
+ base64_data = output_uri.split(",", 1)[1]
113
+ image_data = base64.b64decode(base64_data)
114
+ image_stream = io.BytesIO(image_data)
115
+ image = Image.open(image_stream)
116
+ return gr.update(visible=True, value=image), gr.update(visible=False), gr.update(visible=False), keys
117
+ elif output_uri.startswith("data:audio"):
118
+ # Process as audio
119
+ base64_data = output_uri.split(",", 1)[1]
120
+ audio_data = base64.b64decode(base64_data)
121
+ audio_stream = io.BytesIO(audio_data)
122
+ # Here you can save the audio or return the stream for further processing
123
+ filename = f"{uuid.uuid4()}.wav" # Change format as needed
124
+ with open(filename, "wb") as audio_file:
125
+ audio_file.write(audio_stream.getbuffer())
126
+ return gr.update(visible=False), gr.update(visible=True, value=filename), gr.update(visible=False), keys
127
+
128
  else:
129
  raise gr.Error("The submission failed!")
130
+ return gr.Interface(fn=predict, inputs=inputs, outputs=outputs)
131
 
132
  API_URL = "http://localhost:5000/predictions"
133
  app = create_gradio_app(api_spec, API_URL)