stzhao committed · verified
Commit 2092a85 · 1 Parent(s): c6d11ec

Update app.py

Files changed (1): app.py (+12, -4)
app.py CHANGED
@@ -5,9 +5,6 @@ import spaces
 from diffusers import Lumina2Pipeline
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# # Set up environment
-# os.environ['CUDA_VISIBLE_DEVICES'] = "0"
-
 if torch.cuda.is_available():
     torch_dtype = torch.bfloat16
 else:
@@ -29,12 +26,20 @@ def load_models():
         torch_dtype=torch.bfloat16
     )
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    # pipe.to(device, torch_dtype)
 
     return model, tokenizer, pipe
 
 model, tokenizer, pipe = load_models()
 
+def truncate_caption_by_tokens(caption, max_tokens=256):
+    """Truncate the caption to fit within the max token limit"""
+    tokens = tokenizer.encode(caption)
+    if len(tokens) > max_tokens:
+        truncated_tokens = tokens[:max_tokens]
+        caption = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
+        print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
+    return caption
+
 @spaces.GPU(duration=200)
 def generate_enhanced_caption(image_caption, text_caption, progress=gr.Progress(track_tqdm=True)):
     """Generate enhanced caption using the LeX-Enhancer model"""
@@ -71,6 +76,9 @@ Below is the simple caption of an image with text. Please deduce the detailed de
 @spaces.GPU(duration=200)
 def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
     """Generate image using LeX-Lumina"""
+    # Truncate the caption if it's too long
+    enhanced_caption = truncate_caption_by_tokens(enhanced_caption, max_tokens=256)
+
     generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None
 
     image = pipe(
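
The new helper is a simple encode/truncate/decode round trip through the model's tokenizer, applied to the enhanced caption before it reaches pipe(...). Below is a minimal, self-contained sketch of the same idea that runs outside the Space. The gpt2 tokenizer is a stand-in assumption purely for illustration (the app truncates with the tokenizer returned by load_models(), so exact token counts differ), and the 256-token cap appears to line up with the default max_sequence_length of diffusers' Lumina2Pipeline.

from transformers import AutoTokenizer

# Stand-in tokenizer for illustration; the Space uses its own model's tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

def truncate_caption_by_tokens(caption, max_tokens=256):
    """Clip a caption to at most max_tokens tokens, then decode back to text."""
    tokens = tokenizer.encode(caption)
    if len(tokens) > max_tokens:
        caption = tokenizer.decode(tokens[:max_tokens], skip_special_tokens=True)
        print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
    return caption

# Usage: a caption far over budget comes back clipped to roughly the token cap.
long_caption = "A storefront neon sign that reads 'OPEN', glowing red, " * 100
clipped = truncate_caption_by_tokens(long_caption)
print(len(tokenizer.encode(clipped)))  # ~256

One trade-off of decoding truncated token IDs is that the caption can end mid-word; since the enhanced caption is free-form descriptive prose rather than structured input, a ragged tail is an acceptable price for staying inside the text encoder's sequence limit.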