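# NOTE(review): the lines below look like hosting-page residue captured together
# with this source; they are not part of the program and are kept only as comments.
# Spaces:
# Runtime error
# Runtime error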
import os
import subprocess
import time

import numpy as np
import PIL.Image
import torch
from flask import Flask, request, render_template, send_file, send_from_directory, url_for
from werkzeug.utils import secure_filename
# --- Flask application and on-disk working folders ---------------------------
app = Flask(__name__)

# Folder names are relative to the current working directory.
UPLOAD_FOLDER = "uploads"   # raw user uploads
RESULT_FOLDER = "results"   # generated images served back to the user
MODELS_FOLDER = "models"    # StyleGAN3 checkpoint + age-direction tensor

app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER, RESULT_FOLDER=RESULT_FOLDER)

# Create every working folder up front; existing folders are left untouched.
for _folder in (UPLOAD_FOLDER, RESULT_FOLDER, MODELS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
# Function to download the StyleGAN3 model if it doesn't exist
_STYLEGAN3_PKL_NAME = "stylegan3-r-ffhq-1024x1024.pkl"
_STYLEGAN3_URL = "https://nvlabs-fi-cdn.nvidia.com/stylegan3/pretrained/" + _STYLEGAN3_PKL_NAME


def _download_with_wget(url, dest):
    """Try to fetch *url* into *dest* with the ``wget`` CLI; return True on success."""
    try:
        result = subprocess.run(["wget", "-O", dest, url], capture_output=True, text=True)
    except OSError as e:
        # e.g. FileNotFoundError when wget is not installed: fall back instead of aborting.
        print(f"wget unavailable: {e}")
        return False
    if result.returncode != 0:
        print(f"wget failed: {result.stderr}")
        return False
    return True


def _download_with_requests(url, dest, timeout=60):
    """Stream *url* into *dest* with ``requests``; raise on HTTP/network errors."""
    import requests

    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(dest, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)


def download_stylegan3_model():
    """Ensure the StyleGAN3 FFHQ checkpoint exists in MODELS_FOLDER.

    Tries ``wget`` first and falls back to ``requests``. The download goes to a
    ``.part`` file that is only renamed to the final name once complete, so an
    interrupted or failed download is never mistaken for a valid model on the
    next start.

    Returns:
        The checkpoint path, or ``None`` if downloading failed (errors are
        printed, not raised).
    """
    network_pkl = os.path.join(MODELS_FOLDER, _STYLEGAN3_PKL_NAME)
    if os.path.exists(network_pkl):
        print(f"StyleGAN3 model already exists at {network_pkl}")
        return network_pkl

    print(f"StyleGAN3 model not found. Downloading to {network_pkl}...")
    partial = network_pkl + ".part"
    try:
        if _download_with_wget(_STYLEGAN3_URL, partial):
            method = "wget"
        else:
            print("Trying alternative download method with requests...")
            _download_with_requests(_STYLEGAN3_URL, partial)
            method = "requests"
        # Atomic on the same filesystem: the final name only appears when complete.
        os.replace(partial, network_pkl)
        print(f"Download completed successfully using {method}.")
        return network_pkl
    except Exception as e:
        print(f"Error downloading StyleGAN3 model: {e}")
        # Best-effort cleanup of a truncated download so a later run retries.
        try:
            os.remove(partial)
        except OSError:
            pass
        return None
# Create a simple dummy age direction vector if it doesn't exist
def create_dummy_age_direction():
    """Make sure ``MODELS_FOLDER/age_direction.pt`` exists and return its path.

    When the file is missing, a random ``(1, 512)`` tensor is saved there as a
    stand-in. It is NOT a learned age direction; a real application would ship
    a properly trained vector. Failures are printed, not raised, and the path
    is returned regardless of whether the file could be written.
    """
    target = os.path.join(MODELS_FOLDER, "age_direction.pt")
    if os.path.exists(target):
        return target

    print("Creating a dummy age direction vector...")
    try:
        # 512 is assumed to match the generator's latent size.
        placeholder = torch.randn(1, 512)
        torch.save(placeholder, target)
    except Exception as err:
        print(f"Error creating dummy age direction vector: {err}")
    else:
        print(f"Created dummy age direction vector at {target}")
    return target
# Load StyleGAN3 Model
def load_stylegan3_model():
    """Download (if needed) and load the StyleGAN3 ``G_ema`` generator.

    Also ensures the placeholder age-direction tensor exists. Loading needs the
    StyleGAN3 repository's ``legacy`` module to be importable; the current
    working directory is put on ``sys.path`` so a checkout there is found.

    Returns:
        The generator moved to CUDA when available (CPU otherwise), or ``None``
        if any step fails. Errors are printed, never raised.
    """
    import sys

    try:
        # First ensure the model file exists
        network_pkl = download_stylegan3_model()
        if not network_pkl:
            return None
        # Make sure the age direction vector exists
        create_dummy_age_direction()
        # "" on sys.path means "the current working directory".
        if "" not in sys.path:
            sys.path.append("")
        # EAFP: the old os.path.exists("legacy") check looked for a file/dir
        # literally named "legacy", but the module ships as legacy.py, so the
        # check rejected every real checkout. Just try the import.
        try:
            import legacy
        except ImportError as e:
            print(f"Warning: could not import the 'legacy' module: {e}")
            print("StyleGAN3 requires this module to load models.")
            print("You may need to clone the StyleGAN3 repository to get this module.")
            return None
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f'Loading networks from "{network_pkl}" using device {device}...')
        # NOTE(review): the checkpoint is a .pkl, so load_network_pkl presumably
        # unpickles it (arbitrary code execution risk) -- only load trusted files.
        with open(network_pkl, "rb") as f:
            G = legacy.load_network_pkl(f)["G_ema"].to(device)
        print("StyleGAN3 model loaded successfully!")
        return G
    except ImportError as e:
        # Raised while unpickling when dnnlib / torch_utils etc. are missing.
        print(f"Import error: {e}")
        print("Make sure you have the required modules for StyleGAN3.")
        return None
    except Exception as e:
        print(f"Error loading StyleGAN3 model: {e}")
        return None
# Load the generator once at import time; the helpers and routes below read
# these module-level globals.
G = load_stylegan3_model()
model_loaded = G is not None
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Function to encode an image into latent space
def image_to_latent(image_path):
    """Return a latent code standing in for the image at *image_path*.

    Placeholder only: the image is never read. A random z vector of shape
    ``(1, G.z_dim)`` -- or ``(1, 512)`` when no generator is loaded -- is drawn
    on ``device``. Real image encoding needs GAN inversion (optimization) or an
    encoder network.
    """
    z_dim = 512 if G is None else G.z_dim
    return torch.randn(1, z_dim, device=device)
# Function to modify latent code to make the face look younger
def modify_age(latent_vector, age_factor=-2.0):
    """Shift *latent_vector* along the stored age direction.

    Computes ``latent_vector + age_factor * direction``, where the direction is
    read from ``MODELS_FOLDER/age_direction.pt``. The caller uses a negative
    factor for "younger"; that presumes the direction points toward older
    faces (with the dummy random vector it has no real meaning -- confirm once
    a trained direction is in place).

    Returns:
        The shifted latent, or *latent_vector* unchanged if the direction
        cannot be loaded or applied (the error is printed).
    """
    age_direction_path = os.path.join(MODELS_FOLDER, "age_direction.pt")
    try:
        # map_location loads straight onto the active device, so a direction
        # saved on a GPU machine still loads on a CPU-only host (a plain
        # torch.load(...).to(device) fails there before .to() runs).
        age_direction = torch.load(age_direction_path, map_location=device)
        return latent_vector + age_factor * age_direction.to(latent_vector.dtype)
    except Exception as e:
        print(f"Error modifying age: {e}")
        return latent_vector  # Return original if error
# Function to generate an image from a latent code
def generate_image(latent_vector):
    """Render *latent_vector* to an RGB ``PIL.Image`` with the loaded generator.

    Args:
        latent_vector: either a z code of shape ``(N, z_dim)`` -- what
            ``image_to_latent`` / ``modify_age`` produce -- which is first run
            through ``G.mapping``, or W-space ``ws`` of shape
            ``(N, num_ws, w_dim)`` which go straight to ``G.synthesis``.
            Previously z codes were handed to ``G.synthesis`` directly, which
            expects ``ws``, so every request fell back to the placeholder.

    Returns:
        The first image of the batch, or a 1024x1024 white placeholder when no
        generator is loaded or generation fails (the error is printed).
    """
    def _placeholder():
        return PIL.Image.new('RGB', (1024, 1024), color=(255, 255, 255))

    if G is None:
        # If model isn't loaded, return a placeholder image
        return _placeholder()
    try:
        with torch.no_grad():
            ws = latent_vector
            if ws.dim() == 2:
                # z space -> W space; unconditional models have c_dim == 0.
                c = torch.zeros([ws.shape[0], G.c_dim], device=ws.device)
                ws = G.mapping(ws, c)
            img = G.synthesis(ws, noise_mode="const")
        # [-1, 1] -> [0, 255]; clamp before the cast so out-of-range values
        # saturate instead of wrapping around in uint8.
        img = ((img + 1) * (255 / 2)).clamp(0, 255).to(torch.uint8)
        img = img.permute(0, 2, 3, 1)[0].cpu().numpy()
        return PIL.Image.fromarray(img)
    except Exception as e:
        print(f"Error generating image: {e}")
        # Return a blank image if there's an error
        return _placeholder()
# Flask Routes
@app.route("/", methods=["GET", "POST"])
def upload_file():
    """Show the upload form (GET) or process an uploaded image (POST).

    The route decorator was missing, so the app registered no URL at all and
    result.html's ``url_for('upload_file')`` could not be built.

    On success, the generated image is written to RESULT_FOLDER as
    ``young_<stem>.png`` and ``result.html`` is rendered. Every failure
    re-renders ``index.html`` with an error message and the model status.
    """
    model_status = (
        "Model loaded successfully"
        if model_loaded
        else "Model could not be loaded. See server logs for details."
    )

    def _form(error=None):
        return render_template("index.html", error=error, model_status=model_status)

    if request.method != "POST":
        return _form()
    if "file" not in request.files:
        return _form("No file uploaded")
    file = request.files["file"]
    # filename may be None as well as "" for an empty file field.
    if not file.filename:
        return _form("No selected file")
    if not model_loaded:
        return _form("StyleGAN3 model is not loaded. Cannot process images.")

    # secure_filename can strip a name down to "" (e.g. "../.."); joining that
    # would point file.save() at the upload directory itself.
    filename = secure_filename(file.filename)
    if not filename:
        return _form("Invalid file name")

    try:
        input_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
        file.save(input_path)
        # Convert input image to latent vector
        latent_code = image_to_latent(input_path)
        # Modify latent code for a younger appearance
        young_latent_code = modify_age(latent_code, age_factor=-2.0)
        # Generate a younger-looking face
        young_image = generate_image(young_latent_code)
        # Always save as PNG: PIL picks the format from the extension and
        # rejects uploads with an unknown or missing one.
        result_name = "young_" + os.path.splitext(filename)[0] + ".png"
        young_image.save(os.path.join(app.config["RESULT_FOLDER"], result_name))
    except Exception as e:
        return _form(f"Error processing image: {str(e)}")
    return render_template("result.html", filename=result_name)
@app.route("/display/<filename>")
def display_image(filename):
    """Serve a generated image from RESULT_FOLDER inline (used by result.html).

    ``send_from_directory`` refuses names that would resolve outside the
    directory. It resolves a relative directory against ``app.root_path``,
    while results are written relative to the working directory, so the
    folder is made absolute here to keep both in agreement.
    """
    return send_from_directory(os.path.abspath(app.config["RESULT_FOLDER"]), filename)
@app.route("/download/<filename>")
def download_file(filename):
    """Serve a generated image from RESULT_FOLDER as a file download.

    Uses ``send_from_directory`` so names cannot escape the results folder.
    The folder is made absolute because a relative directory would be resolved
    against ``app.root_path`` instead of the working directory the results are
    written to.
    """
    return send_from_directory(
        os.path.abspath(app.config["RESULT_FOLDER"]), filename, as_attachment=True
    )
# Create the Jinja templates used by the routes (rewritten on every call)
def create_templates(template_dir=None):
    """Write ``index.html`` and ``result.html`` for the routes.

    Both files are overwritten on every call so they always match this source.

    Args:
        template_dir: target directory. Defaults to the folder Flask actually
            loads templates from (``app.template_folder`` under
            ``app.root_path``). The old cwd-relative "templates" path only
            matched that when the server was started from this file's directory.
    """
    if template_dir is None:
        template_dir = os.path.join(app.root_path, app.template_folder or "templates")
    os.makedirs(template_dir, exist_ok=True)
    # Create index.html
    index_html = """<!DOCTYPE html>
<html>
<head>
<title>Image Age Reduction</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
.container { border: 1px solid #ddd; padding: 20px; border-radius: 5px; }
.error { color: red; margin-bottom: 15px; }
.status { color: blue; margin-bottom: 15px; }
.form-group { margin-bottom: 15px; }
.btn { background-color: #4CAF50; color: white; padding: 10px 15px; border: none; cursor: pointer; }
</style>
</head>
<body>
<div class="container">
<h1>Image Age Reduction</h1>
{% if error %}
<div class="error">{{ error }}</div>
{% endif %}
{% if model_status %}
<div class="status">Model Status: {{ model_status }}</div>
{% endif %}
<form method="post" enctype="multipart/form-data">
<div class="form-group">
<label for="file">Select an image:</label>
<input type="file" name="file" id="file" accept="image/*" required>
</div>
<button type="submit" class="btn">Process Image</button>
</form>
</div>
</body>
</html>"""
    # Create result.html
    result_html = """<!DOCTYPE html>
<html>
<head>
<title>Processing Result</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
.container { border: 1px solid #ddd; padding: 20px; border-radius: 5px; }
.result-img { max-width: 100%; margin: 20px 0; }
.btn { background-color: #4CAF50; color: white; padding: 10px 15px; border: none; cursor: pointer; text-decoration: none; display: inline-block; margin-right: 10px; }
.btn-back { background-color: #f44336; }
</style>
</head>
<body>
<div class="container">
<h1>Processing Result</h1>
<img src="{{ url_for('display_image', filename=filename) }}" class="result-img">
<div>
<a href="{{ url_for('download_file', filename=filename) }}" class="btn">Download</a>
<a href="{{ url_for('upload_file') }}" class="btn btn-back">Back</a>
</div>
</div>
</body>
</html>"""
    # Write the template files (explicit encoding: platform defaults vary)
    for name, content in (("index.html", index_html), ("result.html", result_html)):
        with open(os.path.join(template_dir, name), "w", encoding="utf-8") as f:
            f.write(content)
# Create the template files before running the app
create_templates()

# Run the Flask app
if __name__ == "__main__":
    # debug=True on host 0.0.0.0 exposes the Werkzeug interactive debugger,
    # which executes arbitrary Python sent from a browser, to the whole
    # network. Keep it opt-in (FLASK_DEBUG=1). PORT defaults to Flask's 5000.
    debug = os.environ.get("FLASK_DEBUG", "0") == "1"
    port = int(os.environ.get("PORT", "5000"))
    app.run(debug=debug, host="0.0.0.0", port=port)