Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import torch | |
| import random | |
| from transformers import ( | |
| GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling, | |
| TrainerCallback # Import TrainerCallback here | |
| ) | |
| from datasets import Dataset | |
| from huggingface_hub import HfApi | |
| import plotly.graph_objects as go | |
| import time | |
| from datetime import datetime | |
| import threading | |
| # Cyberpunk and Loading Animation Styling | |
| def setup_cyberpunk_style(): | |
| st.markdown(""" | |
| <style> | |
| body, button, input, select, textarea { | |
| font-family: 'Orbitron', sans-serif !important; | |
| color: #00ff9d !important; | |
| } | |
| .stApp { | |
| background: radial-gradient(circle, rgba(0, 0, 0, 0.95) 20%, rgba(0, 50, 80, 0.95) 90%); | |
| color: #00ff9d; | |
| font-family: 'Orbitron', sans-serif; | |
| font-size: 16px; | |
| line-height: 1.6; | |
| padding: 20px; | |
| box-sizing: border-box; | |
| } | |
| .main-title { | |
| text-align: center; | |
| font-size: 4em; | |
| color: #00ff9d; | |
| letter-spacing: 4px; | |
| animation: glow 2s ease-in-out infinite alternate; | |
| } | |
| @keyframes glow { | |
| from {text-shadow: 0 0 5px #00ff9d, 0 0 10px #00ff9d;} | |
| to {text-shadow: 0 0 15px #00b8ff, 0 0 20px #00b8ff;} | |
| } | |
| .stButton > button { | |
| font-family: 'Orbitron', sans-serif; | |
| background: linear-gradient(45deg, #00ff9d, #00b8ff); | |
| color: #000; | |
| font-size: 1.1em; | |
| padding: 10px 20px; | |
| border: none; | |
| border-radius: 8px; | |
| transition: all 0.3s ease; | |
| } | |
| .stButton > button:hover { | |
| transform: scale(1.1); | |
| box-shadow: 0 0 20px rgba(0, 255, 157, 0.5); | |
| } | |
| .progress-bar-container { | |
| background: rgba(0, 0, 0, 0.5); | |
| border-radius: 15px; | |
| overflow: hidden; | |
| width: 100%; | |
| height: 30px; | |
| position: relative; | |
| margin: 10px 0; | |
| } | |
| .progress-bar { | |
| height: 100%; | |
| width: 0%; | |
| background: linear-gradient(45deg, #00ff9d, #00b8ff); | |
| transition: width 0.5s ease; | |
| } | |
| .go-button { | |
| font-family: 'Orbitron', sans-serif; | |
| background: linear-gradient(45deg, #00ff9d, #00b8ff); | |
| color: #000; | |
| font-size: 1.1em; | |
| padding: 10px 20px; | |
| border: none; | |
| border-radius: 8px; | |
| transition: all 0.3s ease; | |
| cursor: pointer; | |
| } | |
| .go-button:hover { | |
| transform: scale(1.1); | |
| box-shadow: 0 0 20px rgba(0, 255, 157, 0.5); | |
| } | |
| .loading-animation { | |
| display: inline-block; | |
| width: 20px; | |
| height: 20px; | |
| border: 3px solid #00ff9d; | |
| border-radius: 50%; | |
| border-top-color: transparent; | |
| animation: spin 1s ease-in-out infinite; | |
| } | |
| @keyframes spin { | |
| to {transform: rotate(360deg);} | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| # Prepare Dataset Function with Padding Token Fix | |
| def prepare_dataset(data, tokenizer, block_size=128): | |
| tokenizer.pad_token = tokenizer.eos_token | |
| def tokenize_function(examples): | |
| return tokenizer(examples['text'], truncation=True, max_length=block_size, padding='max_length') | |
| raw_dataset = Dataset.from_dict({'text': data}) | |
| tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=['text']) | |
| tokenized_dataset = tokenized_dataset.map(lambda examples: {'labels': examples['input_ids']}, batched=True) | |
| tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels']) | |
| return tokenized_dataset | |
| # Training Dashboard Class with Enhanced Display | |
| class TrainingDashboard: | |
| def __init__(self): | |
| self.metrics = { | |
| 'current_loss': 0, | |
| 'best_loss': float('inf'), | |
| 'generation': 0, | |
| 'individual': 0, | |
| 'start_time': time.time(), | |
| 'training_speed': 0 | |
| } | |
| self.history = [] | |
| def update(self, loss, generation, individual): | |
| self.metrics['current_loss'] = loss | |
| self.metrics['generation'] = generation | |
| self.metrics['individual'] = individual | |
| if loss < self.metrics['best_loss']: | |
| self.metrics['best_loss'] = loss | |
| elapsed_time = time.time() - self.metrics['start_time'] | |
| self.metrics['training_speed'] = (generation * individual) / elapsed_time | |
| self.history.append({'loss': loss, 'timestamp': datetime.now().strftime('%H:%M:%S')}) | |
| # Define Model Initialization | |
| def initialize_model(model_name="gpt2"): | |
| model = GPT2LMHeadModel.from_pretrained(model_name) | |
| tokenizer = GPT2Tokenizer.from_pretrained(model_name) | |
| tokenizer.pad_token = tokenizer.eos_token | |
| return model, tokenizer | |
| # Load Dataset Function with Uploaded File Option | |
| def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None): | |
| if data_source == "demo": | |
| data = ["In the neon-lit streets of Neo-Tokyo, a lone hacker fights against the oppressive megacorporations.", | |
| "The rain falls in sheets, washing away the bloodstains from the alleyways.", | |
| "She plugs into the matrix, seeking answers to questions that have haunted her for years."] | |
| elif uploaded_file is not None: | |
| if uploaded_file.name.endswith(".txt"): | |
| data = [uploaded_file.read().decode("utf-8")] | |
| elif uploaded_file.name.endswith(".csv"): | |
| import pandas as pd | |
| df = pd.read_csv(uploaded_file) | |
| data = df[df.columns[0]].tolist() # assuming first column is text data | |
| else: | |
| data = ["No file uploaded. Please upload a dataset."] | |
| dataset = prepare_dataset(data, tokenizer) | |
| return dataset | |
| # Train Model Function with Customized Progress Bar | |
| def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4, progress_callback=None): | |
| training_args = TrainingArguments( | |
| output_dir="./results", | |
| overwrite_output_dir=True, | |
| num_train_epochs=epochs, | |
| per_device_train_batch_size=batch_size, | |
| save_steps=10_000, | |
| save_total_limit=2, | |
| logging_dir="./logs", | |
| logging_steps=100, | |
| ) | |
| data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| data_collator=data_collator, | |
| train_dataset=train_dataset, | |
| callbacks=[ProgressCallback(progress_callback)] | |
| ) | |
| trainer.train() | |
| class ProgressCallback(TrainerCallback): | |
| def __init__(self, progress_callback): | |
| super().__init__() | |
| self.progress_callback = progress_callback | |
| def on_epoch_end(self, args, state, control, **kwargs): | |
| loss = state.log_history[-1]['loss'] | |
| generation = state.global_step // args.gradient_accumulation_steps + 1 | |
| individual = args.gradient_accumulation_steps | |
| self.progress_callback(loss, generation, individual) | |
| # Main App Logic | |
| def main(): | |
| setup_cyberpunk_style() | |
| st.markdown('<h1 class="main-title">Neural Training Hub</h1>', unsafe_allow_html=True) | |
| # Initialize model and tokenizer | |
| model, tokenizer = initialize_model() | |
| # Sidebar Configuration with Additional Options | |
| with st.sidebar: | |
| st.markdown("### Configuration Panel") | |
| # Hugging Face API Token Input | |
| hf_token = st.text_input("Enter your Hugging Face Token", type="password") | |
| if hf_token: | |
| api = HfApi() | |
| api.set_access_token(hf_token) | |
| st.success("Hugging Face token added successfully!") | |
| # Training Parameters | |
| training_epochs = st.slider("Training Epochs", min_value=1, max_value=5, value=3) | |
| batch_size = st.slider("Batch Size", min_value=2, max_value=8, value=4) | |
| model_choice = st.selectbox("Model Selection", ("gpt2", "distilgpt2", "gpt2-medium")) | |
| # Dataset Source Selection | |
| data_source = st.selectbox("Data Source", ("demo", "uploaded file")) | |
| uploaded_file = st.file_uploader("Upload a text file", type=["txt", "csv"]) if data_source == "uploaded file" else None | |
| custom_learning_rate = st.slider("Learning Rate", min_value=1e-6, max_value=5e-4, value=3e-5, step=1e-6) | |
| # Advanced Settings Toggle | |
| advanced_toggle = st.checkbox("Advanced Training Settings") | |
| if advanced_toggle: | |
| warmup_steps = st.slider("Warmup Steps", min_value=0, max_value=500, value=100) | |
| weight_decay = st.slider("Weight Decay", min_value=0.0, max_value=0.1, step=0.01, value=0.01) | |
| else: | |
| warmup_steps = 100 | |
| weight_decay = 0.01 | |
| # Load Dataset | |
| train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file) | |
| # Chatbot Interaction | |
| if st.checkbox("Enable Chatbot"): | |
| user_input = st.text_input("You:", placeholder="Type your message here...") | |
| if user_input: | |
| inputs = tokenizer(user_input, return_tensors="pt") | |
| outputs = model.generate(inputs['input_ids'], max_length=100, num_return_sequences=1) | |
| response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| st.write("Bot:", response) | |
| # Go Button to Start Training | |
| if st.button("Go"): | |
| progress_placeholder = st.empty() | |
| loading_animation = st.empty() | |
| st.markdown("### Model Training Progress") | |
| dashboard = TrainingDashboard() | |
| def train_progress(loss, generation, individual): | |
| progress = (generation + 1) / dashboard.metrics['training_epochs'] * 100 | |
| progress_placeholder.markdown(f""" | |
| <div class="progress-bar-container"> | |
| <div class="progress-bar" style="width: {progress}%;"></div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| dashboard.update(loss=loss, generation=generation, individual=individual) | |
| thread = threading.Thread(target=train_model, args=(model, train_dataset, tokenizer, training_epochs, batch_size, train_progress)) | |
| thread.start() | |
| loading_animation.markdown(""" | |
| <div class="loading-animation"></div> | |
| """, unsafe_allow_html=True) | |
| thread.join() | |
| loading_animation.empty() | |
| st.success("Training Complete!") | |
| st.write("Training Metrics:") | |
| st.write(dashboard.metrics) | |
| if __name__ == "__main__": | |
| main() |