khang119966 commited on
Commit
3b9285d
·
verified ·
1 Parent(s): 0ea2c0e

Delete demo.py

Browse files
Files changed (1) hide show
  1. demo.py +0 -98
demo.py DELETED
@@ -1,98 +0,0 @@
1
- import argparse
2
- import os
3
-
4
- from PIL import Image
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
6
-
7
- import cv2
8
- try:
9
- from mmengine.visualization import Visualizer
10
- except ImportError:
11
- Visualizer = None
12
- print("Warning: mmengine is not installed, visualization is disabled.")
13
-
14
-
15
- def parse_args():
16
- parser = argparse.ArgumentParser(description='Video Reasoning Segmentation')
17
- parser.add_argument('image_folder', help='Path to image file')
18
- parser.add_argument('--model_path', default="ByteDance/Sa2VA-8B")
19
- parser.add_argument('--work-dir', default=None, help='The dir to save results.')
20
- parser.add_argument('--text', type=str, default="<image>Please describe the video content.")
21
- parser.add_argument('--select', type=int, default=-1)
22
- args = parser.parse_args()
23
- return args
24
-
25
-
26
- def visualize(pred_mask, image_path, work_dir):
27
- visualizer = Visualizer()
28
- img = cv2.imread(image_path)
29
- visualizer.set_image(img)
30
- visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
31
- visual_result = visualizer.get_image()
32
-
33
- output_path = os.path.join(work_dir, os.path.basename(image_path))
34
- cv2.imwrite(output_path, visual_result)
35
-
36
- if __name__ == "__main__":
37
- cfg = parse_args()
38
- model_path = cfg.model_path
39
- model = AutoModelForCausalLM.from_pretrained(
40
- model_path,
41
- torch_dtype="auto",
42
- device_map="auto",
43
- trust_remote_code=True
44
- )
45
-
46
- tokenizer = AutoTokenizer.from_pretrained(
47
- model_path,
48
- trust_remote_code=True
49
- )
50
-
51
- image_files = []
52
- image_paths = []
53
- image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}
54
- for filename in sorted(list(os.listdir(cfg.image_folder))):
55
- if os.path.splitext(filename)[1].lower() in image_extensions:
56
- image_files.append(filename)
57
- image_paths.append(os.path.join(cfg.image_folder, filename))
58
-
59
- vid_frames = []
60
- for img_path in image_paths:
61
- img = Image.open(img_path).convert('RGB')
62
- vid_frames.append(img)
63
-
64
-
65
- if cfg.select > 0:
66
- img_frame = vid_frames[cfg.select - 1]
67
-
68
- print(f"Selected frame {cfg.select}")
69
- print(f"The input is:\n{cfg.text}")
70
- result = model.predict_forward(
71
- image=img_frame,
72
- text=cfg.text,
73
- tokenizer=tokenizer,
74
- )
75
- else:
76
- print(f"The input is:\n{cfg.text}")
77
- result = model.predict_forward(
78
- video=vid_frames,
79
- text=cfg.text,
80
- tokenizer=tokenizer,
81
- )
82
-
83
- prediction = result['prediction']
84
- print(f"The output is:\n{prediction}")
85
-
86
- if '[SEG]' in prediction and Visualizer is not None:
87
- _seg_idx = 0
88
- pred_masks = result['prediction_masks'][_seg_idx]
89
- for frame_idx in range(len(vid_frames)):
90
- pred_mask = pred_masks[frame_idx]
91
- if cfg.work_dir:
92
- os.makedirs(cfg.work_dir, exist_ok=True)
93
- visualize(pred_mask, image_paths[frame_idx], cfg.work_dir)
94
- else:
95
- os.makedirs('./temp_visualize_results', exist_ok=True)
96
- visualize(pred_mask, image_paths[frame_idx], './temp_visualize_results')
97
- else:
98
- pass