Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,9 @@ import matplotlib.pyplot as plt
|
|
9 |
import numpy as np
|
10 |
from openai import OpenAI
|
11 |
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
12 |
|
13 |
# 初始化模型
|
14 |
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
@@ -16,13 +19,29 @@ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
|
16 |
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
17 |
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
|
18 |
sam_checkpoint = hf_hub_download(
|
19 |
-
repo_id="facebook/sam",
|
|
|
|
|
20 |
)
|
21 |
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
|
22 |
sam_predictor = SamPredictor(sam)
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
# 自动识别图片类型
|
28 |
def classify_image_type(image):
|
|
|
9 |
import numpy as np
|
10 |
from openai import OpenAI
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
+
from segment_anything import SamPredictor, sam_model_registry
|
13 |
+
from yolo_world.models.detectors import build_detector
|
14 |
+
from mmcv import Config
|
15 |
|
16 |
# 初始化模型
|
17 |
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
|
|
19 |
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
20 |
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
|
21 |
# Download the SAM checkpoint from the Hugging Face Hub and wrap the resulting
# model in a SamPredictor.
# NOTE(review): `facebook/sam-vit-large` appears to publish its weights in the
# `transformers` SamModel layout, whereas `sam_model_registry` builders load
# the original segment-anything `.pth` checkpoints -- confirm that this
# `.safetensors` file can actually be loaded here, or point at an original
# SAM checkpoint instead.
sam_checkpoint = hf_hub_download(
    repo_id="facebook/sam-vit-large",  # Hub repository ID
    filename="model.safetensors",  # checkpoint file name
    token=False,  # public repository: never send a locally stored token
)
# The checkpoint is downloaded from the ViT-L ("large") repository, so build
# the matching ViT-L architecture; the "vit_h" builder would expect ViT-H
# parameter shapes.
sam = sam_model_registry["vit_l"](checkpoint=sam_checkpoint)
sam_predictor = SamPredictor(sam)
|
28 |
+
# Download the YOLO-World weights from the Hugging Face Hub.
yolo_checkpoint = hf_hub_download(
    repo_id="stevengrove/YOLO-World",  # Hub repository ID
    filename="yolo_world_v2_xl_obj365v1_goldg_cc3mlite_pretrain.pth",  # weights file name
    token=False,  # public repository: never send a locally stored token
)
# Load the YOLO-World configuration file.
# TODO(review): 'path/to/yolo_world_config.py' is a placeholder; replace it
# with the real config path, otherwise this line fails at import time.
# NOTE(review): with mmcv >= 2.0 `Config` is presumably provided by `mmengine`
# rather than `mmcv` -- confirm against the installed version.
yolo_config = Config.fromfile('path/to/yolo_world_config.py')
# Build the YOLO-World model from the `model` section of the config.
yolo_model = build_detector(yolo_config.model)
# Load the weights on CPU; the model can be moved to a GPU afterwards.
# NOTE(review): torch.load unpickles the downloaded file; consider
# `weights_only=True` if the checkpoint contents allow it.
checkpoint = torch.load(yolo_checkpoint, map_location="cpu")
# Weights are expected under "state_dict"; fall back to the checkpoint itself
# when it is already a bare state dict.
yolo_model.load_state_dict(checkpoint.get("state_dict", checkpoint))
yolo_model.eval()  # switch to evaluation mode

# Load the WD tagger processor and classification model.
# NOTE(review): confirm that "SmilingWolf/wd-vit-tagger-v3" ships the
# preprocessor/model config files the `transformers` Auto classes require, and
# that AutoProcessor / AutoModelForImageClassification are imported above.
wd_processor = AutoProcessor.from_pretrained("SmilingWolf/wd-vit-tagger-v3")
wd_model = AutoModelForImageClassification.from_pretrained("SmilingWolf/wd-vit-tagger-v3")
|
45 |
|
46 |
# 自动识别图片类型
|
47 |
def classify_image_type(image):
|