"""SQLAlchemy models and session helpers for the stock analyzer's persistence layer.

Configuration is read from the environment at import time:

* ``DATABASE_URL`` -- SQLAlchemy database URL (default: a SQLite file at
  ``data/stock_analyzer.db``, resolved relative to the current working directory).
* ``USE_DATABASE`` -- when ``'true'`` (case-insensitive), tables are created on
  import via :func:`init_db`.
"""
import os
from datetime import datetime

from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, Text, JSON
from sqlalchemy.orm import sessionmaker

try:
    # Canonical location since SQLAlchemy 1.4; the sqlalchemy.ext.declarative
    # alias is deprecated and warns (MovedIn20Warning) on SQLAlchemy 2.x.
    from sqlalchemy.orm import declarative_base
except ImportError:  # SQLAlchemy < 1.4
    from sqlalchemy.ext.declarative import declarative_base

# --- Configuration ----------------------------------------------------------
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///data/stock_analyzer.db')
# Only the literal string 'true' (any letter case) enables the database.
USE_DATABASE = os.getenv('USE_DATABASE', 'False').lower() == 'true'

# Timestamp format used by every model's to_dict().
DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S'

# --- Engine and declarative base ---------------------------------------------
engine = create_engine(DATABASE_URL)
Base = declarative_base()


def _format_datetime(value):
    """Return *value* formatted with DATETIME_FORMAT, or None when it is unset."""
    return value.strftime(DATETIME_FORMAT) if value else None


# --- Models ------------------------------------------------------------------
class StockInfo(Base):
    """Basic descriptive information about a single stock."""

    __tablename__ = 'stock_info'

    id = Column(Integer, primary_key=True)
    stock_code = Column(String(10), nullable=False, index=True)
    stock_name = Column(String(50))
    market_type = Column(String(5))
    industry = Column(String(50))
    # Set on insert and refreshed on every ORM update (local, naive time).
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    def to_dict(self):
        """Serialize to a plain dict; the primary key is not included."""
        return {
            'stock_code': self.stock_code,
            'stock_name': self.stock_name,
            'market_type': self.market_type,
            'industry': self.industry,
            'updated_at': _format_datetime(self.updated_at),
        }


class AnalysisResult(Base):
    """One stored analysis run for a stock, with its score and supporting data."""

    __tablename__ = 'analysis_results'

    id = Column(Integer, primary_key=True)
    stock_code = Column(String(10), nullable=False, index=True)
    market_type = Column(String(5))
    analysis_date = Column(DateTime, default=datetime.now)
    score = Column(Float)
    recommendation = Column(String(100))
    # JSON payloads; their inner schema is defined by the producers, not here.
    technical_data = Column(JSON)
    fundamental_data = Column(JSON)
    capital_flow_data = Column(JSON)
    ai_analysis = Column(Text)

    def to_dict(self):
        """Serialize to a plain dict; the primary key is not included."""
        return {
            'stock_code': self.stock_code,
            'market_type': self.market_type,
            'analysis_date': _format_datetime(self.analysis_date),
            'score': self.score,
            'recommendation': self.recommendation,
            'technical_data': self.technical_data,
            'fundamental_data': self.fundamental_data,
            'capital_flow_data': self.capital_flow_data,
            'ai_analysis': self.ai_analysis,
        }


class Portfolio(Base):
    """A named, user-owned collection of stocks."""

    __tablename__ = 'portfolios'

    id = Column(Integer, primary_key=True)
    user_id = Column(String(50), nullable=False, index=True)
    name = Column(String(100))
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
    stocks = Column(JSON)  # JSON holding the portfolio's stock list

    def to_dict(self):
        """Serialize to a plain dict, including the primary key."""
        return {
            'id': self.id,
            'user_id': self.user_id,
            'name': self.name,
            'created_at': _format_datetime(self.created_at),
            'updated_at': _format_datetime(self.updated_at),
            'stocks': self.stocks,
        }


# --- Sessions and initialization ----------------------------------------------
# Session factory bound to the module-level engine.
Session = sessionmaker(bind=engine)


def _ensure_sqlite_directory():
    """Create the parent directory of a file-backed SQLite database if missing.

    SQLite creates the database file itself but not missing parent
    directories, so the default URL would otherwise fail with
    "unable to open database file" when ``data/`` does not exist yet.
    """
    if engine.dialect.name != 'sqlite':
        return
    database = engine.url.database
    # In-memory databases and SQLite URI filenames have no plain path to prepare.
    if not database or database == ':memory:' or database.startswith('file:'):
        return
    directory = os.path.dirname(database)
    if directory:
        os.makedirs(directory, exist_ok=True)


def init_db():
    """Create all tables declared on ``Base`` that do not exist yet."""
    _ensure_sqlite_directory()
    Base.metadata.create_all(engine)


def get_session():
    """Return a new ORM session; the caller is responsible for closing it."""
    return Session()


# Create the schema at import time only when the database is enabled.
if USE_DATABASE:
    init_db()