stzhao committed on
Commit
4920a3b
·
verified ·
1 Parent(s): 1de771c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -16,8 +16,8 @@ def load_models():
16
 
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
- torch_dtype="auto",
20
- device_map="auto"
21
  )
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
 
@@ -26,7 +26,7 @@ def load_models():
26
  torch_dtype=torch.bfloat16
27
  )
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
- pipe.to("cuda")
30
 
31
  return model, tokenizer, pipe
32
 
@@ -43,6 +43,7 @@ def truncate_caption_by_tokens(caption, max_tokens=256):
43
 
44
  @spaces.GPU(duration=50)
45
  def generate_enhanced_caption(image_caption, text_caption):
 
46
  """Generate enhanced caption using the LeX-Enhancer model"""
47
  combined_caption = f"{image_caption}, with the text on it: {text_caption}."
48
  instruction = """
@@ -76,6 +77,7 @@ Below is the simple caption of an image with text. Please deduce the detailed de
76
 
77
  @spaces.GPU(duration=60)
78
  def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
 
79
  """Generate image using LeX-Lumina"""
80
  # Truncate the caption if it's too long
81
  enhanced_caption = truncate_caption_by_tokens(enhanced_caption, max_tokens=256)
 
16
 
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
+ torch_dtype=torch.bfloat16,
20
+ # device_map="auto"
21
  )
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
 
 
26
  torch_dtype=torch.bfloat16
27
  )
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ # pipe.to("cuda")
30
 
31
  return model, tokenizer, pipe
32
 
 
43
 
44
  @spaces.GPU(duration=50)
45
  def generate_enhanced_caption(image_caption, text_caption):
46
+ model.to("cuda")
47
  """Generate enhanced caption using the LeX-Enhancer model"""
48
  combined_caption = f"{image_caption}, with the text on it: {text_caption}."
49
  instruction = """
 
77
 
78
  @spaces.GPU(duration=60)
79
  def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
80
+ pipe.to("cuda")
81
  """Generate image using LeX-Lumina"""
82
  # Truncate the caption if it's too long
83
  enhanced_caption = truncate_caption_by_tokens(enhanced_caption, max_tokens=256)