import gradio as gr
import torch
import cv2
import numpy as np
import segmentation_models_pytorch as smp
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2
from transformers import pipeline

# Constants
MODEL_PATH = "Acnes_model.pth"
MASK_OPACITY = 0.9            # opacity of the red mask overlay
DEVICE = torch.device("cpu")  # run inference on CPU

# ----------------- Rebuild Model & Load Weights -----------------
# Recreate the model with the same architecture used during training
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,  # no ImageNet weights; everything comes from the checkpoint
    in_channels=3,
    classes=1,
)

# Load the trained weights (state_dict only)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()
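
# Assumption: the checkpoint was saved with torch.save(model.state_dict(), MODEL_PATH)
# during training; if the full model object was pickled instead, load it with
# torch.load(...) directly rather than load_state_dict.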

# ----------------- Classification Model -----------------
classification_pipe = pipeline("image-classification", model="imfarzanansari/skintelligent-acne")
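# Note: the image-classification pipeline returns a list of dicts like
# [{"label": ..., "score": ...}]; predict() below keeps the top-scoring entry.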

# ----------------- Preprocessing -----------------
def preprocess_image(image, img_size=(256, 256)):
    """Resize and normalize an RGB numpy image; return (batched tensor, original image)."""
    transform = Compose([
        Resize(*img_size),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    augmented = transform(image=image)
    tensor_image = augmented["image"].unsqueeze(0)  # add batch dimension
    return tensor_image, image

# ----------------- Inference -----------------
def predict_mask(model, image_tensor):
    """Run the UNet and return the sigmoid probability mask as a numpy array."""
    with torch.no_grad():
        image_tensor = image_tensor.to(DEVICE)
        output = model(image_tensor)
        mask = torch.sigmoid(output)
    return mask.squeeze().cpu().numpy()

# ----------------- Overlay -----------------
def overlay_mask(image, mask, color=(255, 0, 0), alpha=MASK_OPACITY):
    """Blend the predicted acne mask onto the original RGB image."""
    mask_resized = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
    overlay = image.copy()
    mask_colored = np.zeros_like(image, dtype=np.uint8)
    mask_colored[mask_resized > 0.2] = color  # 0.2 probability threshold for acne pixels
    blended = cv2.addWeighted(overlay, 1, mask_colored, alpha, 0)
    return blended

# ----------------- Severity Mapping -----------------
def map_classification_label_to_level(label):
    """Translate the classifier's raw label into a human-readable severity description."""
    levels = {
        'level -1': "Level -1: Clear Skin",
        'level 0': "Level 0: Occasional Spots",
        'level 1': "Level 1: Mild Acne",
        'level 2': "Level 2: Moderate Acne",
        'level 3': "Level 3: Severe Acne",
        'level 4': "Level 4: Very Severe Acne",
    }
    return levels.get(label, "Unknown")

# ----------------- Combined Prediction -----------------
def predict(image):
    """Segment acne regions and classify overall severity for a single RGB image."""
    input_tensor, original_image = preprocess_image(image)
    predicted_mask = predict_mask(model, input_tensor)
    overlayed_image = overlay_mask(original_image, predicted_mask, color=(255, 0, 0), alpha=MASK_OPACITY)

    # Save a temporary BGR copy on disk for the classification pipeline
    temp_path = "/tmp/temp_image.jpg"
    cv2.imwrite(temp_path, cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR))
    classification_result = classification_pipe(temp_path)

    # Pick the top-scoring prediction once instead of scanning the list twice
    top_prediction = max(classification_result, key=lambda x: x['score'])
    predicted_label = top_prediction['label']
    confidence = top_prediction['score']
    severity = map_classification_label_to_level(predicted_label)

    return overlayed_image, f"{severity}\nConfidence: {confidence:.2f}"
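
# Optional local sanity check (hypothetical file name; uncomment to try without the UI):
# test_img = cv2.cvtColor(cv2.imread("sample_face.jpg"), cv2.COLOR_BGR2RGB)
# overlay, severity_text = predict(test_img)
# print(severity_text)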

# ----------------- Gradio UI -----------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Face Image"),
    outputs=[
        gr.Image(label="Segmentation Overlay"),
        gr.Text(label="Acne Severity Prediction"),
    ],
    title="🧼 Acne Segmentation & Severity Classification",
    description="Upload a facial image to detect acne regions and predict the severity level using a UNet segmenter and a pretrained classifier.",
)

demo.launch()