csuer commited on
Commit
48f08ee
·
verified ·
1 Parent(s): 38e331e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn as nn
4
  from torchvision import transforms
5
  from torchvision.models import resnet18, ResNet18_Weights
6
  from PIL import Image
 
7
 
8
  # number convert to label
9
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
@@ -58,6 +59,17 @@ def predict(image_path):
58
  result = {labels[i]: float(prediction[i]) for i in range(5)}
59
  return result
60
 
 
 
 
 
 
 
 
 
 
 
 
61
  # 美化 Gradio 界面
62
  css = """
63
  .gradio-container {
@@ -92,7 +104,7 @@ css = """
92
  footer {visibility: hidden}
93
  """
94
 
95
- with gr.Blocks(theme="glass", css=css) as demo:
96
  gr.Markdown("# 🌟 NSFW 图片分类器")
97
  gr.Markdown(description)
98
 
@@ -112,5 +124,8 @@ with gr.Blocks(theme="glass", css=css) as demo:
112
 
113
  submit_btn.click(predict, inputs=image_input, outputs=output_label)
114
 
 
 
 
115
  # 启动 Web 界面
116
  demo.launch()
 
4
  from torchvision import transforms
5
  from torchvision.models import resnet18, ResNet18_Weights
6
  from PIL import Image
7
+ from io import BytesIO
8
 
9
  # number convert to label
10
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
 
59
  result = {labels[i]: float(prediction[i]) for i in range(5)}
60
  return result
61
 
62
+ # API 预测接口
63
+ def api_predict(image: gr.Image) -> dict:
64
+ img = Image.open(BytesIO(image)).convert("RGB")
65
+ img = preprocess(img).unsqueeze(0)
66
+
67
+ with torch.no_grad():
68
+ prediction = torch.nn.functional.softmax(model(img)[0], dim=0)
69
+
70
+ result = {labels[i]: float(prediction[i]) for i in range(5)}
71
+ return result
72
+
73
  # 美化 Gradio 界面
74
  css = """
75
  .gradio-container {
 
104
  footer {visibility: hidden}
105
  """
106
 
107
+ with gr.Blocks(theme="huggingface", css=css) as demo:
108
  gr.Markdown("# 🌟 NSFW 图片分类器")
109
  gr.Markdown(description)
110
 
 
124
 
125
  submit_btn.click(predict, inputs=image_input, outputs=output_label)
126
 
127
+ # 启动 API 接口
128
+ gr.Interface(fn=api_predict, inputs=gr.Image(), outputs=gr.JSON(), api=True)
129
+
130
  # 启动 Web 界面
131
  demo.launch()