PSNbst commited on
Commit
19860f0
·
verified ·
1 Parent(s): 52cb57a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -4
app.py CHANGED
@@ -9,6 +9,9 @@ import matplotlib.pyplot as plt
9
  import numpy as np
10
  from openai import OpenAI
11
  from huggingface_hub import hf_hub_download
 
 
 
12
 
13
  # 初始化模型
14
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
@@ -16,13 +19,29 @@ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
16
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
17
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
18
  sam_checkpoint = hf_hub_download(
19
- repo_id="facebook/sam", filename="sam_vit_h_4b8939.pth", use_auth_token=False
 
 
20
  )
21
  sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
22
  sam_predictor = SamPredictor(sam)
23
- yolo_model = YOLO("yolov8x.pt") # 替换为实际 YOLO 模型路径
24
- wd_processor = AutoProcessor.from_pretrained("SmilingWolf/wd-v1-4-vit-large-tagger")
25
- wd_model = AutoModelForImageClassification.from_pretrained("SmilingWolf/wd-v1-4-vit-large-tagger")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # 自动识别图片类型
28
  def classify_image_type(image):
 
9
  import numpy as np
10
  from openai import OpenAI
11
  from huggingface_hub import hf_hub_download
12
+ from segment_anything import SamPredictor, sam_model_registry
13
+ from yolo_world.models.detectors import build_detector
14
+ from mmcv import Config
15
 
16
  # 初始化模型
17
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 
19
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
20
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
21
  sam_checkpoint = hf_hub_download(
22
+ repo_id="facebook/sam-vit-large", # 仓库 ID
23
+ filename="model.safetensors", # 模型文件名
24
+ use_auth_token=False # 公共仓库无需身份验证
25
  )
26
  sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
27
  sam_predictor = SamPredictor(sam)
28
+ # Hugging Face 下载 YOLO-World 权重
29
+ yolo_checkpoint = hf_hub_download(
30
+ repo_id="stevengrove/YOLO-World", # Hugging Face 仓库 ID
31
+ filename="yolo_world_v2_xl_obj365v1_goldg_cc3mlite_pretrain.pth", # 模型权重文件名
32
+ use_auth_token=False # 公共仓库无需身份验证
33
+ )
34
+ # 加载 YOLO-World 配置文件
35
+ yolo_config = Config.fromfile('path/to/yolo_world_config.py') # 替换为实际配置文件路径
36
+ # 构建 YOLO-World 模型
37
+ yolo_model = build_detector(yolo_config.model)
38
+ # 加载权重到模型
39
+ checkpoint = torch.load(yolo_checkpoint, map_location="cpu") # 使用 CPU 加载权重,后续可以转移到 GPU
40
+ yolo_model.load_state_dict(checkpoint["state_dict"])
41
+ yolo_model.eval() # 设置为评估模式
42
+
43
+ wd_processor = AutoProcessor.from_pretrained("SmilingWolf/wd-vit-tagger-v3")
44
+ wd_model = AutoModelForImageClassification.from_pretrained("SmilingWolf/wd-vit-tagger-v3")
45
 
46
  # 自动识别图片类型
47
  def classify_image_type(image):