csuer commited on
Commit
88589c5
·
verified ·
1 Parent(s): 60d0c94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -18
app.py CHANGED
@@ -4,7 +4,6 @@ import torch.nn as nn
4
  from torchvision import transforms
5
  from torchvision.models import resnet18, ResNet18_Weights
6
  from PIL import Image
7
- from io import BytesIO
8
 
9
  # number convert to label
10
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
@@ -59,17 +58,6 @@ def predict(image_path):
59
  result = {labels[i]: float(prediction[i]) for i in range(5)}
60
  return result
61
 
62
- # API 预测接口
63
- def api_predict(image: gr.Image) -> dict:
64
- img = Image.open(BytesIO(image)).convert("RGB")
65
- img = preprocess(img).unsqueeze(0)
66
-
67
- with torch.no_grad():
68
- prediction = torch.nn.functional.softmax(model(img)[0], dim=0)
69
-
70
- result = {labels[i]: float(prediction[i]) for i in range(5)}
71
- return result
72
-
73
  # 美化 Gradio 界面
74
  css = """
75
  .gradio-container {
@@ -104,8 +92,7 @@ css = """
104
  footer {visibility: hidden}
105
  """
106
 
107
- # 使用 Blocks API 创建界面
108
- with gr.Blocks(css=css) as demo:
109
  gr.Markdown("# 🌟 NSFW 图片分类器")
110
  gr.Markdown(description)
111
 
@@ -123,10 +110,7 @@ with gr.Blocks(css=css) as demo:
123
  ]
124
  gr.Examples(examples=examples, inputs=image_input, label="点击上传示例图片")
125
 
126
- submit_btn.click(predict, inputs=image_input, outputs=output_label)
127
-
128
- # 手动为 API 设置路由
129
- demo.add_api_route("/api/predict", api_predict, methods=["POST"])
130
 
131
  # 启动 Web 界面
132
  demo.launch()
 
4
  from torchvision import transforms
5
  from torchvision.models import resnet18, ResNet18_Weights
6
  from PIL import Image
 
7
 
8
  # number convert to label
9
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
 
58
  result = {labels[i]: float(prediction[i]) for i in range(5)}
59
  return result
60
 
 
 
 
 
 
 
 
 
 
 
 
61
  # 美化 Gradio 界面
62
  css = """
63
  .gradio-container {
 
92
  footer {visibility: hidden}
93
  """
94
 
95
+ with gr.Blocks(theme="huggingface", css=css) as demo:
 
96
  gr.Markdown("# 🌟 NSFW 图片分类器")
97
  gr.Markdown(description)
98
 
 
110
  ]
111
  gr.Examples(examples=examples, inputs=image_input, label="点击上传示例图片")
112
 
113
+ submit_btn.click(predict, inputs=image_input, outputs=output_label, api_name="predict")
 
 
 
114
 
115
  # 启动 Web 界面
116
  demo.launch()