File size: 11,295 Bytes
0bc1cb9
a1dcb35
8f8a72d
a1dcb35
0612339
 
8866c4b
a1dcb35
 
 
 
 
0612339
a1dcb35
 
0612339
a1dcb35
 
 
 
0612339
a1dcb35
0612339
 
 
8866c4b
0612339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1dcb35
 
 
8866c4b
 
0612339
 
 
 
a1dcb35
 
 
 
8866c4b
0612339
 
8866c4b
 
 
 
 
a1dcb35
 
 
8866c4b
0612339
 
 
 
8866c4b
 
 
 
 
 
 
0612339
a1dcb35
 
 
 
8866c4b
0612339
8866c4b
a1dcb35
 
0612339
a1dcb35
 
 
0612339
8866c4b
 
0612339
8866c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0612339
8866c4b
0612339
8866c4b
 
 
 
a1dcb35
 
 
 
 
0612339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1dcb35
8f8a72d
8866c4b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import os
import subprocess
import time
import uuid

import numpy as np
import PIL.Image
import torch
from flask import Flask, render_template, request, send_file, send_from_directory, url_for
from werkzeug.utils import secure_filename

# Set up the Flask application.
app = Flask(__name__)

# Working directories, created relative to the current working directory:
#   uploads/ - files received from the upload form
#   results/ - generated images served by the /results and /download routes
#   models/  - the StyleGAN3 checkpoint and the age-direction tensor
UPLOAD_FOLDER = "uploads"
RESULT_FOLDER = "results"
MODELS_FOLDER = "models"

# Expose the upload/result locations through the app config for the routes.
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER, RESULT_FOLDER=RESULT_FOLDER)

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULT_FOLDER, exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)

# Public download location of the StyleGAN3 FFHQ 1024x1024 checkpoint.
STYLEGAN3_FFHQ_URL = (
    "https://nvlabs-fi-cdn.nvidia.com/stylegan3/pretrained/stylegan3-r-ffhq-1024x1024.pkl"
)


def _download_with_wget(url, destination):
    """Try to fetch *url* into *destination* with the ``wget`` CLI.

    Returns:
        True on success; False if wget is not installed or exits with a
        non-zero status (its stderr is printed in that case).
    """
    try:
        result = subprocess.run(
            ["wget", url, "-O", destination],
            capture_output=True,
            text=True,
        )
    except FileNotFoundError:
        # wget is not on PATH: report it and let the caller fall back.
        print("wget is not installed.")
        return False
    if result.returncode != 0:
        print(f"wget failed: {result.stderr}")
        return False
    return True


def _download_with_requests(url, destination, timeout=60):
    """Stream *url* into *destination* using ``requests``.

    *timeout* (seconds) bounds the connect and per-read waits so a stalled
    server cannot hang startup forever.

    Raises:
        Whatever ``requests`` raises for connection errors, timeouts or
        non-2xx HTTP responses, plus ``OSError`` for local write failures.
    """
    import requests  # lazy: only needed when wget is unavailable or fails

    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(destination, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)


# Function to download the StyleGAN3 model if it doesn't exist
def download_stylegan3_model():
    """Return the local path of the StyleGAN3 FFHQ checkpoint, downloading it if needed.

    The file lives at ``MODELS_FOLDER/stylegan3-r-ffhq-1024x1024.pkl``.
    If it is missing, ``wget`` is tried first and ``requests`` is used as a
    fallback (also when wget is not installed).

    Returns:
        The checkpoint path on success, or ``None`` if the download failed
        (the error is printed).
    """
    network_pkl = os.path.join(MODELS_FOLDER, "stylegan3-r-ffhq-1024x1024.pkl")

    if os.path.exists(network_pkl):
        print(f"StyleGAN3 model already exists at {network_pkl}")
        return network_pkl

    print(f"StyleGAN3 model not found. Downloading to {network_pkl}...")
    # Download under a temporary name and rename only on success. Otherwise an
    # interrupted or failed download would leave a truncated .pkl that the
    # existence check above treats as a valid model on every later run.
    partial_path = network_pkl + ".part"
    try:
        if _download_with_wget(STYLEGAN3_FFHQ_URL, partial_path):
            print("Download completed successfully using wget.")
        else:
            print("Trying alternative download method with requests...")
            _download_with_requests(STYLEGAN3_FFHQ_URL, partial_path)
            print("Download completed successfully using requests.")
        os.replace(partial_path, network_pkl)
        return network_pkl
    except Exception as e:
        print(f"Error downloading StyleGAN3 model: {e}")
        return None
    finally:
        # Best-effort cleanup of a leftover partial file. This is a no-op after
        # a successful rename; a failed removal must not mask the real outcome.
        if os.path.exists(partial_path):
            try:
                os.remove(partial_path)
            except OSError as cleanup_error:
                print(f"Could not remove partial download {partial_path}: {cleanup_error}")

# Create a placeholder age-direction tensor on first run.
def create_dummy_age_direction():
    """Ensure ``MODELS_FOLDER/age_direction.pt`` exists and return its path.

    When the file is missing, a random ``torch.randn(1, 512)`` tensor is
    saved there as a stand-in. A real application would use a direction
    learned from labelled data. Failures while creating the file are
    printed, not raised; the path is returned either way.
    """
    target = os.path.join(MODELS_FOLDER, "age_direction.pt")
    if os.path.exists(target):
        return target

    print("Creating a dummy age direction vector...")
    try:
        # 512 matches the latent size assumed elsewhere in this file.
        placeholder = torch.randn(1, 512)
        torch.save(placeholder, target)
    except Exception as err:
        print(f"Error creating dummy age direction vector: {err}")
    else:
        print(f"Created dummy age direction vector at {target}")
    return target

# Load StyleGAN3 Model
def load_stylegan3_model():
    """Prepare model assets and load the StyleGAN3 generator.

    Steps:
        1. Download the FFHQ checkpoint if it is missing.
        2. Make sure the (dummy) age direction tensor exists.
        3. Import the StyleGAN3 ``legacy`` module and load ``G_ema`` from the
           checkpoint onto CUDA if available, else CPU.

    Returns:
        The generator on success, or ``None`` if any step fails (the reason
        is printed).
    """
    try:
        network_pkl = download_stylegan3_model()
        if not network_pkl:
            return None

        create_dummy_age_direction()

        import importlib.util
        import sys

        # Make the current working directory importable, so the StyleGAN3
        # ``legacy`` module is found when the app runs from a checkout of
        # that repository.
        if "" not in sys.path:
            sys.path.append("")

        # ``import legacy`` resolves either a legacy.py file or a legacy/
        # package. Checking only for a path literally named "legacy" misses
        # the single-file case, so ask the import system instead.
        if importlib.util.find_spec("legacy") is None:
            print("Warning: 'legacy' module not found in current directory.")
            print("StyleGAN3 requires this module to load models.")
            print("You may need to clone the StyleGAN3 repository to get this module.")
            return None

        import legacy

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f'Loading networks from "{network_pkl}" using device {device}...')

        # NOTE(security): .pkl checkpoints are Python pickles, and loading one
        # can execute arbitrary code. Only load files from a trusted source.
        with open(network_pkl, "rb") as f:
            G = legacy.load_network_pkl(f)["G_ema"].to(device)

        print("StyleGAN3 model loaded successfully!")
        return G
    except ImportError as e:
        print(f"Import error: {e}")
        print("Make sure you have the required modules for StyleGAN3.")
        return None
    except Exception as e:
        print(f"Error loading StyleGAN3 model: {e}")
        return None

# Device on which request-time latents are created and moved.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the generator once at import time; None means it is unavailable and
# the upload route will refuse to process images.
G = load_stylegan3_model()
model_loaded = G is not None

# Placeholder "encoder": maps an uploaded image to a latent code.
def image_to_latent(image_path):
    """Return a latent code standing in for the encoding of *image_path*.

    Placeholder only: the image is not read. A fresh standard-normal
    tensor of shape ``(1, z_dim)`` is drawn on ``device``, using ``G.z_dim``
    when a generator is loaded and 512 otherwise. Real inversion would need
    optimization-based projection or a trained encoder network.
    """
    latent_dim = 512 if G is None else G.z_dim
    return torch.randn(1, latent_dim, device=device)

# Function to modify latent code to make the face look younger
def modify_age(latent_vector, age_factor=-2.0):
    """Shift *latent_vector* along the stored age direction.

    The direction tensor is read from ``MODELS_FOLDER/age_direction.pt`` on
    every call. It is scaled by *age_factor* and added to the latent; the
    upload route uses a negative factor for a "younger" result.

    Args:
        latent_vector: Latent tensor on ``device``. It must broadcast with
            the stored direction. The dummy direction is saved with shape
            (1, 512); TODO confirm the shape of any real direction file.
        age_factor: Scale applied to the direction before adding it.

    Returns:
        The shifted latent, or *latent_vector* unchanged if loading or
        applying the direction fails (the error is printed).
    """
    age_direction_path = os.path.join(MODELS_FOLDER, "age_direction.pt")
    try:
        # map_location restores the tensor directly onto the active device.
        # Without it, a direction saved from a CUDA machine fails to load on a
        # CPU-only host, and the age edit would silently be skipped.
        age_direction = torch.load(age_direction_path, map_location=device)
        return latent_vector + age_factor * age_direction
    except Exception as e:
        print(f"Error modifying age: {e}")
        return latent_vector  # Return original if error

# Function to generate an image from a latent code
def generate_image(latent_vector):
    """Render a PIL RGB image from a latent code with the loaded generator.

    Accepts either:
        * a z-space latent of shape ``(batch, z_dim)`` (what
          ``image_to_latent`` produces), which is first passed through
          ``G.mapping``; or
        * per-layer w latents of shape ``(batch, num_ws, w_dim)``, which go
          straight to ``G.synthesis``.
    Only the first image of the batch is returned.

    Returns:
        The rendered image, or a blank white 1024x1024 image if no model is
        loaded or generation fails (the error is printed).
    """
    try:
        if G is None:
            # If model isn't loaded, return a placeholder image
            return PIL.Image.new('RGB', (1024, 1024), color=(255, 255, 255))

        # Inference only. No autograd graph is needed, and calling .numpy() on
        # a tensor that requires grad raises a RuntimeError.
        with torch.no_grad():
            ws = latent_vector
            if ws.dim() == 2:
                # In the StyleGAN3 reference implementation, G.synthesis takes
                # (batch, num_ws, w_dim) w latents, so map z vectors into w
                # space first. None = no class label (unconditional model).
                ws = G.mapping(ws, None)
            img = G.synthesis(ws, noise_mode="const")

            # Map the nominal [-1, 1] output to [0, 255]. Clamp before the
            # uint8 cast, because out-of-range values would otherwise wrap
            # around and show up as speckle artifacts.
            img = (img.permute(0, 2, 3, 1) + 1) * (255 / 2)
            img = img.clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
        return PIL.Image.fromarray(img)
    except Exception as e:
        print(f"Error generating image: {e}")
        # Return a blank image if there's an error
        return PIL.Image.new('RGB', (1024, 1024), color=(255, 255, 255))

# Flask Routes
@app.route("/", methods=["GET", "POST"])
def upload_file():
    """Render the upload form and process submitted images.

    GET renders ``index.html``. POST validates the upload, saves it under
    ``UPLOAD_FOLDER`` with a random prefix, runs the latent pipeline
    (``image_to_latent`` -> ``modify_age`` -> ``generate_image``) and
    renders ``result.html`` for the image written to ``RESULT_FOLDER``.
    Validation and processing errors re-render ``index.html`` with a
    message. The model status line is shown on every index render.
    """
    model_status = (
        "Model loaded successfully"
        if model_loaded
        else "Model could not be loaded. See server logs for details."
    )

    def _form(error=None):
        # Re-render the upload form with an optional error message.
        return render_template("index.html", error=error, model_status=model_status)

    if request.method != "POST":
        return _form()

    if "file" not in request.files:
        return _form("No file uploaded")

    file = request.files["file"]
    if file.filename == "":
        return _form("No selected file")

    if not model_loaded:
        return _form("StyleGAN3 model is not loaded. Cannot process images.")

    # secure_filename() strips path parts and unsafe characters. It can return
    # "" (e.g. for "../.."), which would turn input_path into the upload folder
    # itself.
    filename = secure_filename(file.filename)
    if not filename:
        return _form("Invalid file name")

    # A random prefix keeps concurrent uploads with the same name from
    # overwriting each other's files and makes result URLs hard to guess.
    stored_name = f"{uuid.uuid4().hex}_{filename}"
    # PIL picks the output format from the file extension, so give
    # extension-less names a .png suffix instead of failing on save.
    suffix = "" if os.path.splitext(filename)[1] else ".png"
    result_name = "young_" + stored_name + suffix

    try:
        input_path = os.path.join(app.config["UPLOAD_FOLDER"], stored_name)
        file.save(input_path)

        # Convert input image to latent vector
        latent_code = image_to_latent(input_path)

        # Modify latent code for a younger appearance
        young_latent_code = modify_age(latent_code, age_factor=-2.0)

        # Generate a younger-looking face
        young_image = generate_image(young_latent_code)
        young_image.save(os.path.join(app.config["RESULT_FOLDER"], result_name))
    except Exception as e:
        return _form(f"Error processing image: {str(e)}")

    return render_template("result.html", filename=result_name)

@app.route("/results/<filename>")
def display_image(filename):
    """Serve a generated image inline from ``RESULT_FOLDER``.

    ``send_from_directory`` joins *filename* safely, rejecting paths that
    would escape the folder. It answers 404 for missing files, where a
    plain ``send_file`` on a missing path fails with a 500.
    """
    return send_from_directory(app.config["RESULT_FOLDER"], filename)

@app.route("/download/<filename>")
def download_file(filename):
    """Serve a generated image from ``RESULT_FOLDER`` as a file download.

    Uses ``send_from_directory`` for a safe path join and a 404 on missing
    files; ``as_attachment=True`` makes the browser save the file.
    """
    return send_from_directory(app.config["RESULT_FOLDER"], filename, as_attachment=True)

# Write the HTML templates rendered by the routes (overwrites existing copies).
def create_templates():
    """Write ``templates/index.html`` and ``templates/result.html``.

    The ``templates`` directory is created if needed. Both files are
    rewritten on every call: ``index.html`` is the upload form (showing
    ``error`` and ``model_status``), and ``result.html`` shows the generated
    image with download and back links.
    """
    template_dir = "templates"
    os.makedirs(template_dir, exist_ok=True)

    # File name -> Jinja source, written in this order.
    pages = {
        "index.html": """<!DOCTYPE html>
<html>
<head>
    <title>Image Age Reduction</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        .container { border: 1px solid #ddd; padding: 20px; border-radius: 5px; }
        .error { color: red; margin-bottom: 15px; }
        .status { color: blue; margin-bottom: 15px; }
        .form-group { margin-bottom: 15px; }
        .btn { background-color: #4CAF50; color: white; padding: 10px 15px; border: none; cursor: pointer; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Image Age Reduction</h1>
        {% if error %}
        <div class="error">{{ error }}</div>
        {% endif %}
        {% if model_status %}
        <div class="status">Model Status: {{ model_status }}</div>
        {% endif %}
        <form method="post" enctype="multipart/form-data">
            <div class="form-group">
                <label for="file">Select an image:</label>
                <input type="file" name="file" id="file" accept="image/*" required>
            </div>
            <button type="submit" class="btn">Process Image</button>
        </form>
    </div>
</body>
</html>""",
        "result.html": """<!DOCTYPE html>
<html>
<head>
    <title>Processing Result</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        .container { border: 1px solid #ddd; padding: 20px; border-radius: 5px; }
        .result-img { max-width: 100%; margin: 20px 0; }
        .btn { background-color: #4CAF50; color: white; padding: 10px 15px; border: none; cursor: pointer; text-decoration: none; display: inline-block; margin-right: 10px; }
        .btn-back { background-color: #f44336; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Processing Result</h1>
        <img src="{{ url_for('display_image', filename=filename) }}" class="result-img">
        <div>
            <a href="{{ url_for('download_file', filename=filename) }}" class="btn">Download</a>
            <a href="{{ url_for('upload_file') }}" class="btn btn-back">Back</a>
        </div>
    </div>
</body>
</html>""",
    }

    for page_name, markup in pages.items():
        with open(os.path.join(template_dir, page_name), "w") as handle:
            handle.write(markup)

# Create the template files before running the app
create_templates()

# Run the Flask app
if __name__ == "__main__":
    # debug=True turns on the Werkzeug interactive debugger, which allows
    # arbitrary code execution from the browser. With host="0.0.0.0" that would
    # be reachable from other machines, so debug mode is opt-in via
    # FLASK_DEBUG=1 (or true/yes/on).
    debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(debug=debug_mode, host="0.0.0.0")