File size: 4,485 Bytes
4883955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4441bea
 
4883955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12070f2
 
4441bea
 
255e0b3
ffffc22
a240e21
6904ba7
4441bea
 
 
12070f2
6904ba7
 
4441bea
4883955
 
 
 
 
 
 
 
 
 
6904ba7
 
 
4883955
 
 
a240e21
 
 
4883955
 
 
 
 
 
a240e21
4883955
 
a240e21
4883955
 
 
 
 
 
 
6904ba7
 
4883955
a240e21
4883955
a240e21
4883955
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import base64
import io
import os
import requests
import threading
import time
from dash import Dash, dcc, html, Input, Output, State, ctx
import dash_bootstrap_components as dbc

# Stability AI credential read from the environment (configure it as a secret,
# e.g. a Hugging Face Space secret). It is None when the variable is unset.
STABILITY_API_KEY = os.environ.get('STABILITY_API_KEY')

# Bytes of the most recent upscaled image. This is module-level state, so it
# is shared by every browser session served by this process; upscale_image()
# writes it and the callbacks below read it.
generated_file = None

# The Dash application, styled with the Bootstrap theme.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Function to upscale image
def upscale_image(contents, timeout=120):
    """Upscale an uploaded image with Stability AI's "fast" upscale endpoint.

    The upscaled PNG bytes are stored in the module-level ``generated_file``
    rather than returned.

    Args:
        contents: The data-URL string produced by ``dcc.Upload``, i.e.
            ``"data:<mime type>;base64,<payload>"``.
        timeout: Seconds to wait for the HTTP request before giving up.
            Defaults to 120.

    Returns:
        True when the API answered HTTP 200 and ``generated_file`` now holds
        the response bytes; False otherwise. On every failure
        ``generated_file`` is set to ``b""`` (falsy but not ``None``) so any
        code waiting for it to become non-``None`` is released instead of
        waiting forever.
    """
    import mimetypes

    global generated_file
    generated_file = None

    if not STABILITY_API_KEY:
        # Without a key the request would be sent as "Bearer None" and rejected.
        print("Error: STABILITY_API_KEY is not set; cannot call the upscale API.")
        generated_file = b""
        return False

    # Split on the first comma only: "<header>,<base64 payload>".
    header, separator, content_string = contents.partition(',')
    if not separator or not content_string:
        print("Error: uploaded data is not a base64 data URL.")
        generated_file = b""
        return False
    try:
        decoded = base64.b64decode(content_string)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        print(f"Error: uploaded data is not valid base64: {exc}")
        generated_file = b""
        return False

    # Forward the upload's real MIME type (e.g. image/jpeg) instead of always
    # labelling the bytes as PNG; fall back to PNG if the header is unexpected.
    mime_type = 'image/png'
    if header.startswith('data:'):
        mime_type = header[len('data:'):].split(';', 1)[0] or mime_type
    extension = mimetypes.guess_extension(mime_type) or '.png'

    # Prepare the API request
    url = "https://api.stability.ai/v2beta/stable-image/upscale/fast"
    headers = {
        "Authorization": f"Bearer {STABILITY_API_KEY}",
        "Accept": "image/*"
    }
    files = {
        "image": (f"image{extension}", io.BytesIO(decoded), mime_type)
    }
    data = {
        "output_format": "png"
    }

    # A timeout keeps a stalled connection from blocking the caller forever.
    try:
        response = requests.post(url, headers=headers, files=files, data=data, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: request to Stability AI failed: {exc}")
        generated_file = b""
        return False

    if response.status_code == 200:
        generated_file = response.content
        return True

    print(f"Error: {response.status_code}, {response.text}")
    generated_file = b""
    return False

# App layout
# App layout: an upload area on top, then the original image (with the
# upscale trigger and a spinner) beside the upscaled result and its download
# button.
app.layout = dbc.Container([
    html.H1("Image Upscaler", className="text-center my-4"),
    dbc.Row([
        dbc.Col([
            dcc.Upload(
                id='upload-image',
                children=html.Div([
                    'Drag and Drop or ',
                    html.A('Select an Image')
                ]),
                style={
                    'width': '100%',
                    'height': '60px',
                    'lineHeight': '60px',
                    'borderWidth': '1px',
                    'borderStyle': 'dashed',
                    'borderRadius': '5px',
                    'textAlign': 'center',
                    'margin': '10px'
                },
                # Restrict the file picker to images so arbitrary files are
                # not forwarded to the upscale API.
                accept='image/*',
                multiple=False
            ),
        ], width=12),
    ]),
    dbc.Row([
        dbc.Col([
            html.H4("Original Image", className="mt-4"),
            dbc.Button("Upscale Image", id="upscale-button", color="primary", className="mb-3"),
            # The spinner wraps an output of the upscale callback, so it is
            # shown while that callback runs.
            dbc.Spinner(html.Div(id="loading-output"), color="primary", type="border", size="sm"),
            dbc.Card(id='output-image-upload', body=True)
        ], md=12, lg=6),
        dbc.Col([
            html.H4("Upscaled Image", className="mt-4"),
            dbc.Button("Download Upscaled Image", id="download-button", color="success", className="mb-3", disabled=True),
            dbc.Card(id='output-upscaled-image', body=True),
            dcc.Download(id="download-image")
        ], md=12, lg=6),
    ]),
], fluid=True)

@app.callback(
    Output('output-image-upload', 'children'),
    Input('upload-image', 'contents'),
    State('upload-image', 'filename')
)
def update_output(contents, filename):
    """Preview the freshly uploaded image together with its file name.

    Returns None (leaving the card empty) until something has been uploaded.
    """
    if contents is None:
        return None
    preview = html.Img(src=contents, style={'width': '100%'})
    caption = html.P(filename)
    return html.Div([preview, caption])

@app.callback(
    [Output('output-upscaled-image', 'children'),
     Output('download-button', 'disabled'),
     Output('loading-output', 'children')],
    Input('upscale-button', 'n_clicks'),
    State('upload-image', 'contents'),
    prevent_initial_call=True
)
def upscale_image_callback(n_clicks, contents):
    """Upscale the current upload and show the result.

    Returns a 3-tuple for the outputs declared above:
    (upscaled-image children, download-button ``disabled``, spinner children).
    On any failure the upscaled card is cleared and downloading is disabled.
    """
    if contents is None:
        return None, True, ""

    # Run the upscale synchronously: the callback must block until the result
    # is ready anyway. The previous thread + busy-wait on the global could
    # return the *previous* run's image (read before the worker reset it) and
    # could spin forever when the request failed and left the global None.
    try:
        succeeded = upscale_image(contents)
    except (requests.RequestException, ValueError) as exc:
        # Network failures or malformed upload data.
        print(f"Error: {exc}")
        return None, True, ""

    result = generated_file
    if not succeeded or not result:
        return None, True, ""

    encoded = base64.b64encode(result).decode()
    upscaled_image = html.Div([
        html.Img(src=f"data:image/png;base64,{encoded}", style={'width': '100%'})
    ])
    return upscaled_image, False, ""

@app.callback(
    Output("download-image", "data"),
    Input("download-button", "n_clicks"),
    prevent_initial_call=True
)
def download_image(n_clicks):
    """Send the stored upscaled bytes to the browser as upscaled_image.png.

    Returns None (no download) when there is nothing stored.
    """
    if not generated_file:
        return None
    return dcc.send_bytes(generated_file, "upscaled_image.png")

if __name__ == '__main__':
    # Debug mode serves tracebacks and the Flask/Werkzeug debugger to clients.
    # Because the server binds to 0.0.0.0, that would be anyone who can reach
    # the port, so debug is opt-in via the DASH_DEBUG environment variable.
    debug_mode = os.getenv('DASH_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    print("Starting the Dash application...")
    app.run(debug=debug_mode, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")