Jainesh212 committed on
Commit
a83ff17
·
1 Parent(s): aab37dc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, BertForSequenceClassification, DistilBertModel
5
+ import torch
6
+ from torch import cuda
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import finetuning
9
+ from finetuning import CustomDistilBertClass
10
+
11
+
12
# Display name -> checkpoint identifier handed to AutoTokenizer.from_pretrained.
model_map = dict(
    BERT='bert-base-uncased',
    RoBERTa='roberta-base',
    DistilBERT='distilbert-base-uncased',
)

# Choices offered in the sidebar selector, in declaration order.
model_options = [*model_map]

# Label names; classify_text indexes this list with the predicted class position.
label_cols = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]
20
+
21
+
22
@st.cache_resource
def load_model(model_name):
    """Load the fine-tuned classifier and the tokenizer for ``model_name``.

    The classifier is always read from ``finetuned_model.pt``; ``model_name``
    only selects which tokenizer checkpoint is looked up in ``model_map``.

    Args:
        model_name: A key of ``model_map``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is mapped to the CPU and
        switched to evaluation mode.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_map``.
    """
    path = "finetuned_model.pt"
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    # NOTE(review): newer PyTorch releases default to weights_only=True, which
    # rejects fully pickled modules -- confirm against the deployed version.
    # map_location keeps a GPU-saved checkpoint loadable on CPU-only hosts and
    # matches classify_text, which feeds the model CPU tensors.
    model = torch.load(path, map_location=torch.device("cpu"))
    # Disable dropout and similar training-only behaviour so predictions are
    # deterministic.
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_map[model_name])
    return model, tokenizer
29
+
30
def classify_text(model, tokenizer, text):
    """Classify ``text`` and return the top label with its probability.

    Args:
        model: Callable taking ``(input_ids, attention_mask)`` whose first
            output element holds the logits.
        tokenizer: Tokenizer exposing ``encode_plus``.
        text: The string to classify.

    Returns:
        A ``(label, probability)`` tuple: the entry of ``label_cols`` with the
        highest softmax score, and that score rounded to two decimals.
    """
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_tensors='pt',
        truncation=True,
    )
    with torch.no_grad():
        logits = model(inputs['input_ids'], inputs['attention_mask'])[0]
        # NOTE(review): the labels look like independent (multi-label)
        # categories; confirm softmax rather than per-label sigmoid matches
        # how the model was trained.
        probabilities = torch.softmax(logits, dim=1)[0]
    # Convert to a plain int so it indexes both the list and the tensor.
    pred_class = int(torch.argmax(probabilities, dim=0))
    # Report the probability of the predicted class, not of class 0.
    return label_cols[pred_class], round(probabilities[pred_class].item(), 2)
45
+
46
+
47
+
48
# Page header and sidebar model picker.
st.title('Toxicity Classification App')
model_name = st.sidebar.selectbox('Select model', model_options)
st.sidebar.write('You selected:', model_name)
model, tokenizer = load_model(model_name)

# Free-text input, limited to 500 characters by the widget.
st.subheader('Enter your text below:')
text_input = st.text_area(label='', height=100, max_chars=500)

# One-off classification: the result is displayed but not stored.
classify_clicked = st.button('Classify')
if classify_clicked and text_input:
    class_label, class_prob = classify_text(model, tokenizer, text_input)
    st.subheader('Result')
    for caption, value in (
        ('Input Text:', text_input),
        ('Highest Toxicity Class:', class_label),
        ('Probability:', class_prob),
    ):
        st.write(caption, value)
elif classify_clicked:
    st.write('Please enter some text')
66
+
67
# Running table of classified texts, kept in the Streamlit session state.
st.subheader('Classification Results')
if 'classification_results' not in st.session_state:
    st.session_state.classification_results = pd.DataFrame(
        columns=['text', 'toxicity_class', 'probability']
    )

if st.button('Add to Results'):
    if not text_input:
        st.write('Please enter some text')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Result')
        st.write('Input Text:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)

        new_row = pd.DataFrame([{
            'text': text_input,
            'toxicity_class': class_label,
            'probability': class_prob,
        }])
        existing = st.session_state.classification_results
        # DataFrame.append was removed in pandas 2.0, so use pd.concat. The
        # still-empty table is skipped to avoid pandas' empty-frame concat
        # deprecation warning.
        if existing.empty:
            st.session_state.classification_results = new_row
        else:
            st.session_state.classification_results = pd.concat(
                [existing, new_row], ignore_index=True
            )

# NOTE(review): indentation was lost in the source; the table is assumed to be
# rendered on every run, not only after a button click -- confirm.
st.write(st.session_state.classification_results)