import gradio as gr
import torch
import cv2
import numpy as np
import segmentation_models_pytorch as smp
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2
from transformers import pipeline

# Constants
MODEL_PATH = "Acnes_model.pth"  # trained UNet (ResNet-34 encoder) weights
MASK_OPACITY = 0.9              # blending weight for the segmentation overlay
DEVICE = torch.device("cpu")    # CPU-only inference

# ----------------- Rebuild Model & Load Weights -----------------
# Recreate your model (same as in training)
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=1
)

# Load the saved state dict only (weights_only=True avoids unpickling arbitrary objects)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
model.to(DEVICE)
model.eval()

# ----------------- Classification Model -----------------
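# Pretrained acne severity classifier, fetched from the Hugging Face Hub (cached after the first download)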
classification_pipe = pipeline("image-classification", model="imfarzanansari/skintelligent-acne")

# ----------------- Preprocessing -----------------
def preprocess_image(image, img_size=(256, 256)):
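    """Resize to img_size, normalize with ImageNet statistics, and return a batched tensor plus the original RGB image."""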
    transform = Compose([
        Resize(*img_size),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])
    augmented = transform(image=image)
    tensor_image = augmented["image"].unsqueeze(0)
    return tensor_image, image

# ----------------- Inference -----------------
def predict_mask(model, image_tensor):
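    """Run the segmentation model on a preprocessed tensor and return the sigmoid probability mask as a 2D numpy array."""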
    with torch.no_grad():
        image_tensor = image_tensor.to(DEVICE)
        output = model(image_tensor)
        mask = torch.sigmoid(output)
        return mask.squeeze().cpu().numpy()

# ----------------- Overlay -----------------
def overlay_mask(image, mask, color=(255, 0, 0), alpha=MASK_OPACITY):
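    """Resize the mask to the image resolution and blend a colored overlay wherever the predicted probability exceeds 0.2."""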
    mask_resized = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
    overlay = image.copy()
    mask_colored = np.zeros_like(image, dtype=np.uint8)
    mask_colored[mask_resized > 0.2] = color
    blended = cv2.addWeighted(overlay, 1, mask_colored, alpha, 0)
    return blended

# ----------------- Severity Mapping -----------------
def map_classification_label_to_level(label):
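    """Convert the classifier's raw label into a human-readable severity description ('Unknown' if unrecognized)."""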
    levels = {
        'level -1': "Level -1: Clear Skin",
        'level 0': "Level 0: Occasional Spots",
        'level 1': "Level 1: Mild Acne",
        'level 2': "Level 2: Moderate Acne",
        'level 3': "Level 3: Severe Acne",
        'level 4': "Level 4: Very Severe Acne"
    }
    return levels.get(label, "Unknown")

# ----------------- Combined Prediction -----------------
def predict(image):
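    """End-to-end pipeline: segment acne regions, overlay the mask, classify severity, and return the overlay plus a summary string."""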
    input_tensor, original_image = preprocess_image(image)
    predicted_mask = predict_mask(model, input_tensor)
    overlayed_image = overlay_mask(original_image, predicted_mask, color=(255, 0, 0), alpha=MASK_OPACITY)

    # The image-classification pipeline expects a file path (or PIL image),
    # so write a temporary JPEG; cv2.imwrite expects BGR, hence the conversion.
    temp_path = "/tmp/temp_image.jpg"
    cv2.imwrite(temp_path, cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR))

    classification_result = classification_pipe(temp_path)
    top_prediction = max(classification_result, key=lambda x: x['score'])
    predicted_label = top_prediction['label']
    confidence = top_prediction['score']
    severity = map_classification_label_to_level(predicted_label)

    return overlayed_image, f"{severity}\nConfidence: {confidence:.2f}"

# ----------------- Gradio UI -----------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Face Image"),
    outputs=[
        gr.Image(label="Segmentation Overlay"),
        gr.Text(label="Acne Severity Prediction")
    ],
    title="🧼 Acne Segmentation & Severity Classification",
    description="Upload a facial image to detect acne regions and predict severity level using UNet and a pretrained classifier."
)

demo.launch()