"""Gradio demo: acne segmentation with a U-Net plus severity classification."""

import os
import tempfile

import gradio as gr
import torch
import cv2
import numpy as np
import segmentation_models_pytorch as smp
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2
from transformers import pipeline

# Constants
MODEL_PATH = "Acnes_model.pth"
MASK_OPACITY = 0.9
MASK_THRESHOLD = 0.2  # sigmoid probability above which a pixel is painted
DEVICE = torch.device("cpu")

# Display text for each label string looked up by
# map_classification_label_to_level.
SEVERITY_LEVELS = {
    'level -1': "Level -1: Clear Skin",
    'level 0': "Level 0: Occasional Spots",
    'level 1': "Level 1: Mild Acne",
    'level 2': "Level 2: Moderate Acne",
    'level 3': "Level 3: Severe Acne",
    'level 4': "Level 4: Very Severe Acne",
}

# ----------------- Rebuild Model & Load Weights -----------------
# Recreate the model (must match the architecture used in training).
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=1,
)

# Load weights only (safe): weights_only=True restricts unpickling to tensors
# and plain containers, so a tampered checkpoint cannot run arbitrary code.
_state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(_state_dict)
model.to(DEVICE)
model.eval()

# ----------------- Classification Model -----------------
classification_pipe = pipeline(
    "image-classification",
    model="imfarzanansari/skintelligent-acne",
)


# ----------------- Preprocessing -----------------
def preprocess_image(image, img_size=(256, 256)):
    """Resize, normalize and tensorize an image for the segmentation model.

    Args:
        image: Image array as delivered by the Gradio ``numpy`` image input.
        img_size: Target size passed positionally to ``Resize``.

    Returns:
        A tuple ``(tensor_image, image)``: the transformed image with a
        leading batch dimension added, and the untouched input image.
    """
    transform = Compose([
        Resize(*img_size),
        # Presumably the ImageNet channel statistics -- confirm against the
        # preprocessing used at training time.
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    augmented = transform(image=image)
    tensor_image = augmented["image"].unsqueeze(0)
    return tensor_image, image


# ----------------- Inference -----------------
def predict_mask(model, image_tensor):
    """Run the segmentation model and return its sigmoid output as NumPy.

    Args:
        model: Segmentation network to evaluate.
        image_tensor: Batched input tensor, moved to ``DEVICE`` before use.

    Returns:
        The sigmoid of the model output with all size-1 dimensions squeezed
        out, as a CPU NumPy array.
    """
    with torch.no_grad():
        image_tensor = image_tensor.to(DEVICE)
        output = model(image_tensor)
        mask = torch.sigmoid(output)
    return mask.squeeze().cpu().numpy()


# ----------------- Overlay -----------------
def overlay_mask(image, mask, color=(255, 0, 0), alpha=MASK_OPACITY,
                 threshold=MASK_THRESHOLD):
    """Paint the thresholded mask onto the image.

    Args:
        image: Image array to draw on; it is not modified.
        mask: Probability mask; resized to the image's height and width.
        color: Color written into the masked pixels.
        alpha: Weight of the colored mask when it is added to the image.
        threshold: Mask values strictly above this are treated as foreground.

    Returns:
        A new image array with the colored mask added on top.
    """
    height, width = image.shape[:2]
    # Nearest-neighbour keeps the resized mask free of interpolated values.
    mask_resized = cv2.resize(
        mask, (width, height), interpolation=cv2.INTER_NEAREST
    )
    mask_colored = np.zeros_like(image, dtype=np.uint8)
    mask_colored[mask_resized > threshold] = color
    # The image keeps weight 1, so pixels outside the mask are unchanged;
    # addWeighted returns a new array, so no defensive copy is needed.
    return cv2.addWeighted(image, 1, mask_colored, alpha, 0)


# ----------------- Severity Mapping -----------------
def map_classification_label_to_level(label):
    """Return the display text for a classifier label, or ``"Unknown"``."""
    return SEVERITY_LEVELS.get(label, "Unknown")


# ----------------- Combined Prediction -----------------
def predict(image):
    """Segment acne regions and classify severity for one uploaded image.

    Args:
        image: Image array from the Gradio input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        A tuple ``(overlayed_image, text)`` where ``text`` holds the severity
        description and the classifier confidence. When no image is supplied,
        ``(None, "No image provided.")`` is returned.

    Raises:
        RuntimeError: If the temporary image for the classifier cannot be
            written.
    """
    if image is None:
        return None, "No image provided."

    input_tensor, original_image = preprocess_image(image)
    predicted_mask = predict_mask(model, input_tensor)
    overlayed_image = overlay_mask(
        original_image, predicted_mask, color=(255, 0, 0), alpha=MASK_OPACITY
    )

    # A per-call temporary directory keeps concurrent requests from
    # overwriting each other's file and removes it once classification is done.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = os.path.join(tmp_dir, "temp_image.jpg")
        bgr_image = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
        # imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(temp_path, bgr_image):
            raise RuntimeError("Could not write temporary image for classification.")
        classification_result = classification_pipe(temp_path)

    top_prediction = max(classification_result, key=lambda x: x['score'])
    predicted_label = top_prediction['label']
    confidence = top_prediction['score']
    severity = map_classification_label_to_level(predicted_label)

    return overlayed_image, f"{severity}\nConfidence: {confidence:.2f}"


# ----------------- Gradio UI -----------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Face Image"),
    outputs=[
        gr.Image(label="Segmentation Overlay"),
        gr.Text(label="Acne Severity Prediction"),
    ],
    title="🧼 Acne Segmentation & Severity Classification",
    description="Upload a facial image to detect acne regions and predict severity level using UNet and a pretrained classifier.",
)

demo.launch()