ashwml committed
Commit d2864bb · 1 Parent(s): 000e1d3

Update app.py

Files changed (1)
  1. app.py +16 -16
app.py CHANGED
@@ -47,28 +47,28 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 vitgpt_model.to(device)
 
 def generate_caption(processor, model, image, tokenizer=None):
-    max_length = 16
-    num_beams = 4
-    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+    # max_length = 16
+    # num_beams = 4
+    # gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
 
-    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
-    pixel_values = pixel_values.to(device)
+    # pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+    # pixel_values = pixel_values.to(device)
 
-    output_ids = model.generate(pixel_values, **gen_kwargs)
+    # output_ids = model.generate(pixel_values, **gen_kwargs)
 
-    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-    preds = [pred.strip() for pred in preds]
-    return preds
-    # inputs = processor(images=image, return_tensors="pt").to(device)
+    # preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+    # preds = [pred.strip() for pred in preds]
+    # return preds
+    inputs = processor(images=image, return_tensors="pt").to(device)
 
-    # generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
+    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
 
-    # if tokenizer is not None:
-    #     generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    # else:
-    #     generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    if tokenizer is not None:
+        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    else:
+        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-    # return generated_caption
+    return generated_caption
 
 def predict_event(image):
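
In short, the commit swaps the old ViT-GPT2 feature-extractor/beam-search path (which referenced a module-level feature_extractor and returned a list of candidates) for a processor-driven path that returns a single caption string. As a standalone sanity check, here is a minimal runnable sketch of the new code path. The checkpoint is an assumption: the diff does not show which processor/model pair app.py loads, so a BLIP checkpoint (Salesforce/blip-image-captioning-base) and a local example.jpg stand in for illustration.

# Minimal sketch of the updated generate_caption path.
# Assumption (not from the diff): a BLIP checkpoint stands in for the
# processor/model pair that app.py actually loads.
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical checkpoint, for illustration only.
ckpt = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(ckpt)
model = BlipForConditionalGeneration.from_pretrained(ckpt).to(device)

def generate_caption(processor, model, image, tokenizer=None):
    # Preprocess the image and move the tensors to the model's device.
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Greedy decoding capped at 50 tokens, as in the updated app.py.
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)

    # Decode with an explicit tokenizer when one is passed (the ViT-GPT2
    # call site), otherwise fall back to the processor's own decoder.
    if tokenizer is not None:
        caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption

print(generate_caption(processor, model, Image.open("example.jpg").convert("RGB")))

The tokenizer=None default is what lets one helper serve both kinds of checkpoint: a ViT-GPT2-style pair passes its separate tokenizer explicitly, while processor-bundled models such as BLIP fall through to processor.batch_decode.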