Kvikontent commited on
Commit
c2ba8e6
·
1 Parent(s): 4d54af4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -11
app.py CHANGED
@@ -1,15 +1,71 @@
1
  import gradio as gr
2
- from transformers import BertTokenizer, BertForSequenceClassification
 
3
 
4
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
5
- model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=8)
 
6
 
7
- def classify(text):
8
- inputs = tokenizer(text, return_tensors='pt')
9
- outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
10
- logits = outputs[0].logits
11
- probabilities = torch.softmax(logits, dim=-1).tolist()
12
- predicted_label = np.argmax(probabilities)
13
- return {'Label': predicted_label}
 
 
 
 
14
 
15
- gr.Interface(fn=classify, inputs="text", outputs="json").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, BartModel
4
 
5
+ # Load the BART model and tokenizer
6
+ tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
7
+ model = BartModel.from_pretrained("facebook/bart-base")
8
 
9
+ # Define the user input widget
10
+ user_input = gr.TextInput(
11
+ value="",
12
+ placeholder="Type something...",
13
+ rows=10,
14
+ cols=50,
15
+ fontsize=16,
16
+ padding=10,
17
+ border=True,
18
+ background="#f2f2f2"
19
+ )
20
 
21
+ # Define the chatbot output widget
22
+ chatbot_output = gr.TextOutput(
23
+ value="",
24
+ fontsize=16,
25
+ padding=10,
26
+ border=True,
27
+ background="#ffffff"
28
+ )
29
+
30
+ # Define the button to trigger the chatbot response
31
+ button = gr.Button(
32
+ label="Send",
33
+ fontsize=16,
34
+ padding=10,
35
+ border=True,
36
+ background="#4CAF50"
37
+ )
38
+
39
+ # Define the callback function to call when the button is clicked
40
+ def send_message():
41
+ # Get the user input text
42
+ input_text = user_input.value
43
+
44
+ # Tokenize the input text
45
+ inputs = tokenizer(input_text, return_tensors="pt")
46
+
47
+ # Run the input through the BART model to get the generated text
48
+ outputs = model(**inputs)
49
+ last_hidden_states = outputs.last_hidden_state
50
+
51
+ # Extract the generated text from the last hidden state
52
+ generated_text = tokenizer.generate(last_hidden_states, max_length=50, padding="max_length", truncation=True).strip()
53
+
54
+ # Update the chatbot output text
55
+ chatbot_output.value = generated_text
56
+
57
+ # Add the widgets to the layout
58
+ gr.layout(
59
+ gr.Row(
60
+ gr.Column(user_input),
61
+ gr.Column(button),
62
+ gr.Column(chatbot_output)
63
+ )
64
+ )
65
+
66
+ # Set up the button click event handler
67
+ button.on_click(send_message)
68
+
69
+ # Start the Gradio app
70
+ if __name__ == "__main__":
71
+ gr.run()