LinkLinkWu's picture
Update func.py
628c80f verified
raw
history blame
3.16 kB
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
from bs4 import BeautifulSoup
import requests
# ----------- Eager Initialization of Pipelines -----------
# Both pipelines are built at import time (no lazy loading), so the first
# call to analyze_sentiment / extract_org_entities pays no model-load cost.
# NOTE: importing this module therefore downloads/loads two models.

# Sentiment pipeline — fine-tuned binary sentiment model.
model_id = "LinkLinkWu/ISOM5240HKUSTBASE"
sentiment_tokenizer = AutoTokenizer.from_pretrained(model_id)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(model_id)
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model=sentiment_model,
    tokenizer=sentiment_tokenizer
)

# NER pipeline — generic English NER; grouped_entities=True merges wordpiece
# sub-tokens into whole entities with an "entity_group" key per result.
ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
ner_pipeline = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    grouped_entities=True
)
# ----------- Core Functions -----------
def fetch_news(ticker):
    """Scrape up to 30 recent headlines for *ticker* from finviz.com.

    Args:
        ticker: Stock ticker symbol (any case).

    Returns:
        A list of {'title': str, 'link': str} dicts, or [] on any failure
        (non-200 response, redirect to a non-ticker page, missing news
        table, network error) — callers rely on this best-effort contract.
    """
    try:
        url = f"https://finviz.com/quote.ashx?t={ticker}"
        # Browser-like headers: finviz rejects requests without a User-Agent.
        headers = {
            'User-Agent': 'Mozilla/5.0',
            'Accept': 'text/html',
            'Accept-Language': 'en-US,en;q=0.5',
            'Referer': 'https://finviz.com/',
            'Connection': 'keep-alive',
        }
        # Fix: requests has NO default timeout — without one a stalled
        # connection hangs the caller forever.
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code != 200:
            return []
        soup = BeautifulSoup(response.text, 'html.parser')
        # A real quote page embeds the ticker in <title>; unknown tickers get
        # a generic page. Fix: compare case-insensitively so a lowercase
        # ticker argument no longer always returns [].
        title = soup.title.text if soup.title else ""
        if ticker.upper() not in title.upper():
            return []
        news_table = soup.find(id='news-table')
        if news_table is None:
            return []
        news = []
        # find_all is the current bs4 name; findAll is a deprecated alias.
        for row in news_table.find_all('tr')[:30]:
            a_tag = row.find('a')
            if a_tag:
                news.append({'title': a_tag.get_text(), 'link': a_tag['href']})
        return news
    except Exception:
        # Deliberate best-effort: any scraping failure yields an empty list.
        return []
def analyze_sentiment(text, pipe=None):
    """Classify *text* as "Positive" or "Negative".

    Supports two calling styles:
    - analyze_sentiment(text)            -> uses the module-level sentiment_pipeline
    - analyze_sentiment(text, pipeline)  -> uses the supplied pipeline

    Returns:
        "Positive" when the top label is 'POSITIVE', "Negative" for any
        other label, and "Unknown" when the pipeline raises.
    """
    try:
        classifier = pipe or sentiment_pipeline
        top_prediction = classifier(text)[0]
        if top_prediction['label'] == 'POSITIVE':
            return "Positive"
        return "Negative"
    except Exception:
        return "Unknown"
def extract_org_entities(text, pipe=None, max_entities=5):
    """Extract up to *max_entities* unique ORG entity names from *text*.

    Supports two calling styles:
    - extract_org_entities(text)            -> uses the module-level ner_pipeline
    - extract_org_entities(text, pipeline)  -> uses the supplied pipeline

    Args:
        text: Input string to run NER over.
        pipe: Optional NER pipeline; falls back to the global ner_pipeline.
        max_entities: Cap on the number of organisations returned.
            Generalized from the previously hard-coded 5, which remains the
            default for backward compatibility.

    Returns:
        Upper-cased, de-duplicated organisation names in first-seen order
        (wordpiece "##" markers stripped); [] on any pipeline error.
    """
    try:
        ner_pipe = pipe or ner_pipeline
        orgs = []
        for ent in ner_pipe(text):
            if ent["entity_group"] != "ORG":
                continue
            # Strip wordpiece continuation markers and normalise case so
            # variants like "Apple"/"apple" collapse into one entry.
            name = ent["word"].replace("##", "").strip().upper()
            if name not in orgs:
                orgs.append(name)
                if len(orgs) >= max_entities:
                    break
        return orgs
    except Exception:
        return []
# ----------- Helper Functions for Imports -----------
def get_sentiment_pipeline():
    """Return the module-level sentiment-analysis pipeline (built eagerly at import)."""
    return sentiment_pipeline
def get_ner_pipeline():
    """Return the module-level NER pipeline (built eagerly at import, grouped entities)."""
    return ner_pipeline