# BuckLakeAI / us_stock_yfinance.py
# parkerjj's picture
# feat: Add data source configuration and implement yfinance integration for US stock indices
# ac60dc3
import logging
import re
import pandas as pd
from datetime import datetime, timedelta
import time # 导入标准库的 time 模块
import os
import requests
import threading
import asyncio
import yfinance as yf
logging.basicConfig(level=logging.INFO)

# Directory containing this module; the constituent CSV files are resolved
# relative to it so loading does not depend on the current working directory.
base_dir = os.path.dirname(os.path.abspath(__file__))

# Absolute paths of the index-constituent CSV files.
nasdaq_100_path = os.path.join(base_dir, './model/nasdaq100.csv')
dow_jones_path = os.path.join(base_dir, './model/dji.csv')
sp500_path = os.path.join(base_dir, './model/sp500.csv')
nasdaq_composite_path = os.path.join(base_dir, './model/nasdaq_all.csv')

# Load every constituent table once at import time, in the order listed.
(
    nasdaq_100_stocks,
    dow_jones_stocks,
    sp500_stocks,
    nasdaq_composite_stocks,
) = [
    pd.read_csv(csv_path)
    for csv_path in (nasdaq_100_path, dow_jones_path, sp500_path, nasdaq_composite_path)
]
def fetch_stock_us_spot_data_with_retries():
    """Build the US ticker table from the local index CSVs plus a curated list.

    Tickers are taken from the 'Symbol' column of each constituent table
    (or 'Code' when 'Symbol' is absent), merged with a fixed list of popular
    ETFs and stocks, de-duplicated and sorted.

    Returns:
        A DataFrame with columns '代码' (ticker) and '名称' (placeholder name
        "<ticker> Inc."). If anything raises, a small hard-coded fallback
        table with the same columns is returned instead. Despite the name,
        no retries are performed.
    """
    try:
        collected = set()
        constituent_tables = (
            (nasdaq_100_stocks, "NASDAQ-100"),
            (dow_jones_stocks, "Dow Jones"),
            (sp500_stocks, "S&P 500"),
            (nasdaq_composite_stocks, "NASDAQ Composite"),
        )
        for table, _label in constituent_tables:
            # Prefer 'Symbol'; fall back to 'Code'; skip tables with neither.
            column = next((c for c in ('Symbol', 'Code') if c in table.columns), None)
            if column is not None:
                collected.update(table[column].dropna().astype(str).tolist())

        extra_tickers = (
            # Major ETFs
            'SPY', 'QQQ', 'IWM', 'VTI', 'ARKK', 'TQQQ', 'SQQQ', 'SPXL',
            # Popular tech stocks
            'AAPL', 'MSFT', 'GOOGL', 'GOOG', 'AMZN', 'TSLA', 'META', 'NVDA', 'NFLX',
            'AMD', 'INTC', 'ORCL', 'CRM', 'ADBE', 'PYPL', 'UBER', 'LYFT',
            # US-listed Chinese companies
            'BABA', 'JD', 'PDD', 'NIO', 'XPEV', 'LI', 'DIDI', 'TME',
            # Other popular stocks
            'COST', 'WMT', 'JPM', 'BAC', 'XOM', 'CVX', 'PFE', 'JNJ', 'KO', 'PEP',
        )
        collected.update(extra_tickers)

        ordered = sorted(collected)
        symbols_df = pd.DataFrame({
            '代码': ordered,
            '名称': [f'{code} Inc.' for code in ordered],  # naive name placeholder
        })
        print(f"Created symbols dataframe with {len(symbols_df)} symbols")
        return symbols_df
    except Exception as e:
        print(f"Error creating symbols dataframe: {e}")
        # Minimal fallback so callers always get a usable table.
        fallback = [
            'AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA', 'META', 'NVDA', 'NFLX',
            'SPY', 'QQQ', 'IWM', 'VTI',
        ]
        return pd.DataFrame({
            '代码': fallback,
            '名称': [f'{code} Inc.' for code in fallback],
        })
async def fetch_stock_us_spot_data_with_retries_async():
    """Run fetch_stock_us_spot_data_with_retries in a worker thread.

    Returns the ticker table, or an empty DataFrame if the call raises
    (the error is printed).
    """
    try:
        result = await asyncio.to_thread(fetch_stock_us_spot_data_with_retries)
    except Exception as e:
        print(f"Error in async fetch: {e}")
        result = pd.DataFrame()
    return result
# Module-level ticker table filled in by fetch_symbols(); None until it runs.
symbols = None


async def fetch_symbols():
    """Load the ticker table into the module-level ``symbols`` variable.

    On failure, or when the loaded table is None/empty, ``symbols`` is set
    to an empty DataFrame. A completion message is printed in every case.
    """
    global symbols
    try:
        print("Starting symbols initialization...")
        loaded = await fetch_stock_us_spot_data_with_retries_async()
        symbols = loaded
        if symbols is None or symbols.empty:
            print("Symbols initialization failed, using empty dataset")
            symbols = pd.DataFrame()
        else:
            print(f"Symbols initialized successfully: {len(symbols)} symbols loaded")
    except Exception as e:
        print(f"Error in fetch_symbols: {e}")
        symbols = pd.DataFrame()
    finally:
        print("Symbols initialization completed")
# Cached index histories (S&P 500, Dow Jones, NASDAQ Composite, NASDAQ-100),
# refreshed by update_stock_indices(); None until the first refresh runs.
index_us_stock_index_INX = index_us_stock_index_DJI = None
index_us_stock_index_IXIC = index_us_stock_index_NDX = None
def _fetch_index_frame(yf_symbol, var_name, start_date, end_date):
    """Fetch one index's daily bars and convert them to the akshare-style layout.

    Returns an empty DataFrame when yfinance returns no rows or raises.
    """
    try:
        hist_data = yf.Ticker(yf_symbol).history(start=start_date, end=end_date)
        if hist_data.empty:
            print(f"No data received for {yf_symbol}")
            return pd.DataFrame()
        formatted_data = pd.DataFrame({
            'date': hist_data.index.strftime('%Y-%m-%d'),
            '开盘': hist_data['Open'].values,
            '收盘': hist_data['Close'].values,
            '最高': hist_data['High'].values,
            '最低': hist_data['Low'].values,
            '成交量': hist_data['Volume'].values,
            # yfinance has no turnover field; approximate it as close * volume.
            '成交额': (hist_data['Close'] * hist_data['Volume']).values,
        })
        print(f"Successfully fetched {var_name} data: {len(formatted_data)} records")
        return formatted_data
    except Exception as e:
        print(f"Error fetching {yf_symbol}: {e}")
        return pd.DataFrame()


def update_stock_indices():
    """Refresh the cached US index histories from yfinance and reschedule.

    Fetches the last 8 weeks of daily bars for ^GSPC, ^DJI, ^IXIC and ^NDX.
    Each is stored in the matching ``index_us_stock_index_*`` global with the
    columns 'date', '开盘', '收盘', '最高', '最低', '成交量', '成交额'. A symbol
    that fails or returns no rows is stored as an empty DataFrame.

    Afterwards the next refresh is scheduled 12 hours later on a daemon timer.
    The timer is a daemon so that a pending refresh never keeps the
    interpreter alive after the main program exits.
    """
    global index_us_stock_index_INX, index_us_stock_index_DJI, index_us_stock_index_IXIC, index_us_stock_index_NDX
    try:
        print("Starting stock indices update using yfinance...")
        end_date = datetime.now()
        start_date = end_date - timedelta(weeks=8)
        # yfinance ticker -> suffix of the global that caches it
        indices = {
            '^GSPC': 'INX',   # S&P 500
            '^DJI': 'DJI',    # Dow Jones
            '^IXIC': 'IXIC',  # NASDAQ Composite
            '^NDX': 'NDX',    # NASDAQ 100
        }
        results = {
            var_name: _fetch_index_frame(yf_symbol, var_name, start_date, end_date)
            for yf_symbol, var_name in indices.items()
        }
        index_us_stock_index_INX = results.get('INX', pd.DataFrame())
        index_us_stock_index_DJI = results.get('DJI', pd.DataFrame())
        index_us_stock_index_IXIC = results.get('IXIC', pd.DataFrame())
        index_us_stock_index_NDX = results.get('NDX', pd.DataFrame())
        print("Stock indices updated successfully using yfinance")
    except Exception as e:
        print(f"Error updating stock indices: {e}")
    # Re-run every 12 hours. threading.Timer threads are non-daemon by default,
    # which would block interpreter shutdown for up to 12 hours.
    timer = threading.Timer(12 * 60 * 60, update_stock_indices)
    timer.daemon = True
    timer.start()
# Do not refresh during import; start the first refresh after a short delay.
def start_indices_update(delay=5):
    """Schedule the first index refresh ``delay`` seconds from now.

    The refresh runs on a daemon timer, so a pending refresh never blocks
    interpreter shutdown (threading.Timer is non-daemon by default).

    Args:
        delay: Seconds to wait before the first update_stock_indices() call.

    Returns:
        The started threading.Timer, so callers may cancel it.
    """
    timer = threading.Timer(delay, update_stock_indices)
    timer.daemon = True
    timer.start()
    return timer


# Kick off the delayed first refresh at import time.
start_indices_update()
# Maps the akshare-style (Chinese) column names to the English names used downstream.
column_mapping = dict((
    ('日期', 'date'),                        # trade date
    ('开盘', 'open'),                        # opening price
    ('收盘', 'close'),                       # closing price
    ('最高', 'high'),                        # high
    ('最低', 'low'),                         # low
    ('成交量', 'volume'),                    # volume
    ('成交额', 'amount'),                    # turnover amount
    ('振幅', 'amplitude'),                   # amplitude
    ('涨跌幅', 'price_change_percentage'),   # percent change
    ('涨跌额', 'price_change_amount'),       # absolute change
    ('换手率', 'turnover_rate'),             # turnover rate
))

# Canonical column order of every history frame this module returns.
standard_columns = ['date', 'open', 'close', 'high', 'low', 'volume', 'amount']
# Ticker lookup
def find_stock_entry(stock_code):
    """Resolve ``stock_code`` to the ticker stored in the ``symbols`` table.

    Matching is an exact, case-insensitive comparison against the '代码'
    column. When the code is not in the table, or the table is not loaded,
    the stripped, upper-cased input is returned as-is.

    Returns:
        The matching ticker, the normalized input, or "" when ``stock_code``
        is empty/None.
    """
    if not stock_code:
        # An empty code used to match every row (every string ends with "").
        return ""
    wanted = str(stock_code).strip().upper()
    if symbols is None or symbols.empty:
        # Returning "" here made every history lookup fall back to zeros
        # until (or unless) fetch_symbols() succeeded. Use the raw ticker,
        # exactly like the not-found branch below.
        print("Warning: symbols data is empty")
        return wanted
    try:
        codes = symbols['代码']
        # Exact match instead of str.endswith: a suffix match (a leftover from
        # exchange-prefixed codes) resolved e.g. 'T' to 'AT' or 'B' to 'BRK.B'.
        mask = codes.astype(str).str.upper().eq(wanted) & codes.notna()
        matches = symbols.loc[mask, '代码']
        if not matches.empty:
            return matches.iloc[0]
        # Not in the table: assume it is a valid ticker.
        return wanted
    except Exception as e:
        print(f"Error in find_stock_entry: {e}")
        return wanted
def reduce_columns(df, columns_to_keep):
    """Return the part of ``df`` selected by ``columns_to_keep``.

    Selection is plain ``df[columns_to_keep]``, so a missing column raises
    KeyError and a list selector yields a DataFrame in that column order.
    """
    selection = columns_to_keep
    trimmed = df[selection]
    return trimmed
# Price cache: symbol -> (price, time at which the price was fetched).
_price_cache = {}


def _coerce_price(value):
    """Return ``value`` as a finite float, or None if it is missing or unusable."""
    import math

    if value is None:
        return None
    try:
        price = float(value)
    except (TypeError, ValueError):
        return None
    # NaN/inf are not real quotes and must never be cached or returned.
    return price if math.isfinite(price) else None


def get_last_minute_stock_price(symbol: str, max_retries=3) -> float:
    """Return the latest price of ``symbol``, or -1.0 if it cannot be fetched.

    Successful prices are cached per symbol for 30 minutes. Each attempt
    reads ``Ticker.info`` ('regularMarketPrice', then 'currentPrice'). If
    neither value is usable, it falls back to the last close of today's
    1-minute bars. NaN/inf values are treated as missing. Attempts are
    separated by a 1-second sleep.

    Args:
        symbol: Ticker symbol; "" and "NONE_SYMBOL_FOUND" return -1.0 at once.
        max_retries: Number of fetch attempts.

    Returns:
        The price as a float, or -1.0 on failure.
    """
    if not symbol or symbol == "NONE_SYMBOL_FOUND":
        return -1.0

    current_time = datetime.now()
    cached = _price_cache.get(symbol)
    if cached is not None:
        cached_price, cached_time = cached
        if current_time - cached_time < timedelta(minutes=30):
            return cached_price

    for attempt in range(max_retries):
        try:
            ticker = yf.Ticker(symbol)
            info = ticker.info
            current_price = (
                _coerce_price(info.get('regularMarketPrice'))
                or _coerce_price(info.get('currentPrice'))
            )
            if current_price is None:
                # Fall back to the most recent 1-minute bar of the day.
                hist = ticker.history(period='1d', interval='1m')
                if not hist.empty:
                    current_price = _coerce_price(hist['Close'].iloc[-1])
            if current_price is not None:
                _price_cache[symbol] = (current_price, current_time)
                return current_price
            print(f"Warning: No price data for {symbol}, attempt {attempt + 1}/{max_retries}")
        except Exception as e:
            print(f"Error fetching price for {symbol}, attempt {attempt + 1}/{max_retries}: {str(e)}")
        if attempt == max_retries - 1:
            return -1.0
        time.sleep(1)
    return -1.0
# Per-stock daily history
def _format_yf_history(raw):
    """Convert a yfinance history frame to the akshare-style column layout."""
    return pd.DataFrame({
        'date': raw.index.strftime('%Y-%m-%d'),
        '开盘': raw['Open'].values,
        '收盘': raw['Close'].values,
        '最高': raw['High'].values,
        '最低': raw['Low'].values,
        '成交量': raw['Volume'].values,
        # yfinance has no turnover field; approximate it as close * volume.
        '成交额': (raw['Close'] * raw['Volume']).values,
    })


def _download_history(symbol, news_date, start_date, end_date, retries):
    """Fetch daily bars with up to ``retries + 1`` attempts; None if none succeed.

    Only exceptions are retried (2 s apart). An empty response is returned
    as None immediately, because asking again for the same window does not help.
    """
    for attempt in range(retries + 1):
        try:
            raw = yf.Ticker(symbol).history(start=start_date, end=end_date)
            if raw.empty:
                print(f"No data for {symbol} on {news_date}.")
                return None
            return _format_yf_history(raw)
        except Exception as e:
            print(f"Error {e} scraping data for {symbol} on {news_date}. Retrying...")
            if attempt < retries:
                time.sleep(2)
    return None


def _zero_history(start_date, end_date):
    """All-zero placeholder with one row per calendar day of the window."""
    days = pd.date_range(start=start_date, end=end_date)
    frame = pd.DataFrame({'date': days.strftime('%Y-%m-%d')})
    for column in ('开盘', '收盘', '最高', '最低', '成交量', '成交额'):
        frame[column] = 0
    return frame


def get_stock_history(symbol, news_date, retries=10):
    """Return the last 60 calendar days of daily bars for ``symbol``.

    The window ends at ``datetime.now()``. ``news_date`` does not move the
    window; it only appears in log messages.

    Symbols without digits are resolved through find_stock_entry(). Codes
    containing digits are used unchanged.

    Returns:
        A DataFrame with ``standard_columns`` (date, open, close, high, low,
        volume, amount). ``amount`` is approximated as close * volume. When
        the symbol is empty/None or no data can be fetched, every calendar
        day of the window is returned with zero values.
    """
    symbol = "" if symbol is None else str(symbol)
    if symbol and not any(ch.isdigit() for ch in symbol):
        symbol = find_stock_entry(symbol) or ""

    end_date = datetime.now()
    start_date = end_date - timedelta(days=60)

    history = None
    if symbol:
        history = _download_history(symbol, news_date, start_date, end_date, retries)
    if history is None:
        history = _zero_history(start_date, end_date)

    # Normalize to the English column names and the canonical column order.
    history = history.rename(columns=column_mapping)
    history = history.reindex(columns=standard_columns)
    return reduce_columns(history, standard_columns)
# History of the index a stock belongs to
def _fetch_index_on_demand(index_code, start_date, end_date):
    """Fetch an index's daily bars from yfinance when the cache is empty.

    ``start_date``/``end_date`` are 'YYYY-MM-DD' strings (``end_date``
    inclusive). Returns an all-zero frame for the window if yfinance fails
    or returns nothing.
    """
    yf_symbol = {
        '.INX': '^GSPC',
        '.DJI': '^DJI',
        '.IXIC': '^IXIC',
        '.NDX': '^NDX',
    }.get(index_code, '^IXIC')
    # yfinance treats `end` as exclusive; ask for one extra day so end_date's bar is included.
    fetch_end = (pd.Timestamp(end_date) + pd.Timedelta(days=1)).strftime('%Y-%m-%d')
    try:
        hist_data = yf.Ticker(yf_symbol).history(start=start_date, end=fetch_end)
        if not hist_data.empty:
            return pd.DataFrame({
                'date': hist_data.index.strftime('%Y-%m-%d'),
                '开盘': hist_data['Open'].values,
                '收盘': hist_data['Close'].values,
                '最高': hist_data['High'].values,
                '最低': hist_data['Low'].values,
                '成交量': hist_data['Volume'].values,
                '成交额': (hist_data['Close'] * hist_data['Volume']).values,
            })
    except Exception as e:
        print(f"Error fetching real-time index data: {e}")
    date_range = pd.date_range(start=start_date, end=end_date)
    return pd.DataFrame({
        'date': date_range.strftime('%Y-%m-%d'),
        '开盘': 0, '收盘': 0, '最高': 0, '最低': 0, '成交量': 0, '成交额': 0
    })


def get_stock_index_history(symbol, news_date, force_index=0):
    """Return the last 8 weeks of daily bars for the index tied to ``symbol``.

    Args:
        symbol: Ticker used to pick the index by constituent membership.
        news_date: Unused; the window always ends today.
        force_index: 1 = NASDAQ-100, 2 = Dow Jones, 3 = S&P 500,
            4 = NASDAQ Composite. A non-zero value takes precedence over
            membership. With 0 (or any other value), the first of NASDAQ-100,
            Dow Jones, S&P 500 whose CSV lists ``symbol`` is used; otherwise
            NASDAQ Composite.

    Returns:
        A DataFrame with ``standard_columns``. Data comes from the cached
        ``index_us_stock_index_*`` globals, fetched on demand when the cache
        is empty. The cached frames are never modified.
    """
    index_choices = {
        1: ('.NDX', index_us_stock_index_NDX),
        2: ('.DJI', index_us_stock_index_DJI),
        3: ('.INX', index_us_stock_index_INX),
        4: ('.IXIC', index_us_stock_index_IXIC),
    }
    # force_index is checked first; previously it could not override a
    # membership match (e.g. AAPL with force_index=3 still returned NDX).
    if force_index in index_choices:
        index_code, index_data = index_choices[force_index]
    elif symbol in nasdaq_100_stocks['Symbol'].values:
        index_code, index_data = index_choices[1]
    elif symbol in dow_jones_stocks['Symbol'].values:
        index_code, index_data = index_choices[2]
    elif symbol in sp500_stocks['Symbol'].values:
        index_code, index_data = index_choices[3]
    else:
        # NASDAQ Composite members, empty/None symbols and unknown tickers.
        index_code, index_data = index_choices[4]

    current_date = datetime.now()
    start_date = (current_date - timedelta(weeks=8)).strftime("%Y-%m-%d")
    end_date = current_date.strftime("%Y-%m-%d")

    if index_data is None or index_data.empty:
        print(f"Index data for {index_code} is empty, fetching real-time data...")
        index_data = _fetch_index_on_demand(index_code, start_date, end_date)

    # Copy before converting: the cached globals are shared with the refresh
    # thread and other callers and must not be mutated in place.
    index_data = index_data.copy()
    index_data['date'] = pd.to_datetime(index_data['date'])
    in_window = (index_data['date'] >= start_date) & (index_data['date'] <= end_date)
    index_hist_df = index_data[in_window]

    index_hist_df = index_hist_df.rename(columns=column_mapping)
    index_hist_df = index_hist_df.reindex(columns=standard_columns)
    return reduce_columns(index_hist_df, standard_columns)
def find_stock_codes_or_names(entities, sources=None):
    """Map named entities to ticker symbols.

    An entity matches when its upper-cased text equals a known ticker, or
    when its text occurs as a whole word (case-insensitive) in a company
    name. In the second case the ticker comes from the same CSV row as the name.

    Args:
        entities: Iterable of ``(text, entity_type)`` pairs. The type is ignored,
            and empty/None texts are skipped.
        sources: Optional sequence of ``(DataFrame, symbol_column, name_column)``
            triples to search. Defaults to the four index-constituent tables
            loaded at import time.

    Returns:
        A list of matched tickers (unordered), or ``['NONE_SYMBOL_FOUND']``.
    """
    if sources is None:
        sources = (
            (nasdaq_100_stocks, 'Symbol', 'Name'),
            (dow_jones_stocks, 'Symbol', 'Company'),
            (sp500_stocks, 'Symbol', 'Security'),
            (nasdaq_composite_stocks, 'Symbol', 'Name'),
        )

    known_symbols = set()
    name_to_symbol = {}
    for frame, symbol_col, name_col in sources:
        if symbol_col not in frame.columns:
            continue
        known_symbols.update(frame[symbol_col].dropna().astype(str))
        if name_col not in frame.columns:
            continue
        # Pair each name with the symbol from the SAME row. The previous
        # version zipped two independently ordered/deduplicated lists, so most
        # names mapped to unrelated tickers.
        pairs = frame[[symbol_col, name_col]].dropna()
        for sym, name in zip(pairs[symbol_col].astype(str), pairs[name_col].astype(str)):
            name_to_symbol.setdefault(name.lower(), sym)  # first table wins

    stock_codes = set()
    for entity, _entity_type in entities:
        text = "" if entity is None else str(entity).strip()
        if not text:
            # An empty pattern (r'\b\b') would match every company name.
            continue
        if text.upper() in known_symbols:
            stock_codes.add(text.upper())
        # Whole-word match so e.g. 'jpm' does not match 'jpmorgan'.
        pattern = re.compile(rf'\b{re.escape(text.lower())}\b')
        for name, sym in name_to_symbol.items():
            if pattern.search(name):
                stock_codes.add(sym.upper())

    return list(stock_codes) if stock_codes else ['NONE_SYMBOL_FOUND']
def process_history(stock_history, target_date, history_days=30, following_days=3):
    """Split a daily history into look-back and look-ahead windows around a date.

    Args:
        stock_history: DataFrame with a 'date' column plus value columns.
        target_date: Anything pd.to_datetime accepts (e.g. '20240214').
        history_days: Rows to keep up to and including the target row.
        following_days: Rows to keep after the target row.

    Returns:
        ``(previous_rows, following_rows)``. 'date' is dropped and only the
        first six remaining columns are kept. A short window is padded at the
        front with -1 rows. If the input is None/empty, lacks 'date', or has
        no row on/before ``target_date``, both windows are all -1.
        The caller's DataFrame is not modified.
    """
    if stock_history is None or stock_history.empty or 'date' not in stock_history.columns:
        return create_empty_data(history_days), create_empty_data(following_days)

    # Copy first: assigning the converted 'date' column used to mutate the caller's frame.
    history = stock_history.copy()
    history['date'] = pd.to_datetime(history['date'])
    target = pd.to_datetime(target_date)
    history = history.sort_values('date')

    # After sorting (NaT last), the rows dated on/before the target form a
    # prefix, so the target position is the prefix length minus one. This
    # avoids index.get_loc, which misbehaves with duplicate index labels.
    upto_target = int((history['date'] <= target).sum())
    if upto_target == 0:
        return create_empty_data(history_days), create_empty_data(following_days)
    target_pos = upto_target - 1

    start_pos = max(0, target_pos - history_days + 1)
    previous_rows = history.iloc[start_pos:target_pos + 1].drop(columns=['date'])
    following_rows = history.iloc[target_pos + 1:target_pos + following_days + 1].drop(columns=['date'])

    previous_rows = handle_insufficient_data(previous_rows, history_days)
    following_rows = handle_insufficient_data(following_rows, following_days)
    return previous_rows.iloc[:, :6], following_rows.iloc[:, :6]


def create_empty_data(days, columns=None):
    """Return a placeholder frame of ``days`` rows filled with -1.

    Args:
        days: Number of rows.
        columns: Column labels. Defaults to the six akshare-style value columns
            '开盘', '收盘', '最高', '最低', '成交量', '成交额'.
    """
    if columns is None:
        columns = ['开盘', '收盘', '最高', '最低', '成交量', '成交额']
    return pd.DataFrame({column: [-1] * days for column in columns})


def handle_insufficient_data(data, required_days):
    """Front-pad ``data`` with -1 rows until it has ``required_days`` rows.

    Returns ``data`` unchanged when it already has enough rows. Otherwise it
    returns the padded frame with a fresh 0..n-1 index.
    """
    missing_rows = required_days - len(data)
    if missing_rows <= 0:
        return data
    # Pad with data's own columns. Always using the default Chinese columns
    # made concat build a column union, and callers keeping the first six
    # columns ended up with NaN instead of the real rows.
    columns = data.columns if len(data.columns) > 0 else None
    padding = create_empty_data(missing_rows, columns)
    return pd.concat([padding, data]).reset_index(drop=True)
if __name__ == "__main__":
    # Manual smoke test of the public helpers; the calls hit live yfinance data.
    demo_date = '20240214'
    checks = (
        ("find_stock_entry", lambda: find_stock_entry('AAPL')),
        ("get_stock_history", lambda: get_stock_history('AAPL', demo_date)),
        ("get_stock_index_history", lambda: get_stock_index_history('AAPL', demo_date)),
        ("find_stock_codes_or_names",
         lambda: find_stock_codes_or_names([('苹果', 'ORG'), ('苹果公司', 'ORG')])),
        ("process_history",
         lambda: process_history(get_stock_history('AAPL', demo_date), demo_date)),
        ("process_history",
         lambda: process_history(get_stock_index_history('AAPL', demo_date), demo_date)),
    )
    for label, run in checks:
        result = run()
        print(f"{label}: {result}")