stzhao commited on
Commit
b1cdcda
·
verified ·
1 Parent(s): a07a29c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -10,16 +10,16 @@ def load_models():
10
 
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_name,
13
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
14
- device_map="auto"
15
- )
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
 
18
  return model, tokenizer
19
 
20
  model, tokenizer = load_models()
21
 
22
- @spaces.GPU()
23
  def generate_enhanced_caption(image_caption, text_caption):
24
  """Generate enhanced caption using the LeX-Enhancer model"""
25
  combined_caption = f"{image_caption}, with the text on it: {text_caption}."
 
10
 
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_name,
13
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
14
+ # device_map="auto"
15
+ ).to("cuda")
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
 
18
  return model, tokenizer
19
 
20
  model, tokenizer = load_models()
21
 
22
+ # @spaces.GPU()
23
  def generate_enhanced_caption(image_caption, text_caption):
24
  """Generate enhanced caption using the LeX-Enhancer model"""
25
  combined_caption = f"{image_caption}, with the text on it: {text_caption}."