Jainesh212 commited on
Commit
20fdf2e
·
1 Parent(s): 8a84d79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -78
app.py CHANGED
@@ -1,84 +1,48 @@
1
  import streamlit as st
 
2
  import pandas as pd
3
- import numpy as np
4
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, BertForSequenceClassification, DistilBertModel
5
- import torch
6
- from torch import cuda
7
- from torch.utils.data import Dataset, DataLoader
8
- import finetuning
9
- from finetuning import CustomDistilBertClass
10
 
11
-
12
- model_map = {
13
- 'BERT': 'bert-base-uncased',
14
- 'RoBERTa': 'roberta-base',
15
- 'DistilBERT': 'distilbert-base-uncased'
16
- }
17
-
18
- model_options = list(model_map.keys())
19
- label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
20
-
21
-
22
- @st.cache_resource
23
- def load_model(model_name):
24
- """Load pretrained BERT model."""
25
- path = "finetuned_model.pt"
26
- model = torch.load(path)
27
- tokenizer = AutoTokenizer.from_pretrained(model_map[model_name])
28
- return model, tokenizer
29
-
30
- def classify_text(model, tokenizer, text):
31
- """Classify text using pretrained BERT model."""
32
- inputs = tokenizer.encode_plus(
33
- text,
34
- add_special_tokens=True,
35
- max_length=512,
36
- padding='max_length',
37
- return_tensors='pt',
38
- truncation=True
39
- )
40
- with torch.no_grad():
41
- logits = model(inputs['input_ids'],inputs['attention_mask'])[0]
42
- probabilities = torch.softmax(logits, dim=1)[0]
43
- pred_class = torch.argmax(probabilities, dim=0)
44
- return label_cols[pred_class], round(probabilities[0].tolist(),2)
45
 
 
 
46
 
 
 
 
47
 
48
- st.title('Toxicity Classification App')
49
- model_name = st.sidebar.selectbox('Select model', model_options)
50
- st.sidebar.write('You selected:', model_name)
51
- model, tokenizer = load_model(model_name)
52
-
53
-
54
- st.subheader('Enter your text below:')
55
- text_input = st.text_area(label='', height=100, max_chars=500)
56
-
57
- if st.button('Classify'):
58
- if not text_input:
59
- st.write('Please enter some text')
60
- else:
61
- class_label, class_prob = classify_text(model, tokenizer, text_input)
62
- st.subheader('Result')
63
- st.write('Input Text:', text_input)
64
- st.write('Highest Toxicity Class:', class_label)
65
- st.write('Probability:', class_prob)
66
-
67
- st.subheader('Classification Results')
68
- if 'classification_results' not in st.session_state:
69
- st.session_state.classification_results = pd.DataFrame(columns=['text', 'toxicity_class', 'probability'])
70
- if st.button('Add to Results'):
71
- if not text_input:
72
- st.write('Please enter some text')
73
- else:
74
- class_label, class_prob = classify_text(model, tokenizer, text_input)
75
- st.subheader('Result')
76
- st.write('Input Text:', text_input)
77
- st.write('Highest Toxicity Class:', class_label)
78
- st.write('Probability:', class_prob)
79
- st.session_state.classification_results = st.session_state.classification_results.append({
80
- 'text': text_input,
81
- 'toxicity_class': class_label,
82
- 'probability': class_prob
83
- }, ignore_index=True)
84
- st.write(st.session_state.classification_results)
 
1
  import streamlit as st
2
+ import transformers
3
  import pandas as pd
 
 
 
 
 
 
 
4
 
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
6
+
7
+ # Load the pre-trained BERT model
8
+ model_name = 'nlptown/bert-base-multilingual-uncased-sentiment'
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
+ pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer, framework='pt', task='text-classification')
12
+
13
+ # Define the toxicity classification function
14
+ def classify_toxicity(text):
15
+ result = pipeline(text)[0]
16
+ label = result['label']
17
+ score = result['score']
18
+ return label, score
19
+
20
+
21
+ # Define the Streamlit app
22
+ def app():
23
+ # Create a persistent DataFrame
24
+ if 'results' not in st.session_state:
25
+ st.session_state.results = pd.DataFrame(columns=['text', 'toxicity', 'score'])
26
+
27
+ # Create a form for users to enter their text
28
+ with st.form(key='text_form'):
29
+ text_input = st.text_input(label='Enter your text:')
30
+ submit_button = st.form_submit_button(label='Classify')
31
+
32
+ # Classify the text and display the results
33
+ if submit_button and text_input != '':
34
+ label, score = classify_toxicity(text_input)
35
+ st.write('Classification Result:')
36
+ st.write(f'Text: {text_input}')
37
+ st.write(f'Toxicity: {label}')
38
+ st.write(f'Score: {score}')
39
 
40
+ # Add the classification result to the persistent DataFrame
41
+ st.session_state.results = st.session_state.results.append({'text': text_input, 'toxicity': label, 'score': score}, ignore_index=True)
42
 
43
+ # Display the persistent DataFrame
44
+ st.write('Classification Results:')
45
+ st.write(st.session_state.results)
46
 
47
+ if __name__ == '__main__':
48
+ app()