ashwml commited on
Commit
c845a40
·
1 Parent(s): 3e2a45d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -39,6 +39,7 @@ import torch
39
  # f1_metric.set(f1)
40
 
41
  feature_extractor = ViTImageProcessor.from_pretrained("model")
 
42
  cap_model = VisionEncoderDecoderModel.from_pretrained("model")
43
  tokenizer = AutoTokenizer.from_pretrained("model")
44
 
@@ -61,7 +62,7 @@ def generate_caption(processor, model, image, tokenizer=None):
61
  # return preds
62
  inputs = processor(images=image, return_tensors="pt").to(device)
63
 
64
- generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
65
 
66
  if tokenizer is not None:
67
  generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
39
  # f1_metric.set(f1)
40
 
41
  feature_extractor = ViTImageProcessor.from_pretrained("model")
42
+ print(feature_extractor)
43
  cap_model = VisionEncoderDecoderModel.from_pretrained("model")
44
  tokenizer = AutoTokenizer.from_pretrained("model")
45
 
 
62
  # return preds
63
  inputs = processor(images=image, return_tensors="pt").to(device)
64
 
65
+ generated_ids = model.generate(pixel_values=inputs.pixel_values)
66
 
67
  if tokenizer is not None:
68
  generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]