# image-upscaler / app.py
# Hugging Face Space by bluenevus, commit 255e0b3 ("Update app.py")
import base64
import io
import os
import requests
import threading
import time
from dash import Dash, dcc, html, Input, Output, State, ctx
import dash_bootstrap_components as dbc
# Dash application object, styled with the Bootstrap theme bundled in
# dash-bootstrap-components.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Stability AI API key, read from the environment (configure it as a
# Hugging Face secret). None when the variable is not set.
STABILITY_API_KEY = os.environ.get('STABILITY_API_KEY')

# Bytes of the most recent successful upscale, or None when there is none.
# NOTE(review): this module-level state is shared by every browser session.
generated_file = None
# Function to upscale image
def upscale_image(contents):
global generated_file
generated_file = None
# Decode the base64 image
content_type, content_string = contents.split(',')
decoded = base64.b64decode(content_string)
# Prepare the API request
url = "https://api.stability.ai/v2beta/stable-image/upscale/fast"
headers = {
"Authorization": f"Bearer {STABILITY_API_KEY}",
"Accept": "image/*"
}
files = {
"image": ("image.png", io.BytesIO(decoded), "image/png")
}
data = {
"output_format": "png"
}
# Make the API request
response = requests.post(url, headers=headers, files=files, data=data)
if response.status_code == 200:
generated_file = response.content
return True
else:
print(f"Error: {response.status_code}, {response.text}")
return False
def _build_layout():
    """Assemble the page: an upload area, then side-by-side original/upscaled panes."""
    upload_style = {
        'width': '100%',
        'height': '60px',
        'lineHeight': '60px',
        'borderWidth': '1px',
        'borderStyle': 'dashed',
        'borderRadius': '5px',
        'textAlign': 'center',
        'margin': '10px'
    }
    uploader = dcc.Upload(
        id='upload-image',
        children=html.Div(['Drag and Drop or ', html.A('Select an Image')]),
        style=upload_style,
        multiple=False
    )
    upload_row = dbc.Row([dbc.Col([uploader], width=12)])

    # Left pane: the uploaded image plus the trigger button and progress spinner.
    original_pane = dbc.Col([
        html.H4("Original Image", className="mt-4"),
        dbc.Button("Upscale Image", id="upscale-button", color="primary", className="mb-3"),
        dbc.Spinner(html.Div(id="loading-output"), color="primary", type="border", size="sm"),
        dbc.Card(id='output-image-upload', body=True)
    ], md=12, lg=6)

    # Right pane: the upscaled result and its (initially disabled) download button.
    upscaled_pane = dbc.Col([
        html.H4("Upscaled Image", className="mt-4"),
        dbc.Button("Download Upscaled Image", id="download-button", color="success",
                   className="mb-3", disabled=True),
        dbc.Card(id='output-upscaled-image', body=True),
        dcc.Download(id="download-image")
    ], md=12, lg=6)

    title = html.H1("Image Upscaler", className="text-center my-4")
    return dbc.Container([title, upload_row, dbc.Row([original_pane, upscaled_pane])], fluid=True)


app.layout = _build_layout()
@app.callback(
    Output('output-image-upload', 'children'),
    Input('upload-image', 'contents'),
    State('upload-image', 'filename')
)
def update_output(contents, filename):
    """Preview the uploaded image with its filename underneath.

    Returns None (empty card) until something has been uploaded.
    """
    if contents is None:
        return None
    preview = html.Img(src=contents, style={'width': '100%'})
    caption = html.P(filename)
    return html.Div([preview, caption])
@app.callback(
    [Output('output-upscaled-image', 'children'),
     Output('download-button', 'disabled'),
     Output('loading-output', 'children')],
    Input('upscale-button', 'n_clicks'),
    State('upload-image', 'contents'),
    prevent_initial_call=True
)
def upscale_image_callback(n_clicks, contents):
    """Upscale the current upload and display the result.

    Args:
        n_clicks: Click count of the upscale button (only used as a trigger).
        contents: Data-URL contents of the upload component, or None.

    Returns:
        A 3-tuple of (upscaled image component or None, whether the download
        button stays disabled, empty string for the ``loading-output`` div).
    """
    if contents is None:
        return None, True, ""

    # Call synchronously and trust the return value. Polling the shared
    # ``generated_file`` global while a thread does the work spins forever
    # when the upscale fails (the global stays None). It can also return a
    # stale image from an earlier click before the thread resets it.
    # ``loading-output`` is an output of this callback, so the surrounding
    # dbc.Spinner can indicate progress while the call is pending.
    if not upscale_image(contents) or not generated_file:
        return None, True, ""

    encoded = base64.b64encode(generated_file).decode()
    upscaled_image = html.Div([
        html.Img(src=f"data:image/png;base64,{encoded}", style={'width': '100%'})
    ])
    return upscaled_image, False, ""
@app.callback(
    Output("download-image", "data"),
    Input("download-button", "n_clicks"),
    prevent_initial_call=True
)
def download_image(n_clicks):
    """Send the most recent upscaled image to the browser as a PNG download.

    Returns None (no download) when nothing has been upscaled yet.
    """
    if not generated_file:
        return None
    return dcc.send_bytes(generated_file, "upscaled_image.png")
if __name__ == '__main__':
    # Dash's debug flag enables dev tools and Flask debug mode. The server
    # listens on all interfaces, so only turn debug on when explicitly
    # requested (e.g. DASH_DEBUG=1 for local development).
    debug_mode = os.getenv('DASH_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    print("Starting the Dash application...")
    app.run(debug=debug_mode, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")