image-upscaler / app.py
bluenevus's picture
Update app.py
a240e21 verified
raw
history blame
4.57 kB
import base64
import io
import os
import requests
import threading
import time
from dash import Dash, dcc, html, Input, Output, State, ctx
import dash_bootstrap_components as dbc
# Dash application, styled with the Bootstrap theme shipped by
# dash-bootstrap-components.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Stability AI credentials, read from the process environment (intended to be
# configured as a Hugging Face secret). None when the variable is unset.
STABILITY_API_KEY = os.environ.get('STABILITY_API_KEY')

# Raw bytes of the most recently upscaled image, or None when there is none.
# NOTE(review): this is module-level state, so every browser session served by
# this process reads and writes the same value.
generated_file = None
def upscale_image(contents, timeout=120):
    """Upscale an uploaded image with Stability AI's fast-upscale endpoint.

    The module-level ``generated_file`` is reset to None first. On success it
    holds the upscaled PNG bytes returned by the API. On any failure it stays
    None and a diagnostic line is printed.

    Args:
        contents: Data URL as produced by ``dcc.Upload``, i.e.
            ``"data:<mime>;base64,<payload>"``.
        timeout: Seconds to wait for the API before giving up. Without a
            timeout, a stalled connection would block the caller indefinitely.

    Returns:
        True if the API answered HTTP 200 and the result was stored,
        otherwise False.
    """
    global generated_file
    generated_file = None

    # Decode the data URL. Split on the first comma only; the payload is
    # base64 and never contains one. validate=True rejects non-base64 input
    # instead of silently discarding characters.
    try:
        header, content_string = contents.split(',', 1)
        decoded = base64.b64decode(content_string, validate=True)
    except (AttributeError, ValueError):  # binascii.Error subclasses ValueError
        print("Error: uploaded contents are not a valid base64 data URL")
        return False
    if not decoded:
        print("Error: uploaded image is empty")
        return False

    # Forward the MIME type reported in the data-URL header rather than
    # labelling every upload as PNG. Fall back to PNG when there is no image type.
    mime_type = header.partition(':')[2].partition(';')[0]
    if not mime_type.startswith("image/"):
        mime_type = "image/png"
    filename = f"image.{mime_type.split('/', 1)[1]}"

    # Fail fast with a clear message instead of sending "Bearer None".
    if not STABILITY_API_KEY:
        print("Error: STABILITY_API_KEY is not set")
        return False

    url = "https://api.stability.ai/v2beta/stable-image/upscale/fast"
    headers = {
        "Authorization": f"Bearer {STABILITY_API_KEY}",
        "Accept": "image/*"
    }
    files = {
        "image": (filename, io.BytesIO(decoded), mime_type)
    }
    data = {
        "output_format": "png"
    }

    # Network failures must not escape: callers rely on the True/False result.
    try:
        response = requests.post(url, headers=headers, files=files, data=data, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error: request to Stability AI failed: {err}")
        return False

    if response.status_code == 200:
        generated_file = response.content
        return True
    print(f"Error: {response.status_code}, {response.text}")
    return False
# App layout: an upload area, the upscale button with a spinner, and
# side-by-side previews of the original and upscaled image plus a download
# button.
app.layout = dbc.Container([
    html.H1("Image Upscaler", className="text-center my-4"),
    dbc.Row([
        dbc.Col([
            dcc.Upload(
                id='upload-image',
                children=html.Div([
                    'Drag and Drop or ',
                    html.A('Select an Image')
                ]),
                style={
                    'width': '100%',
                    'height': '60px',
                    'lineHeight': '60px',
                    'borderWidth': '1px',
                    'borderStyle': 'dashed',
                    'borderRadius': '5px',
                    'textAlign': 'center',
                    'margin': '10px'
                },
                # Restrict uploads to image MIME types so arbitrary files are
                # not offered for upscaling.
                accept='image/*',
                multiple=False
            ),
        ], width=12),
    ]),
    dbc.Row([
        dbc.Col([
            dbc.Button("Upscale Image", id="upscale-button", color="primary", className="mb-3"),
            # The spinner wraps 'loading-output', which the upscale callback
            # writes to, so it spins while that callback runs.
            dbc.Spinner(html.Div(id="loading-output"), color="primary", type="border", size="sm"),
        ], width=12, className="text-center"),
    ]),
    dbc.Row([
        dbc.Col([
            html.H4("Original Image", className="mt-4"),
            dbc.Card(id='output-image-upload', body=True)
        ], md=12, lg=6),
        dbc.Col([
            html.H4("Upscaled Image", className="mt-4"),
            # Disabled until an upscale succeeds.
            dbc.Button("Download Upscaled Image", id="download-button", color="success", className="mb-3", disabled=True),
            dbc.Card(id='output-upscaled-image', body=True),
            dcc.Download(id="download-image")
        ], md=12, lg=6),
    ]),
], fluid=True)
@app.callback(
    Output('output-image-upload', 'children'),
    Input('upload-image', 'contents'),
    State('upload-image', 'filename')
)
def update_output(contents, filename):
    """Show the uploaded image with its filename beneath it.

    Returns None (leaving the card empty) until an image has been uploaded.
    """
    if contents is None:
        return None
    preview = html.Img(src=contents, style={'width': '100%'})
    caption = html.P(filename)
    return html.Div([preview, caption])
@app.callback(
    [Output('output-upscaled-image', 'children'),
     Output('download-button', 'disabled'),
     Output('loading-output', 'children')],
    Input('upscale-button', 'n_clicks'),
    State('upload-image', 'contents'),
    prevent_initial_call=True
)
def upscale_image_callback(n_clicks, contents):
    """Upscale the uploaded image when the upscale button is clicked.

    Args:
        n_clicks: Click count of the upscale button; used only as the trigger.
        contents: Data URL of the uploaded image, or None if nothing uploaded.

    Returns:
        A 3-tuple ``(preview, download_disabled, spinner_text)``. ``preview`` is
        the upscaled image preview (None on failure), ``download_disabled`` is
        whether the download button stays disabled, and ``spinner_text`` is an
        empty string that clears the spinner target.
    """
    if contents is None:
        return None, True, ""

    # Call the upscaler directly. The previous version started a thread and
    # then busy-waited for ``generated_file`` to become non-None. That blocked
    # this callback anyway, and it looped forever when the upscale failed,
    # because upscale_image leaves ``generated_file`` as None on error.
    succeeded = upscale_image(contents)
    # Snapshot the shared global right after the call.
    # NOTE(review): ``generated_file`` is shared by all sessions, so a
    # concurrent request can still overwrite it; per-session storage
    # (e.g. dcc.Store) would remove that race.
    result = generated_file
    if not succeeded or not result:
        return None, True, ""

    encoded = base64.b64encode(result).decode()
    upscaled_image = html.Div([
        html.Img(src=f"data:image/png;base64,{encoded}", style={'width': '100%'})
    ])
    return upscaled_image, False, ""
@app.callback(
    Output("download-image", "data"),
    Input("download-button", "n_clicks"),
    prevent_initial_call=True
)
def download_image(n_clicks):
    """Send the most recent upscaled image to the browser as a PNG file.

    Returns None (no download) when no upscaled image is available.
    """
    if not generated_file:
        return None
    return dcc.send_bytes(generated_file, "upscaled_image.png")
if __name__ == '__main__':
    print("Starting the Dash application...")
    # Debug mode turns on Flask/Werkzeug's interactive debugger. Because the
    # server binds to all interfaces, debug is opt-in via DASH_DEBUG rather
    # than always on.
    debug_mode = os.getenv('DASH_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(debug=debug_mode, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")