Edwin Salguero
commited on
Commit
·
2c67d05
1
Parent(s):
859af74
Add FinRL integration with comprehensive RL trading agent
Browse files- Add FinRL agent with support for PPO, A2C, DDPG, and TD3 algorithms
- Create custom trading environment compatible with Gymnasium
- Implement technical indicators integration (RSI, Bollinger Bands, MACD)
- Add comprehensive configuration system for FinRL parameters
- Create demo script with training, evaluation, and visualization
- Add comprehensive test suite for FinRL functionality
- Update requirements.txt with FinRL dependencies
- Update README with detailed FinRL documentation
- Create necessary directories for models, logs, and plots
- README.md +133 -1
- agentic_ai_system/finrl_agent.py +447 -0
- config.yaml +27 -0
- finrl_demo.py +294 -0
- requirements.txt +5 -0
- tests/test_finrl_agent.py +373 -0
README.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# Algorithmic Trading System
|
| 2 |
|
| 3 |
-
A comprehensive algorithmic trading system with synthetic data generation, comprehensive logging,
|
| 4 |
|
| 5 |
## Features
|
| 6 |
|
|
@@ -10,6 +10,15 @@ A comprehensive algorithmic trading system with synthetic data generation, compr
|
|
| 10 |
- **Risk Management**: Position sizing and drawdown limits
|
| 11 |
- **Order Execution**: Simulated broker integration with realistic execution delays
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
### Synthetic Data Generation
|
| 14 |
- **Realistic Market Data**: Generate OHLCV data using geometric Brownian motion
|
| 15 |
- **Multiple Frequencies**: Support for 1min, 5min, 1H, and 1D data
|
|
@@ -226,6 +235,129 @@ logger.warning("High volatility detected")
|
|
| 226 |
logger.error("Order execution failed", exc_info=True)
|
| 227 |
```
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
## Testing
|
| 230 |
|
| 231 |
### Test Structure
|
|
|
|
| 1 |
# Algorithmic Trading System
|
| 2 |
|
| 3 |
+
A comprehensive algorithmic trading system with synthetic data generation, comprehensive logging, extensive testing capabilities, and FinRL reinforcement learning integration.
|
| 4 |
|
| 5 |
## Features
|
| 6 |
|
|
|
|
| 10 |
- **Risk Management**: Position sizing and drawdown limits
|
| 11 |
- **Order Execution**: Simulated broker integration with realistic execution delays
|
| 12 |
|
| 13 |
+
### FinRL Reinforcement Learning
|
| 14 |
+
- **Multiple RL Algorithms**: Support for PPO, A2C, DDPG, and TD3
|
| 15 |
+
- **Custom Trading Environment**: Gymnasium-compatible environment for RL training
|
| 16 |
+
- **Technical Indicators Integration**: Automatic calculation and inclusion of technical indicators
|
| 17 |
+
- **Portfolio Management**: Realistic portfolio simulation with transaction costs
|
| 18 |
+
- **Model Persistence**: Save and load trained models for inference
|
| 19 |
+
- **TensorBoard Integration**: Training progress visualization and monitoring
|
| 20 |
+
- **Comprehensive Evaluation**: Performance metrics including Sharpe ratio and total returns
|
| 21 |
+
|
| 22 |
### Synthetic Data Generation
|
| 23 |
- **Realistic Market Data**: Generate OHLCV data using geometric Brownian motion
|
| 24 |
- **Multiple Frequencies**: Support for 1min, 5min, 1H, and 1D data
|
|
|
|
| 235 |
logger.error("Order execution failed", exc_info=True)
|
| 236 |
```
|
| 237 |
|
| 238 |
+
## FinRL Integration
|
| 239 |
+
|
| 240 |
+
### Overview
|
| 241 |
+
The system now includes FinRL (Financial Reinforcement Learning) integration, providing state-of-the-art reinforcement learning capabilities for algorithmic trading. The FinRL agent can learn optimal trading strategies through interaction with a simulated market environment.
|
| 242 |
+
|
| 243 |
+
### Supported Algorithms
|
| 244 |
+
- **PPO (Proximal Policy Optimization)**: Stable policy gradient method
|
| 245 |
+
- **A2C (Advantage Actor-Critic)**: Actor-critic method with advantage estimation
|
| 246 |
+
- **DDPG (Deep Deterministic Policy Gradient)**: Continuous action space algorithm
|
| 247 |
+
- **TD3 (Twin Delayed DDPG)**: Improved version of DDPG with twin critics
|
| 248 |
+
|
| 249 |
+
### Trading Environment
|
| 250 |
+
The custom trading environment provides:
|
| 251 |
+
- **Action Space**: Discrete actions (0=Buy, 1=Hold, 2=Sell)
|
| 252 |
+
- **Observation Space**: OHLCV data + technical indicators + portfolio state
|
| 253 |
+
- **Reward Function**: Portfolio return-based rewards
|
| 254 |
+
- **Transaction Costs**: Realistic trading fees and slippage
|
| 255 |
+
- **Position Limits**: Maximum position constraints
|
| 256 |
+
|
| 257 |
+
### Usage Examples
|
| 258 |
+
|
| 259 |
+
#### Basic FinRL Training
|
| 260 |
+
```python
|
| 261 |
+
from agentic_ai_system.finrl_agent import FinRLAgent, FinRLConfig
|
| 262 |
+
import pandas as pd
|
| 263 |
+
|
| 264 |
+
# Create configuration
|
| 265 |
+
config = FinRLConfig(
|
| 266 |
+
algorithm="PPO",
|
| 267 |
+
learning_rate=0.0003,
|
| 268 |
+
batch_size=64,
|
| 269 |
+
total_timesteps=100000
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
# Initialize agent
|
| 273 |
+
agent = FinRLAgent(config)
|
| 274 |
+
|
| 275 |
+
# Train the agent
|
| 276 |
+
training_result = agent.train(
|
| 277 |
+
data=market_data,
|
| 278 |
+
total_timesteps=100000,
|
| 279 |
+
eval_freq=10000
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Generate predictions
|
| 283 |
+
predictions = agent.predict(test_data)
|
| 284 |
+
|
| 285 |
+
# Evaluate performance
|
| 286 |
+
evaluation = agent.evaluate(test_data)
|
| 287 |
+
print(f"Total Return: {evaluation['total_return']:.2%}")
|
| 288 |
+
```
|
| 289 |
+
|
| 290 |
+
#### Using Configuration File
|
| 291 |
+
```python
|
| 292 |
+
from agentic_ai_system.finrl_agent import create_finrl_agent_from_config
|
| 293 |
+
|
| 294 |
+
# Create agent from config file
|
| 295 |
+
agent = create_finrl_agent_from_config('config.yaml')
|
| 296 |
+
|
| 297 |
+
# Train and evaluate
|
| 298 |
+
agent.train(market_data)
|
| 299 |
+
results = agent.evaluate(test_data)
|
| 300 |
+
```
|
| 301 |
+
|
| 302 |
+
#### Running FinRL Demo
|
| 303 |
+
```bash
|
| 304 |
+
# Run the complete FinRL demo
|
| 305 |
+
python finrl_demo.py
|
| 306 |
+
|
| 307 |
+
# This will:
|
| 308 |
+
# 1. Generate synthetic training and test data
|
| 309 |
+
# 2. Train a FinRL agent
|
| 310 |
+
# 3. Evaluate performance
|
| 311 |
+
# 4. Generate trading predictions
|
| 312 |
+
# 5. Create visualization plots
|
| 313 |
+
```
|
| 314 |
+
|
| 315 |
+
### Configuration
|
| 316 |
+
FinRL settings can be configured in `config.yaml`:
|
| 317 |
+
|
| 318 |
+
```yaml
|
| 319 |
+
finrl:
|
| 320 |
+
algorithm: 'PPO' # PPO, A2C, DDPG, TD3
|
| 321 |
+
learning_rate: 0.0003
|
| 322 |
+
batch_size: 64
|
| 323 |
+
buffer_size: 1000000
|
| 324 |
+
gamma: 0.99
|
| 325 |
+
tensorboard_log: 'logs/finrl_tensorboard'
|
| 326 |
+
training:
|
| 327 |
+
total_timesteps: 100000
|
| 328 |
+
eval_freq: 10000
|
| 329 |
+
save_best_model: true
|
| 330 |
+
model_save_path: 'models/finrl_best/'
|
| 331 |
+
inference:
|
| 332 |
+
use_trained_model: false
|
| 333 |
+
model_path: 'models/finrl_best/best_model'
|
| 334 |
+
```
|
| 335 |
+
|
| 336 |
+
### Model Management
|
| 337 |
+
```python
|
| 338 |
+
# Save trained model
|
| 339 |
+
agent.save_model('models/my_finrl_model')
|
| 340 |
+
|
| 341 |
+
# Load pre-trained model
|
| 342 |
+
agent.load_model('models/my_finrl_model')
|
| 343 |
+
|
| 344 |
+
# Continue training
|
| 345 |
+
agent.train(more_data, total_timesteps=50000)
|
| 346 |
+
```
|
| 347 |
+
|
| 348 |
+
### Performance Monitoring
|
| 349 |
+
- **TensorBoard Integration**: Monitor training progress
|
| 350 |
+
- **Evaluation Metrics**: Total return, Sharpe ratio, portfolio value
|
| 351 |
+
- **Trading Statistics**: Buy/sell signal analysis
|
| 352 |
+
- **Visualization**: Price charts with trading signals
|
| 353 |
+
|
| 354 |
+
### Advanced Features
|
| 355 |
+
- **Multi-timeframe Support**: Train on different data frequencies
|
| 356 |
+
- **Feature Engineering**: Automatic technical indicator calculation
|
| 357 |
+
- **Risk Management**: Built-in position and drawdown limits
|
| 358 |
+
- **Backtesting**: Comprehensive backtesting capabilities
|
| 359 |
+
- **Hyperparameter Tuning**: Easy configuration for different algorithms
|
| 360 |
+
|
| 361 |
## Testing
|
| 362 |
|
| 363 |
### Test Structure
|
agentic_ai_system/finrl_agent.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FinRL Agent for Algorithmic Trading
|
| 3 |
+
|
| 4 |
+
This module provides a FinRL-based reinforcement learning agent that can be integrated
|
| 5 |
+
with the existing algorithmic trading system. It supports various RL algorithms
|
| 6 |
+
including PPO, A2C, DDPG, and TD3.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import gymnasium as gym
|
| 12 |
+
from gymnasium import spaces
|
| 13 |
+
from stable_baselines3 import PPO, A2C, DDPG, TD3
|
| 14 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
| 15 |
+
from stable_baselines3.common.callbacks import EvalCallback
|
| 16 |
+
import torch
|
| 17 |
+
import logging
|
| 18 |
+
from typing import Dict, List, Tuple, Optional, Any
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
import yaml
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class FinRLConfig:
|
| 27 |
+
"""Configuration for FinRL agent"""
|
| 28 |
+
algorithm: str = "PPO" # PPO, A2C, DDPG, TD3
|
| 29 |
+
learning_rate: float = 0.0003
|
| 30 |
+
batch_size: int = 64
|
| 31 |
+
buffer_size: int = 1000000
|
| 32 |
+
learning_starts: int = 100
|
| 33 |
+
gamma: float = 0.99
|
| 34 |
+
tau: float = 0.005
|
| 35 |
+
train_freq: int = 1
|
| 36 |
+
gradient_steps: int = 1
|
| 37 |
+
target_update_interval: int = 1
|
| 38 |
+
exploration_fraction: float = 0.1
|
| 39 |
+
exploration_initial_eps: float = 1.0
|
| 40 |
+
exploration_final_eps: float = 0.05
|
| 41 |
+
max_grad_norm: float = 10.0
|
| 42 |
+
verbose: int = 1
|
| 43 |
+
tensorboard_log: str = "logs/finrl_tensorboard"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TradingEnvironment(gym.Env):
|
| 47 |
+
"""
|
| 48 |
+
Custom trading environment for FinRL
|
| 49 |
+
|
| 50 |
+
This environment simulates a trading scenario where the agent can:
|
| 51 |
+
- Buy, sell, or hold positions
|
| 52 |
+
- Use technical indicators for decision making
|
| 53 |
+
- Manage portfolio value and risk
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(self, data: pd.DataFrame, initial_balance: float = 100000,
|
| 57 |
+
transaction_fee: float = 0.001, max_position: int = 100):
|
| 58 |
+
super().__init__()
|
| 59 |
+
|
| 60 |
+
self.data = data
|
| 61 |
+
self.initial_balance = initial_balance
|
| 62 |
+
self.transaction_fee = transaction_fee
|
| 63 |
+
self.max_position = max_position
|
| 64 |
+
|
| 65 |
+
# Reset state
|
| 66 |
+
self.reset()
|
| 67 |
+
|
| 68 |
+
# Define action space: [-1, 0, 1] for sell, hold, buy
|
| 69 |
+
self.action_space = spaces.Discrete(3)
|
| 70 |
+
|
| 71 |
+
# Define observation space
|
| 72 |
+
# Features: OHLCV + technical indicators + portfolio state
|
| 73 |
+
n_features = len(self._get_features(self.data.iloc[0]))
|
| 74 |
+
self.observation_space = spaces.Box(
|
| 75 |
+
low=-np.inf, high=np.inf, shape=(n_features,), dtype=np.float32
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
def _get_features(self, row: pd.Series) -> np.ndarray:
|
| 79 |
+
"""Extract features from market data row"""
|
| 80 |
+
features = []
|
| 81 |
+
|
| 82 |
+
# Price features
|
| 83 |
+
features.extend([
|
| 84 |
+
row['open'], row['high'], row['low'], row['close'], row['volume']
|
| 85 |
+
])
|
| 86 |
+
|
| 87 |
+
# Technical indicators (if available)
|
| 88 |
+
for indicator in ['sma_20', 'sma_50', 'rsi', 'bb_upper', 'bb_lower', 'macd']:
|
| 89 |
+
if indicator in row.index:
|
| 90 |
+
features.append(row[indicator])
|
| 91 |
+
else:
|
| 92 |
+
features.append(0.0)
|
| 93 |
+
|
| 94 |
+
# Portfolio state
|
| 95 |
+
features.extend([
|
| 96 |
+
self.balance,
|
| 97 |
+
self.position,
|
| 98 |
+
self.portfolio_value,
|
| 99 |
+
self.total_return
|
| 100 |
+
])
|
| 101 |
+
|
| 102 |
+
return np.array(features, dtype=np.float32)
|
| 103 |
+
|
| 104 |
+
def _calculate_portfolio_value(self) -> float:
|
| 105 |
+
"""Calculate current portfolio value"""
|
| 106 |
+
current_price = self.data.iloc[self.current_step]['close']
|
| 107 |
+
return self.balance + (self.position * current_price)
|
| 108 |
+
|
| 109 |
+
def _calculate_reward(self) -> float:
|
| 110 |
+
"""Calculate reward based on portfolio performance"""
|
| 111 |
+
current_value = self._calculate_portfolio_value()
|
| 112 |
+
previous_value = self.previous_portfolio_value
|
| 113 |
+
|
| 114 |
+
# Calculate return
|
| 115 |
+
if previous_value > 0:
|
| 116 |
+
return (current_value - previous_value) / previous_value
|
| 117 |
+
else:
|
| 118 |
+
return 0.0
|
| 119 |
+
|
| 120 |
+
def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
|
| 121 |
+
"""Execute one step in the environment"""
|
| 122 |
+
|
| 123 |
+
# Get current market data
|
| 124 |
+
current_data = self.data.iloc[self.current_step]
|
| 125 |
+
current_price = current_data['close']
|
| 126 |
+
|
| 127 |
+
# Execute action
|
| 128 |
+
if action == 0: # Sell
|
| 129 |
+
if self.position > 0:
|
| 130 |
+
shares_to_sell = min(self.position, self.max_position)
|
| 131 |
+
sell_value = shares_to_sell * current_price * (1 - self.transaction_fee)
|
| 132 |
+
self.balance += sell_value
|
| 133 |
+
self.position -= shares_to_sell
|
| 134 |
+
elif action == 2: # Buy
|
| 135 |
+
if self.balance > 0:
|
| 136 |
+
max_shares = min(
|
| 137 |
+
int(self.balance / current_price),
|
| 138 |
+
self.max_position - self.position
|
| 139 |
+
)
|
| 140 |
+
if max_shares > 0:
|
| 141 |
+
buy_value = max_shares * current_price * (1 + self.transaction_fee)
|
| 142 |
+
self.balance -= buy_value
|
| 143 |
+
self.position += max_shares
|
| 144 |
+
|
| 145 |
+
# Update portfolio value
|
| 146 |
+
self.previous_portfolio_value = self.portfolio_value
|
| 147 |
+
self.portfolio_value = self._calculate_portfolio_value()
|
| 148 |
+
self.total_return = (self.portfolio_value - self.initial_balance) / self.initial_balance
|
| 149 |
+
|
| 150 |
+
# Calculate reward
|
| 151 |
+
reward = self._calculate_reward()
|
| 152 |
+
|
| 153 |
+
# Move to next step
|
| 154 |
+
self.current_step += 1
|
| 155 |
+
|
| 156 |
+
# Check if episode is done
|
| 157 |
+
done = self.current_step >= len(self.data) - 1
|
| 158 |
+
|
| 159 |
+
# Get observation
|
| 160 |
+
if not done:
|
| 161 |
+
observation = self._get_features(self.data.iloc[self.current_step])
|
| 162 |
+
else:
|
| 163 |
+
# Use last available data for final observation
|
| 164 |
+
observation = self._get_features(self.data.iloc[-1])
|
| 165 |
+
|
| 166 |
+
info = {
|
| 167 |
+
'balance': self.balance,
|
| 168 |
+
'position': self.position,
|
| 169 |
+
'portfolio_value': self.portfolio_value,
|
| 170 |
+
'total_return': self.total_return,
|
| 171 |
+
'current_price': current_price
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
return observation, reward, done, False, info
|
| 175 |
+
|
| 176 |
+
def reset(self, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
|
| 177 |
+
"""Reset the environment"""
|
| 178 |
+
super().reset(seed=seed)
|
| 179 |
+
|
| 180 |
+
self.current_step = 0
|
| 181 |
+
self.balance = self.initial_balance
|
| 182 |
+
self.position = 0
|
| 183 |
+
self.portfolio_value = self.initial_balance
|
| 184 |
+
self.previous_portfolio_value = self.initial_balance
|
| 185 |
+
self.total_return = 0.0
|
| 186 |
+
|
| 187 |
+
observation = self._get_features(self.data.iloc[self.current_step])
|
| 188 |
+
info = {
|
| 189 |
+
'balance': self.balance,
|
| 190 |
+
'position': self.position,
|
| 191 |
+
'portfolio_value': self.portfolio_value,
|
| 192 |
+
'total_return': self.total_return
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
return observation, info
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class FinRLAgent:
|
| 199 |
+
"""
|
| 200 |
+
FinRL-based reinforcement learning agent for algorithmic trading
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
def __init__(self, config: FinRLConfig):
|
| 204 |
+
self.config = config
|
| 205 |
+
self.model = None
|
| 206 |
+
self.env = None
|
| 207 |
+
self.eval_env = None
|
| 208 |
+
self.callback = None
|
| 209 |
+
|
| 210 |
+
logger.info(f"Initializing FinRL agent with algorithm: {config.algorithm}")
|
| 211 |
+
|
| 212 |
+
def create_environment(self, data: pd.DataFrame, initial_balance: float = 100000) -> TradingEnvironment:
|
| 213 |
+
"""Create trading environment from market data"""
|
| 214 |
+
return TradingEnvironment(
|
| 215 |
+
data=data,
|
| 216 |
+
initial_balance=initial_balance,
|
| 217 |
+
transaction_fee=0.001,
|
| 218 |
+
max_position=100
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
def prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
|
| 222 |
+
"""Prepare data with technical indicators for FinRL"""
|
| 223 |
+
df = data.copy()
|
| 224 |
+
|
| 225 |
+
# Add technical indicators if not present
|
| 226 |
+
if 'sma_20' not in df.columns:
|
| 227 |
+
df['sma_20'] = df['close'].rolling(window=20).mean()
|
| 228 |
+
if 'sma_50' not in df.columns:
|
| 229 |
+
df['sma_50'] = df['close'].rolling(window=50).mean()
|
| 230 |
+
if 'rsi' not in df.columns:
|
| 231 |
+
df['rsi'] = self._calculate_rsi(df['close'])
|
| 232 |
+
if 'bb_upper' not in df.columns or 'bb_lower' not in df.columns:
|
| 233 |
+
bb_upper, bb_lower = self._calculate_bollinger_bands(df['close'])
|
| 234 |
+
df['bb_upper'] = bb_upper
|
| 235 |
+
df['bb_lower'] = bb_lower
|
| 236 |
+
if 'macd' not in df.columns:
|
| 237 |
+
df['macd'] = self._calculate_macd(df['close'])
|
| 238 |
+
|
| 239 |
+
# Fill NaN values
|
| 240 |
+
df = df.fillna(method='bfill').fillna(0)
|
| 241 |
+
|
| 242 |
+
return df
|
| 243 |
+
|
| 244 |
+
def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
|
| 245 |
+
"""Calculate RSI indicator"""
|
| 246 |
+
delta = prices.diff()
|
| 247 |
+
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
| 248 |
+
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
| 249 |
+
rs = gain / loss
|
| 250 |
+
rsi = 100 - (100 / (1 + rs))
|
| 251 |
+
return rsi
|
| 252 |
+
|
| 253 |
+
def _calculate_bollinger_bands(self, prices: pd.Series, period: int = 20, std_dev: int = 2) -> Tuple[pd.Series, pd.Series]:
|
| 254 |
+
"""Calculate Bollinger Bands"""
|
| 255 |
+
sma = prices.rolling(window=period).mean()
|
| 256 |
+
std = prices.rolling(window=period).std()
|
| 257 |
+
upper_band = sma + (std * std_dev)
|
| 258 |
+
lower_band = sma - (std * std_dev)
|
| 259 |
+
return upper_band, lower_band
|
| 260 |
+
|
| 261 |
+
def _calculate_macd(self, prices: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9) -> pd.Series:
|
| 262 |
+
"""Calculate MACD indicator"""
|
| 263 |
+
ema_fast = prices.ewm(span=fast).mean()
|
| 264 |
+
ema_slow = prices.ewm(span=slow).mean()
|
| 265 |
+
macd_line = ema_fast - ema_slow
|
| 266 |
+
return macd_line
|
| 267 |
+
|
| 268 |
+
def train(self, data: pd.DataFrame, total_timesteps: int = 100000,
|
| 269 |
+
eval_freq: int = 10000, eval_data: Optional[pd.DataFrame] = None) -> Dict[str, Any]:
|
| 270 |
+
"""Train the FinRL agent"""
|
| 271 |
+
|
| 272 |
+
logger.info("Starting FinRL agent training")
|
| 273 |
+
|
| 274 |
+
# Prepare data
|
| 275 |
+
train_data = self.prepare_data(data)
|
| 276 |
+
|
| 277 |
+
# Create training environment
|
| 278 |
+
self.env = DummyVecEnv([lambda: self.create_environment(train_data)])
|
| 279 |
+
|
| 280 |
+
# Create evaluation environment if provided
|
| 281 |
+
if eval_data is not None:
|
| 282 |
+
eval_data = self.prepare_data(eval_data)
|
| 283 |
+
self.eval_env = DummyVecEnv([lambda: self.create_environment(eval_data)])
|
| 284 |
+
self.callback = EvalCallback(
|
| 285 |
+
self.eval_env,
|
| 286 |
+
best_model_save_path="models/finrl_best/",
|
| 287 |
+
log_path="logs/finrl_eval/",
|
| 288 |
+
eval_freq=eval_freq,
|
| 289 |
+
deterministic=True,
|
| 290 |
+
render=False
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Initialize model based on algorithm
|
| 294 |
+
if self.config.algorithm == "PPO":
|
| 295 |
+
self.model = PPO(
|
| 296 |
+
"MlpPolicy",
|
| 297 |
+
self.env,
|
| 298 |
+
learning_rate=self.config.learning_rate,
|
| 299 |
+
batch_size=self.config.batch_size,
|
| 300 |
+
gamma=self.config.gamma,
|
| 301 |
+
verbose=self.config.verbose,
|
| 302 |
+
tensorboard_log=self.config.tensorboard_log
|
| 303 |
+
)
|
| 304 |
+
elif self.config.algorithm == "A2C":
|
| 305 |
+
self.model = A2C(
|
| 306 |
+
"MlpPolicy",
|
| 307 |
+
self.env,
|
| 308 |
+
learning_rate=self.config.learning_rate,
|
| 309 |
+
gamma=self.config.gamma,
|
| 310 |
+
verbose=self.config.verbose,
|
| 311 |
+
tensorboard_log=self.config.tensorboard_log
|
| 312 |
+
)
|
| 313 |
+
elif self.config.algorithm == "DDPG":
|
| 314 |
+
self.model = DDPG(
|
| 315 |
+
"MlpPolicy",
|
| 316 |
+
self.env,
|
| 317 |
+
learning_rate=self.config.learning_rate,
|
| 318 |
+
buffer_size=self.config.buffer_size,
|
| 319 |
+
learning_starts=self.config.learning_starts,
|
| 320 |
+
gamma=self.config.gamma,
|
| 321 |
+
tau=self.config.tau,
|
| 322 |
+
train_freq=self.config.train_freq,
|
| 323 |
+
gradient_steps=self.config.gradient_steps,
|
| 324 |
+
verbose=self.config.verbose,
|
| 325 |
+
tensorboard_log=self.config.tensorboard_log
|
| 326 |
+
)
|
| 327 |
+
elif self.config.algorithm == "TD3":
|
| 328 |
+
self.model = TD3(
|
| 329 |
+
"MlpPolicy",
|
| 330 |
+
self.env,
|
| 331 |
+
learning_rate=self.config.learning_rate,
|
| 332 |
+
buffer_size=self.config.buffer_size,
|
| 333 |
+
learning_starts=self.config.learning_starts,
|
| 334 |
+
gamma=self.config.gamma,
|
| 335 |
+
tau=self.config.tau,
|
| 336 |
+
train_freq=self.config.train_freq,
|
| 337 |
+
gradient_steps=self.config.gradient_steps,
|
| 338 |
+
target_update_interval=self.config.target_update_interval,
|
| 339 |
+
verbose=self.config.verbose,
|
| 340 |
+
tensorboard_log=self.config.tensorboard_log
|
| 341 |
+
)
|
| 342 |
+
else:
|
| 343 |
+
raise ValueError(f"Unsupported algorithm: {self.config.algorithm}")
|
| 344 |
+
|
| 345 |
+
# Train the model
|
| 346 |
+
callbacks = [self.callback] if self.callback else None
|
| 347 |
+
self.model.learn(
|
| 348 |
+
total_timesteps=total_timesteps,
|
| 349 |
+
callback=callbacks
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
logger.info("FinRL agent training completed")
|
| 353 |
+
|
| 354 |
+
return {
|
| 355 |
+
'algorithm': self.config.algorithm,
|
| 356 |
+
'total_timesteps': total_timesteps,
|
| 357 |
+
'model_path': f"models/finrl_{self.config.algorithm.lower()}"
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
def predict(self, data: pd.DataFrame) -> List[int]:
|
| 361 |
+
"""Generate trading predictions using the trained model"""
|
| 362 |
+
if self.model is None:
|
| 363 |
+
raise ValueError("Model not trained. Call train() first.")
|
| 364 |
+
|
| 365 |
+
# Prepare data
|
| 366 |
+
test_data = self.prepare_data(data)
|
| 367 |
+
|
| 368 |
+
# Create test environment
|
| 369 |
+
test_env = self.create_environment(test_data)
|
| 370 |
+
|
| 371 |
+
predictions = []
|
| 372 |
+
obs, _ = test_env.reset()
|
| 373 |
+
|
| 374 |
+
done = False
|
| 375 |
+
while not done:
|
| 376 |
+
action, _ = self.model.predict(obs, deterministic=True)
|
| 377 |
+
predictions.append(action)
|
| 378 |
+
obs, _, done, _, _ = test_env.step(action)
|
| 379 |
+
|
| 380 |
+
return predictions
|
| 381 |
+
|
| 382 |
+
def evaluate(self, data: pd.DataFrame) -> Dict[str, float]:
|
| 383 |
+
"""Evaluate the trained model on test data"""
|
| 384 |
+
if self.model is None:
|
| 385 |
+
raise ValueError("Model not trained. Call train() first.")
|
| 386 |
+
|
| 387 |
+
# Prepare data
|
| 388 |
+
test_data = self.prepare_data(data)
|
| 389 |
+
|
| 390 |
+
# Create test environment
|
| 391 |
+
test_env = self.create_environment(test_data)
|
| 392 |
+
|
| 393 |
+
obs, _ = test_env.reset()
|
| 394 |
+
done = False
|
| 395 |
+
total_reward = 0
|
| 396 |
+
steps = 0
|
| 397 |
+
|
| 398 |
+
while not done:
|
| 399 |
+
action, _ = self.model.predict(obs, deterministic=True)
|
| 400 |
+
obs, reward, done, _, info = test_env.step(action)
|
| 401 |
+
total_reward += reward
|
| 402 |
+
steps += 1
|
| 403 |
+
|
| 404 |
+
# Calculate metrics
|
| 405 |
+
final_portfolio_value = info['portfolio_value']
|
| 406 |
+
initial_balance = test_env.initial_balance
|
| 407 |
+
total_return = (final_portfolio_value - initial_balance) / initial_balance
|
| 408 |
+
|
| 409 |
+
return {
|
| 410 |
+
'total_reward': total_reward,
|
| 411 |
+
'total_return': total_return,
|
| 412 |
+
'final_portfolio_value': final_portfolio_value,
|
| 413 |
+
'steps': steps,
|
| 414 |
+
'sharpe_ratio': total_reward / steps if steps > 0 else 0
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
def save_model(self, path: str):
|
| 418 |
+
"""Save the trained model"""
|
| 419 |
+
if self.model is None:
|
| 420 |
+
raise ValueError("No model to save. Train the model first.")
|
| 421 |
+
|
| 422 |
+
self.model.save(path)
|
| 423 |
+
logger.info(f"Model saved to {path}")
|
| 424 |
+
|
| 425 |
+
def load_model(self, path: str):
|
| 426 |
+
"""Load a trained model"""
|
| 427 |
+
if self.config.algorithm == "PPO":
|
| 428 |
+
self.model = PPO.load(path)
|
| 429 |
+
elif self.config.algorithm == "A2C":
|
| 430 |
+
self.model = A2C.load(path)
|
| 431 |
+
elif self.config.algorithm == "DDPG":
|
| 432 |
+
self.model = DDPG.load(path)
|
| 433 |
+
elif self.config.algorithm == "TD3":
|
| 434 |
+
self.model = TD3.load(path)
|
| 435 |
+
else:
|
| 436 |
+
raise ValueError(f"Unsupported algorithm: {self.config.algorithm}")
|
| 437 |
+
|
| 438 |
+
logger.info(f"Model loaded from {path}")
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def create_finrl_agent_from_config(config_path: str) -> FinRLAgent:
|
| 442 |
+
"""Create FinRL agent from configuration file"""
|
| 443 |
+
with open(config_path, 'r') as file:
|
| 444 |
+
config_data = yaml.safe_load(file)
|
| 445 |
+
|
| 446 |
+
finrl_config = FinRLConfig(**config_data.get('finrl', {}))
|
| 447 |
+
return FinRLAgent(finrl_config)
|
config.yaml
CHANGED
|
@@ -33,3 +33,30 @@ logging:
|
|
| 33 |
enable_file: true
|
| 34 |
max_file_size_mb: 10
|
| 35 |
backup_count: 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
enable_file: true
|
| 34 |
max_file_size_mb: 10
|
| 35 |
backup_count: 5
|
| 36 |
+
|
| 37 |
+
# FinRL configuration
|
| 38 |
+
finrl:
|
| 39 |
+
algorithm: 'PPO' # PPO, A2C, DDPG, TD3
|
| 40 |
+
learning_rate: 0.0003
|
| 41 |
+
batch_size: 64
|
| 42 |
+
buffer_size: 1000000
|
| 43 |
+
learning_starts: 100
|
| 44 |
+
gamma: 0.99
|
| 45 |
+
tau: 0.005
|
| 46 |
+
train_freq: 1
|
| 47 |
+
gradient_steps: 1
|
| 48 |
+
target_update_interval: 1
|
| 49 |
+
exploration_fraction: 0.1
|
| 50 |
+
exploration_initial_eps: 1.0
|
| 51 |
+
exploration_final_eps: 0.05
|
| 52 |
+
max_grad_norm: 10.0
|
| 53 |
+
verbose: 1
|
| 54 |
+
tensorboard_log: 'logs/finrl_tensorboard'
|
| 55 |
+
training:
|
| 56 |
+
total_timesteps: 100000
|
| 57 |
+
eval_freq: 10000
|
| 58 |
+
save_best_model: true
|
| 59 |
+
model_save_path: 'models/finrl_best/'
|
| 60 |
+
inference:
|
| 61 |
+
use_trained_model: false
|
| 62 |
+
model_path: 'models/finrl_best/best_model'
|
finrl_demo.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
FinRL Demo Script
|
| 4 |
+
|
| 5 |
+
This script demonstrates the integration of FinRL with the algorithmic trading system.
|
| 6 |
+
It shows how to train a reinforcement learning agent and use it for trading decisions.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import yaml
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import numpy as np
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import seaborn as sns
|
| 16 |
+
from datetime import datetime, timedelta
|
| 17 |
+
import logging
|
| 18 |
+
|
| 19 |
+
# Add the project root to the path
|
| 20 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 21 |
+
|
| 22 |
+
from agentic_ai_system.finrl_agent import FinRLAgent, FinRLConfig, create_finrl_agent_from_config
|
| 23 |
+
from agentic_ai_system.synthetic_data_generator import SyntheticDataGenerator
|
| 24 |
+
from agentic_ai_system.logger_config import setup_logging
|
| 25 |
+
|
| 26 |
+
# Setup logging
|
| 27 |
+
setup_logging()
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_config(config_path: str = 'config.yaml') -> dict:
|
| 32 |
+
"""Load configuration from YAML file"""
|
| 33 |
+
with open(config_path, 'r') as file:
|
| 34 |
+
return yaml.safe_load(file)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def generate_training_data(config: dict) -> pd.DataFrame:
|
| 38 |
+
"""Generate synthetic data for training"""
|
| 39 |
+
logger.info("Generating synthetic training data")
|
| 40 |
+
|
| 41 |
+
generator = SyntheticDataGenerator(config)
|
| 42 |
+
|
| 43 |
+
# Generate training data (longer period)
|
| 44 |
+
train_data = generator.generate_ohlcv_data(
|
| 45 |
+
symbol='AAPL',
|
| 46 |
+
start_date='2023-01-01',
|
| 47 |
+
end_date='2023-12-31',
|
| 48 |
+
frequency='1H'
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Add technical indicators
|
| 52 |
+
train_data['sma_20'] = train_data['close'].rolling(window=20).mean()
|
| 53 |
+
train_data['sma_50'] = train_data['close'].rolling(window=50).mean()
|
| 54 |
+
train_data['rsi'] = calculate_rsi(train_data['close'])
|
| 55 |
+
bb_upper, bb_lower = calculate_bollinger_bands(train_data['close'])
|
| 56 |
+
train_data['bb_upper'] = bb_upper
|
| 57 |
+
train_data['bb_lower'] = bb_lower
|
| 58 |
+
train_data['macd'] = calculate_macd(train_data['close'])
|
| 59 |
+
|
| 60 |
+
# Fill NaN values
|
| 61 |
+
train_data = train_data.fillna(method='bfill').fillna(0)
|
| 62 |
+
|
| 63 |
+
logger.info(f"Generated {len(train_data)} training samples")
|
| 64 |
+
return train_data
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def generate_test_data(config: dict) -> pd.DataFrame:
|
| 68 |
+
"""Generate synthetic data for testing"""
|
| 69 |
+
logger.info("Generating synthetic test data")
|
| 70 |
+
|
| 71 |
+
generator = SyntheticDataGenerator(config)
|
| 72 |
+
|
| 73 |
+
# Generate test data (shorter period)
|
| 74 |
+
test_data = generator.generate_ohlcv_data(
|
| 75 |
+
symbol='AAPL',
|
| 76 |
+
start_date='2024-01-01',
|
| 77 |
+
end_date='2024-03-31',
|
| 78 |
+
frequency='1H'
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Add technical indicators
|
| 82 |
+
test_data['sma_20'] = test_data['close'].rolling(window=20).mean()
|
| 83 |
+
test_data['sma_50'] = test_data['close'].rolling(window=50).mean()
|
| 84 |
+
test_data['rsi'] = calculate_rsi(test_data['close'])
|
| 85 |
+
bb_upper, bb_lower = calculate_bollinger_bands(test_data['close'])
|
| 86 |
+
test_data['bb_upper'] = bb_upper
|
| 87 |
+
test_data['bb_lower'] = bb_lower
|
| 88 |
+
test_data['macd'] = calculate_macd(test_data['close'])
|
| 89 |
+
|
| 90 |
+
# Fill NaN values
|
| 91 |
+
test_data = test_data.fillna(method='bfill').fillna(0)
|
| 92 |
+
|
| 93 |
+
logger.info(f"Generated {len(test_data)} test samples")
|
| 94 |
+
return test_data
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def calculate_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
|
| 98 |
+
"""Calculate RSI indicator"""
|
| 99 |
+
delta = prices.diff()
|
| 100 |
+
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
| 101 |
+
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
| 102 |
+
rs = gain / loss
|
| 103 |
+
rsi = 100 - (100 / (1 + rs))
|
| 104 |
+
return rsi
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def calculate_bollinger_bands(prices: pd.Series, period: int = 20, std_dev: int = 2):
|
| 108 |
+
"""Calculate Bollinger Bands"""
|
| 109 |
+
sma = prices.rolling(window=period).mean()
|
| 110 |
+
std = prices.rolling(window=period).std()
|
| 111 |
+
upper_band = sma + (std * std_dev)
|
| 112 |
+
lower_band = sma - (std * std_dev)
|
| 113 |
+
return upper_band, lower_band
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def calculate_macd(prices: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9) -> pd.Series:
|
| 117 |
+
"""Calculate MACD indicator"""
|
| 118 |
+
ema_fast = prices.ewm(span=fast).mean()
|
| 119 |
+
ema_slow = prices.ewm(span=slow).mean()
|
| 120 |
+
macd_line = ema_fast - ema_slow
|
| 121 |
+
return macd_line
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def train_finrl_agent(config: dict, train_data: pd.DataFrame, test_data: pd.DataFrame) -> FinRLAgent:
|
| 125 |
+
"""Train the FinRL agent"""
|
| 126 |
+
logger.info("Starting FinRL agent training")
|
| 127 |
+
|
| 128 |
+
# Create FinRL agent
|
| 129 |
+
finrl_config = FinRLConfig(**config['finrl'])
|
| 130 |
+
agent = FinRLAgent(finrl_config)
|
| 131 |
+
|
| 132 |
+
# Train the agent
|
| 133 |
+
training_result = agent.train(
|
| 134 |
+
data=train_data,
|
| 135 |
+
total_timesteps=config['finrl']['training']['total_timesteps'],
|
| 136 |
+
eval_freq=config['finrl']['training']['eval_freq'],
|
| 137 |
+
eval_data=test_data
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
logger.info(f"Training completed: {training_result}")
|
| 141 |
+
|
| 142 |
+
# Save the model
|
| 143 |
+
if config['finrl']['training']['save_best_model']:
|
| 144 |
+
model_path = config['finrl']['training']['model_save_path']
|
| 145 |
+
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
| 146 |
+
agent.save_model(model_path)
|
| 147 |
+
|
| 148 |
+
return agent
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def evaluate_agent(agent: FinRLAgent, test_data: pd.DataFrame) -> dict:
|
| 152 |
+
"""Evaluate the trained agent"""
|
| 153 |
+
logger.info("Evaluating FinRL agent")
|
| 154 |
+
|
| 155 |
+
# Evaluate on test data
|
| 156 |
+
evaluation_results = agent.evaluate(test_data)
|
| 157 |
+
|
| 158 |
+
logger.info(f"Evaluation results: {evaluation_results}")
|
| 159 |
+
|
| 160 |
+
return evaluation_results
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def generate_predictions(agent: FinRLAgent, test_data: pd.DataFrame) -> list:
|
| 164 |
+
"""Generate trading predictions"""
|
| 165 |
+
logger.info("Generating trading predictions")
|
| 166 |
+
|
| 167 |
+
predictions = agent.predict(test_data)
|
| 168 |
+
|
| 169 |
+
logger.info(f"Generated {len(predictions)} predictions")
|
| 170 |
+
|
| 171 |
+
return predictions
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def plot_results(test_data: pd.DataFrame, predictions: list, evaluation_results: dict):
|
| 175 |
+
"""Plot trading results"""
|
| 176 |
+
logger.info("Creating visualization plots")
|
| 177 |
+
|
| 178 |
+
# Create figure with subplots
|
| 179 |
+
fig, axes = plt.subplots(3, 1, figsize=(15, 12))
|
| 180 |
+
|
| 181 |
+
# Plot 1: Price and predictions
|
| 182 |
+
axes[0].plot(test_data.index, test_data['close'], label='Close Price', alpha=0.7)
|
| 183 |
+
|
| 184 |
+
# Mark buy/sell signals
|
| 185 |
+
buy_signals = [i for i, pred in enumerate(predictions) if pred == 2]
|
| 186 |
+
sell_signals = [i for i, pred in enumerate(predictions) if pred == 0]
|
| 187 |
+
|
| 188 |
+
if buy_signals:
|
| 189 |
+
axes[0].scatter(test_data.index[buy_signals], test_data['close'].iloc[buy_signals],
|
| 190 |
+
color='green', marker='^', s=100, label='Buy Signal', alpha=0.8)
|
| 191 |
+
if sell_signals:
|
| 192 |
+
axes[0].scatter(test_data.index[sell_signals], test_data['close'].iloc[sell_signals],
|
| 193 |
+
color='red', marker='v', s=100, label='Sell Signal', alpha=0.8)
|
| 194 |
+
|
| 195 |
+
axes[0].set_title('Price Action and Trading Signals')
|
| 196 |
+
axes[0].set_ylabel('Price')
|
| 197 |
+
axes[0].legend()
|
| 198 |
+
axes[0].grid(True, alpha=0.3)
|
| 199 |
+
|
| 200 |
+
# Plot 2: Technical indicators
|
| 201 |
+
axes[1].plot(test_data.index, test_data['close'], label='Close Price', alpha=0.7)
|
| 202 |
+
axes[1].plot(test_data.index, test_data['sma_20'], label='SMA 20', alpha=0.7)
|
| 203 |
+
axes[1].plot(test_data.index, test_data['sma_50'], label='SMA 50', alpha=0.7)
|
| 204 |
+
axes[1].plot(test_data.index, test_data['bb_upper'], label='BB Upper', alpha=0.5)
|
| 205 |
+
axes[1].plot(test_data.index, test_data['bb_lower'], label='BB Lower', alpha=0.5)
|
| 206 |
+
|
| 207 |
+
axes[1].set_title('Technical Indicators')
|
| 208 |
+
axes[1].set_ylabel('Price')
|
| 209 |
+
axes[1].legend()
|
| 210 |
+
axes[1].grid(True, alpha=0.3)
|
| 211 |
+
|
| 212 |
+
# Plot 3: RSI
|
| 213 |
+
axes[2].plot(test_data.index, test_data['rsi'], label='RSI', color='purple')
|
| 214 |
+
axes[2].axhline(y=70, color='r', linestyle='--', alpha=0.5, label='Overbought')
|
| 215 |
+
axes[2].axhline(y=30, color='g', linestyle='--', alpha=0.5, label='Oversold')
|
| 216 |
+
axes[2].set_title('RSI Indicator')
|
| 217 |
+
axes[2].set_ylabel('RSI')
|
| 218 |
+
axes[2].set_xlabel('Time')
|
| 219 |
+
axes[2].legend()
|
| 220 |
+
axes[2].grid(True, alpha=0.3)
|
| 221 |
+
|
| 222 |
+
plt.tight_layout()
|
| 223 |
+
|
| 224 |
+
# Save plot
|
| 225 |
+
os.makedirs('plots', exist_ok=True)
|
| 226 |
+
plt.savefig('plots/finrl_trading_results.png', dpi=300, bbox_inches='tight')
|
| 227 |
+
plt.show()
|
| 228 |
+
|
| 229 |
+
logger.info("Plots saved to plots/finrl_trading_results.png")
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def print_summary(evaluation_results: dict, predictions: list):
|
| 233 |
+
"""Print trading summary"""
|
| 234 |
+
print("\n" + "="*60)
|
| 235 |
+
print("FINRL TRADING SYSTEM SUMMARY")
|
| 236 |
+
print("="*60)
|
| 237 |
+
|
| 238 |
+
print(f"Algorithm: {evaluation_results.get('algorithm', 'Unknown')}")
|
| 239 |
+
print(f"Total Return: {evaluation_results['total_return']:.2%}")
|
| 240 |
+
print(f"Final Portfolio Value: ${evaluation_results['final_portfolio_value']:,.2f}")
|
| 241 |
+
print(f"Total Reward: {evaluation_results['total_reward']:.4f}")
|
| 242 |
+
print(f"Sharpe Ratio: {evaluation_results['sharpe_ratio']:.4f}")
|
| 243 |
+
print(f"Number of Trading Steps: {evaluation_results['steps']}")
|
| 244 |
+
|
| 245 |
+
# Trading statistics
|
| 246 |
+
buy_signals = sum(1 for pred in predictions if pred == 2)
|
| 247 |
+
sell_signals = sum(1 for pred in predictions if pred == 0)
|
| 248 |
+
hold_signals = sum(1 for pred in predictions if pred == 1)
|
| 249 |
+
|
| 250 |
+
print(f"\nTrading Signals:")
|
| 251 |
+
print(f" Buy signals: {buy_signals}")
|
| 252 |
+
print(f" Sell signals: {sell_signals}")
|
| 253 |
+
print(f" Hold signals: {hold_signals}")
|
| 254 |
+
print(f" Total signals: {len(predictions)}")
|
| 255 |
+
|
| 256 |
+
print("\n" + "="*60)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def main():
|
| 260 |
+
"""Main function to run the FinRL demo"""
|
| 261 |
+
logger.info("Starting FinRL Demo")
|
| 262 |
+
|
| 263 |
+
try:
|
| 264 |
+
# Load configuration
|
| 265 |
+
config = load_config()
|
| 266 |
+
|
| 267 |
+
# Generate data
|
| 268 |
+
train_data = generate_training_data(config)
|
| 269 |
+
test_data = generate_test_data(config)
|
| 270 |
+
|
| 271 |
+
# Train FinRL agent
|
| 272 |
+
agent = train_finrl_agent(config, train_data, test_data)
|
| 273 |
+
|
| 274 |
+
# Evaluate agent
|
| 275 |
+
evaluation_results = evaluate_agent(agent, test_data)
|
| 276 |
+
|
| 277 |
+
# Generate predictions
|
| 278 |
+
predictions = generate_predictions(agent, test_data)
|
| 279 |
+
|
| 280 |
+
# Create visualizations
|
| 281 |
+
plot_results(test_data, predictions, evaluation_results)
|
| 282 |
+
|
| 283 |
+
# Print summary
|
| 284 |
+
print_summary(evaluation_results, predictions)
|
| 285 |
+
|
| 286 |
+
logger.info("FinRL Demo completed successfully")
|
| 287 |
+
|
| 288 |
+
except Exception as e:
|
| 289 |
+
logger.error(f"Error in FinRL demo: {str(e)}")
|
| 290 |
+
raise
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
if __name__ == "__main__":
|
| 294 |
+
main()
|
requirements.txt
CHANGED
|
@@ -7,3 +7,8 @@ pytest
|
|
| 7 |
pytest-cov
|
| 8 |
python-dateutil
|
| 9 |
scipy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
pytest-cov
|
| 8 |
python-dateutil
|
| 9 |
scipy
|
| 10 |
+
finrl
|
| 11 |
+
stable-baselines3
|
| 12 |
+
gymnasium
|
| 13 |
+
tensorboard
|
| 14 |
+
torch
|
tests/test_finrl_agent.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for FinRL Agent
|
| 3 |
+
|
| 4 |
+
This module contains comprehensive tests for the FinRL agent functionality.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
import yaml
|
| 11 |
+
import tempfile
|
| 12 |
+
import os
|
| 13 |
+
from unittest.mock import Mock, patch
|
| 14 |
+
|
| 15 |
+
# Add the project root to the path
|
| 16 |
+
import sys
|
| 17 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 18 |
+
|
| 19 |
+
from agentic_ai_system.finrl_agent import (
|
| 20 |
+
FinRLAgent,
|
| 21 |
+
FinRLConfig,
|
| 22 |
+
TradingEnvironment,
|
| 23 |
+
create_finrl_agent_from_config
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TestFinRLConfig:
|
| 28 |
+
"""Test FinRL configuration"""
|
| 29 |
+
|
| 30 |
+
def test_default_config(self):
|
| 31 |
+
"""Test default configuration values"""
|
| 32 |
+
config = FinRLConfig()
|
| 33 |
+
|
| 34 |
+
assert config.algorithm == "PPO"
|
| 35 |
+
assert config.learning_rate == 0.0003
|
| 36 |
+
assert config.batch_size == 64
|
| 37 |
+
assert config.gamma == 0.99
|
| 38 |
+
|
| 39 |
+
def test_custom_config(self):
|
| 40 |
+
"""Test custom configuration values"""
|
| 41 |
+
config = FinRLConfig(
|
| 42 |
+
algorithm="A2C",
|
| 43 |
+
learning_rate=0.001,
|
| 44 |
+
batch_size=128
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
assert config.algorithm == "A2C"
|
| 48 |
+
assert config.learning_rate == 0.001
|
| 49 |
+
assert config.batch_size == 128
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class TestTradingEnvironment:
|
| 53 |
+
"""Test trading environment"""
|
| 54 |
+
|
| 55 |
+
@pytest.fixture
|
| 56 |
+
def sample_data(self):
|
| 57 |
+
"""Create sample market data"""
|
| 58 |
+
dates = pd.date_range('2024-01-01', periods=100, freq='1H')
|
| 59 |
+
data = pd.DataFrame({
|
| 60 |
+
'open': np.random.uniform(100, 200, 100),
|
| 61 |
+
'high': np.random.uniform(100, 200, 100),
|
| 62 |
+
'low': np.random.uniform(100, 200, 100),
|
| 63 |
+
'close': np.random.uniform(100, 200, 100),
|
| 64 |
+
'volume': np.random.uniform(1000, 10000, 100),
|
| 65 |
+
'sma_20': np.random.uniform(100, 200, 100),
|
| 66 |
+
'sma_50': np.random.uniform(100, 200, 100),
|
| 67 |
+
'rsi': np.random.uniform(0, 100, 100),
|
| 68 |
+
'bb_upper': np.random.uniform(100, 200, 100),
|
| 69 |
+
'bb_lower': np.random.uniform(100, 200, 100),
|
| 70 |
+
'macd': np.random.uniform(-10, 10, 100)
|
| 71 |
+
}, index=dates)
|
| 72 |
+
return data
|
| 73 |
+
|
| 74 |
+
def test_environment_initialization(self, sample_data):
|
| 75 |
+
"""Test environment initialization"""
|
| 76 |
+
env = TradingEnvironment(sample_data)
|
| 77 |
+
|
| 78 |
+
assert env.initial_balance == 100000
|
| 79 |
+
assert env.transaction_fee == 0.001
|
| 80 |
+
assert env.max_position == 100
|
| 81 |
+
assert env.action_space.n == 3
|
| 82 |
+
assert len(env.observation_space.shape) == 1
|
| 83 |
+
|
| 84 |
+
def test_environment_reset(self, sample_data):
|
| 85 |
+
"""Test environment reset"""
|
| 86 |
+
env = TradingEnvironment(sample_data)
|
| 87 |
+
obs, info = env.reset()
|
| 88 |
+
|
| 89 |
+
assert env.current_step == 0
|
| 90 |
+
assert env.balance == env.initial_balance
|
| 91 |
+
assert env.position == 0
|
| 92 |
+
assert env.portfolio_value == env.initial_balance
|
| 93 |
+
assert isinstance(obs, np.ndarray)
|
| 94 |
+
assert isinstance(info, dict)
|
| 95 |
+
|
| 96 |
+
def test_environment_step(self, sample_data):
|
| 97 |
+
"""Test environment step"""
|
| 98 |
+
env = TradingEnvironment(sample_data)
|
| 99 |
+
obs, info = env.reset()
|
| 100 |
+
|
| 101 |
+
# Test hold action
|
| 102 |
+
obs, reward, done, truncated, info = env.step(1)
|
| 103 |
+
|
| 104 |
+
assert isinstance(obs, np.ndarray)
|
| 105 |
+
assert isinstance(reward, float)
|
| 106 |
+
assert isinstance(done, bool)
|
| 107 |
+
assert isinstance(truncated, bool)
|
| 108 |
+
assert isinstance(info, dict)
|
| 109 |
+
assert env.current_step == 1
|
| 110 |
+
|
| 111 |
+
def test_buy_action(self, sample_data):
|
| 112 |
+
"""Test buy action"""
|
| 113 |
+
env = TradingEnvironment(sample_data, initial_balance=10000)
|
| 114 |
+
obs, info = env.reset()
|
| 115 |
+
|
| 116 |
+
initial_balance = env.balance
|
| 117 |
+
initial_position = env.position
|
| 118 |
+
|
| 119 |
+
# Buy action
|
| 120 |
+
obs, reward, done, truncated, info = env.step(2)
|
| 121 |
+
|
| 122 |
+
assert env.position > initial_position
|
| 123 |
+
assert env.balance < initial_balance
|
| 124 |
+
|
| 125 |
+
def test_sell_action(self, sample_data):
|
| 126 |
+
"""Test sell action"""
|
| 127 |
+
env = TradingEnvironment(sample_data, initial_balance=10000)
|
| 128 |
+
obs, info = env.reset()
|
| 129 |
+
|
| 130 |
+
# First buy some shares
|
| 131 |
+
obs, reward, done, truncated, info = env.step(2)
|
| 132 |
+
initial_position = env.position
|
| 133 |
+
initial_balance = env.balance
|
| 134 |
+
|
| 135 |
+
# Then sell
|
| 136 |
+
obs, reward, done, truncated, info = env.step(0)
|
| 137 |
+
|
| 138 |
+
assert env.position < initial_position
|
| 139 |
+
assert env.balance > initial_balance
|
| 140 |
+
|
| 141 |
+
def test_portfolio_value_calculation(self, sample_data):
|
| 142 |
+
"""Test portfolio value calculation"""
|
| 143 |
+
env = TradingEnvironment(sample_data)
|
| 144 |
+
obs, info = env.reset()
|
| 145 |
+
|
| 146 |
+
# Buy some shares
|
| 147 |
+
obs, reward, done, truncated, info = env.step(2)
|
| 148 |
+
|
| 149 |
+
expected_value = env.balance + (env.position * sample_data.iloc[env.current_step]['close'])
|
| 150 |
+
assert abs(env.portfolio_value - expected_value) < 1e-6
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class TestFinRLAgent:
|
| 154 |
+
"""Test FinRL agent"""
|
| 155 |
+
|
| 156 |
+
@pytest.fixture
|
| 157 |
+
def sample_data(self):
|
| 158 |
+
"""Create sample market data"""
|
| 159 |
+
dates = pd.date_range('2024-01-01', periods=100, freq='1H')
|
| 160 |
+
data = pd.DataFrame({
|
| 161 |
+
'open': np.random.uniform(100, 200, 100),
|
| 162 |
+
'high': np.random.uniform(100, 200, 100),
|
| 163 |
+
'low': np.random.uniform(100, 200, 100),
|
| 164 |
+
'close': np.random.uniform(100, 200, 100),
|
| 165 |
+
'volume': np.random.uniform(1000, 10000, 100)
|
| 166 |
+
}, index=dates)
|
| 167 |
+
return data
|
| 168 |
+
|
| 169 |
+
@pytest.fixture
|
| 170 |
+
def finrl_config(self):
|
| 171 |
+
"""Create FinRL configuration"""
|
| 172 |
+
return FinRLConfig(
|
| 173 |
+
algorithm="PPO",
|
| 174 |
+
learning_rate=0.0003,
|
| 175 |
+
batch_size=32,
|
| 176 |
+
total_timesteps=1000
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
def test_agent_initialization(self, finrl_config):
|
| 180 |
+
"""Test agent initialization"""
|
| 181 |
+
agent = FinRLAgent(finrl_config)
|
| 182 |
+
|
| 183 |
+
assert agent.config == finrl_config
|
| 184 |
+
assert agent.model is None
|
| 185 |
+
assert agent.env is None
|
| 186 |
+
|
| 187 |
+
def test_prepare_data(self, finrl_config, sample_data):
|
| 188 |
+
"""Test data preparation"""
|
| 189 |
+
agent = FinRLAgent(finrl_config)
|
| 190 |
+
prepared_data = agent.prepare_data(sample_data)
|
| 191 |
+
|
| 192 |
+
# Check that technical indicators were added
|
| 193 |
+
assert 'sma_20' in prepared_data.columns
|
| 194 |
+
assert 'sma_50' in prepared_data.columns
|
| 195 |
+
assert 'rsi' in prepared_data.columns
|
| 196 |
+
assert 'bb_upper' in prepared_data.columns
|
| 197 |
+
assert 'bb_lower' in prepared_data.columns
|
| 198 |
+
assert 'macd' in prepared_data.columns
|
| 199 |
+
|
| 200 |
+
# Check that no NaN values remain
|
| 201 |
+
assert not prepared_data.isnull().any().any()
|
| 202 |
+
|
| 203 |
+
def test_create_environment(self, finrl_config, sample_data):
|
| 204 |
+
"""Test environment creation"""
|
| 205 |
+
agent = FinRLAgent(finrl_config)
|
| 206 |
+
env = agent.create_environment(sample_data)
|
| 207 |
+
|
| 208 |
+
assert isinstance(env, TradingEnvironment)
|
| 209 |
+
assert env.data.equals(sample_data)
|
| 210 |
+
|
| 211 |
+
def test_technical_indicators_calculation(self, finrl_config):
|
| 212 |
+
"""Test technical indicators calculation"""
|
| 213 |
+
agent = FinRLAgent(finrl_config)
|
| 214 |
+
|
| 215 |
+
# Test RSI calculation
|
| 216 |
+
prices = pd.Series([100, 101, 99, 102, 98, 103, 97, 104, 96, 105])
|
| 217 |
+
rsi = agent._calculate_rsi(prices, period=3)
|
| 218 |
+
assert len(rsi) == len(prices)
|
| 219 |
+
assert not rsi.isnull().all()
|
| 220 |
+
|
| 221 |
+
# Test Bollinger Bands calculation
|
| 222 |
+
bb_upper, bb_lower = agent._calculate_bollinger_bands(prices, period=3)
|
| 223 |
+
assert len(bb_upper) == len(prices)
|
| 224 |
+
assert len(bb_lower) == len(prices)
|
| 225 |
+
assert (bb_upper >= bb_lower).all()
|
| 226 |
+
|
| 227 |
+
# Test MACD calculation
|
| 228 |
+
macd = agent._calculate_macd(prices)
|
| 229 |
+
assert len(macd) == len(prices)
|
| 230 |
+
|
| 231 |
+
@patch('agentic_ai_system.finrl_agent.PPO')
|
| 232 |
+
def test_training_ppo(self, mock_ppo, finrl_config, sample_data):
|
| 233 |
+
"""Test PPO training"""
|
| 234 |
+
# Mock the PPO model
|
| 235 |
+
mock_model = Mock()
|
| 236 |
+
mock_ppo.return_value = mock_model
|
| 237 |
+
|
| 238 |
+
agent = FinRLAgent(finrl_config)
|
| 239 |
+
result = agent.train(sample_data, total_timesteps=100)
|
| 240 |
+
|
| 241 |
+
assert result['algorithm'] == 'PPO'
|
| 242 |
+
assert result['total_timesteps'] == 100
|
| 243 |
+
mock_model.learn.assert_called_once()
|
| 244 |
+
|
| 245 |
+
@patch('agentic_ai_system.finrl_agent.A2C')
|
| 246 |
+
def test_training_a2c(self, mock_a2c):
|
| 247 |
+
"""Test A2C training"""
|
| 248 |
+
config = FinRLConfig(algorithm="A2C")
|
| 249 |
+
mock_model = Mock()
|
| 250 |
+
mock_a2c.return_value = mock_model
|
| 251 |
+
|
| 252 |
+
agent = FinRLAgent(config)
|
| 253 |
+
sample_data = pd.DataFrame({
|
| 254 |
+
'open': [100, 101, 102],
|
| 255 |
+
'high': [101, 102, 103],
|
| 256 |
+
'low': [99, 100, 101],
|
| 257 |
+
'close': [100, 101, 102],
|
| 258 |
+
'volume': [1000, 1100, 1200]
|
| 259 |
+
})
|
| 260 |
+
|
| 261 |
+
result = agent.train(sample_data, total_timesteps=100)
|
| 262 |
+
|
| 263 |
+
assert result['algorithm'] == 'A2C'
|
| 264 |
+
mock_model.learn.assert_called_once()
|
| 265 |
+
|
| 266 |
+
def test_invalid_algorithm(self):
|
| 267 |
+
"""Test invalid algorithm handling"""
|
| 268 |
+
config = FinRLConfig(algorithm="INVALID")
|
| 269 |
+
agent = FinRLAgent(config)
|
| 270 |
+
sample_data = pd.DataFrame({
|
| 271 |
+
'open': [100, 101, 102],
|
| 272 |
+
'high': [101, 102, 103],
|
| 273 |
+
'low': [99, 100, 101],
|
| 274 |
+
'close': [100, 101, 102],
|
| 275 |
+
'volume': [1000, 1100, 1200]
|
| 276 |
+
})
|
| 277 |
+
|
| 278 |
+
with pytest.raises(ValueError, match="Unsupported algorithm"):
|
| 279 |
+
agent.train(sample_data, total_timesteps=100)
|
| 280 |
+
|
| 281 |
+
def test_predict_without_training(self, finrl_config, sample_data):
|
| 282 |
+
"""Test prediction without training"""
|
| 283 |
+
agent = FinRLAgent(finrl_config)
|
| 284 |
+
|
| 285 |
+
with pytest.raises(ValueError, match="Model not trained"):
|
| 286 |
+
agent.predict(sample_data)
|
| 287 |
+
|
| 288 |
+
def test_evaluate_without_training(self, finrl_config, sample_data):
|
| 289 |
+
"""Test evaluation without training"""
|
| 290 |
+
agent = FinRLAgent(finrl_config)
|
| 291 |
+
|
| 292 |
+
with pytest.raises(ValueError, match="Model not trained"):
|
| 293 |
+
agent.evaluate(sample_data)
|
| 294 |
+
|
| 295 |
+
@patch('agentic_ai_system.finrl_agent.PPO')
|
| 296 |
+
def test_save_and_load_model(self, mock_ppo, finrl_config, sample_data):
|
| 297 |
+
"""Test model saving and loading"""
|
| 298 |
+
# Mock the PPO model
|
| 299 |
+
mock_model = Mock()
|
| 300 |
+
mock_ppo.return_value = mock_model
|
| 301 |
+
mock_ppo.load.return_value = mock_model
|
| 302 |
+
|
| 303 |
+
agent = FinRLAgent(finrl_config)
|
| 304 |
+
|
| 305 |
+
# Train the agent
|
| 306 |
+
agent.train(sample_data, total_timesteps=100)
|
| 307 |
+
|
| 308 |
+
# Test saving
|
| 309 |
+
with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as tmp_file:
|
| 310 |
+
agent.save_model(tmp_file.name)
|
| 311 |
+
mock_model.save.assert_called_once_with(tmp_file.name)
|
| 312 |
+
|
| 313 |
+
# Test loading
|
| 314 |
+
agent.load_model(tmp_file.name)
|
| 315 |
+
mock_ppo.load.assert_called_once_with(tmp_file.name)
|
| 316 |
+
|
| 317 |
+
# Clean up
|
| 318 |
+
os.unlink(tmp_file.name)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class TestFinRLIntegration:
|
| 322 |
+
"""Test FinRL integration with configuration"""
|
| 323 |
+
|
| 324 |
+
def test_create_agent_from_config(self):
|
| 325 |
+
"""Test creating agent from configuration file"""
|
| 326 |
+
config_data = {
|
| 327 |
+
'finrl': {
|
| 328 |
+
'algorithm': 'PPO',
|
| 329 |
+
'learning_rate': 0.001,
|
| 330 |
+
'batch_size': 128,
|
| 331 |
+
'gamma': 0.95
|
| 332 |
+
}
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as tmp_file:
|
| 336 |
+
yaml.dump(config_data, tmp_file)
|
| 337 |
+
tmp_file_path = tmp_file.name
|
| 338 |
+
|
| 339 |
+
try:
|
| 340 |
+
agent = create_finrl_agent_from_config(tmp_file_path)
|
| 341 |
+
|
| 342 |
+
assert agent.config.algorithm == 'PPO'
|
| 343 |
+
assert agent.config.learning_rate == 0.001
|
| 344 |
+
assert agent.config.batch_size == 128
|
| 345 |
+
assert agent.config.gamma == 0.95
|
| 346 |
+
finally:
|
| 347 |
+
os.unlink(tmp_file_path)
|
| 348 |
+
|
| 349 |
+
def test_create_agent_from_config_missing_finrl(self):
|
| 350 |
+
"""Test creating agent from config without finrl section"""
|
| 351 |
+
config_data = {
|
| 352 |
+
'trading': {
|
| 353 |
+
'symbol': 'AAPL',
|
| 354 |
+
'capital': 100000
|
| 355 |
+
}
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as tmp_file:
|
| 359 |
+
yaml.dump(config_data, tmp_file)
|
| 360 |
+
tmp_file_path = tmp_file.name
|
| 361 |
+
|
| 362 |
+
try:
|
| 363 |
+
agent = create_finrl_agent_from_config(tmp_file_path)
|
| 364 |
+
|
| 365 |
+
# Should use default values
|
| 366 |
+
assert agent.config.algorithm == 'PPO'
|
| 367 |
+
assert agent.config.learning_rate == 0.0003
|
| 368 |
+
finally:
|
| 369 |
+
os.unlink(tmp_file_path)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
if __name__ == "__main__":
|
| 373 |
+
pytest.main([__file__])
|