Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
| import torch | |
# Load the pre-trained BERT tokenizer and a 6-label sequence-classification
# model. Streamlit re-executes this script on every widget interaction, so the
# load is wrapped in st.cache_resource: the same objects are reused across
# reruns instead of the weights being read again each time.
# NOTE(review): 'bert-base-uncased' is a base checkpoint, so the 6-label
# classification head is presumably freshly initialised rather than trained
# for toxicity -- confirm, and point this at a fine-tuned checkpoint if so.
@st.cache_resource
def _load_model_and_tokenizer():
    """Return the (tokenizer, model) pair, with the model in eval mode."""
    loaded_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    loaded_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
    loaded_model.eval()  # inference only: switch off training-time behaviour
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model_and_tokenizer()
# Function to classify text using the pre-trained BERT model
def classify_text(text):
    """Score *text* with the module-level BERT classifier.

    Args:
        text: The raw string to classify.

    Returns:
        A NumPy array holding the element-wise sigmoid of the model's logits.
        The input is sent as a batch of one, so the array has a single row
        with one value per model label.
    """
    # Tokenize with the special tokens added. Truncate to the model's maximum
    # number of positions: without this, a long input produces more tokens
    # than the model has position embeddings for and the forward pass fails.
    input_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    # Wrap in an outer list so the tensor carries a batch dimension of one.
    input_tensor = torch.tensor([input_ids])
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(input_tensor)[0]
    # Independent per-label scores (sigmoid, not softmax).
    return torch.sigmoid(logits).numpy()
# Empty table of classification results: the input text followed by one
# score column per toxicity label. One row is appended per classified text.
results_df = pd.DataFrame(
    columns=[
        'Text',
        'Toxic',
        'Severe Toxic',
        'Obscene',
        'Threat',
        'Insult',
        'Identity Hate',
    ],
)
# Streamlit app
def app():
    """Render the toxicity-classification page.

    Shows a text area and a "Classify" button. When the button is pressed
    with non-blank text, the text is scored with classify_text, each label's
    score is written as a percentage, and a row is appended to the results
    table. The accumulated results table is displayed below.
    """
    st.title("Toxicity Classification App")
    st.write("Enter text below to classify its toxicity.")

    # Streamlit re-runs the whole script on each interaction, which rebuilds
    # the module-level results_df empty. Keep the accumulated rows in
    # st.session_state so earlier classifications survive reruns; the
    # module-level frame only supplies the column layout.
    if 'results_df' not in st.session_state:
        st.session_state['results_df'] = results_df.copy()
    history = st.session_state['results_df']

    # User input
    user_input = st.text_area("Enter text here:", "", key='user_input')

    # Classification
    if st.button("Classify"):
        if not user_input.strip():
            # Nothing meaningful to score; do not add a blank row.
            st.warning("Please enter some text to classify.")
        else:
            # classify_text returns one row per input; take the only row.
            scores = classify_text(user_input)[0]
            # Every column after 'Text' is a label, in model output order.
            label_names = list(history.columns[1:])

            st.write("Classification Results:")
            for label_name, score in zip(label_names, scores):
                st.write("{}: {:.2%}".format(label_name, score))

            # Append in place so the frame held in session state is updated.
            history.loc[len(history)] = [user_input, *scores]

    # Show results DataFrame
    st.write("Classification Results DataFrame:")
    st.write(history)


# Run the app
if __name__ == "__main__":
    app()