"""Gradio demo that classifies an image into one of five NSFW categories."""

import threading

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights

# Class names, indexed by the position of the corresponding model output.
labels = ["drawings", "hentai", "neutral", "porn", "sexy"]

# Markdown text shown above the demo (user-facing; kept verbatim).
description = (
    "This is a demo of classifing nsfw pictures. Label division is based on "
    "the following: "
    "[*https://github.com/alex000kim/nsfw_data_scraper*]"
    "(https://github.com/alex000kim/nsfw_data_scraper). "
    "You can continue to train this model with the same preprocess-to-images. "
    "Finally, welcome to star my "
    "[*github repository*](https://github.com/csuer411/nsfw_classify)"
)


class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully connected head.

    The backbone's 1000-dimensional output is fed through the head, which
    produces one raw score (logit) per entry in ``labels``.

    Args:
        pretrained: When True (the default), initialise the backbone with
            torchvision's default ImageNet weights. Pass False to skip that
            download when a full checkpoint is loaded right afterwards.
    """

    def __init__(self, pretrained: bool = True) -> None:
        super().__init__()
        # ``weights`` must be an enum *member* (or None); passing the
        # ``ResNet18_Weights`` class itself is rejected by torchvision.
        backbone_weights = ResNet18_Weights.DEFAULT if pretrained else None
        self.cnn_layers = resnet18(weights=backbone_weights)
        # NOTE: the layer order (and therefore the Sequential indices) must
        # stay as-is so the parameter names match the saved checkpoint.
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 5),
        )

    def forward(self, x):
        """Return the five class logits for a batch of image tensors."""
        features = self.cnn_layers(x)
        return self.fc_layers(features)


# Pre-processing: resize, convert to a tensor, then normalise with the
# mean/std values given below.
preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load the trained model. ``load_state_dict`` is strict by default, so every
# parameter and buffer is replaced by the checkpoint; fetching the ImageNet
# backbone weights first would only be a wasted download.
model = Classifier(pretrained=False)
model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
model.eval()


def predict(inp):
    """Classify one image and return a probability per label.

    Args:
        inp: PIL image supplied by the Gradio ``Image`` input, or None when
            the user submits without providing an image.

    Returns:
        Dict mapping each entry of ``labels`` to its softmax probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if inp is None:
        raise gr.Error("Please provide an image to classify.")
    # Normalize() above is configured for three channels; force RGB so
    # grayscale or RGBA images cannot break it.
    batch = preprocess(inp.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch)[0], dim=0)
    return dict(zip(labels, probabilities.tolist()))


inputs = gr.Image(type='pil')
outputs = gr.Label(num_top_classes=2)
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=["./example/anime.jpg", "./example/real.jpg"],
    description=description,
)
demo.launch()