stock / database.py
gitdeem's picture
Upload 32 files
f5d52f6 verified
import os
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, Text, JSON
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from datetime import datetime
# 读取配置
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///data/stock_analyzer.db')
USE_DATABASE = os.getenv('USE_DATABASE', 'False').lower() == 'true'
# 创建引擎
engine = create_engine(DATABASE_URL)
Base = declarative_base()
# 定义模型
class StockInfo(Base):
__tablename__ = 'stock_info'
id = Column(Integer, primary_key=True)
stock_code = Column(String(10), nullable=False, index=True)
stock_name = Column(String(50))
market_type = Column(String(5))
industry = Column(String(50))
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
def to_dict(self):
return {
'stock_code': self.stock_code,
'stock_name': self.stock_name,
'market_type': self.market_type,
'industry': self.industry,
'updated_at': self.updated_at.strftime('%Y-%m-%d %H:%M:%S') if self.updated_at else None
}
class AnalysisResult(Base):
__tablename__ = 'analysis_results'
id = Column(Integer, primary_key=True)
stock_code = Column(String(10), nullable=False, index=True)
market_type = Column(String(5))
analysis_date = Column(DateTime, default=datetime.now)
score = Column(Float)
recommendation = Column(String(100))
technical_data = Column(JSON)
fundamental_data = Column(JSON)
capital_flow_data = Column(JSON)
ai_analysis = Column(Text)
def to_dict(self):
return {
'stock_code': self.stock_code,
'market_type': self.market_type,
'analysis_date': self.analysis_date.strftime('%Y-%m-%d %H:%M:%S') if self.analysis_date else None,
'score': self.score,
'recommendation': self.recommendation,
'technical_data': self.technical_data,
'fundamental_data': self.fundamental_data,
'capital_flow_data': self.capital_flow_data,
'ai_analysis': self.ai_analysis
}
class Portfolio(Base):
__tablename__ = 'portfolios'
id = Column(Integer, primary_key=True)
user_id = Column(String(50), nullable=False, index=True)
name = Column(String(100))
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
stocks = Column(JSON) # 存储股票列表的JSON
def to_dict(self):
return {
'id': self.id,
'user_id': self.user_id,
'name': self.name,
'created_at': self.created_at.strftime('%Y-%m-%d %H:%M:%S') if self.created_at else None,
'updated_at': self.updated_at.strftime('%Y-%m-%d %H:%M:%S') if self.updated_at else None,
'stocks': self.stocks
}
# 创建会话工厂
Session = sessionmaker(bind=engine)
# 初始化数据库
def init_db():
Base.metadata.create_all(engine)
# 获取数据库会话
def get_session():
return Session()
# 如果启用数据库,则初始化
if USE_DATABASE:
init_db()