# Hugging Face Space — status: Paused
import base64 | |
import io | |
import os | |
import requests | |
import threading | |
import time | |
from dash import Dash, dcc, html, Input, Output, State, ctx | |
import dash_bootstrap_components as dbc | |
# Dash application instance, styled with the Bootstrap theme from
# dash-bootstrap-components.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Stability AI API key, read from the environment (intended to be configured
# as a Hugging Face Space secret). None when the variable is not set.
STABILITY_API_KEY = os.environ.get('STABILITY_API_KEY')

# Bytes of the most recently upscaled image, or None when there is none.
# NOTE: this is module-level state, so every request handled by this process
# reads and writes the same value.
generated_file = None
def upscale_image(contents, timeout=60):
    """Upscale an uploaded image with Stability AI's "fast" upscale endpoint.

    Args:
        contents: Data-URL string as produced by ``dcc.Upload``
            (``"<header>,<base64 payload>"``).
        timeout: Seconds to wait for the API before giving up. Without a
            timeout, ``requests`` can block indefinitely.

    Returns:
        True on success, with the upscaled PNG bytes stored in the
        module-level ``generated_file``. False on any failure (undecodable
        upload, missing API key, network error, non-200 response), with
        ``generated_file`` left as None.
    """
    global generated_file
    generated_file = None

    # Decode the base64 payload. Split only on the first comma so that
    # everything after the data-URL header is treated as payload.
    try:
        _header, content_string = contents.split(',', 1)
        decoded = base64.b64decode(content_string)
    except (AttributeError, ValueError) as err:
        # AttributeError: contents is not a string (e.g. None).
        # ValueError: no comma, or invalid base64 (binascii.Error subclasses it).
        print(f"Error: could not decode uploaded image: {err}")
        return False

    # Without a key the request would be sent as "Bearer None" and rejected.
    if not STABILITY_API_KEY:
        print("Error: STABILITY_API_KEY is not set")
        return False

    # Prepare the API request.
    url = "https://api.stability.ai/v2beta/stable-image/upscale/fast"
    headers = {
        "Authorization": f"Bearer {STABILITY_API_KEY}",
        "Accept": "image/*"
    }
    files = {
        "image": ("image.png", io.BytesIO(decoded), "image/png")
    }
    data = {
        "output_format": "png"
    }

    # Report network failures as a False return instead of raising, so that
    # callers only need to check the boolean result.
    try:
        response = requests.post(
            url, headers=headers, files=files, data=data, timeout=timeout
        )
    except requests.RequestException as err:
        print(f"Error: request to Stability AI failed: {err}")
        return False

    if response.status_code == 200:
        generated_file = response.content
        return True
    print(f"Error: {response.status_code}, {response.text}")
    return False
# App layout: a full-width upload area, then two columns (md=12 stacks them on
# medium screens, lg=6 places them side by side on large screens). The left
# column holds the original image and the upscale button; the right column
# holds the upscaled image and the download button.
app.layout = dbc.Container([
    html.H1("Image Upscaler", className="text-center my-4"),
    dbc.Row([
        dbc.Col([
            dcc.Upload(
                id='upload-image',
                children=html.Div([
                    'Drag and Drop or ',
                    html.A('Select an Image')
                ]),
                # Only image files are meaningful input for the upscaler.
                accept='image/*',
                style={
                    'width': '100%',
                    'height': '60px',
                    'lineHeight': '60px',
                    'borderWidth': '1px',
                    'borderStyle': 'dashed',
                    'borderRadius': '5px',
                    'textAlign': 'center',
                    'margin': '10px'
                },
                multiple=False
            ),
        ], width=12),
    ]),
    dbc.Row([
        dbc.Col([
            html.H4("Original Image", className="mt-4"),
            dbc.Button("Upscale Image", id="upscale-button", color="primary", className="mb-3"),
            # Placeholder wrapped in a spinner; a callback that writes to it
            # makes the spinner visible while that callback runs.
            dbc.Spinner(html.Div(id="loading-output"), color="primary", type="border", size="sm"),
            dbc.Card(id='output-image-upload', body=True)
        ], md=12, lg=6),
        dbc.Col([
            html.H4("Upscaled Image", className="mt-4"),
            # Disabled until an upscaled image is available.
            dbc.Button("Download Upscaled Image", id="download-button", color="success", className="mb-3", disabled=True),
            dbc.Card(id='output-upscaled-image', body=True),
            dcc.Download(id="download-image")
        ], md=12, lg=6),
    ]),
], fluid=True)
@app.callback(
    Output('output-image-upload', 'children'),
    Input('upload-image', 'contents'),
    State('upload-image', 'filename'),
)
def update_output(contents, filename):
    """Preview the uploaded image and its filename in the "Original Image" card.

    Args:
        contents: Data-URL string from ``dcc.Upload``, or None before any upload.
        filename: Name of the uploaded file as reported by ``dcc.Upload``.

    Returns:
        An ``html.Div`` containing the image and the filename, or None when
        nothing has been uploaded.
    """
    if contents is None:
        return None
    return html.Div([
        html.Img(src=contents, style={'width': '100%'}),
        html.P(filename)
    ])
@app.callback(
    Output('output-upscaled-image', 'children'),
    Output('download-button', 'disabled'),
    Output('loading-output', 'children'),
    Input('upscale-button', 'n_clicks'),
    State('upload-image', 'contents'),
)
def upscale_image_callback(n_clicks, contents):
    """Upscale the current upload when the "Upscale Image" button is clicked.

    Args:
        n_clicks: Click count of the upscale button (used only as the trigger).
        contents: Data-URL string of the current upload, or None.

    Returns:
        A 3-tuple ``(upscaled_preview, download_disabled, loading_text)``:
        - the preview ``html.Div``, or None on failure;
        - whether the download button stays disabled;
        - an empty string for the spinner-wrapped placeholder.
    """
    if contents is None:
        return None, True, ""

    # Call the upscaler synchronously and use its return value. Polling
    # ``generated_file`` for a non-None value would never end when upscaling
    # fails, because upscale_image leaves it as None in that case.
    if not upscale_image(contents) or not generated_file:
        return None, True, ""

    encoded = base64.b64encode(generated_file).decode()
    upscaled_image = html.Div([
        html.Img(src=f"data:image/png;base64,{encoded}", style={'width': '100%'})
    ])
    return upscaled_image, False, ""
@app.callback(
    Output('download-image', 'data'),
    Input('download-button', 'n_clicks'),
    prevent_initial_call=True,
)
def download_image(n_clicks):
    """Send the most recently upscaled image to the browser as a PNG file.

    Args:
        n_clicks: Click count of the download button (used only as the trigger).

    Returns:
        The ``dcc.send_bytes`` payload for ``dcc.Download``, or None when no
        upscaled image is available.

    NOTE(review): ``generated_file`` is module-level, so concurrent users of
    the same process share it and can download each other's latest result.
    """
    if not generated_file:
        return None
    return dcc.send_bytes(generated_file, "upscaled_image.png")
if __name__ == '__main__':
    # debug=True enables Dash dev tools and the Flask/Werkzeug debugger and
    # reloader. That should not be on by default for a server bound to all
    # interfaces, so it is opt-in via DASH_DEBUG=1/true/yes.
    debug = os.getenv('DASH_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    print("Starting the Dash application...")
    app.run(debug=debug, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")