File size: 4,332 Bytes
0bc1cb9
a1dcb35
8f8a72d
a1dcb35
8866c4b
a1dcb35
 
 
 
 
 
 
 
 
 
 
 
 
 
8866c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1dcb35
 
 
8866c4b
 
a1dcb35
 
 
 
 
8866c4b
 
 
 
 
 
 
a1dcb35
 
 
8866c4b
 
 
 
 
 
 
 
 
a1dcb35
 
 
 
8866c4b
 
 
 
 
a1dcb35
 
8866c4b
a1dcb35
 
 
8866c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1dcb35
 
 
 
 
 
8f8a72d
8866c4b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os

import numpy as np
import PIL.Image
import torch
from flask import Flask, request, render_template, send_file, send_from_directory, url_for
from werkzeug.utils import secure_filename

# Create the Flask application object that the route decorators attach to.
app = Flask(__name__)

# Folder names for incoming uploads and generated results. Both are recorded
# in the app config and created up front if they do not exist yet.
UPLOAD_FOLDER = "uploads"
RESULT_FOLDER = "results"
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER, RESULT_FOLDER=RESULT_FOLDER)
for _folder in (UPLOAD_FOLDER, RESULT_FOLDER):
    os.makedirs(_folder, exist_ok=True)

# Load StyleGAN3 Model
#
# The module-level names used by the rest of the file (network_pkl, device, G,
# model_loaded) are bound BEFORE the try block so they always exist, even when
# loading fails part-way (e.g. `import legacy` raising). Code that runs later
# can then check `model_loaded` / `G is None` instead of hitting a NameError.
network_pkl = "models/stylegan3-r-ffhq-1024x1024.pkl"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G = None
model_loaded = False

try:
    import legacy

    print(f'Loading networks from "{network_pkl}"...')

    # NOTE(review): the ".pkl" extension suggests this goes through pickle,
    # which can execute arbitrary code -- only load trusted model files.
    with open(network_pkl, "rb") as f:
        G = legacy.load_network_pkl(f)["G_ema"].to(device)

    model_loaded = True
    print("StyleGAN3 model loaded successfully!")
except Exception as e:
    # Best effort: the app still starts and reports the problem to the user.
    print(f"Error loading StyleGAN3 model: {e}")
    G = None
    model_loaded = False

# Function to encode an image into latent space
def image_to_latent(image_path):
    """Return a latent vector standing in for the encoding of ``image_path``.

    This is a placeholder: ``image_path`` is never read. A real projection of
    an image into latent space would need an optimisation loop or an encoder
    network. Instead, a fresh random tensor of shape ``(1, G.z_dim)`` is
    sampled on the module-level ``device``.
    """
    batch_size = 1
    random_latent = torch.randn(batch_size, G.z_dim, device=device)
    return random_latent

# Function to modify latent code to make the face look younger
def modify_age(latent_vector, age_factor=-2.0, direction_path="models/age_direction.pt"):
    """Shift ``latent_vector`` along a precomputed age direction.

    Args:
        latent_vector: Latent tensor to edit.
        age_factor: Scale applied to the age direction before it is added.
            The caller in this file uses a negative value for a younger look.
        direction_path: File holding the saved direction tensor. Defaults to
            the path this function has always used.

    Returns:
        ``latent_vector + age_factor * age_direction``, or the unmodified
        ``latent_vector`` if the direction cannot be loaded or applied (the
        error is printed; this function never raises).

    NOTE(review): the direction must live in the same latent space and have a
    shape broadcastable to ``latent_vector`` -- confirm against how
    ``age_direction.pt`` was produced.
    """
    try:
        # map_location places the tensor on the latent's device at load time,
        # so a direction saved on a GPU still loads on a CPU-only host.
        age_direction = torch.load(direction_path, map_location=latent_vector.device)
        age_direction = age_direction.to(latent_vector.device)
        return latent_vector + age_factor * age_direction
    except Exception as e:
        print(f"Error modifying age: {e}")
        return latent_vector  # Return original if error

# Function to generate an image from a latent code
def generate_image(latent_vector):
    """Synthesise a PIL image from a latent code using the global generator ``G``.

    Args:
        latent_vector: Either a 3-D tensor, passed to ``G.synthesis`` as-is,
            or a 2-D tensor, treated as a z-space code and first passed
            through ``G.mapping``. The StyleGAN3 synthesis network expects
            ``(N, num_ws, w_dim)`` input, so a raw ``(N, z_dim)`` code cannot
            be fed to it directly.

    Returns:
        A ``PIL.Image`` built from the first item of the batch. If anything
        fails, the error is printed and a blank white 1024x1024 RGB image is
        returned instead.
    """
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            if latent_vector.dim() == 2:
                # Conditioning labels; zero-width when the model is unconditional.
                labels = torch.zeros(
                    [latent_vector.shape[0], getattr(G, "c_dim", 0)],
                    device=latent_vector.device,
                )
                latent_vector = G.mapping(latent_vector, labels)

            img = G.synthesis(latent_vector, noise_mode="const")
            # Map [-1, 1] to [0, 255]; clamp so that slightly out-of-range
            # values do not wrap around when cast to uint8.
            img = ((img + 1) * (255 / 2)).clamp(0, 255)
            img = img.permute(0, 2, 3, 1).cpu().numpy()[0].astype(np.uint8)
        return PIL.Image.fromarray(img)
    except Exception as e:
        print(f"Error generating image: {e}")
        # Return a blank image if there's an error
        return PIL.Image.new('RGB', (1024, 1024), color=(255, 255, 255))

# Flask Routes
@app.route("/", methods=["GET", "POST"])
def upload_file():
    """Serve the upload form (GET) and process an uploaded face image (POST).

    On POST the upload is validated, saved under ``UPLOAD_FOLDER``, converted
    to a latent code, shifted with ``modify_age`` and rendered with
    ``generate_image``. The result is saved under ``RESULT_FOLDER`` as
    ``"young_" + <sanitised filename>`` and ``result.html`` is rendered with
    that name. Every failure re-renders ``index.html`` with an ``error``
    message.
    """
    # File types accepted for upload. The output is saved under the same
    # extension, so it must be one the image library can write; rejecting
    # anything else up front avoids a late failure after the model has run.
    allowed_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp"}

    error_message = None
    if not model_loaded:
        error_message = "StyleGAN3 model could not be loaded. Please check the server logs."

    if request.method != "POST":
        return render_template("index.html", error=error_message)

    if "file" not in request.files:
        return render_template("index.html", error="No file uploaded")

    file = request.files["file"]
    if not file.filename:
        return render_template("index.html", error="No selected file")

    if not model_loaded:
        return render_template("index.html", error=error_message)

    # secure_filename can return an empty string (e.g. when the name consists
    # only of characters it strips); joining that would point at the folder.
    filename = secure_filename(file.filename)
    if not filename:
        return render_template("index.html", error="Invalid file name")

    extension = os.path.splitext(filename)[1].lower()
    if extension not in allowed_extensions:
        return render_template("index.html", error="Unsupported file type")

    try:
        input_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
        file.save(input_path)

        # Convert input image to latent vector
        latent_code = image_to_latent(input_path)

        # Modify latent code for a younger appearance
        young_latent_code = modify_age(latent_code, age_factor=-2.0)

        # Generate a younger-looking face
        young_image = generate_image(young_latent_code)
        result_name = "young_" + filename
        output_path = os.path.join(app.config["RESULT_FOLDER"], result_name)
        young_image.save(output_path)

        return render_template("result.html", filename=result_name)
    except Exception as e:
        error_message = f"Error processing image: {str(e)}"
        return render_template("index.html", error=error_message)

@app.route("/results/<filename>")
def display_image(filename):
    """Serve a generated image from ``RESULT_FOLDER`` for inline display.

    ``filename`` comes from the URL and is therefore untrusted, so it is
    served through ``send_from_directory`` rather than joined into a path
    for ``send_file``.
    """
    return send_from_directory(app.config["RESULT_FOLDER"], filename)

@app.route("/download/<filename>")
def download_file(filename):
    """Serve a generated image from ``RESULT_FOLDER`` as a download.

    ``filename`` comes from the URL and is therefore untrusted, so it is
    served through ``send_from_directory`` rather than joined into a path
    for ``send_file``. ``as_attachment=True`` asks the browser to save the
    file instead of displaying it.
    """
    return send_from_directory(app.config["RESULT_FOLDER"], filename, as_attachment=True)

# Run the Flask app
if __name__ == "__main__":
    # The server listens on all interfaces, so debug mode must not be on by
    # default: the interactive debugger lets anyone who can reach the port
    # run code on the host. Opt in explicitly with FLASK_DEBUG=1 for local
    # development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug_enabled, host="0.0.0.0")