import os

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import tensorflow as tf
from groq import Groq
from sklearn.preprocessing import MinMaxScaler

# Initialize session state
if 'model' not in st.session_state:
    st.session_state.model = None
if 'threshold' not in st.session_state:
    st.session_state.threshold = None
if 'anomalies' not in st.session_state:
    st.session_state.anomalies = None


# Autoencoder model definition
def build_autoencoder(input_dim):
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(input_dim)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model


# Sidebar controls
st.sidebar.title("Configuration")
fine_tune = st.sidebar.button("Fine-tune Model")
groq_api_key = st.sidebar.text_input("Groq API Key (optional)", type="password")

# Main interface
st.title("🛰️ AI Network Anomaly Detector")
st.write("Upload your network data (CSV) to detect anomalies")

# File uploader
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

# Load uploaded data or fall back to sample data
if uploaded_file is not None:
    data = pd.read_csv(uploaded_file)
else:
    st.info("Using sample data. Upload a file to use your own dataset.")
    data = pd.read_csv("sample_wifi_data.csv")  # You should provide this sample file

# Preprocessing: scale the feature columns to [0, 1]
features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data[features])

# Model training / fine-tuning
if fine_tune or st.session_state.model is None:
    with st.spinner("Training model..."):
        autoencoder = build_autoencoder(data_scaled.shape[1])
        autoencoder.fit(data_scaled, data_scaled,
                        epochs=100, batch_size=32,
                        verbose=0, validation_split=0.1)
        st.session_state.model = autoencoder

        # Set the anomaly threshold at the 95th percentile of reconstruction error
        reconstructions = autoencoder.predict(data_scaled)
        mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
        st.session_state.threshold = np.percentile(mse, 95)

# Anomaly detection
if st.session_state.model is not None and st.button("Detect Anomalies"):
    reconstructions = st.session_state.model.predict(data_scaled)
    mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
    anomalies = mse > st.session_state.threshold
    st.session_state.anomalies = anomalies  # persist for the download button

    # Visualization: traffic over time with anomalous points highlighted
    fig = px.line(data, x=data.index, y='connection_attempts',
                  title='Network Traffic with Anomalies')
    fig.add_scatter(x=data[anomalies].index,
                    y=data[anomalies]['connection_attempts'],
                    mode='markers', name='Anomalies')
    st.plotly_chart(fig)

    # Generate alert with Groq/Llama 3
    if groq_api_key:
        try:
            client = Groq(api_key=groq_api_key)
            response = client.chat.completions.create(
                model="llama3-70b-8192",
                messages=[{
                    "role": "user",
                    "content": (
                        f"Generate a network security alert for {sum(anomalies)} "
                        f"anomalies detected. Max connection attempts: "
                        f"{data['connection_attempts'].max()}"
                    )
                }]
            )
            st.warning(response.choices[0].message.content)
        except Exception as e:
            st.error(f"Groq API Error: {str(e)}")
    else:
        st.warning(f"Detected {sum(anomalies)} anomalies! "
                   "Consider adding a Groq API key for detailed analysis.")

# Download button for results (only once detection has produced anomalies)
if st.session_state.anomalies is not None:
    st.download_button(
        label="Download Anomaly Report",
        data=data[st.session_state.anomalies].to_csv().encode('utf-8'),
        file_name='anomalies_report.csv',
        mime='text/csv'
    )
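# ------------------------------------------------------------------
# Note: the fallback path above expects a sample_wifi_data.csv next to the
# script. A minimal sketch (an assumption, not part of the original app) for
# generating one with the four expected feature columns; save it as a separate
# script, e.g. make_sample_data.py (hypothetical name), run it once, then
# launch the app with `streamlit run app.py` (assuming this file is app.py):
#
# import numpy as np
# import pandas as pd
#
# rng = np.random.default_rng(42)
# n = 500
# pd.DataFrame({
#     'device_count': rng.integers(5, 50, size=n),        # devices seen per interval
#     'connection_attempts': rng.poisson(20, size=n),     # attempts per interval
#     'packet_loss': rng.uniform(0, 5, size=n),           # percent
#     'latency': rng.normal(30, 10, size=n).clip(min=1),  # milliseconds
# }).to_csv('sample_wifi_data.csv', index=False)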