|
import gradio as gr |
|
import torch |
|
import threading |
|
import torch.nn as nn |
|
from torchvision import transforms |
|
from torchvision.models import resnet18, ResNet18_Weights |
|
from PIL import Image |
|
|
|
|
|
# Output class names, in the same order as the model's logits
# (``predict`` pairs them positionally with the softmax output).
labels = ["drawings", "hentai", "neutral", "porn", "sexy"]

# Markdown shown under the demo title. A plain string: there is nothing to
# interpolate, so no f-string prefix is needed.
description = """This is a demo that classifies NSFW pictures. The label categories follow the ones defined in:

[*https://github.com/alex000kim/nsfw_data_scraper*](https://github.com/alex000kim/nsfw_data_scraper).

You can continue training this model, provided you apply the same preprocessing to your images.

Finally, feel free to star my [*GitHub repository*](https://github.com/csuer411/nsfw_classify)."""
|
|
|
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully connected classification head.

    The backbone's 1000-dimensional output is passed through the head, which
    produces ``num_classes`` raw logits (no softmax is applied here).
    """

    def __init__(self, num_classes: int = 5, pretrained: bool = True):
        """Build the network.

        Args:
            num_classes: Number of output logits. Defaults to 5, the size used
                by this demo's label set and checkpoint.
            pretrained: When True (the default), initialise the backbone with
                torchvision's default pretrained ResNet-18 weights. Pass False
                to skip that when every parameter will be overwritten by
                ``load_state_dict`` anyway.
        """
        super().__init__()
        # torchvision expects an enum *member* (or None). Passing the bare
        # ``ResNet18_Weights`` class is rejected with a TypeError.
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        self.cnn_layers = resnet18(weights=weights)
        # NOTE: the layer order must stay exactly as-is. Saved state dicts
        # address these layers by index (e.g. ``fc_layers.0.weight``).
        # NOTE(review): there is no activation between the first two Linear
        # layers. This is possibly unintentional, but changing it would
        # invalidate trained checkpoints.
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Args:
            x: Image batch. Presumably (batch, 3, H, W), normalized like
                ``preprocess`` output (TODO confirm).

        Returns:
            Raw logits, one row per input and ``num_classes`` columns.
        """
        features = self.cnn_layers(x)
        return self.fc_layers(features)
|
|
|
|
|
|
|
# Per-channel normalization constants. These are the standard ImageNet
# statistics used with torchvision's pretrained classification models.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Image -> normalized float tensor pipeline applied before inference.
# Resize(224) scales the shorter side to 224 while keeping the aspect ratio.
_preprocess_steps = [
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
|
|
|
def _load_model(checkpoint_path="classify_nsfw_v3.0.pth"):
    """Build a ``Classifier``, load its trained weights, and switch it to eval mode."""
    net = Classifier()
    # NOTE(review): torch.load unpickles the file, so only use trusted
    # checkpoints. Consider weights_only=True if the installed torch
    # supports it.
    state = torch.load(checkpoint_path, map_location="cpu")  # force tensors onto the CPU
    net.load_state_dict(state)
    net.eval()  # turn off Dropout for inference
    return net


model = _load_model()
|
|
|
|
|
def predict(inp):
    """Classify one image and return a probability per label.

    Args:
        inp: The image from the Gradio ``Image`` input. It is a PIL image,
            since the component uses ``type='pil'``, or ``None`` when the user
            submits without an image.

    Returns:
        A dict mapping each name in ``labels`` to its softmax probability, as a
        Python float. Returns an empty dict when no image was provided.
    """
    if inp is None:
        return {}
    # RGBA, palette, or grayscale uploads would not have the three channels
    # that the 3-element mean/std in ``preprocess`` require.
    image = inp.convert("RGB")
    batch = preprocess(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.inference_mode():
        logits = model(batch)[0]
    probabilities = torch.softmax(logits, dim=0).tolist()
    # Pair positionally with ``labels`` instead of hard-coding the class count.
    return dict(zip(labels, probabilities))
|
|
|
|
|
# UI wiring: a PIL image goes in, and the two most probable labels come out.
inputs = gr.Image(type='pil')
outputs = gr.Label(num_top_classes=2)

# Bound to ``demo``, which is Gradio's conventional name (its reload mode looks for it).
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=["./example/anime.jpg", "./example/real.jpg"],
    description=description,
)

# Launch only when run as a script, so importing this module (e.g. to reuse
# ``predict``) does not start a blocking server.
if __name__ == "__main__":
    demo.launch()
|
|