andrejrad committed
Commit b989be2 · verified · 1 Parent(s): 8d3d460

Update app.py

Files changed (1)
  1. app.py +46 -43
app.py CHANGED
@@ -5,14 +5,17 @@ from PIL import Image
 import torch
 from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig
 
-# -------- Env / params --------
+# ===== Env / params =====
 MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
 HF_TOKEN = os.environ.get("HF_TOKEN")
-TEMP = 0.1
-MAX_NEW_TOKENS = 768  # safe for demo; raise if needed
+
+# Latency/quality knobs (tuned for A100)
+TEMP = 0.1  # per model docs
+MAX_NEW_TOKENS = 384  # fast + sufficient for schema (raise to 512/768 later if needed)
+VISION_LONG_SIDE = 896  # matches your vision_config.image_size
 DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
-# -------- Prompts --------
+# ===== Prompts (exact, no example output) =====
 SYSTEM_PROMPT = (
     "You are an image annotation API trained to analyze YouTube video keyframes. "
     "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -27,39 +30,40 @@ Your job is to extract detailed **factual elements directly visible** in the ima
 Return JSON in this structure:
 
 {
-  "description": "...",
-  "objects": ["..."],
-  "actions": ["..."],
-  "environment": "...",
-  "content_type": "...",
-  "specific_style": "...",
-  "production_quality": "...",
-  "summary": "...",
-  "logos": ["..."]
+  "description": "A detailed, factual account of what is visibly happening (4 sentences max). Only mention concrete elements or actions that are clearly shown. Do not include anything about how the image is styled, shot, or composed. Do not lead the description with something like 'This image shows' or 'this keyframe is...', just get right into the details.",
+  "objects": ["object1 with relevant visual details", "object2 with relevant visual details", ...],
+  "actions": ["action1 with participants and context", "action2 with participants and context", ...],
+  "environment": "Detailed factual description of the setting and atmosphere based on visible cues (e.g., interior of a classroom with fluorescent lighting, or outdoor forest path with snow-covered trees).",
+  "content_type": "The type of content it is, e.g., 'real-world footage', 'video game', 'animation', 'cartoon', 'CGI', 'VTuber', etc.",
+  "specific_style": "Specific genre, aesthetic, or platform style (e.g., anime, 3D animation, mobile gameplay, vlog, tutorial, news broadcast, etc.)",
+  "production_quality": "Visible production level: e.g., 'professional studio', 'amateur handheld', 'webcam recording', 'TV broadcast', etc.",
+  "summary": "One clear, comprehensive sentence summarizing the visual content of the frame. Like the description, get right to the point.",
+  "logos": ["logo1 with visual description", "logo2 with visual description", ...]
 }
 
 Rules:
 - Be specific and literal. Focus on what is explicitly visible.
 - Do NOT include interpretations of emotion, mood, or narrative unless it's visually explicit.
 - No artistic or cinematic analysis.
-- Always include the language of any text in the image if present as an object.
+- Always include the language of any text in the image if present as an object, e.g., "English text", "Japanese text", "Russian text", etc.
 - Maximum 10 objects and 5 actions.
-- Return [] for 'logos' if none are present.
-- Strictly valid JSON only.
+- Return an empty array for 'logos' if none are present.
+- Always output strictly valid JSON with proper escaping.
+- Output **only the JSON**, no extra text or explanation.
+- Do not use placeholder strings or ellipses ('...'). Replace with concrete values directly observed in the image only.
 """
 
-# -------- Utils --------
+# ===== Utils =====
 def extract_top_level_json(s: str):
+    """Parse JSON; if there’s surrounding text, extract the first balanced {...} block."""
     try:
         return json.loads(s)
     except Exception:
         pass
-    start = None
-    depth = 0
+    start, depth = None, 0
     for i, ch in enumerate(s):
         if ch == '{':
-            if depth == 0:
-                start = i
+            if depth == 0: start = i
             depth += 1
         elif ch == '}':
             if depth > 0:
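The hunk ends mid-function, so the rest of extract_top_level_json is not shown. For orientation, a balanced-brace fallback of this kind usually looks like the sketch below (a reconstruction from the visible lines, not the exact body in app.py):

import json

def extract_top_level_json_sketch(s: str):
    """Sketch of the fallback: return the first balanced {...} block that parses as JSON."""
    try:
        return json.loads(s)
    except Exception:
        pass
    start, depth = None, 0
    for i, ch in enumerate(s):
        if ch == '{':
            if depth == 0:
                start = i
            depth += 1
        elif ch == '}':
            if depth > 0:
                depth -= 1
                if depth == 0 and start is not None:
                    try:
                        return json.loads(s[start:i + 1])
                    except Exception:
                        start = None  # keep scanning for a later balanced block
    return None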
@@ -79,16 +83,16 @@ def build_messages(image):
          {"type": "text", "text": USER_PROMPT}]}
     ]
 
-def downscale_if_huge(pil: Image.Image, max_side: int = 1792) -> Image.Image:
+def resize_to_vision(pil: Image.Image, long_side: int = VISION_LONG_SIDE) -> Image.Image:
     if pil is None: return pil
     w, h = pil.size
     m = max(w, h)
-    if m <= max_side:
+    if m <= long_side:
         return pil.convert("RGB")
-    s = max_side / m
+    s = long_side / m
     return pil.convert("RGB").resize((int(w*s), int(h*s)), Image.BICUBIC)
 
-# -------- Load model --------
+# ===== Load model (A100) =====
 processor = tokenizer = model = None
 LOAD_ERROR = None
 try:
@@ -100,10 +104,10 @@ try:
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         token=HF_TOKEN,
-        device_map="auto",
+        device_map="cuda",  # keep on A100
         torch_dtype=DTYPE,
         trust_remote_code=True,
-        # quantization_config=None,  # uncomment if you want to force full precision
+        # quantization_config=None,  # uncomment to force full precision if you removed quant
     )
     tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
         MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
@@ -111,58 +115,57 @@ try:
 except Exception as e:
     LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"
 
-# -------- Inference --------
+# ===== Inference =====
 def generate(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
     if image is None:
         return "Please upload an image.", None, False
     if model is None or processor is None:
         return f"❌ Load error:\n{LOAD_ERROR}", None, False
 
-    image = downscale_if_huge(image)
+    image = resize_to_vision(image, VISION_LONG_SIDE)
 
+    # Chat prompt
     if hasattr(processor, "apply_chat_template"):
         prompt = processor.apply_chat_template(build_messages(image), add_generation_prompt=True, tokenize=False)
     else:
         prompt = USER_PROMPT
 
     inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
+    eos = getattr(model.config, "eos_token_id", None)
 
     tried = []
-    # (1) Greedy
+
+    # (1) Greedy (fast, stable)
     try:
         g = dict(do_sample=False, max_new_tokens=MAX_NEW_TOKENS)
-        eos = getattr(model.config, "eos_token_id", None)
-        if eos is not None:
-            g["eos_token_id"] = eos
+        if eos is not None: g["eos_token_id"] = eos
         with torch.inference_mode():
             out = model.generate(**inputs, **g)
         text = processor.decode(out[0], skip_special_tokens=True)
         parsed = extract_top_level_json(text)
-        if isinstance(parsed, dict):
+        if isinstance(parsed, dict) and "..." not in json.dumps(parsed):
             return json.dumps(parsed, indent=2), parsed, True
-        tried.append(("greedy", "parse-failed"))
+        tried.append(("greedy", "parse-failed-or-ellipses"))
     except Exception as e:
         tried.append(("greedy", f"err={e}"))
 
-    # (2) Sampling
+    # (2) Short sampled retry
     try:
         g = dict(do_sample=True, temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
-        eos = getattr(model.config, "eos_token_id", None)
-        if eos is not None:
-            g["eos_token_id"] = eos
+        if eos is not None: g["eos_token_id"] = eos
         with torch.inference_mode():
             out = model.generate(**inputs, **g)
         text = processor.decode(out[0], skip_special_tokens=True)
         parsed = extract_top_level_json(text)
-        if isinstance(parsed, dict):
+        if isinstance(parsed, dict) and "..." not in json.dumps(parsed):
             return json.dumps(parsed, indent=2), parsed, True
-        tried.append(("sample", "parse-failed"))
+        tried.append(("sample_t0.1", "parse-failed-or-ellipses"))
     except Exception as e:
-        tried.append(("sample", f"err={e}"))
+        tried.append(("sample_t0.1", f"err={e}"))
 
     return "Generation failed.\n" + str(tried), None, False
 
-# -------- UI --------
+# ===== UI =====
 with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="ClipTagger (VLM)") as demo:
     gr.Markdown("# ClipTagger\nUpload an image to get **strict JSON** annotations.")
     if LOAD_ERROR:
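The diff is truncated inside the Gradio Blocks section, so the event wiring is not visible here. A minimal illustration of how generate() is typically hooked up in a Blocks demo (illustrative only; the actual components and layout in app.py may differ):

import gradio as gr

# Illustrative wiring only; the real UI past the "if LOAD_ERROR:" line is not shown in this diff.
with gr.Blocks(title="ClipTagger (VLM)") as sketch_demo:
    image_in = gr.Image(type="pil", label="Keyframe")
    run_btn = gr.Button("Annotate")
    json_out = gr.Code(label="Strict JSON", language="json")

    # generate() returns (text, parsed_dict_or_None, ok_flag); only the text is displayed here.
    run_btn.click(fn=lambda img: generate(img)[0], inputs=image_in, outputs=json_out)

# sketch_demo.launch()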
 
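Finally, the updated prompt forbids placeholder ellipses and the new parse path rejects any output containing "...". A caller could extend that into a fuller schema check; a hypothetical helper (not part of this commit) might look like:

import json

REQUIRED_KEYS = {
    "description", "objects", "actions", "environment", "content_type",
    "specific_style", "production_quality", "summary", "logos",
}

def looks_valid(parsed: dict) -> bool:
    """Hypothetical post-check: required keys present, list fields are lists,
    prompt limits respected, and no '...' placeholders anywhere."""
    if not isinstance(parsed, dict) or not REQUIRED_KEYS.issubset(parsed):
        return False
    if not all(isinstance(parsed[k], list) for k in ("objects", "actions", "logos")):
        return False
    if len(parsed["objects"]) > 10 or len(parsed["actions"]) > 5:
        return False
    return "..." not in json.dumps(parsed)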