# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" were Hugging Face
# Spaces page-status residue from extraction, not Python source; kept here
# as a comment so the module parses.
# --- Imports ---
# Standard library
import os
import time  # retained for potential retry logic around the model download
from datetime import datetime

# Third-party
import numpy as np
import tensorflow as tf
from flask import Flask, request, render_template, jsonify
from tensorflow.keras.utils import load_img, img_to_array
from werkzeug.utils import secure_filename
from huggingface_hub import hf_hub_download  # fetches the model file from HF Hub

# Flask application instance for the whole service.
app = Flask(__name__)
# --- Model Loading Configuration ---
MODEL_FILE_NAME = "model.keras"
# IMPORTANT: MODEL_REPO_ID must be the actual REPO ID of YOUR model on
# Hugging Face Hub — the repository where 'model.keras' is stored,
# e.g. "your_huggingface_username/your_model_repo_name".
MODEL_REPO_ID = "nonamelife/garbage-detection-model"  # <--- MAKE SURE THIS IS YOUR CORRECT MODEL REPO ID!

# Loaded Keras model handle. Stays None when download or deserialization
# fails so request handlers can detect and report the failure state.
model = None

# --- Model Loading Logic ---
try:
    print(f"Attempting to download '{MODEL_FILE_NAME}' from Hugging Face Hub ({MODEL_REPO_ID})...")
    # hf_hub_download returns the FULL PATH of the cached file. Pinning
    # local_dir under /app guarantees write permission inside the container,
    # and disabling symlinks avoids Docker symlink pitfalls.
    model_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILE_NAME,
        local_dir="/app/.cache/huggingface/models",  # created on demand
        local_dir_use_symlinks=False,
    )
    print(f"'{MODEL_FILE_NAME}' downloaded successfully to: {model_path}")
    # Safeguard only — a successful hf_hub_download should imply existence.
    if not os.path.exists(model_path):
        print(f"ERROR: Download reported success, but file not found at expected path: {model_path}")
    else:
        model = tf.keras.models.load_model(model_path)
        print("Model loaded successfully!")
except Exception as e:
    # Log any download/loading failure; the app still starts, but prediction
    # endpoints will report an error until a model is available.
    print(f"FATAL: Could not download or load model from Hugging Face Hub: {e}")
    model = None  # Keep the sentinel in the failure case
# --- End Model Loading Logic ---
# --- Flask app configuration ---
# Only common raster image formats are accepted for upload.
ALLOWED_EXTENSIONS = {'jpg', 'jpeg', 'png'}
UPLOAD_FOLDER = os.path.join('static', 'uploads')
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Guarantee the destination for uploaded images exists inside the container.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
def allowed_file(filename):
    """Return True when *filename* carries an allowed image extension.

    Matching is case-insensitive and based on the text after the last dot.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
def preprocess_image(image_path):
    """Load an image from disk and turn it into a model-ready batch.

    The image is resized to 224x224, converted to a float array with pixel
    values scaled into [0, 1], and given a leading batch dimension, yielding
    shape (1, 224, 224, 3) for the model input.
    """
    image = load_img(image_path, target_size=(224, 224))
    scaled = img_to_array(image) / 255.0
    return scaled[np.newaxis, ...]
# --- Flask Routes ---
# NOTE(review): the @app.route decorators appear to have been lost from this
# file — without them none of these views is registered. Paths below are
# reconstructed from the function/template names; verify against the templates.
@app.route('/')
def index():
    """Render the home page."""
    return render_template('home.html')
@app.route('/tool')  # NOTE(review): decorator reconstructed — verify path
def tool():
    """Render the image upload tool page."""
    return render_template('tool.html')
@app.route('/about')  # NOTE(review): decorator reconstructed — verify path
def about():
    """Render the about page."""
    return render_template('about.html')
@app.route('/contact')  # NOTE(review): decorator reconstructed — verify path
def contact():
    """Render the contact page."""
    return render_template('contact.html')
@app.route('/predict', methods=['POST'])  # NOTE(review): decorator reconstructed — verify path
def predict():
    """Handle uploaded images and render cleanliness predictions.

    Expects one or more files under the 'file' form field. Each valid image
    is saved to the upload folder, classified by the model, and shown on the
    results page. Invalid files or per-file failures produce an 'Error'
    entry instead of aborting the whole batch.

    Returns:
        A rendered 'results.html' page on success, or a JSON error payload
        with status 400 (bad request) / 500 (model unavailable).
    """
    # The model is loaded once at startup; bail out early if that failed.
    if model is None:
        return jsonify({'error': 'Model not loaded. Please check server logs.'}), 500

    # Validate that the request actually carries files.
    if 'file' not in request.files:
        return jsonify({'error': 'No files uploaded'}), 400
    files = request.files.getlist('file')
    if not files or all(f.filename == '' for f in files):
        return jsonify({'error': 'No files selected'}), 400

    results = []
    for file in files:
        if file and allowed_file(file.filename):
            # Sanitize the client-supplied name, then prefix a microsecond
            # timestamp so same-named uploads cannot collide.
            filename = secure_filename(file.filename)
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
            # BUG FIX: the sanitized filename was previously dropped from the
            # unique name; include it so saved files keep a recognizable name.
            unique_filename = f"{timestamp}_{filename}"
            file_path = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
            file.save(file_path)  # Persist the upload for prediction and display
            try:
                img_array = preprocess_image(file_path)
                # Single sigmoid output: values above 0.5 mean "Dirty".
                prediction = model.predict(img_array)[0][0]
                label = "Dirty" if prediction > 0.5 else "Clean"
                confidence = prediction if label == "Dirty" else 1 - prediction
                results.append({
                    'label': label,
                    'confidence': f"{confidence:.2%}",  # e.g. "97.25%"
                    'image_url': f"/static/uploads/{unique_filename}"
                })
            except Exception as e:
                # A failure on one image must not abort the whole batch.
                results.append({
                    'label': 'Error',
                    'confidence': 'N/A',
                    'image_url': None,
                    'error': str(e)
                })
        else:
            # Reject unsupported extensions but keep processing the rest.
            results.append({
                'label': 'Error',
                'confidence': 'N/A',
                'image_url': None,
                'error': f"Invalid file type: {file.filename}"
            })
    # Render the results page with all per-file outcomes.
    return render_template('results.html', results=results)
# --- Main execution block ---
if __name__ == '__main__':
    # Hugging Face Spaces supplies the PORT environment variable; 7860 is the
    # conventional fallback for Spaces apps. Debug stays OFF in production
    # deployments for security.
    listen_port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=listen_port, debug=False)