# service.py
import calendar
from datetime import date, datetime, timedelta, timezone  # noqa: F401 (timezone kept for importers)
from decimal import Decimal
from typing import Any, Dict, Generator, Iterable, List, Optional  # noqa: F401

from tortoise.exceptions import DoesNotExist
from tortoise.expressions import Q  # noqa: F401
from tortoise.transactions import in_transaction

from App.schemas import AppException  # Assuming AppException is in App.schemas
from .models import (
    Portfolio,
    PortfolioStock,
    PortfolioUTT,
    PortfolioBond,
    PortfolioTransaction,
    PortfolioCalendar,
    PortfolioSnapshot,
)

# Assuming models for stocks, utts, bonds are in these paths
from ..stocks.models import Stock, StockPriceData
from ..utt.models import UTTFund, UTTFundData
from ..bonds.models import Bond  # Bond may optionally expose price_per_100

# Pydantic response schemas
from .schemas import (
    PortfolioSummary,
    StockHoldingResponse,
    UTTHoldingResponse,
    BondHoldingResponse,
    AssetAllocation,
    PortfolioBase,
    TransactionResponse,
    CalendarEventResponse,
)
from App.routers.tasks.models import ImportTask


# ---------------------------------------------------------------------------
# Small numeric / date helpers
# ---------------------------------------------------------------------------


def _to_decimal(value: Any) -> Optional[Decimal]:
    """
    Convert a numeric value to ``Decimal``; ``None`` stays ``None``.

    Floats are routed through ``str`` so that e.g. ``0.1`` becomes
    ``Decimal("0.1")`` rather than its exact binary-float expansion.
    """
    if value is None:
        return None
    if isinstance(value, Decimal):
        return value
    if isinstance(value, float):
        return Decimal(str(value))
    return Decimal(value)


def _sum_decimal(values: Iterable[Any]) -> Decimal:
    """Sum numeric values as ``Decimal``, treating ``None`` as zero (empty -> 0)."""
    total = Decimal("0")
    for value in values:
        converted = _to_decimal(value)
        if converted is not None:
            total += converted
    return total


def _percentage(part: Decimal, whole: Decimal) -> Decimal:
    """Return ``part`` as a percentage of ``whole``, or 0 when ``whole`` is not positive."""
    if whole > 0:
        return part / whole * Decimal("100")
    return Decimal("0")


def _as_date(value: Any) -> Any:
    """Return the date part of a ``datetime``; plain dates and ``None`` pass through."""
    if isinstance(value, datetime):
        return value.date()
    return value


def _clamped_date(year: int, month: int, day: int) -> date:
    """
    Build ``date(year, month, day)``, pulling ``day`` back to the month's last
    day when it does not exist (e.g. 31 -> 30 in September, Feb 29 -> Feb 28
    in a non-leap year).
    """
    last_day = calendar.monthrange(year, month)[1]
    return date(year, month, min(day, last_day))


def _calculate_bond_coupon_dates(
    bond: Bond, start_date: date, end_date: date
) -> Generator[date, None, None]:
    """
    Yield the semi-annual coupon payment dates of ``bond`` within a date range.

    This makes the common assumption that coupons are paid semi-annually: one
    payment on the maturity month/day each year, the other six months apart.

    A date is yielded (in ascending order) only when it is:
      * within ``[start_date, end_date]`` (inclusive), and
      * strictly after the bond's ``effective_date`` (when set) and on or
        before its ``maturity_date`` -- no coupons before issue or after
        maturity.

    If the maturity day does not exist in a coupon month, the coupon is placed
    on that month's last day instead of being dropped.

    Nothing is yielded for a bond without a maturity date or with a missing /
    non-positive coupon rate.
    """
    maturity = _as_date(bond.maturity_date)
    if not maturity or bond.coupon_rate is None or bond.coupon_rate <= 0:
        return
    issue = _as_date(bond.effective_date)

    coupon_day = maturity.day
    month1 = maturity.month
    # Six months after month1, wrapping past December (12 -> 6, 7 -> 1, ...).
    month2 = (month1 + 5) % 12 + 1

    # Years outside the requested window can never produce a match, so the
    # loop is bounded by both the bond's life and the filter range.
    first_year = start_date.year if issue is None else max(issue.year, start_date.year)
    last_year = min(maturity.year, end_date.year)

    for year in range(first_year, last_year + 1):
        candidates = sorted(
            (
                _clamped_date(year, month1, coupon_day),
                _clamped_date(year, month2, coupon_day),
            )
        )
        for coupon in candidates:
            if issue is not None and coupon <= issue:
                continue
            if coupon > maturity:
                continue
            if start_date <= coupon <= end_date:
                yield coupon


class PortfolioService:
    @staticmethod
    async def get_user_portfolios(
        user_id: int, include_inactive: bool = False
    ) -> List[Portfolio]:
        """Get all portfolios for a user, newest first (active only unless ``include_inactive``)."""
        query = Portfolio.filter(user_id=user_id)
        if not include_inactive:
            query = query.filter(is_active=True)
        return await query.order_by("-created_at").all()

    @staticmethod
    async def create_portfolio(
        user_id: int, name: str, description: Optional[str] = None
    ) -> Portfolio:
        """Create a new portfolio for a user."""
        return await Portfolio.create(
            user_id=user_id, name=name, description=description
        )

    @staticmethod
    async def get_portfolio_summary(portfolio_id: int) -> PortfolioSummary:
        """
        Build a comprehensive portfolio summary.

        The summary includes:
          * every holding with its current market value;
          * total market value and total cost basis;
          * overall unrealized gain/loss;
          * asset allocation;
          * the 10 most recent transactions;
          * the next 10 incomplete calendar events from today onward.

        Raises:
            DoesNotExist: if no portfolio has ``portfolio_id``.
        """
        portfolio_orm = await Portfolio.get_or_none(id=portfolio_id)
        if not portfolio_orm:
            raise DoesNotExist("Portfolio not found")

        # Holdings with calculated values
        stock_holdings_resp = await PortfolioService._get_stock_holdings_with_values(
            portfolio_id
        )
        utt_holdings_resp = await PortfolioService._get_utt_holdings_with_values(
            portfolio_id
        )
        bond_holdings_resp = await PortfolioService._get_bond_holdings_with_values(
            portfolio_id
        )

        # Market values: holdings without price data (market_value None) count
        # as zero instead of blowing up in Decimal(None).
        total_stock_value = _sum_decimal(h.market_value for h in stock_holdings_resp)
        total_utt_value = _sum_decimal(h.market_value for h in utt_holdings_resp)
        total_bond_value = _sum_decimal(h.market_value for h in bond_holdings_resp)
        total_market_value = total_stock_value + total_utt_value + total_bond_value

        # Cost basis. For stocks/UTTs, purchase_price is the average UNIT price
        # of the aggregated holding; for bonds it is the TOTAL purchase cost.
        total_stock_cost = _sum_decimal(
            h.purchase_price * h.quantity for h in stock_holdings_resp
        )
        total_utt_cost = _sum_decimal(
            h.purchase_price * h.units_held for h in utt_holdings_resp
        )
        total_bond_cost = _sum_decimal(h.purchase_price for h in bond_holdings_resp)
        total_cost_basis = total_stock_cost + total_utt_cost + total_bond_cost

        overall_unrealized_gain_loss = total_market_value - total_cost_basis
        overall_unrealized_gain_loss_percentage = _percentage(
            overall_unrealized_gain_loss, total_cost_basis
        )

        recent_transactions_orm = (
            await PortfolioTransaction.filter(portfolio_id=portfolio_id)
            .order_by("-transaction_date", "-created_at")
            .limit(10)
            .all()
        )
        recent_transactions_resp = [
            TransactionResponse.from_orm(t) for t in recent_transactions_orm
        ]

        upcoming_events_orm = (
            await PortfolioCalendar.filter(
                portfolio_id=portfolio_id,
                event_date__gte=date.today(),
                is_completed=False,
            )
            .order_by("event_date")
            .limit(10)
            .all()
        )
        upcoming_events_resp = [
            CalendarEventResponse.from_orm(e) for e in upcoming_events_orm
        ]

        asset_alloc = AssetAllocation(
            stocks_percentage=_percentage(total_stock_value, total_market_value),
            bonds_percentage=_percentage(total_bond_value, total_market_value),
            utts_percentage=_percentage(total_utt_value, total_market_value),
            cash_percentage=Decimal("0"),  # Cash is not tracked yet
            total_value=total_market_value,
        )

        return PortfolioSummary(
            portfolio=PortfolioBase.from_orm(portfolio_orm),
            total_market_value=total_market_value,
            total_cost_basis=total_cost_basis,
            overall_unrealized_gain_loss=overall_unrealized_gain_loss,
            overall_unrealized_gain_loss_percentage=overall_unrealized_gain_loss_percentage,
            stock_holdings=stock_holdings_resp,
            utt_holdings=utt_holdings_resp,
            bond_holdings=bond_holdings_resp,
            asset_allocation=asset_alloc,
            recent_transactions=recent_transactions_resp,
            upcoming_events=upcoming_events_resp,
        )

    @staticmethod
    async def _get_stock_holdings_with_values(
        portfolio_id: int,
    ) -> List[StockHoldingResponse]:
        """
        Return the portfolio's stock holdings valued at each stock's latest closing price.

        Market value and gain/loss are ``None`` when no price data exists.
        """
        holdings_orm = (
            await PortfolioStock.filter(portfolio_id=portfolio_id)
            .prefetch_related("stock")
            .all()
        )
        results = []
        for holding in holdings_orm:  # one aggregated record per stock
            latest_price_data = (
                await StockPriceData.filter(stock_id=holding.stock_id)
                .order_by("-date")
                .first()
            )
            current_price = (
                _to_decimal(latest_price_data.closing_price)
                if latest_price_data
                else None
            )
            market_value = (
                current_price * holding.quantity if current_price is not None else None
            )
            # holding.purchase_price is the average unit price
            cost_basis = holding.purchase_price * holding.quantity
            gain_loss = market_value - cost_basis if market_value is not None else None
            gain_loss_percentage = (
                gain_loss / cost_basis * Decimal("100")
                if gain_loss is not None and cost_basis > 0
                else None
            )
            results.append(
                StockHoldingResponse(
                    id=holding.id,  # ID of the PortfolioStock record itself
                    stock_id=holding.stock.id,
                    stock_symbol=holding.stock.symbol,
                    stock_name=holding.stock.name,
                    quantity=holding.quantity,
                    purchase_price=holding.purchase_price,  # Average unit purchase price
                    purchase_date=holding.purchase_date,  # Date of the latest buy
                    current_price=current_price,
                    market_value=market_value,
                    gain_loss=gain_loss,
                    gain_loss_percentage=gain_loss_percentage,
                    notes=holding.notes,
                    created_at=holding.created_at,
                )
            )
        return results

    @staticmethod
    async def _get_utt_holdings_with_values(
        portfolio_id: int,
    ) -> List[UTTHoldingResponse]:
        """
        Return the portfolio's UTT holdings valued at each fund's latest NAV per unit.

        Market value and gain/loss are ``None`` when no NAV data exists.
        """
        holdings_orm = (
            await PortfolioUTT.filter(portfolio_id=portfolio_id)
            .prefetch_related("utt_fund")
            .all()
        )
        results = []
        for holding in holdings_orm:  # one aggregated record per fund
            latest_nav_data = (
                await UTTFundData.filter(fund_id=holding.utt_fund_id)
                .order_by("-date")
                .first()
            )
            # Convert via str so float NAVs don't carry binary-float noise
            # (matches the conversion used by create_portfolio_snapshot).
            current_nav = (
                _to_decimal(latest_nav_data.nav_per_unit) if latest_nav_data else None
            )
            market_value = (
                current_nav * holding.units_held if current_nav is not None else None
            )
            # holding.purchase_price is the average unit price
            cost_basis = holding.purchase_price * holding.units_held
            gain_loss = market_value - cost_basis if market_value is not None else None
            gain_loss_percentage = (
                gain_loss / cost_basis * Decimal("100")
                if gain_loss is not None and cost_basis > 0
                else None
            )
            results.append(
                UTTHoldingResponse(
                    id=holding.id,  # ID of the PortfolioUTT record itself
                    utt_fund_id=holding.utt_fund.id,
                    fund_symbol=holding.utt_fund.symbol,
                    fund_name=holding.utt_fund.name,
                    units_held=holding.units_held,
                    purchase_price=holding.purchase_price,  # Average unit purchase price
                    purchase_date=holding.purchase_date,  # Date of the latest buy
                    current_nav=current_nav,
                    market_value=market_value,
                    gain_loss=gain_loss,
                    gain_loss_percentage=gain_loss_percentage,
                    notes=holding.notes,
                    created_at=holding.created_at,
                )
            )
        return results

    @staticmethod
    async def _get_bond_holdings_with_values(
        portfolio_id: int,
    ) -> List[BondHoldingResponse]:
        """
        Return the portfolio's bond holdings with market values.

        Market value is ``face_value_held * price_per_100 / 100``. When the bond
        has no (truthy) ``price_per_100``, it is priced at par (100).
        ``purchase_price`` on PortfolioBond is the TOTAL cost of the holding.
        """
        holdings_orm = (
            await PortfolioBond.filter(portfolio_id=portfolio_id)
            .prefetch_related("bond")
            .all()
        )
        results = []
        for holding in holdings_orm:  # one aggregated record per bond
            raw_price = getattr(holding.bond, "price_per_100", None)
            # Normalise to Decimal so a float price can't raise when multiplied
            # against a Decimal face value.
            current_price_percentage = (
                _to_decimal(raw_price) if raw_price else Decimal("100")
            )
            market_value = (
                _to_decimal(holding.face_value_held)
                * current_price_percentage
                / Decimal("100")
            )
            cost_basis = holding.purchase_price  # total cost of this holding
            gain_loss = market_value - cost_basis
            results.append(
                BondHoldingResponse(
                    id=holding.id,  # ID of the PortfolioBond record itself
                    bond_id=holding.bond.id,
                    instrument_type=holding.bond.instrument_type,
                    auction_number=getattr(holding.bond, "auction_number", None),
                    maturity_date=holding.bond.maturity_date,
                    face_value_held=holding.face_value_held,
                    purchase_price=cost_basis,  # Total purchase price of this holding
                    purchase_date=holding.purchase_date,  # Date of the latest buy
                    current_price=current_price_percentage,
                    market_value=market_value,
                    accrued_interest=None,
                    yield_to_maturity=None,
                    gain_loss=gain_loss,
                    notes=holding.notes,
                    created_at=holding.created_at,
                )
            )
        return results

    @staticmethod
    async def add_stock_to_portfolio(
        portfolio_id: int,
        stock_id: int,
        quantity_to_add: Decimal,  # Quantity for this specific purchase
        purchase_price_of_lot: Decimal,  # Unit price for this specific purchase
        purchase_date: date,
        notes: Optional[str] = None,
    ) -> PortfolioStock:
        """
        Record a stock purchase and merge it into the portfolio's aggregated holding.

        An existing holding gets its quantity increased and its average unit price
        re-weighted. ``purchase_date`` is set to this purchase's date, and the notes
        are appended. Otherwise a new holding is created. A BUY transaction is
        always logged.

        Raises:
            DoesNotExist: if the stock does not exist.
            AppException(400): if ``quantity_to_add`` is not positive.
        """
        stock_obj = await Stock.get_or_none(id=stock_id)
        if not stock_obj:
            raise DoesNotExist("Stock not found")
        if quantity_to_add <= 0:
            raise AppException(
                status_code=400, detail="Quantity to add must be positive."
            )

        async with in_transaction():
            holding = await PortfolioStock.get_or_none(
                portfolio_id=portfolio_id, stock_id=stock_id
            )
            if holding:
                new_total_cost = (holding.quantity * holding.purchase_price) + (
                    quantity_to_add * purchase_price_of_lot
                )
                holding.quantity += quantity_to_add
                if holding.quantity > 0:
                    holding.purchase_price = new_total_cost / holding.quantity
                else:  # Defensive: unreachable while quantity_to_add > 0
                    holding.purchase_price = purchase_price_of_lot
                holding.purchase_date = purchase_date
                if notes:
                    holding.notes = (
                        f"{holding.notes}\n{notes}".strip() if holding.notes else notes
                    )
                await holding.save()
            else:
                holding = await PortfolioStock.create(
                    portfolio_id=portfolio_id,
                    stock=stock_obj,
                    quantity=quantity_to_add,
                    purchase_price=purchase_price_of_lot,  # Initial average price
                    purchase_date=purchase_date,
                    notes=notes,
                )

            await PortfolioTransaction.create(
                portfolio_id=portfolio_id,
                transaction_type="BUY",
                asset_type="STOCK",
                asset_id=stock_obj.id,
                asset_name=stock_obj.symbol,
                quantity=quantity_to_add,
                price=purchase_price_of_lot,
                total_amount=quantity_to_add * purchase_price_of_lot,
                transaction_date=purchase_date,
                notes=notes or f"Bought {quantity_to_add} shares of {stock_obj.symbol}",
            )
        return holding

    @staticmethod
    async def sell_stock_holding(
        portfolio_id: int,
        stock_id: int,  # This is the asset_id
        quantity_to_sell: Decimal,
        sell_price: Decimal,
        sell_date: date,
        notes: Optional[str] = None,
    ) -> PortfolioTransaction:
        """
        Sell shares from an aggregated stock holding and log a SELL transaction.

        The average purchase price is unchanged. The holding is deleted when its
        quantity reaches zero.

        Raises:
            DoesNotExist: if the portfolio holds no such stock.
            AppException(400): for a non-positive quantity or insufficient shares.
        """
        holding = await PortfolioStock.get_or_none(
            portfolio_id=portfolio_id, stock_id=stock_id
        ).prefetch_related("stock")
        if not holding:
            raise DoesNotExist("Stock holding not found in this portfolio.")
        if quantity_to_sell <= 0:
            raise AppException(
                status_code=400, detail="Quantity to sell must be positive."
            )
        if holding.quantity < quantity_to_sell:
            raise AppException(
                status_code=400,
                detail=f"Not enough shares to sell. Currently hold {holding.quantity}, trying to sell {quantity_to_sell}.",
            )

        async with in_transaction():
            transaction = await PortfolioTransaction.create(
                portfolio_id=portfolio_id,
                transaction_type="SELL",
                asset_type="STOCK",
                asset_id=holding.stock.id,
                asset_name=holding.stock.symbol,
                quantity=quantity_to_sell,
                price=sell_price,
                total_amount=quantity_to_sell * sell_price,
                transaction_date=sell_date,
                notes=notes
                or f"Sold {quantity_to_sell} shares of {holding.stock.symbol}",
            )
            holding.quantity -= quantity_to_sell
            if holding.quantity == 0:
                await holding.delete()
            else:
                await holding.save()
        return transaction

    @staticmethod
    async def add_utt_to_portfolio(
        portfolio_id: int,
        utt_fund_id: int,
        units_to_add: Decimal,  # Units for this specific purchase
        purchase_price_of_lot: Decimal,  # Unit price for this specific purchase
        purchase_date: date,
        notes: Optional[str] = None,
    ) -> PortfolioUTT:
        """
        Record a UTT purchase and merge it into the portfolio's aggregated holding.

        Works the same way as ``add_stock_to_portfolio``: the average unit price is
        re-weighted and a BUY transaction is logged.

        Raises:
            DoesNotExist: if the fund does not exist.
            AppException(400): if ``units_to_add`` is not positive.
        """
        utt_fund_obj = await UTTFund.get_or_none(id=utt_fund_id)
        if not utt_fund_obj:
            raise DoesNotExist("UTT Fund not found")
        if units_to_add <= 0:
            raise AppException(status_code=400, detail="Units to add must be positive.")

        async with in_transaction():
            holding = await PortfolioUTT.get_or_none(
                portfolio_id=portfolio_id, utt_fund_id=utt_fund_id
            )
            if holding:
                new_total_cost = (holding.units_held * holding.purchase_price) + (
                    units_to_add * purchase_price_of_lot
                )
                holding.units_held += units_to_add
                if holding.units_held > 0:
                    holding.purchase_price = new_total_cost / holding.units_held
                else:  # Defensive: unreachable while units_to_add > 0
                    holding.purchase_price = purchase_price_of_lot
                holding.purchase_date = purchase_date
                if notes:
                    holding.notes = (
                        f"{holding.notes}\n{notes}".strip() if holding.notes else notes
                    )
                await holding.save()
            else:
                holding = await PortfolioUTT.create(
                    portfolio_id=portfolio_id,
                    utt_fund=utt_fund_obj,
                    units_held=units_to_add,
                    purchase_price=purchase_price_of_lot,  # Initial average price
                    purchase_date=purchase_date,
                    notes=notes,
                )

            await PortfolioTransaction.create(
                portfolio_id=portfolio_id,
                transaction_type="BUY",
                asset_type="UTT",
                asset_id=utt_fund_obj.id,
                asset_name=utt_fund_obj.symbol,
                quantity=units_to_add,
                price=purchase_price_of_lot,
                total_amount=units_to_add * purchase_price_of_lot,
                transaction_date=purchase_date,
                notes=notes or f"Bought {units_to_add} units of {utt_fund_obj.symbol}",
            )
        return holding

    @staticmethod
    async def sell_utt_holding(
        portfolio_id: int,
        utt_fund_id: int,  # Asset id (the fund), not the holding id
        units_to_sell: Decimal,
        sell_price: Decimal,
        sell_date: date,
        notes: Optional[str] = None,
    ) -> PortfolioTransaction:
        """
        Sell units from an aggregated UTT holding and log a SELL transaction.

        The average purchase price is unchanged. The holding is deleted when no
        units remain.

        Raises:
            DoesNotExist: if the portfolio holds no such fund.
            AppException(400): for non-positive units or insufficient units.
        """
        holding = await PortfolioUTT.get_or_none(
            portfolio_id=portfolio_id, utt_fund_id=utt_fund_id
        ).prefetch_related("utt_fund")
        if not holding:
            raise DoesNotExist("UTT holding not found for this fund in the portfolio.")
        if units_to_sell <= 0:
            raise AppException(
                status_code=400, detail="Units to sell must be positive."
            )
        if holding.units_held < units_to_sell:
            raise AppException(
                status_code=400,
                detail=f"Not enough units to sell. Currently hold {holding.units_held}, trying to sell {units_to_sell}.",
            )

        async with in_transaction():
            transaction = await PortfolioTransaction.create(
                portfolio_id=portfolio_id,
                transaction_type="SELL",
                asset_type="UTT",
                asset_id=holding.utt_fund.id,
                asset_name=holding.utt_fund.symbol,
                quantity=units_to_sell,
                price=sell_price,
                total_amount=units_to_sell * sell_price,
                transaction_date=sell_date,
                notes=notes
                or f"Sold {units_to_sell} units of {holding.utt_fund.symbol}",
            )
            holding.units_held -= units_to_sell
            if holding.units_held == 0:
                await holding.delete()
            else:
                await holding.save()
        return transaction

    @staticmethod
    async def add_bond_to_portfolio(
        portfolio_id: int,
        bond_id: int,
        face_value_to_add: Decimal,  # Face value for this specific purchase
        total_purchase_price_of_lot: Decimal,  # TOTAL price paid for face_value_to_add
        purchase_date: date,
        notes: Optional[str] = None,
    ) -> PortfolioBond:
        """
        Record a bond purchase and merge it into the portfolio's aggregated holding.

        Face value and TOTAL cost are both accumulated. The BUY transaction's unit
        price is ``total_purchase_price_of_lot / face_value_to_add``.

        Raises:
            DoesNotExist: if the bond does not exist.
            AppException(400): if ``face_value_to_add`` is not positive.
        """
        bond_obj = await Bond.get_or_none(id=bond_id)
        if not bond_obj:
            raise DoesNotExist("Bond not found")
        if face_value_to_add <= 0:
            raise AppException(
                status_code=400, detail="Face value to add must be positive."
            )

        async with in_transaction():
            holding = await PortfolioBond.get_or_none(
                portfolio_id=portfolio_id, bond_id=bond_id
            )
            if holding:
                holding.face_value_held += face_value_to_add
                holding.purchase_price += total_purchase_price_of_lot  # total cost
                holding.purchase_date = purchase_date
                if notes:
                    holding.notes = (
                        f"{holding.notes}\n{notes}".strip() if holding.notes else notes
                    )
                await holding.save()
            else:
                holding = await PortfolioBond.create(
                    portfolio_id=portfolio_id,
                    bond=bond_obj,
                    face_value_held=face_value_to_add,
                    purchase_price=total_purchase_price_of_lot,  # total cost of lot
                    purchase_date=purchase_date,
                    notes=notes,
                )

            unit_price_for_transaction = (
                total_purchase_price_of_lot / face_value_to_add
                if face_value_to_add > 0
                else Decimal("0")
            )
            await PortfolioTransaction.create(
                portfolio_id=portfolio_id,
                transaction_type="BUY",
                asset_type="BOND",
                asset_id=bond_obj.id,
                asset_name=f"Bond {bond_obj.auction_number or bond_obj.id}",
                quantity=face_value_to_add,
                price=unit_price_for_transaction,
                total_amount=total_purchase_price_of_lot,
                transaction_date=purchase_date,
                notes=notes
                or f"Bought {face_value_to_add} face value of Bond {bond_obj.auction_number or bond_obj.id}",
            )
        return holding

    @staticmethod
    async def sell_bond_holding(
        portfolio_id: int,
        bond_id: int,  # Asset id (the bond), not the holding id
        face_value_to_sell: Decimal,
        sell_price_total: Decimal,  # TOTAL proceeds for face_value_to_sell
        sell_date: date,
        notes: Optional[str] = None,
    ) -> PortfolioTransaction:
        """
        Sell face value from an aggregated bond holding and log a SELL transaction.

        The remaining holding's TOTAL cost is reduced in proportion to the face
        value sold. The holding is deleted when no face value remains.

        Raises:
            DoesNotExist: if the portfolio holds no such bond.
            AppException(400): for non-positive or excessive face value.
        """
        holding = await PortfolioBond.get_or_none(
            portfolio_id=portfolio_id, bond_id=bond_id
        ).prefetch_related("bond")
        if not holding:
            raise DoesNotExist("Bond holding not found for this bond in the portfolio.")
        if face_value_to_sell <= 0:
            raise AppException(
                status_code=400, detail="Face value to sell must be positive."
            )
        if holding.face_value_held < face_value_to_sell:
            raise AppException(
                status_code=400,
                detail=f"Not enough face value to sell. Currently hold {holding.face_value_held}, trying to sell {face_value_to_sell}.",
            )

        async with in_transaction():
            unit_sell_price = (
                sell_price_total / face_value_to_sell
                if face_value_to_sell > 0
                else Decimal("0")
            )
            transaction = await PortfolioTransaction.create(
                portfolio_id=portfolio_id,
                transaction_type="SELL",
                asset_type="BOND",
                asset_id=holding.bond.id,
                asset_name=f"Bond {holding.bond.auction_number or holding.bond.id}",
                quantity=face_value_to_sell,
                price=unit_sell_price,
                total_amount=sell_price_total,
                transaction_date=sell_date,
                notes=notes
                or f"Sold {face_value_to_sell} face value of Bond {holding.bond.auction_number or holding.bond.id}",
            )

            original_face_value_held = holding.face_value_held
            original_total_purchase_price = holding.purchase_price
            holding.face_value_held -= face_value_to_sell

            if holding.face_value_held == Decimal("0"):
                await holding.delete()
            else:
                # Scale the remaining TOTAL cost by the fraction of face value kept.
                if original_face_value_held > 0:
                    holding.purchase_price = (
                        holding.face_value_held / original_face_value_held
                    ) * original_total_purchase_price
                else:  # Defensive: unreachable given the checks above
                    holding.purchase_price = Decimal("0")
                await holding.save()
        return transaction

    @staticmethod
    async def remove_holding(
        portfolio_id: int, asset_type_str: str, asset_id_value: int
    ) -> bool:
        """
        Hard-delete an aggregated holding from a portfolio.

        ``asset_type_str`` is matched case-insensitively against STOCK / UTT /
        BOND. ``asset_id_value`` is the matching stock_id, utt_fund_id or bond_id.
        Returns True if at least one row was deleted.

        Raises:
            AppException(400): for an unknown asset type.
        """
        asset_models = {
            "STOCK": (PortfolioStock, "stock_id"),
            "UTT": (PortfolioUTT, "utt_fund_id"),
            "BOND": (PortfolioBond, "bond_id"),
        }
        try:
            model_to_delete, asset_id_field_name = asset_models[asset_type_str.upper()]
        except KeyError:
            raise AppException(
                status_code=400, detail=f"Unknown asset type: {asset_type_str}"
            ) from None

        filter_kwargs = {
            "portfolio_id": portfolio_id,
            asset_id_field_name: asset_id_value,
        }
        deleted_count = await model_to_delete.filter(**filter_kwargs).delete()
        return deleted_count > 0

    @staticmethod
    async def create_portfolio_snapshot(
        portfolio_id: int, snapshot_date_input: Optional[date] = None
    ) -> PortfolioSnapshot:
        """
        Create or update the snapshot of a portfolio for a single date (default: today).

        Steps:
          1. Select holdings whose ``purchase_date`` is on or before the target date.
          2. Value each stock/UTT with the last known price/NAV as of that date.
             Bonds are valued at face value because no historical bond prices are
             stored.
          3. Aggregate values and cost basis, then upsert the snapshot keyed by
             (portfolio_id, midnight of target date).

        Cost basis: stock/UTT ``purchase_price`` is an average UNIT price and is
        multiplied by the quantity held. Bond ``purchase_price`` is already a TOTAL.

        NOTE(review): holdings are aggregated and ``purchase_date`` is overwritten
        with the latest buy. A holding that was partly bought before the target
        date is therefore excluded entirely. Exact history would require
        replaying PortfolioTransaction rows.
        """
        target_date: date = date.today()
        if snapshot_date_input:
            target_date = _as_date(snapshot_date_input)

        stock_val = Decimal("0.0")
        bond_val = Decimal("0.0")
        utt_val = Decimal("0.0")
        total_cost_basis = Decimal("0.0")

        # --- 1. Stock holdings ---
        stock_holdings = await PortfolioStock.filter(
            portfolio_id=portfolio_id, purchase_date__lte=target_date
        ).select_related("stock")
        for holding in stock_holdings:
            quantity = _to_decimal(holding.quantity)
            price_data = (
                await StockPriceData.filter(
                    stock_id=holding.stock_id, date__lte=target_date
                )
                .order_by("-date")
                .first()
            )
            if price_data and price_data.closing_price is not None:
                stock_val += quantity * _to_decimal(price_data.closing_price)
            # purchase_price is an average unit price -> total cost = price * qty
            total_cost_basis += _to_decimal(holding.purchase_price) * quantity

        # --- 2. UTT holdings ---
        utt_holdings = await PortfolioUTT.filter(
            portfolio_id=portfolio_id, purchase_date__lte=target_date
        ).select_related("utt_fund")
        for holding in utt_holdings:
            units = _to_decimal(holding.units_held)
            price_data = (
                await UTTFundData.filter(
                    fund_id=holding.utt_fund_id, date__lte=target_date
                )
                .order_by("-date")
                .first()
            )
            if price_data and price_data.nav_per_unit is not None:
                utt_val += units * _to_decimal(price_data.nav_per_unit)
            # purchase_price is an average unit price -> total cost = price * units
            total_cost_basis += _to_decimal(holding.purchase_price) * units

        # --- 3. Bond holdings ---
        bond_holdings = await PortfolioBond.filter(
            portfolio_id=portfolio_id, purchase_date__lte=target_date
        ).select_related("bond")
        for holding in bond_holdings:
            # Simplified valuation: market value == face value (no historical
            # bond price table exists). purchase_price is already the TOTAL cost.
            bond_val += _to_decimal(holding.face_value_held)
            total_cost_basis += _to_decimal(holding.purchase_price)

        total_market_value = stock_val + bond_val + utt_val
        unrealized_gain_loss = total_market_value - total_cost_basis

        # Upsert so repeated runs for the same date don't create duplicates.
        snapshot_datetime = datetime.combine(target_date, datetime.min.time())
        snapshot, created = await PortfolioSnapshot.update_or_create(
            portfolio_id=portfolio_id,
            snapshot_date=snapshot_datetime,
            defaults={
                "total_value": total_market_value,
                "stock_value": stock_val,
                "bond_value": bond_val,
                "utt_value": utt_val,
                "cash_value": Decimal("0.0"),  # Cash isn't tracked yet
                "total_cost": total_cost_basis,
                "unrealized_gain_loss": unrealized_gain_loss,
            },
        )

        if created:
            print(f"Created snapshot for portfolio {portfolio_id} on {target_date}")
        else:
            print(f"Updated snapshot for portfolio {portfolio_id} on {target_date}")
        return snapshot

    @staticmethod
    async def regenerate_snapshots_task(
        task_id: int, portfolio_id: int, start_date: Optional[date] = None
    ):
        """
        Background task that (re)generates daily portfolio snapshots up to today.

        - With ``start_date`` (e.g. from a back-dated transaction), generation
          starts there.
        - Without it, generation starts at the portfolio's earliest transaction
          date. If there are no transactions, the task completes with nothing to do.
        - Existing snapshots from the start date onward are deleted first, so the
          regenerated range is clean.

        Progress and the final summary are written to the ImportTask row
        ``task_id``. A failure on one day is recorded and skipped. Any other
        error marks the task as failed.
        """
        await ImportTask.filter(id=task_id).update(status="running")

        try:
            # 1. DETERMINE THE START DATE
            if not start_date:
                first_transaction = (
                    await PortfolioTransaction.filter(portfolio_id=portfolio_id)
                    .order_by("transaction_date")
                    .first()
                )
                if first_transaction:
                    start_date = first_transaction.transaction_date
                    print(
                        f"[Task {task_id}] No start date provided. Found earliest transaction on {start_date}."
                    )
                else:
                    await ImportTask.filter(id=task_id).update(
                        status="completed",
                        details={
                            "message": "No transactions found in portfolio. Nothing to generate."
                        },
                    )
                    print(
                        f"[Task {task_id}] No transactions for portfolio {portfolio_id}. Task complete."
                    )
                    return

            end_date = date.today()
            print(
                f"[Task {task_id}] Starting snapshot generation for portfolio {portfolio_id} from {start_date} to {end_date}"
            )

            # 2. INVALIDATE stale snapshots in the range.
            start_datetime = datetime.combine(start_date, datetime.min.time())
            deleted_count = await PortfolioSnapshot.filter(
                portfolio_id=portfolio_id, snapshot_date__gte=start_datetime
            ).delete()
            print(
                f"[Task {task_id}] Invalidated and deleted {deleted_count} stale snapshots."
            )

            # 3. REGENERATE one snapshot per day, inclusive of both ends.
            def date_range(start, end):
                for n in range(int((end - start).days) + 1):
                    yield start + timedelta(n)

            generated_count = 0
            failed_days = []
            for single_date in date_range(start_date, end_date):
                try:
                    await PortfolioService.create_portfolio_snapshot(
                        portfolio_id=portfolio_id, snapshot_date_input=single_date
                    )
                    print(
                        f"[Task {task_id}] Successfully generated snapshot for {single_date.isoformat()}"
                    )
                    generated_count += 1
                except Exception as e:
                    # Best effort: record the failed day and keep going.
                    failed_days.append(single_date.isoformat())
                    print(
                        f"[Task {task_id}] WARNING: Could not generate snapshot for {single_date}: {e}"
                    )

            # 4. FINALIZE with a summary.
            summary = {
                "message": "Snapshot generation complete.",
                "deleted_stale_snapshots": deleted_count,
                "new_snapshots_generated": generated_count,
                "failed_days_count": len(failed_days),
                "failed_days": failed_days,
                "date_range": f"{start_date.isoformat()} to {end_date.isoformat()}",
            }
            await ImportTask.filter(id=task_id).update(
                status="completed", details=summary
            )
            print(f"[Task {task_id}] Completed successfully. Summary: {summary}")

        except Exception as e:
            # Top-level task boundary: record the fatal error on the task row.
            await ImportTask.filter(id=task_id).update(
                status="failed",
                details={
                    "error": f"A fatal error occurred during snapshot regeneration: {str(e)}"
                },
            )
            print(f"[Task {task_id}] FAILED with a fatal error: {e}")