kidwaiaun commited on
Commit
6ac282e
·
verified ·
1 Parent(s): 7b30e45

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -0
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import necessary libraries
2
+ import gradio as gr
3
+ from datasets import load_dataset
4
+ import pandas as pd
5
+ import numpy as np
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
7
+ import torch
8
+ from sklearn.model_selection import train_test_split
9
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
10
+ import os
11
+ import shutil
12
+
13
+ # Load dataset once at startup
14
+ ds = load_dataset("ashraq/financial-news-articles")
15
+ df = pd.DataFrame(ds['train'])
16
+
17
+ # Simulate labels (replace with real labels in practice)
18
+ np.random.seed(42)
19
+ df['label'] = np.random.randint(0, 3, size=len(df)) # 0=neg, 1=neu, 2=pos
20
+ df['input_text'] = df['title'] + " " + df['text']
21
+
22
+ # Global variables for model and tokenizer
23
+ model = None
24
+ tokenizer = None
25
+ sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
26
+
27
+ # Function to tokenize dataset
28
+ def tokenize_function(examples, tokenizer):
29
+ return tokenizer(examples['input_text'], padding="max_length", truncation=True, max_length=512)
30
+
31
+ # Function to compute metrics
32
+ def compute_metrics(pred):
33
+ labels = pred.label_ids
34
+ preds = np.argmax(pred.predictions, axis=1)
35
+ precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
36
+ acc = accuracy_score(labels, preds)
37
+ return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}
38
+
39
+ # Train the model with user-defined parameters
40
+ def train_model(learning_rate, epochs, batch_size, save_path):
41
+ global model, tokenizer
42
+
43
+ # Split dataset
44
+ train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
45
+
46
+ # Load tokenizer and model
47
+ model_name = "bert-base-uncased"
48
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
49
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
50
+
51
+ # Prepare datasets
52
+ train_dataset = Dataset.from_pandas(train_df[['input_text', 'label']])
53
+ test_dataset = Dataset.from_pandas(test_df[['input_text', 'label']])
54
+ train_dataset = train_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
55
+ test_dataset = test_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
56
+ train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
57
+ test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
58
+
59
+ # Training arguments
60
+ training_args = TrainingArguments(
61
+ output_dir="./temp_model",
62
+ evaluation_strategy="epoch",
63
+ learning_rate=learning_rate,
64
+ per_device_train_batch_size=batch_size,
65
+ per_device_eval_batch_size=batch_size,
66
+ num_train_epochs=epochs,
67
+ weight_decay=0.01,
68
+ logging_dir='./logs',
69
+ logging_steps=10,
70
+ save_strategy="epoch",
71
+ load_best_model_at_end=True,
72
+ )
73
+
74
+ # Initialize trainer
75
+ trainer = Trainer(
76
+ model=model,
77
+ args=training_args,
78
+ train_dataset=train_dataset,
79
+ eval_dataset=test_dataset,
80
+ compute_metrics=compute_metrics,
81
+ )
82
+
83
+ # Train and evaluate
84
+ trainer.train()
85
+ eval_results = trainer.evaluate()
86
+
87
+ # Save the model if path provided
88
+ if save_path:
89
+ trainer.save_model(save_path)
90
+ tokenizer.save_pretrained(save_path)
91
+ output = f"Model saved to {save_path}\nEvaluation results: {eval_results}"
92
+ else:
93
+ output = f"Model trained but not saved.\nEvaluation results: {eval_results}"
94
+
95
+ # Clean up temp directory
96
+ if os.path.exists("./temp_model"):
97
+ shutil.rmtree("./temp_model")
98
+ if os.path.exists("./logs"):
99
+ shutil.rmtree("./logs")
100
+
101
+ return output
102
+
103
+ # Load a pre-trained model for inference
104
+ def load_model(model_path):
105
+ global model, tokenizer
106
+ if not os.path.exists(model_path):
107
+ return "Error: Model path does not exist."
108
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
109
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
110
+ return "Model loaded successfully from " + model_path
111
+
112
+ # Predict sentiment for new input
113
+ def predict_sentiment(title, text):
114
+ global model, tokenizer
115
+ if model is None or tokenizer is None:
116
+ return "Error: Please train or load a model first."
117
+
118
+ input_text = title + " " + text
119
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
120
+ with torch.no_grad():
121
+ outputs = model(**inputs)
122
+ pred_label = np.argmax(outputs.logits.numpy(), axis=1)[0]
123
+ return f"Predicted Sentiment: {sentiment_map[pred_label]}"
124
+
125
+ # Gradio interface
126
+ with gr.Blocks(title="Financial News Sentiment Analyzer") as demo:
127
+ gr.Markdown("# Financial News Sentiment Analyzer")
128
+ gr.Markdown("Train a sentiment model on financial news articles, save it, and predict sentiments.")
129
+
130
+ with gr.Tab("Train Model"):
131
+ gr.Markdown("### Train a New Sentiment Model")
132
+ learning_rate = gr.Slider(1e-5, 5e-5, value=2e-5, label="Learning Rate")
133
+ epochs = gr.Slider(1, 5, value=3, step=1, label="Number of Epochs")
134
+ batch_size = gr.Slider(4, 16, value=8, step=4, label="Batch Size")
135
+ save_path = gr.Textbox(label="Save Model Path (optional)", placeholder="e.g., ./my_sentiment_model")
136
+ train_button = gr.Button("Train Model")
137
+ output = gr.Textbox(label="Training Output")
138
+ train_button.click(
139
+ fn=train_model,
140
+ inputs=[learning_rate, epochs, batch_size, save_path],
141
+ outputs=output
142
+ )
143
+
144
+ with gr.Tab("Load Model"):
145
+ gr.Markdown("### Load an Existing Model")
146
+ model_path = gr.Textbox(label="Model Path", placeholder="e.g., ./my_sentiment_model")
147
+ load_button = gr.Button("Load Model")
148
+ load_output = gr.Textbox(label="Load Status")
149
+ load_button.click(
150
+ fn=load_model,
151
+ inputs=model_path,
152
+ outputs=load_output
153
+ )
154
+
155
+ with gr.Tab("Predict Sentiment"):
156
+ gr.Markdown("### Predict Sentiment for New Input")
157
+ title_input = gr.Textbox(label="Article Title")
158
+ text_input = gr.Textbox(label="Article Text", lines=5)
159
+ predict_button = gr.Button("Predict")
160
+ pred_output = gr.Textbox(label="Prediction")
161
+ predict_button.click(
162
+ fn=predict_sentiment,
163
+ inputs=[title_input, text_input],
164
+ outputs=pred_output
165
+ )
166
+
167
+ # Launch the app
168
+ demo.launch()