youj2005 commited on
Commit
19dbb2e
·
1 Parent(s): d2825f4

fix input_ids

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -10,11 +10,11 @@ qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", dev
10
 
11
  def predict(context, intent):
12
  input_text = "In one word, what is the opposite of: " + intent + "?"
13
- input_ids = qa_tokenizer(input_text, return_tensors="pt")
14
- opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, return_tensors='pt')[0])
15
  input_text = "In one word, what is the following describing: " + context
16
- input_ids = qa_tokenizer(input_text, return_tensors="pt")
17
- object_output = qa_tokenizer.decode(qa_model.generate(input_ids, return_tensors='pt')[0])
18
  batch = ['I think the ' + object_output + ' are long.', 'I think the ' + object_output + ' are ' + opposite_output, 'I think the ' + object_output + ' are the perfect']
19
  outputs = []
20
  for i, hypothesis in enumerate(batch):
@@ -51,5 +51,7 @@ gradio_app = gr.Interface(
51
  title="Intent Analysis",
52
  )
53
 
 
 
54
  if __name__ == "__main__":
55
  gradio_app.launch()
 
10
 
11
  def predict(context, intent):
12
  input_text = "In one word, what is the opposite of: " + intent + "?"
13
+ input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids
14
+ opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length = 1)[0])
15
  input_text = "In one word, what is the following describing: " + context
16
+ input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids
17
+ object_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length = 1)[0])
18
  batch = ['I think the ' + object_output + ' are long.', 'I think the ' + object_output + ' are ' + opposite_output, 'I think the ' + object_output + ' are the perfect']
19
  outputs = []
20
  for i, hypothesis in enumerate(batch):
 
51
  title="Intent Analysis",
52
  )
53
 
54
+ print(predict("The cat is short.", "long"))
55
+
56
  if __name__ == "__main__":
57
  gradio_app.launch()