khurrameycon committed
Commit e59f632 · verified · 1 Parent(s): 1dcef5f

predict_text

Files changed (1): app.py (+20 -2)
app.py CHANGED
@@ -55,11 +55,29 @@ def predict(image, text):
     response = processor.decode(outputs[0], skip_special_tokens=True)
     return response
 
+def predict_text(text):
+    # Prepare the input messages
+    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
+
+    # Create the input text using the processor's chat template
+    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+    # Process the inputs and move to the appropriate device
+    # inputs = processor(image, input_text, return_tensors="pt").to(device)
+    inputs = processor(text=input_text, return_tensors="pt").to("cuda")
+    # Generate a response from the model
+    outputs = model.generate(**inputs, max_new_tokens=250)
+
+    # Decode the output to return the final response
+    response = processor.decode(outputs[0], skip_special_tokens=True)
+    return response
+
+
 # Define the Gradio interface
 interface = gr.Interface(
-    fn=predict,
+    fn=predict_text,
     inputs=[
-        gr.Image(type="pil", label="Image Input"),  # Image input with label
+        # gr.Image(type="pil", label="Image Input"),  # Image input with label
         gr.Textbox(label="Text Input")  # Textbox input with label
     ],
     outputs=gr.Textbox(label="Generated Response"),  # Output with a more descriptive label
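
For context, the diff only touches the prediction function and the interface wiring; it never shows the imports or how `model` and `processor` are created. Below is a minimal sketch of what the full app.py plausibly looks like after this commit. The checkpoint name and the loading code are assumptions, not part of the diff (the original `predict(image, text)` signature and processor usage suggest a vision-language model such as Llama 3.2 Vision); only `predict_text` and the `gr.Interface` block come from the commit itself.

import torch
import gradio as gr
from transformers import AutoProcessor, MllamaForConditionalGeneration

# ASSUMPTION: the diff never names the checkpoint; this ID is a guess based on
# the image+text processor usage in the original predict(image, text) function.
MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = MllamaForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16
).to("cuda")

def predict_text(text):
    # Wrap the raw prompt in the chat-message schema the processor expects
    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]

    # Render the messages into one prompt string; add_generation_prompt=True
    # appends the assistant header so the model answers as the assistant
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    # Tokenize (text only, no image) and move the tensors to the GPU
    inputs = processor(text=input_text, return_tensors="pt").to("cuda")

    # Generate up to 250 new tokens
    outputs = model.generate(**inputs, max_new_tokens=250)

    # Decode the whole sequence, dropping special tokens
    return processor.decode(outputs[0], skip_special_tokens=True)

# Define the Gradio interface
interface = gr.Interface(
    fn=predict_text,
    inputs=[gr.Textbox(label="Text Input")],
    outputs=gr.Textbox(label="Generated Response"),
)

if __name__ == "__main__":
    interface.launch()

Note the design choice visible in the diff: rather than deleting the image path, the commit comments out the `gr.Image` input and the old image-aware processor call and points `fn` at the new text-only `predict_text`, so the image+text path can be restored later by uncommenting two lines.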