# ImageCompress / app.py
# Uploaded by: SAVAI123
# Commit 0612339 (verified): "Update app.py"
import os
import subprocess
import time

import numpy as np
import PIL.Image
import torch
from flask import Flask, request, render_template, send_file, send_from_directory, url_for
from werkzeug.utils import secure_filename
# Set up the Flask application and the on-disk folders it works with.
app = Flask(__name__)

UPLOAD_FOLDER = "uploads"   # raw files saved by the upload route
RESULT_FOLDER = "results"   # generated images served back to the user
MODELS_FOLDER = "models"    # StyleGAN3 pickle and the age-direction vector

app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER, RESULT_FOLDER=RESULT_FOLDER)

# Create every working folder up front (no-op when it already exists).
for _folder in (UPLOAD_FOLDER, RESULT_FOLDER, MODELS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder
# Function to download the StyleGAN3 model if it doesn't exist
def download_stylegan3_model():
    """Make sure the pretrained StyleGAN3-R FFHQ 1024x1024 pickle is in MODELS_FOLDER.

    Downloads with ``wget`` when it is available, and falls back to ``requests`` when
    wget is missing or exits with an error. Data is written to ``<name>.part`` and only
    renamed to the final name once the transfer has completed. That way an interrupted
    download never leaves a truncated pickle that a later start-up would mistake for
    an existing model.

    Returns:
        The local path of the pickle, or None if every download method failed
        (the error is printed, not raised).
    """
    from contextlib import suppress

    url = "https://nvlabs-fi-cdn.nvidia.com/stylegan3/pretrained/stylegan3-r-ffhq-1024x1024.pkl"
    network_pkl = os.path.join(MODELS_FOLDER, "stylegan3-r-ffhq-1024x1024.pkl")
    if os.path.exists(network_pkl):
        print(f"StyleGAN3 model already exists at {network_pkl}")
        return network_pkl

    print(f"StyleGAN3 model not found. Downloading to {network_pkl}...")
    partial_path = network_pkl + ".part"
    try:
        try:
            # Download using subprocess for better feedback
            result = subprocess.run(
                ["wget", url, "-O", partial_path],
                capture_output=True,
                text=True,
            )
            wget_error = result.stderr if result.returncode != 0 else None
        except OSError as e:
            # subprocess.run raises (FileNotFoundError) when wget is not installed;
            # treat that like a failed wget so the requests fallback still runs.
            wget_error = str(e)

        if wget_error is None:
            print("Download completed successfully using wget.")
        else:
            print(f"wget failed: {wget_error}")
            print("Trying alternative download method with requests...")
            import requests
            # The timeout keeps a stalled connection from hanging start-up forever;
            # the context manager releases the connection when done.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            print("Download completed successfully using requests.")

        # Publish the finished file under its final name in a single step.
        os.replace(partial_path, network_pkl)
        return network_pkl
    except Exception as e:
        print(f"Error downloading StyleGAN3 model: {e}")
        # Best-effort cleanup of a partial download; it may not exist at all.
        with suppress(OSError):
            os.remove(partial_path)
        return None
# Create a simple dummy age direction vector if it doesn't exist
def create_dummy_age_direction():
    """Ensure MODELS_FOLDER/age_direction.pt exists, writing a placeholder if needed.

    The placeholder is a random ``torch.randn(1, 512)`` tensor, not a trained
    direction. The path is returned in every case; if creating the file fails,
    the error is only printed.
    """
    target = os.path.join(MODELS_FOLDER, "age_direction.pt")
    if os.path.exists(target):
        return target

    print("Creating a dummy age direction vector...")
    try:
        # Stand-in for a properly learned direction; 512 is the assumed latent size.
        random_direction = torch.randn(1, 512)
        torch.save(random_direction, target)
        print(f"Created dummy age direction vector at {target}")
    except Exception as err:
        print(f"Error creating dummy age direction vector: {err}")
    return target
# Load StyleGAN3 Model
def load_stylegan3_model():
    """Ensure the model assets exist, then load the StyleGAN3 ``G_ema`` generator.

    Steps:
      1. Download the network pickle if needed (download_stylegan3_model).
      2. Make sure an age-direction file exists (create_dummy_age_direction).
      3. Import StyleGAN3's ``legacy`` module.
      4. Unpickle ``G_ema`` onto CUDA when available, otherwise onto the CPU.

    Returns:
        The generator module, or None if any step fails (the reason is printed).
    """
    import importlib.util
    import sys

    try:
        # First ensure the model file exists
        network_pkl = download_stylegan3_model()
        if not network_pkl:
            return None
        # Make sure the age direction vector exists
        create_dummy_age_direction()
        # Put the current directory on the import path so that a copy of StyleGAN3's
        # legacy module placed there can be imported.
        if "" not in sys.path:
            sys.path.append("")
        # NOTE: in the NVlabs StyleGAN3 repository `legacy` is a single file
        # (legacy.py). The previous os.path.exists("legacy") test looked for a path
        # with no extension, so it never matched that file. Ask the import system
        # instead; this finds legacy.py as well as a legacy/ package.
        if importlib.util.find_spec("legacy") is None:
            print("Warning: 'legacy' module not found in current directory.")
            print("StyleGAN3 requires this module to load models.")
            print("You may need to clone the StyleGAN3 repository to get this module.")
            return None
        import legacy

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f'Loading networks from "{network_pkl}" using device {device}...')
        with open(network_pkl, "rb") as f:
            G = legacy.load_network_pkl(f)["G_ema"].to(device)
        print("StyleGAN3 model loaded successfully!")
        return G
    except ImportError as e:
        # e.g. legacy.py found but its own dependencies (dnnlib, torch_utils) are missing
        print(f"Import error: {e}")
        print("Make sure you have the required modules for StyleGAN3.")
        return None
    except Exception as e:
        print(f"Error loading StyleGAN3 model: {e}")
        return None
# Load the generator once at import time; G is None when loading failed.
G = load_stylegan3_model()
model_loaded = G is not None
# Module-wide torch device for the latent tensors created by this app.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Function to encode an image into latent space
def image_to_latent(image_path):
    """Return a latent code for the image at ``image_path``.

    Placeholder implementation: the image file is never read. A random
    ``(1, z_dim)`` tensor is drawn on ``device`` instead, where ``z_dim`` comes from
    the loaded generator (512 when no model is loaded). Real encoding would need
    GAN inversion, e.g. an optimization loop or an encoder network.
    """
    latent_size = 512 if G is None else G.z_dim
    return torch.randn(1, latent_size, device=device)
# Function to modify latent code to make the face look younger
def modify_age(latent_vector, age_factor=-2.0):
    """Shift ``latent_vector`` along the stored age direction.

    Returns ``latent_vector + age_factor * direction``, where ``direction`` is loaded
    from MODELS_FOLDER/age_direction.pt. The caller uses a negative ``age_factor``
    to aim for a younger look.

    If loading the direction or the addition fails, the error is printed and
    ``latent_vector`` is returned unchanged.
    """
    try:
        age_direction_path = os.path.join(MODELS_FOLDER, "age_direction.pt")
        # Load straight onto the latent's device. Without map_location, a direction
        # saved from a GPU session cannot be deserialized on a CPU-only host, which
        # would silently disable the age edit via the except branch below.
        # NOTE(review): torch.load unpickles; only point this at a trusted local file.
        age_direction = torch.load(age_direction_path, map_location=latent_vector.device)
        return latent_vector + age_factor * age_direction
    except Exception as e:
        print(f"Error modifying age: {e}")
        return latent_vector  # Return original if error
# Function to generate an image from a latent code
def generate_image(latent_vector):
    """Render a latent code with the StyleGAN3 generator and return a PIL RGB image.

    ``latent_vector`` may take one of two forms:
      * A 2-D z code of shape (N, G.z_dim), as produced by image_to_latent and
        modify_age. It is first passed through ``G.mapping``, because
        ``G.synthesis`` consumes mapped w codes rather than z.
      * An already-mapped, higher-rank w code, which is fed to ``G.synthesis``
        as-is.

    Only the first image of the batch is returned. A blank white 1024x1024 image
    is returned when no model is loaded or when generation fails (the error is
    printed).
    """
    try:
        if G is None:
            # If model isn't loaded, return a placeholder image
            return PIL.Image.new('RGB', (1024, 1024), color=(255, 255, 255))
        # Pure inference: don't build an autograd graph.
        with torch.no_grad():
            if latent_vector.dim() == 2:
                # Zero class-label tensor, as StyleGAN3's own generation script passes;
                # presumably c_dim == 0 for this unconditional FFHQ model.
                label = torch.zeros([latent_vector.shape[0], G.c_dim], device=latent_vector.device)
                ws = G.mapping(latent_vector, label)
            else:
                ws = latent_vector
            img = G.synthesis(ws, noise_mode="const")
        # Map roughly [-1, 1] to [0, 255]; clamp before the uint8 cast so
        # out-of-range pixels saturate instead of wrapping around.
        img = ((img + 1) * (255 / 2)).clamp(0, 255).to(torch.uint8)
        img = img.permute(0, 2, 3, 1)[0].cpu().numpy()
        return PIL.Image.fromarray(img, "RGB")
    except Exception as e:
        print(f"Error generating image: {e}")
        # Return a blank image if there's an error
        return PIL.Image.new('RGB', (1024, 1024), color=(255, 255, 255))
# Flask Routes
@app.route("/", methods=["GET", "POST"])
def upload_file():
    """Show the upload form, or run the age-reduction pipeline on an uploaded image.

    Non-POST requests just render index.html. On POST the route:
      1. Saves the upload under UPLOAD_FOLDER.
      2. Builds a latent with image_to_latent.
      3. Shifts it with modify_age.
      4. Renders it with generate_image.
      5. Saves the result to RESULT_FOLDER as "young_<filename>" and renders
         result.html.

    Validation failures and processing errors re-render index.html with an error
    message. Every index.html render also shows the model status line.
    """
    model_status = "Model loaded successfully" if model_loaded else "Model could not be loaded. See server logs for details."

    def show_form(error=None):
        # Render the upload page with an optional error message.
        return render_template("index.html", error=error, model_status=model_status)

    if request.method != "POST":
        return show_form()
    if "file" not in request.files:
        return show_form("No file uploaded")
    file = request.files["file"]
    # Test truthiness so a missing (None) filename is rejected as well as "".
    if not file.filename:
        return show_form("No selected file")
    if not model_loaded:
        return show_form("StyleGAN3 model is not loaded. Cannot process images.")

    # secure_filename() strips path separators and non-ASCII characters and is
    # documented to possibly return "" (e.g. for ".."); joining that would make the
    # save target the upload folder itself.
    filename = secure_filename(file.filename)
    if not filename:
        return show_form("Invalid file name")

    try:
        input_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
        file.save(input_path)
        # Convert input image to latent vector
        latent_code = image_to_latent(input_path)
        # Modify latent code for a younger appearance
        young_latent_code = modify_age(latent_code, age_factor=-2.0)
        # Generate a younger-looking face
        young_image = generate_image(young_latent_code)
        result_name = "young_" + filename
        young_image.save(os.path.join(app.config["RESULT_FOLDER"], result_name))
        return render_template("result.html", filename=result_name)
    except Exception as e:
        return show_form(f"Error processing image: {str(e)}")
@app.route("/results/<filename>")
def display_image(filename):
    """Serve a generated image from RESULT_FOLDER for inline display.

    send_from_directory() joins ``filename`` with werkzeug's safe_join. Names that
    would escape the folder (e.g. "..\\app.py" on Windows) and missing files both
    produce a 404, instead of leaking other files or raising a 500.
    """
    return send_from_directory(app.config["RESULT_FOLDER"], filename)
@app.route("/download/<filename>")
def download_file(filename):
    """Send a generated image from RESULT_FOLDER as a file download (attachment).

    send_from_directory() joins ``filename`` with werkzeug's safe_join. Names that
    would escape the folder (e.g. "..\\app.py" on Windows) and missing files both
    produce a 404, instead of leaking other files or raising a 500.
    """
    return send_from_directory(app.config["RESULT_FOLDER"], filename, as_attachment=True)
# Create templates directory and files if they don't exist
def create_templates():
    """Write index.html and result.html into ./templates.

    The directory is created if needed. Both files are rewritten on every call,
    so any edits made to them on disk are replaced.
    """
    os.makedirs("templates", exist_ok=True)
    # Jinja templates keyed by file name; index.html is written first, then result.html.
    pages = {
        "index.html": """<!DOCTYPE html>
<html>
<head>
<title>Image Age Reduction</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
.container { border: 1px solid #ddd; padding: 20px; border-radius: 5px; }
.error { color: red; margin-bottom: 15px; }
.status { color: blue; margin-bottom: 15px; }
.form-group { margin-bottom: 15px; }
.btn { background-color: #4CAF50; color: white; padding: 10px 15px; border: none; cursor: pointer; }
</style>
</head>
<body>
<div class="container">
<h1>Image Age Reduction</h1>
{% if error %}
<div class="error">{{ error }}</div>
{% endif %}
{% if model_status %}
<div class="status">Model Status: {{ model_status }}</div>
{% endif %}
<form method="post" enctype="multipart/form-data">
<div class="form-group">
<label for="file">Select an image:</label>
<input type="file" name="file" id="file" accept="image/*" required>
</div>
<button type="submit" class="btn">Process Image</button>
</form>
</div>
</body>
</html>""",
        "result.html": """<!DOCTYPE html>
<html>
<head>
<title>Processing Result</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
.container { border: 1px solid #ddd; padding: 20px; border-radius: 5px; }
.result-img { max-width: 100%; margin: 20px 0; }
.btn { background-color: #4CAF50; color: white; padding: 10px 15px; border: none; cursor: pointer; text-decoration: none; display: inline-block; margin-right: 10px; }
.btn-back { background-color: #f44336; }
</style>
</head>
<body>
<div class="container">
<h1>Processing Result</h1>
<img src="{{ url_for('display_image', filename=filename) }}" class="result-img">
<div>
<a href="{{ url_for('download_file', filename=filename) }}" class="btn">Download</a>
<a href="{{ url_for('upload_file') }}" class="btn btn-back">Back</a>
</div>
</div>
</body>
</html>""",
    }
    for page_name, markup in pages.items():
        with open(os.path.join("templates", page_name), "w") as handle:
            handle.write(markup)
# Create the template files before running the app. This stays at module level so
# the templates are also written when the module is imported instead of run directly.
create_templates()

# Run the Flask app
if __name__ == "__main__":
    # The Werkzeug debugger executes arbitrary Python submitted from the browser, so
    # it must not be enabled by default while listening on every interface (0.0.0.0).
    # Opt in for local development with FLASK_DEBUG=1 (or "true"/"yes").
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug_enabled, host="0.0.0.0")