Kvikontent commited on
Commit
76a650f
·
1 Parent(s): e69d0e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -1,19 +1,28 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
 
5
- # Load the BART tokenizer and model
6
  tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
7
- model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
8
 
9
- def generate_response(user_input):
10
- inputs = tokenizer(user_input, return_tensors="pt", max_length=512, truncation=True)
11
- outputs = model.generate(**inputs)
12
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
13
- return response
14
 
15
- input_textbox = gr.Textbox(lines=10, label="Enter your text here")
16
- output_textbox = gr.Textbox(label="Chatbot Response")
17
 
18
- chatbot_interface = gr.Interface(fn=generate_response, inputs=input_textbox, outputs=output_textbox, title="Hugging Face BART Chatbot", description="This chatbot uses the Hugging Face BART model to generate responses based on user input.")
19
- chatbot_interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, BartModel
3
  import torch
4
 
5
+ # Load the tokenizer and model
6
  tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
7
+ model = BartModel.from_pretrained("facebook/bart-base")
8
 
9
+ def generate_response(text):
10
+ # Tokenize the input text
11
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
 
12
 
13
+ # Generate model output
14
+ outputs = model(**inputs)
15
 
16
+ # Get the last hidden states from the model output
17
+ last_hidden_states = outputs.last_hidden_state
18
+
19
+ # Return last hidden states as output
20
+ return last_hidden_states
21
+
22
+ iface = gr.Interface(
23
+ fn=generate_response,
24
+ inputs=gr.Textbox(lines=7, label="Input Text"),
25
+ outputs=gr.Textbox(label="Output Last Hidden States")
26
+ )
27
+
28
+ iface.launch()