Update app.py
Browse files
app.py
CHANGED
@@ -12,28 +12,28 @@ ner_tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-
|
|
12 |
ner_model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
|
13 |
ner_model.eval()
|
14 |
|
15 |
-
|
16 |
-
#
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
|
27 |
-
|
28 |
|
29 |
def predict_sentiment_and_stock_info(headline):
|
30 |
# Sentiment Analysis
|
31 |
sentiment_inputs = sentiment_tokenizer(headline, padding=True, truncation=True, return_tensors='pt')
|
32 |
with torch.no_grad():
|
33 |
sentiment_outputs = sentiment_model(**sentiment_inputs)
|
34 |
-
|
35 |
|
36 |
-
|
37 |
sentiment_label = "Positive" if pos > neg and pos > neutr else "Negative" if neg > pos and neg > neutr else "Neutral"
|
38 |
|
39 |
# Named Entity Recognition (NER)
|
@@ -41,17 +41,16 @@ def predict_sentiment_and_stock_info(headline):
|
|
41 |
with torch.no_grad():
|
42 |
ner_outputs = ner_model(**ner_inputs)
|
43 |
|
44 |
-
#
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
|
50 |
-
#
|
51 |
-
|
52 |
|
53 |
-
|
54 |
-
return sentiment_label
|
55 |
|
56 |
# Gradio Interface
|
57 |
'''iface = gr.Interface(
|
|
|
12 |
ner_model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
|
13 |
ner_model.eval()
|
14 |
|
15 |
+
def get_advice(sentiment_label, stocks_mentioned):
|
16 |
+
# Add your own logic for providing advice based on sentiment and stocks mentioned
|
17 |
+
if sentiment_label == "Positive":
|
18 |
+
advice = "Positive sentiment. Consider taking advantage of positive market trends."
|
19 |
+
elif sentiment_label == "Negative":
|
20 |
+
if stocks_mentioned:
|
21 |
+
advice = f"Negative sentiment. Consider re-evaluating your position on stocks: {', '.join(stocks_mentioned)}."
|
22 |
+
else:
|
23 |
+
advice = "Negative sentiment. Consider monitoring the market for potential impacts."
|
24 |
+
else:
|
25 |
+
advice = "Neutral sentiment. The market may not be strongly influenced. Monitor for changes."
|
26 |
|
27 |
+
return advice
|
28 |
|
29 |
def predict_sentiment_and_stock_info(headline):
|
30 |
# Sentiment Analysis
|
31 |
sentiment_inputs = sentiment_tokenizer(headline, padding=True, truncation=True, return_tensors='pt')
|
32 |
with torch.no_grad():
|
33 |
sentiment_outputs = sentiment_model(**sentiment_inputs)
|
34 |
+
sentiment_prediction = torch.nn.functional.softmax(sentiment_outputs.logits, dim=-1)
|
35 |
|
36 |
+
pos, neg, neutr = sentiment_prediction[:, 0].item(), sentiment_prediction[:, 1].item(), sentiment_prediction[:, 2].item()
|
37 |
sentiment_label = "Positive" if pos > neg and pos > neutr else "Negative" if neg > pos and neg > neutr else "Neutral"
|
38 |
|
39 |
# Named Entity Recognition (NER)
|
|
|
41 |
with torch.no_grad():
|
42 |
ner_outputs = ner_model(**ner_inputs)
|
43 |
|
44 |
+
# Identify stocks mentioned in the headline
|
45 |
+
ner_predictions = torch.nn.functional.softmax(ner_outputs.logits, dim=-1).argmax(2)
|
46 |
+
tokens = ner_tokenizer.convert_ids_to_tokens(ner_inputs['input_ids'][0].tolist()) # Use ner_inputs here
|
47 |
+
entities = ner_tokenizer.convert_ids_to_tokens(ner_predictions[0].tolist())
|
48 |
+
stocks_mentioned = [tokens[i] for i, entity in enumerate(entities) if entity.startswith("B")]
|
49 |
|
50 |
+
# Advice based on sentiment and identified stocks
|
51 |
+
advice = get_advice(sentiment_label, stocks_mentioned)
|
52 |
|
53 |
+
return sentiment_label, advice
|
|
|
54 |
|
55 |
# Gradio Interface
|
56 |
'''iface = gr.Interface(
|