zhiqiulin commited on
Commit
2b389ac
·
verified ·
1 Parent(s): 99fb211

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -9,22 +9,16 @@ from t2v_metrics import VQAScore, list_all_vqascore_models
9
  print(list_all_vqascore_models())
10
 
11
  # Initialize the model only once
12
- model_pipe = None
13
-
14
- @spaces.GPU
15
- def initialize_model(model_name):
16
- global model_pipe
17
- if model_pipe is None:
18
- model_pipe = VQAScore(model=model_name) # our recommended scoring model
19
- print("Model initialized!")
20
- return model_pipe
21
 
22
  @spaces.GPU
23
  def generate(model_name, image, text):
24
- print("Model_name:", model_name)
25
  print("Image:", image)
26
  print("Text:", text)
27
- model_pipe = initialize_model(model_name)
28
  return model_pipe(images=[image], texts=[text])
29
 
30
  iface = gr.Interface(
 
9
  print(list_all_vqascore_models())
10
 
11
  # Initialize the model only once
12
+ if torch.cuda.is_available():
13
+ model_pipe = VQAScore(model="clip-flant5-x") # our recommended scoring model
14
+ model_pipe.to("cuda")
15
+ print("Model initialized!")
 
 
 
 
 
16
 
17
  @spaces.GPU
18
  def generate(model_name, image, text):
19
+ # print("Model_name:", model_name)
20
  print("Image:", image)
21
  print("Text:", text)
 
22
  return model_pipe(images=[image], texts=[text])
23
 
24
  iface = gr.Interface(