Szeyu commited on
Commit
6af8332
·
verified ·
1 Parent(s): 90ab0d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -121
app.py CHANGED
@@ -1,135 +1,161 @@
1
- import re
2
  import streamlit as st
3
- from transformers import pipeline
4
- import textwrap
 
 
5
  import numpy as np
6
- import soundfile as sf
7
- import tempfile
8
- import os
9
- from PIL import Image
10
- import string
11
-
12
- # Initialize pipelines with caching
13
- @st.cache_resource
14
- def load_pipelines():
15
- captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
16
- storyer = pipeline("text-generation", model="aspis/gpt2-genre-story-generation")
17
- tts = pipeline("text-to-speech", model="facebook/mms-tts-eng")
18
- return captioner, storyer, tts
19
 
20
- captioner, storyer, tts = load_pipelines()
 
 
 
 
 
 
 
 
 
21
 
22
- def clean_generated_story(raw_story: str) -> str:
23
- """
24
- Cleans the generated story by:
25
- 1. Removing URLs.
26
- 2. Removing digits.
27
- 3. Removing words likely to be random letter combinations based on having no vowels.
28
- 4. Removing single-letter words unless allowed (such as 'a' or 'I').
29
- """
30
- # Remove URLs starting with http://, https://, or www.
31
- no_urls = re.sub(r'\b(?:https?://|www\.)\S+\b', '', raw_story)
32
- # Remove domain names without protocol (e.g., erskybooks.com)
33
- no_urls = re.sub(r'\b\w+\.(com|net|org|co\.uk|ca\.us|me)\b', '', no_urls)
34
-
35
- # Remove all digits
36
- story_without_numbers = re.sub(r'\d+', '', no_urls)
37
-
38
- vowels = set('aeiouAEIOU')
39
-
40
- def is_valid_word(word: str) -> bool:
41
- # Allow "a" and "I" for single-letter words
42
- if len(word) == 1 and word.lower() not in ['a', 'i']:
43
- return False
44
- # For words longer than one letter, filter out those that do not contain any vowels
45
- if len(word) > 1 and not any(char in vowels for char in word):
46
- return False
47
- return True
48
 
49
- # Split the cleaned text into words, filter them, and reassemble
50
- words = story_without_numbers.split()
51
- filtered_words = [word for word in words if is_valid_word(word)]
52
-
53
- # Trim the cleaned story to the first 100 words (optional)
54
- clean_story = " ".join(filtered_words[:100])
55
- return clean_story
 
 
56
 
57
- def get_caption(image) -> str:
58
- """
59
- Takes an image and returns a generated caption.
60
- """
61
- pil_image = Image.open(image)
62
- caption = captioner(pil_image)[0]["generated_text"]
63
- st.write("**🌟 What's in the picture: 🌟**")
64
- st.write(caption)
65
- return caption
 
 
66
 
67
- def get_story(caption: str) -> str:
68
- """
69
- Takes a caption and returns a funny, bright, and playful story targeted toward young children.
70
- """
71
- prompt = (
72
- f"Write a funny and playful story for young children precisely centered on this scene {caption}\nStory: "
73
- f"mention the exact place and venue within {caption}. "
74
- f"Make the story magical and exciting."
75
  )
76
-
77
- raw = storyer(
78
- prompt,
79
- max_new_tokens=150,
80
- temperature=0.7,
81
- top_p=0.9,
82
- no_repeat_ngram_size=2,
83
- return_full_text=False
84
- )[0]["generated_text"].strip()
85
-
86
- story = clean_generated_story(raw)
87
- st.write("**📖 Your funny story: 📖**")
88
- st.write(story)
89
- return story
90
 
91
- def generate_audio(story: str) -> str:
92
- """
93
- Converts the text story into speech audio and returns the file path for the audio.
94
- """
95
- chunks = textwrap.wrap(story, width=200)
96
- audio = np.concatenate([tts(chunk)["audio"].squeeze() for chunk in chunks])
97
-
98
- # Save the audio to a temporary file and return its path.
99
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
100
- sf.write(temp_file.name, audio, tts.model.config.sampling_rate)
101
- temp_file_path = temp_file.name
102
- return temp_file_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- def generate_content(image):
105
- """
106
- Pipeline function that:
107
- - Generates a caption from the uploaded image.
108
- - Uses the caption to generate a story.
109
- - Converts the story to speech audio.
110
- """
111
- caption = get_caption(image)
112
- story = get_story(caption)
113
- audio_path = generate_audio(story)
114
- return caption, story, audio_path
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- # Streamlit UI section
117
- st.title(" Magic Story Maker ")
118
- st.markdown("Upload a picture to make a funny story and hear it too! 📸")
119
 
120
- uploaded_image = st.file_uploader("Choose your picture", type=["jpg", "jpeg", "png"])
 
 
121
 
122
- if uploaded_image is None:
123
- st.image("https://example.com/placeholder_image.jpg", caption="Upload your picture here! 📷", use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  else:
125
- st.image(uploaded_image, caption="Your Picture 🌟", use_container_width=True)
126
-
127
- if st.button("✨ Make My Story! ✨"):
128
- if uploaded_image is not None:
129
- with st.spinner("🔮 Creating your magical story..."):
130
- caption, story, audio_path = generate_content(uploaded_image)
131
- st.success("🎉 Your story is ready! 🎉")
132
- st.audio(audio_path, format="audio/wav")
133
- os.remove(audio_path)
134
- else:
135
- st.warning("Please upload a picture first! 📸")
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ from sklearn.model_selection import train_test_split
4
+ from datasets import Dataset
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, get_linear_schedule_with_warmup
6
  import numpy as np
7
+ import torch
8
+ from transformers import pipeline
9
+ from collections import Counter
10
+ import time
11
+ from tqdm import tqdm
12
+ import evaluate
 
 
 
 
 
 
 
13
 
14
+ # Function to load and process data
15
+ def load_and_process_data(news_file, trend_file):
16
+ news_df = pd.read_csv(news_file)
17
+ trend_df = pd.read_csv(trend_file)
18
+ trend_df = trend_df.rename(columns={'Symbol': 'Stock'})
19
+ news_labeled_df = news_df.merge(trend_df[['Stock', 'Trend']], on='Stock', how='left')
20
+ news_labeled_df = news_labeled_df[news_labeled_df['Trend'].isin(['Positive', 'Negative'])]
21
+ label_map = {'Negative': 0, 'Positive': 1}
22
+ news_labeled_df['label'] = news_labeled_df['Trend'].map(label_map)
23
+ return news_labeled_df
24
 
25
+ # Function to check class imbalance
26
+ def check_class_imbalance(df):
27
+ class_counts = df['label'].value_counts()
28
+ st.write("**Class Distribution:**", class_counts.to_dict())
29
+ if class_counts.min() / class_counts.max() < 0.5:
30
+ st.warning("Warning: Class imbalance detected. Consider balancing techniques.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ # Function to split data
33
+ def split_data(df):
34
+ stocks = df['Stock'].unique()
35
+ train_val_stocks, test_stocks = train_test_split(stocks, test_size=0.2, random_state=42)
36
+ train_stocks, val_stocks = train_test_split(train_val_stocks, test_size=0.25, random_state=42)
37
+ train_df = df[df['Stock'].isin(train_stocks)]
38
+ val_df = df[df['Stock'].isin(val_stocks)]
39
+ test_df = df[df['Stock'].isin(test_stocks)]
40
+ return train_df, val_df, test_df
41
 
42
+ # Function to tokenize datasets
43
+ def tokenize_datasets(train_df, val_df, test_df, tokenizer):
44
+ train_dataset = Dataset.from_pandas(train_df[['Headline', 'label']])
45
+ val_dataset = Dataset.from_pandas(val_df[['Headline', 'label']])
46
+ test_dataset = Dataset.from_pandas(test_df[['Headline', 'label']])
47
+ def tokenize_function(examples):
48
+ return tokenizer(examples['Headline'], padding='max_length', truncation=True, max_length=128)
49
+ tokenized_train = train_dataset.map(tokenize_function, batched=True)
50
+ tokenized_val = val_dataset.map(tokenize_function, batched=True)
51
+ tokenized_test = test_dataset.map(tokenize_function, batched=True)
52
+ return tokenized_train, tokenized_val, tokenized_test
53
 
54
+ # Function to load model with caching
55
+ @st.cache_resource
56
+ def load_model():
57
+ model = AutoModelForSequenceClassification.from_pretrained(
58
+ "yiyanghkust/finbert-tone",
59
+ num_labels=2,
60
+ ignore_mismatched_sizes=True
 
61
  )
62
+ for param in model.bert.encoder.layer[:6].parameters():
63
+ param.requires_grad = False
64
+ return model
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ # Function to train model
67
+ def train_model(tokenized_train, tokenized_val, model):
68
+ training_args = TrainingArguments(
69
+ output_dir="./results",
70
+ num_train_epochs=5,
71
+ per_device_train_batch_size=32,
72
+ per_device_eval_batch_size=32,
73
+ eval_strategy="epoch",
74
+ save_strategy="epoch",
75
+ load_best_model_at_end=True,
76
+ metric_for_best_model="accuracy",
77
+ learning_rate=5e-5,
78
+ weight_decay=0.1,
79
+ report_to="none",
80
+ )
81
+ total_steps = len(tokenized_train) // training_args.per_device_train_batch_size * training_args.num_train_epochs
82
+ optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate)
83
+ trainer = Trainer(
84
+ model=model,
85
+ args=training_args,
86
+ train_dataset=tokenized_train,
87
+ eval_dataset=tokenized_val,
88
+ compute_metrics=lambda eval_pred: {"accuracy": evaluate.load("accuracy").compute(predictions=np.argmax(eval_pred.predictions, axis=1), references=eval_pred.label_ids)},
89
+ optimizers=(optimizer, get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)),
90
+ )
91
+ trainer.train()
92
+ trainer.save_model("./fine_tuned_model")
93
+ return trainer
94
 
95
+ # Function to evaluate model
96
+ def evaluate_model(pipe, df, model_name=""):
97
+ results = []
98
+ total_start = time.perf_counter()
99
+ for stock, group in tqdm(df.groupby("Stock")):
100
+ headlines = group["Headline"].tolist()
101
+ true_trend = group["Trend"].iloc[0]
102
+ try:
103
+ preds = pipe(headlines, truncation=True)
104
+ except Exception as e:
105
+ st.error(f"Error for {stock}: {e}")
106
+ continue
107
+ labels = [p['label'] for p in preds]
108
+ count = Counter(labels)
109
+ num_pos, num_neg = count.get("Positive", 0), count.get("Negative", 0)
110
+ predicted_trend = "Positive" if num_pos > num_neg else "Negative"
111
+ match = predicted_trend == true_trend
112
+ results.append(match)
113
+ total_runtime = time.perf_counter() - total_start
114
+ accuracy = sum(results) / len(results) if results else 0
115
+ st.write(f"**🔍 Evaluation Summary for {model_name}**")
116
+ st.write(f"✅ Accuracy: {accuracy:.2%}")
117
+ st.write(f"⏱ Total Runtime: {total_runtime:.2f} seconds")
118
+ return accuracy
119
 
120
+ # Streamlit UI
121
+ st.title("Financial Sentiment Analysis with FinBERT")
122
+ st.markdown("Upload your CSV files to train and evaluate a sentiment analysis model on financial news headlines.")
123
 
124
+ st.header("Upload CSV Files")
125
+ news_file = st.file_uploader("Upload Train_stock_news.csv", type="csv")
126
+ trend_file = st.file_uploader("Upload Training_price_comparison.csv", type="csv")
127
 
128
+ if news_file and trend_file:
129
+ with st.spinner("Processing data..."):
130
+ df = load_and_process_data(news_file, trend_file)
131
+ check_class_imbalance(df)
132
+ train_df, val_df, test_df = split_data(df)
133
+ st.write(f"**Training stocks:** {len(train_df['Stock'].unique())}")
134
+ st.write(f"**Validation stocks:** {len(val_df['Stock'].unique())}")
135
+ st.write(f"**Test stocks:** {len(test_df['Stock'].unique())}")
136
+
137
+ tokenizer = AutoTokenizer.from_pretrained("yiyanghkust/finbert-tone")
138
+ tokenized_train, tokenized_val, tokenized_test = tokenize_datasets(train_df, val_df, test_df, tokenizer)
139
+
140
+ model = load_model()
141
+
142
+ with st.spinner("Training model..."):
143
+ trainer = train_model(tokenized_train, tokenized_val, model)
144
+
145
+ st.success("Model training completed!")
146
+
147
+ # Evaluate original model
148
+ original_pipe = pipeline("text-classification", model="yiyanghkust/finbert-tone")
149
+ st.write("Evaluating original model...")
150
+ original_accuracy = evaluate_model(original_pipe, test_df, model_name="Original Model")
151
+
152
+ # Evaluate fine-tuned model
153
+ fine_tuned_pipe = pipeline("text-classification", model="./fine_tuned_model")
154
+ st.write("Evaluating fine-tuned model...")
155
+ fine_tuned_accuracy = evaluate_model(fine_tuned_pipe, test_df, model_name="Fine-tuned Model")
156
+
157
+ st.write(f"**Comparison:**")
158
+ st.write(f"Original Model Accuracy: {original_accuracy:.2%}")
159
+ st.write(f"Fine-tuned Model Accuracy: {fine_tuned_accuracy:.2%}")
160
  else:
161
+ st.warning("Please upload both CSV files to proceed.")