|
import gradio as gr |
|
import numpy as np |
|
import pandas as pd |
|
import yfinance as yf |
|
import talib |
|
import requests |
|
import matplotlib.pyplot as plt |
|
import logging |
|
from tensorflow.keras.models import Sequential |
|
from tensorflow.keras.layers import LSTM, Dense, Dropout |
|
from stable_baselines3 import PPO, DQN |
|
from gym import Env, spaces |
|
from selenium import webdriver |
|
from selenium.webdriver.chrome.service import Service |
|
from selenium.webdriver.chrome.options import Options |
|
from selenium.webdriver.common.by import By |
|
from selenium.webdriver.support.ui import WebDriverWait |
|
from selenium.webdriver.support import expected_conditions as EC |
|
from bs4 import BeautifulSoup |
|
import time |
|
from webdriver_manager.chrome import ChromeDriverManager |
|
import threading |
|
import smtplib |
|
from email.mime.text import MIMEText |
|
from email.mime.multipart import MIMEMultipart |
|
|
|
# Configure the root logger once: INFO and above, each record prefixed with
# its timestamp and severity.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# Yahoo Finance chart endpoint prefix. Nothing in the visible part of this
# file reads it; price data is fetched through yfinance instead.
BASE_URL = "https://query1.finance.yahoo.com/v8/finance/chart/"
|
|
|
def fetch_live_price(symbol):
    """Return the most recent closing value yfinance reports for *symbol*.

    A one-day history is requested and the last entry of its ``Close``
    column is returned.

    Args:
        symbol: Ticker string handed to ``yf.Ticker``.

    Returns:
        The last ``Close`` value, or ``None`` when anything goes wrong
        (lookup failure, empty history, ...). Failures are logged at
        ERROR level rather than raised.
    """
    try:
        closes = yf.Ticker(symbol).history(period="1d")["Close"]
        latest = closes.iloc[-1]
    except Exception as e:
        logging.error(f"Error fetching live price: {e}")
        return None
    return latest
|
|
|
|
|
def scrape_groww(symbol):
    """Scrape the displayed price of *symbol* from its Groww stock page.

    Starts a headless Chrome session, opens
    ``https://groww.in/stocks/<symbol in lower case>`` and waits up to ten
    seconds for an element with the CSS class ``stock-price``.

    Args:
        symbol: Stock identifier used as the last URL path segment.

    Returns:
        The price as a ``float`` (commas stripped before parsing), or
        ``None`` when the element never appears or its text is not numeric.

    Raises:
        Any exception raised while starting the browser or loading the page
        propagates to the caller; the browser is still shut down first.
    """
    options = Options()
    options.add_argument("--headless")
    driver = webdriver.Chrome(
        service=Service(ChromeDriverManager().install()),
        options=options,
    )

    # Everything after the browser exists runs under try/finally so the
    # Chrome process is released even when navigation itself fails.
    try:
        driver.get(f"https://groww.in/stocks/{symbol.lower()}")
        try:
            price_element = WebDriverWait(driver, 10).until(
                EC.presence_of_element_located((By.CLASS_NAME, "stock-price"))
            )
            return float(price_element.text.replace(",", ""))
        except Exception as e:
            # Scraping is best-effort: report why it failed instead of
            # silently returning None.
            logging.warning("Could not scrape Groww price for %s: %s", symbol, e)
            return None
    finally:
        driver.quit()
|
|
|
|
|
def fetch_data(symbol, interval='1m', period='5d'):
    """Download price history for *symbol* and append indicator columns.

    Args:
        symbol: Ticker passed to ``yf.download``.
        interval: Bar size passed to ``yf.download`` (default ``'1m'``).
        period: Look-back span passed to ``yf.download`` (default ``'5d'``).

    Returns:
        The downloaded DataFrame with four extra columns: ``SMA_10``
        (10-period simple moving average of ``Close``), ``RSI`` (14-period),
        ``MACD`` and ``MACD_signal`` (TA-Lib default MACD settings). When the
        download is empty the four columns are present but hold no values.
    """
    df = yf.download(symbol, interval=interval, period=period)

    # Newer yfinance releases label columns with a (field, ticker) MultiIndex
    # even for one symbol, which turns df['Close'] into a DataFrame. Drop the
    # ticker level when it carries a single value so the lookups below get a
    # Series again.
    if (
        isinstance(df.columns, pd.MultiIndex)
        and df.columns.get_level_values(-1).nunique() == 1
    ):
        df.columns = df.columns.droplevel(-1)

    indicator_columns = ("SMA_10", "RSI", "MACD", "MACD_signal")
    if df.empty:
        # Nothing to compute on; keep the return shape stable for callers.
        logging.warning("No price data returned for %s", symbol)
        for column in indicator_columns:
            df[column] = np.nan
        return df

    # TA-Lib operates on double-precision input.
    close = df["Close"].astype("float64")
    df["SMA_10"] = talib.SMA(close, timeperiod=10)
    df["RSI"] = talib.RSI(close, timeperiod=14)
    df["MACD"], df["MACD_signal"], _ = talib.MACD(close)
    return df
|
|
|
|
|
def train_lstm_model(data, window=10, epochs=20, batch_size=32):
    """Fit a two-layer LSTM that predicts the next ``Close`` value.

    Each training sample is the previous *window* closing prices, shaped
    ``(window, 1)`` to match the network's declared input; the target is the
    close that immediately follows that window.

    Args:
        data: Table with a ``Close`` column (e.g. a pandas DataFrame).
        window: Number of past closes fed to the network per sample.
        epochs: Training epochs passed to ``model.fit``.
        batch_size: Batch size passed to ``model.fit``.

    Returns:
        The fitted Keras ``Sequential`` model.

    Raises:
        ValueError: If ``Close`` is not one-dimensional or has no more than
            *window* rows, so that no (window, target) pair can be formed.
    """
    closes = np.asarray(data["Close"], dtype=np.float32)
    # A single-column 2-D result (MultiIndex-labelled frame) is still one series.
    if closes.ndim == 2 and closes.shape[1] == 1:
        closes = closes[:, 0]
    if closes.ndim != 1:
        raise ValueError("data['Close'] must be a single price series")
    if len(closes) <= window:
        raise ValueError(
            f"need more than {window} rows to build training windows, "
            f"got {len(closes)}"
        )

    # The original sliced whole DataFrame rows and wrapped them in an extra
    # list, producing (N, 1, window, n_columns) and contradicting
    # input_shape. Build (N, window, 1) from the Close series only.
    X = np.stack(
        [closes[i - window:i] for i in range(window, len(closes))]
    )[..., np.newaxis]
    y = closes[window:]

    model = Sequential([
        LSTM(100, return_sequences=True, input_shape=(window, 1)),
        Dropout(0.2),
        LSTM(100),
        Dense(50, activation='relu'),
        Dense(1),
    ])
    model.compile(optimizer='adam', loss='mse')
    model.fit(X, y, epochs=epochs, batch_size=batch_size)
    return model
|
|
|
|
|
class TradingEnv(Env):
    """Toy three-action trading environment with random observations.

    Actions:
        0: add one unit to the position (reward -0.5).
        1: remove one unit from the position if any is held (reward +10).
        2: do nothing (reward +0.2).

    Observations are freshly drawn random vectors of length 10; they carry
    no market data. An episode ends after ``MAX_STEPS`` steps. ``balance``
    is tracked in ``history`` but is never modified by ``step``.

    NOTE(review): ``step`` returns the classic gym 4-tuple and ``reset``
    returns only the observation. Newer gym/gymnasium releases expect a
    5-tuple and ``(obs, info)`` -- confirm against the installed version.
    """

    MAX_STEPS = 200  # episode length in steps

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(10,), dtype=np.float32
        )
        self._reset_state()

    def _reset_state(self):
        """Put all per-episode bookkeeping back to its starting values."""
        self.current_step = 0
        self.balance = 10000
        self.position = 0
        self.history = []

    def _observation(self):
        """Draw a random observation with the dtype the space declares."""
        return np.random.random(10).astype(np.float32)

    def reset(self):
        """Start a new episode and return its first observation.

        The original class had no ``reset``, so state carried over between
        episodes and any caller invoking ``env.reset()`` failed.
        """
        self._reset_state()
        return self._observation()

    def step(self, action):
        """Apply *action* and advance one step.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` is an empty
            dict and ``done`` becomes true once ``MAX_STEPS`` is reached.
        """
        reward = 0
        if action == 0:
            self.position += 1
            reward -= 0.5
        elif action == 1:
            # Selling is only rewarded when there is something to sell.
            if self.position > 0:
                self.position -= 1
                reward += 10
        elif action == 2:
            reward += 0.2

        self.current_step += 1
        self.history.append((self.current_step, self.balance, self.position))
        done = self.current_step >= self.MAX_STEPS
        return self._observation(), reward, done, {}
|
|
|
# Build the environment and train a PPO agent on it. These statements run
# as soon as the module is executed, before anything defined below them.
env = TradingEnv()
model = PPO(policy="MlpPolicy", env=env, verbose=1)
model.learn(total_timesteps=500_000)
|
|
|
|
|
def place_trade(symbol, action):
    """Return a canned success record for the requested trade.

    No order is sent anywhere; the arguments are simply echoed back.

    Args:
        symbol: Instrument the trade refers to.
        action: Trade direction label supplied by the caller.

    Returns:
        A dict with keys ``status`` (always ``"success"``), ``symbol`` and
        ``action``.
    """
    return dict(status="success", symbol=symbol, action=action)
|
|
|
|
|
def send_email_alert(subject, body):
    """Send a plain-text e-mail through Gmail's SMTP server.

    Sender address, recipient address and password are read from the
    environment variables ``ALERT_SENDER_EMAIL``, ``ALERT_RECEIVER_EMAIL``
    and ``ALERT_EMAIL_PASSWORD``. When a variable is unset the placeholder
    value that used to be hard-coded here is used, so existing behaviour is
    unchanged until the variables are defined.

    Args:
        subject: Value of the ``Subject`` header.
        body: Plain-text message body.

    Raises:
        Connection, TLS, login and send errors from ``smtplib`` (and the
        underlying socket) are not caught here and propagate to the caller.
    """
    import os  # local import: the file's import block does not provide it

    # Keep real credentials out of the source file.
    sender_email = os.environ.get("ALERT_SENDER_EMAIL", "[email protected]")
    receiver_email = os.environ.get("ALERT_RECEIVER_EMAIL", "[email protected]")
    password = os.environ.get("ALERT_EMAIL_PASSWORD", "your_password")

    msg = MIMEMultipart()
    msg['From'] = sender_email
    msg['To'] = receiver_email
    msg['Subject'] = subject
    msg.attach(MIMEText(body, 'plain'))

    # A timeout keeps an unreachable mail server from blocking the caller
    # indefinitely.
    with smtplib.SMTP('smtp.gmail.com', 587, timeout=30) as server:
        server.starttls()
        server.login(sender_email, password)
        server.sendmail(sender_email, receiver_email, msg.as_string())
|
|
|
|
|
def monitor_price(symbol, threshold, poll_interval=60):
    """Poll the price of *symbol* until it reaches *threshold*, then sell.

    Blocks the calling thread. On each iteration the live price is fetched;
    once it is at or above *threshold* an e-mail alert is attempted, a
    ``"sell"`` trade is placed and the function returns.

    Args:
        symbol: Ticker to watch.
        threshold: Price at or above which the alert and trade fire.
        poll_interval: Seconds to sleep between price checks (default 60).
    """
    while True:
        price = fetch_live_price(symbol)
        # fetch_live_price signals failure with None; compare explicitly
        # rather than relying on the truthiness of a number.
        if price is not None and price >= threshold:
            try:
                send_email_alert(
                    "Stock Price Alert", f"{symbol} has reached {price}!"
                )
            except (smtplib.SMTPException, OSError):
                # The notification is secondary: a mail failure must not
                # prevent the trade from being placed.
                logging.exception("Failed to send price alert for %s", symbol)
            place_trade(symbol, "sell")
            break
        time.sleep(poll_interval)
|
|
|
|
|
def stock_dashboard(symbol, threshold_price):
    """Build the dashboard outputs for one request.

    Plots the close price and its 10-period SMA, fetches the live price,
    picks one of ``buy``/``sell``/``hold`` uniformly at random and passes it
    to ``place_trade``.

    Args:
        symbol: Ticker to chart and trade.
        threshold_price: Accepted from the UI but not used by this function.
            NOTE(review): presumably meant to feed ``monitor_price`` --
            confirm the intended wiring.

    Returns:
        ``(figure, live_price_text, trade_text)`` in the order the Gradio
        interface declares its outputs.
    """
    data = fetch_data(symbol)

    fig, ax = plt.subplots()
    ax.plot(data.index, data['Close'], label='Close Price')
    ax.plot(data.index, data['SMA_10'], label='SMA 10', linestyle='dashed')
    ax.legend()
    # pyplot keeps a reference to every figure it creates until it is
    # closed, so a long-running server would accumulate one per request.
    # Closing only unregisters it; the Figure object can still be rendered.
    plt.close(fig)

    live_price = fetch_live_price(symbol)
    # Convert the NumPy string scalar to a plain str before passing it on.
    action = str(np.random.choice(["buy", "sell", "hold"]))
    place_trade(symbol, action)

    return fig, f"Live Price: {live_price}", f"Trade Executed: {action}"
|
|
|
# Gradio front end: a text box (symbol) and a number field (threshold) feed
# stock_dashboard, whose three return values fill a plot and two text boxes.
demo = gr.Interface(
    stock_dashboard,
    ["text", "number"],
    ["plot", "text", "text"],
    title="AI-Powered Intraday Trading Agent",
    description="Enter a stock symbol and set a price threshold to start trading.",
)

# Start the web server when the module is executed.
demo.launch()