JosephZ committed
Commit c3445fa · verified · 1 Parent(s): 4f4e1b3

Update app.py

Files changed (1):
  app.py +29 -18
app.py CHANGED
@@ -10,6 +10,7 @@ import torch
 from transformers import Qwen2VLForConditionalGeneration, GenerationConfig, AutoProcessor
 import spaces
 
+from vllm import LLM, SamplingParams
 
 def extract_answer_content(text: str) -> str:
     """
@@ -62,6 +63,10 @@ SYSTEM_PROMPT = (
 processor = AutoProcessor.from_pretrained("JosephZ/qwen2vl-7b-sft-grpo-close-sgg", max_pixels=1024*28*28)
 
 device='cuda' if torch.cuda.is_available() else "cpu"
+model_name = "JosephZ/qwen2vl-7b-sft-grpo-close-sgg"
+
+
+"""
 model = Qwen2VLForConditionalGeneration.from_pretrained("JosephZ/qwen2vl-7b-sft-grpo-close-sgg",
                                                         torch_dtype=torch.bfloat16,
                                                         device_map=device)
@@ -75,9 +80,25 @@ generation_config=GenerationConfig(
     max_new_tokens=2048,
     use_cache=True
 )
+"""
+model = LLM(
+    model=model_name,
+    limit_mm_per_prompt={"image": 1},
+    dtype='bfloat16',
+    device=device,
+    max_model_len=4096,
+    mm_processor_kwargs={"max_pixels": 1024*28*28, "min_pixels": 4*28*28},
+)
+sampling_params = SamplingParams(
+    temperature=0.01,
+    top_k=1,
+    top_p=0.001,
+    repetition_penalty=1.0,
+    max_tokens=2048,
+)
 
 def build_prompt(image, user_text):
-    #base64_image = encode_image_to_base64(image)
+    base64_image = encode_image_to_base64(image)
     messages = [
         {
             "role": "system",
@@ -86,8 +107,8 @@ def build_prompt(image, user_text):
         {
             "role": "user",
             "content": [
-                #{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
-                {"type": "image"},
+                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
+                # {"type": "image"},
                 {"type": "text", "text": user_text},
             ],
         },
@@ -161,24 +182,14 @@ def generate_sgg(image):
 
     iw, ih = image.size
     scale_factors = (iw / 1000.0, ih / 1000.0)
-
+
     conversation = build_prompt(image, PROMPT_CLOSE)
-    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
 
-    inputs = processor(
-        text=[text_prompt], images=[image], padding=True, return_tensors="pt"
-    )
-    inputs = inputs.to(model.device)
-
-    output_ids = model.generate(**inputs, generation_config=generation_config)
-    generated_ids = [
-        output_ids[len(input_ids):]
-        for input_ids, output_ids in zip(inputs.input_ids, output_ids)
-    ]
-    output_text = processor.batch_decode(
-        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
-    )[0]
+    with torch.no_grad():
+        outputs = model.chat([conversation], sampling_params=sampling_params)
+        output_texts = [output.outputs[0].text for output in outputs]
 
+    output_text = output_texts[0]
     resp = extract_answer_content(output_text)
 
     try:
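
For context, the diff swaps the transformers `model.generate()` path for vLLM's `LLM.chat()` with an OpenAI-style multimodal message. Below is a minimal, self-contained sketch of that pattern under stated assumptions: vLLM installed on a CUDA machine, a placeholder input image path, a simplified `encode_image_to_base64` helper standing in for the one in app.py, and generic system/user prompts in place of the app's `SYSTEM_PROMPT` and `PROMPT_CLOSE`. It is an illustration of the calling convention, not a copy of the app.

```python
# Sketch of the vLLM inference path used by the updated app.py (assumptions noted above).
import base64
import io

from PIL import Image
from vllm import LLM, SamplingParams


def encode_image_to_base64(image: Image.Image) -> str:
    # Simplified stand-in for the helper referenced in app.py.
    buf = io.BytesIO()
    image.convert("RGB").save(buf, format="JPEG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


model_name = "JosephZ/qwen2vl-7b-sft-grpo-close-sgg"

llm = LLM(
    model=model_name,
    limit_mm_per_prompt={"image": 1},  # at most one image per prompt
    dtype="bfloat16",
    max_model_len=4096,
    mm_processor_kwargs={"max_pixels": 1024 * 28 * 28, "min_pixels": 4 * 28 * 28},
)

# Near-greedy decoding, matching the settings added in the commit.
sampling_params = SamplingParams(
    temperature=0.01, top_k=1, top_p=0.001, repetition_penalty=1.0, max_tokens=2048
)

image = Image.open("example.jpg")  # placeholder image path
b64_image = encode_image_to_base64(image)

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},  # app.py uses its own SYSTEM_PROMPT
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"}},
            {"type": "text", "text": "Generate the scene graph for this image."},  # app.py uses PROMPT_CLOSE
        ],
    },
]

# LLM.chat accepts a batch of conversations; each result carries the generated text.
outputs = llm.chat([conversation], sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
```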