ashwml committed on
Commit
000e1d3
·
1 Parent(s): c20bcc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -12
app.py CHANGED
@@ -38,31 +38,43 @@ import torch
38
 
39
  # f1_metric.set(f1)
40
 
41
- vitgpt_processor = ViTImageProcessor.from_pretrained("model")
42
- vitgpt_model = VisionEncoderDecoderModel.from_pretrained("model")
43
- vitgpt_tokenizer = AutoTokenizer.from_pretrained("model", return_tensors="pt")
44
 
45
  device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
47
  vitgpt_model.to(device)
48
 
49
  def generate_caption(processor, model, image, tokenizer=None):
50
- inputs = processor(images=image, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
53
 
54
- if tokenizer is not None:
55
- generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
56
- else:
57
- generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
58
 
59
- return generated_caption
60
 
61
- def predict_event(input):
62
 
63
 
64
 
65
- caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer)
66
 
67
  return caption_vitgpt
68
 
 
38
 
39
  # f1_metric.set(f1)
40
 
41
+ feature_extractor = ViTImageProcessor.from_pretrained("model")
42
+ cap_model = VisionEncoderDecoderModel.from_pretrained("model")
43
+ tokenizer = AutoTokenizer.from_pretrained("model")
44
 
45
  device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
47
  vitgpt_model.to(device)
48
 
49
  def generate_caption(processor, model, image, tokenizer=None):
50
+ max_length = 16
51
+ num_beams = 4
52
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
53
+
54
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
55
+ pixel_values = pixel_values.to(device)
56
+
57
+ output_ids = model.generate(pixel_values, **gen_kwargs)
58
+
59
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
60
+ preds = [pred.strip() for pred in preds]
61
+ return preds
62
+ # inputs = processor(images=image, return_tensors="pt").to(device)
63
 
64
+ # generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
65
 
66
+ # if tokenizer is not None:
67
+ # generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
68
+ # else:
69
+ # generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
70
 
71
+ # return generated_caption
72
 
73
+ def predict_event(image):
74
 
75
 
76
 
77
+ caption_vitgpt = generate_caption(feature_extractor, cap_model, image, tokenizer)
78
 
79
  return caption_vitgpt
80