Desalegnn committed on
Commit
04af1b9
·
verified ·
1 Parent(s): 40c91da

Upload app.py

Files changed (1)
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
+ import numpy as np
+ import os
+ import pandas as pd
+ import torch
+ import matplotlib.pyplot as plt
+ from transformers import XLMRobertaModel, XLMRobertaTokenizer
+ import torch.nn as nn
+ import gradio as gr
+ from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import classification_report
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
+ # Load the pretrained AfriBERTa (XLM-R) encoder and tokenizer via Hugging Face Transformers
+ bert = XLMRobertaModel.from_pretrained('castorini/afriberta_large')
+ tokenizer = XLMRobertaTokenizer.from_pretrained('castorini/afriberta_large')
+ # Define the model architecture
+ class BERT_Arch(nn.Module):
+     def __init__(self, bert):
+         super(BERT_Arch, self).__init__()
+         self.bert = bert
+         self.dropout = nn.Dropout(0.1)       # Dropout layer
+         self.relu = nn.ReLU()                # ReLU activation function
+         self.fc1 = nn.Linear(768, 512)       # Dense layer 1
+         self.fc2 = nn.Linear(512, 2)         # Dense layer 2 (output layer)
+         self.softmax = nn.LogSoftmax(dim=1)  # Log-softmax activation function
+
+     def forward(self, sent_id, mask):  # Define the forward pass
+         cls_hs = self.bert(sent_id, attention_mask=mask)['pooler_output']
+         x = self.fc1(cls_hs)
+         x = self.relu(x)
+         x = self.dropout(x)
+         x = self.fc2(x)      # Output layer
+         x = self.softmax(x)  # Apply log-softmax activation
+         return x
+
+ # Load the model and set it to evaluation mode
+ model = BERT_Arch(bert)
+ fake_news_model_path = "Fake_model.pt"
+ fake_news_model = torch.load(fake_news_model_path, map_location=torch.device('cpu'))
+ fake_news_model.eval()
+
+ # Function to detect fake news
+ def detect_fake_news(text):
+     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
+     with torch.no_grad():
+         outputs = fake_news_model(inputs['input_ids'], inputs['attention_mask'])
+     label = torch.argmax(outputs, dim=1).item()
+     fake_news_result = "Fake" if label == 1 else "Not Fake"
+     return fake_news_result
+
+ # Function to handle post logic
+ def post_text(text, fake_news_result):
+     if fake_news_result == "Fake":
+         return "Your message contains Fake News and cannot be posted.", ""
+     else:
+         return "The text is safe to post.", text
+
+ # Gradio Interface
+ interface = gr.Blocks()
+ with interface:
+     gr.Markdown("## Fake News Detection")
+     with gr.Row():
+         text_input = gr.Textbox(label="Enter Text", lines=5)
+     with gr.Row():
+         detect_fake_button = gr.Button("Detect Fake News")
+     with gr.Row():
+         fake_news_result_box = gr.Textbox(label="Fake News Detection Result", interactive=False)
+     with gr.Row():
+         post_button = gr.Button("Post Text")
+     with gr.Row():
+         post_result_box = gr.Textbox(label="Posting Status", interactive=False)
+         posted_text_box = gr.Textbox(label="Posted Text", interactive=False)
+
+     detect_fake_button.click(
+         fn=detect_fake_news,
+         inputs=text_input,
+         outputs=fake_news_result_box,
+     )
+
+     post_button.click(
+         fn=post_text,
+         inputs=[text_input, fake_news_result_box],
+         outputs=[post_result_box, posted_text_box],
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     interface.launch()
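
Note: the loading code above assumes Fake_model.pt stores the entire pickled BERT_Arch module (saved with torch.save(model)). If the checkpoint instead holds only a state_dict, a minimal loading sketch would look like the following; the file name, class, and helper names come from app.py, while the branching logic and the sample input are assumptions for illustration, not part of this commit.

    # Illustrative sketch, assuming Fake_model.pt may contain either a state_dict or the full module
    state = torch.load("Fake_model.pt", map_location=torch.device('cpu'))
    if isinstance(state, dict):                  # checkpoint holds only the weights
        fake_news_model = BERT_Arch(bert)
        fake_news_model.load_state_dict(state)
    else:                                        # checkpoint holds the full pickled module
        fake_news_model = state
    fake_news_model.eval()

    # Quick smoke test outside the Gradio UI (hypothetical input text)
    print(detect_fake_news("Example headline to classify"))  # prints "Fake" or "Not Fake"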