"""Streamlit UI for fine-tuning (simulated), quantizing and converting Gemma models."""
import streamlit as st
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import time
import json
import re
import os
import asyncio

# -------------------------------
# Utility Functions
# -------------------------------


def _read_hf_token():
    """Return the Hugging Face access token, or None if none is configured.

    Reads ``HF_TOKEN`` from Streamlit secrets and falls back to the ``HF_TOKEN``
    environment variable, so a missing secret no longer aborts the script.
    """
    try:
        return st.secrets["HF_TOKEN"]
    except Exception:
        # Missing key or missing secrets file; the exception type raised
        # differs between Streamlit versions, so fall back deliberately.
        return os.environ.get("HF_TOKEN")


token = _read_hf_token()

# NOTE(review): an empty CURL_CA_BUNDLE has been used as a workaround to skip
# TLS certificate verification with older `requests` releases. If that is its
# purpose here it weakens transport security for the whole process — confirm
# it is still needed before deploying.
os.environ['CURL_CA_BUNDLE'] = ''


@st.cache_resource
def _load_model_cached(model_id: str, token):
    """Download the tokenizer and model for ``model_id`` and cache them.

    Raises on failure. ``st.cache_resource`` stores only return values, so a
    failed load is not cached and will be retried on the next rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
    return tokenizer, model


def load_model(model_id: str, token: str):
    """
    Loads and caches the Gemma model and tokenizer with authentication token.

    Args:
        model_id: Hugging Face model ID, e.g. "google/gemma-3-1b-it".
        token: Hugging Face access token (may be None for anonymous access).

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` after the error
        has been printed and shown in the UI.
    """
    try:
        return _load_model_cached(model_id, token)
    except Exception as e:
        print(f"An error occurred: {e}")
        st.error(f"Model loading failed: {e}")
        return None, None


async def async_load(model_id, token):
    """
    Dummy async function: awaits a 0.1 s sleep and does nothing else.

    No longer called by ``load_model``; kept so the name remains available.
    """
    await asyncio.sleep(0.1)  # Dummy async operation


def preprocess_data(uploaded_file, file_extension):
    """
    Read an uploaded dataset and return it in a previewable form.

    Args:
        uploaded_file: binary file-like object (e.g. a Streamlit UploadedFile).
        file_extension: lower-case extension without the dot: "csv", "jsonl"
            or "txt".

    Returns:
        - "csv": a ``pd.DataFrame``.
        - "jsonl": a ``pd.DataFrame`` of the records, or the raw list of parsed
          JSON objects if they cannot be tabulated.
        - "txt": a list of lines (decoded as UTF-8).
        - ``None`` for other extensions or when reading fails (the error is
          shown in the UI).
    """
    data = None
    try:
        # Rewind in case the buffer has already been read during this run.
        if hasattr(uploaded_file, "seek"):
            uploaded_file.seek(0)
        if file_extension == "csv":
            data = pd.read_csv(uploaded_file)
        elif file_extension == "jsonl":
            # Each non-blank line is a JSON object; blank lines are skipped so
            # they do not abort the whole parse.
            data = [json.loads(line) for line in uploaded_file.readlines() if line.strip()]
            try:
                data = pd.DataFrame(data)
            except Exception:
                st.warning("Unable to convert JSONL to a table. Previewing raw JSON objects.")
        elif file_extension == "txt":
            text_data = uploaded_file.read().decode("utf-8")
            data = text_data.splitlines()
    except Exception as e:
        st.error(f"Error processing file: {e}")
    return data


def clean_text(text, lowercase=True, remove_punctuation=True):
    """
    Normalize a line of text.

    Args:
        text: input string.
        lowercase: convert to lower case when True.
        remove_punctuation: drop every character that is neither a word
            character nor whitespace when True.

    Returns:
        The cleaned string.
    """
    if lowercase:
        text = text.lower()
    if remove_punctuation:
        text = re.sub(r'[^\w\s]', '', text)
    return text


def plot_training_metrics(epochs, loss_values, accuracy_values):
    """
    Build a figure with training loss (left) and accuracy (right) per epoch.

    ``loss_values`` and ``accuracy_values`` must each hold ``epochs`` entries.
    The figure is created through pyplot, so the caller should
    ``plt.close(fig)`` once it has been rendered.
    """
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    x = range(1, epochs + 1)

    ax[0].plot(x, loss_values, marker='o', color='red')
    ax[0].set_title("Training Loss")
    ax[0].set_xlabel("Epoch")
    ax[0].set_ylabel("Loss")

    ax[1].plot(x, accuracy_values, marker='o', color='green')
    ax[1].set_title("Training Accuracy")
    ax[1].set_xlabel("Epoch")
    ax[1].set_ylabel("Accuracy")
    return fig


def simulate_training(num_epochs):
    """
    Simulate a training loop for demonstration.

    Yields ``(epoch, loss_values, accuracy_values)`` after each epoch, where
    the two lists accumulate the history so far, then sleeps one second to
    mimic computation. Replace this with your actual fine-tuning loop.
    """
    loss_values = []
    accuracy_values = []
    for epoch in range(1, num_epochs + 1):
        loss = np.exp(-epoch) + np.random.random() * 0.1
        acc = 0.5 + (epoch / num_epochs) * 0.5 + np.random.random() * 0.05
        # Accuracy is a fraction; the noise term must not push it past 1.
        acc = min(acc, 1.0)
        loss_values.append(loss)
        accuracy_values.append(acc)
        yield epoch, loss_values, accuracy_values
        time.sleep(1)  # Simulate computation time


def quantize_model(model):
    """
    Apply dynamic int8 quantization to every ``torch.nn.Linear`` in ``model``.

    Returns the quantized copy; in practice, adjust this based on your model
    and target hardware.
    """
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    return quantized_model


def convert_to_torchscript(model):
    """
    Trace ``model`` into TorchScript using a random (1, 10) tensor of token
    IDs in [0, 100) as the example input.
    """
    example_input = torch.randint(0, 100, (1, 10))
    traced_model = torch.jit.trace(model, example_input)
    return traced_model


def convert_to_onnx(model, output_path="model.onnx"):
    """
    Export ``model`` to ONNX at ``output_path`` using a random (1, 10) tensor
    of token IDs in [0, 100) as the dummy input. Returns ``output_path``.
    """
    dummy_input = torch.randint(0, 100, (1, 10))
    torch.onnx.export(model, dummy_input, output_path,
                      input_names=["input"], output_names=["output"])
    return output_path


def load_finetuned_model(model, checkpoint_path="fine_tuned_model.pt"):
    """
    Load a saved ``state_dict`` from ``checkpoint_path`` into ``model`` in
    place, switch it to eval mode and return it.

    Shows an error in the UI (and returns ``model`` unchanged) when the file
    does not exist. ``weights_only=True`` restricts unpickling to tensors and
    plain containers, which is all a state dict needs.
    """
    if os.path.exists(checkpoint_path):
        state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        st.success("Fine-tuned model loaded successfully!")
    else:
        st.error(f"Checkpoint not found: {checkpoint_path}")
    return model


def generate_response(prompt, model, tokenizer, max_length=200, temperature=0.7):
    """
    Generate a sampled completion for ``prompt``.

    Args:
        prompt: input text.
        model: a causal LM supporting ``generate``.
        tokenizer: the matching tokenizer.
        max_length: maximum number of newly generated tokens (the prompt is
            not counted, so a long prompt cannot use up the budget).
        temperature: sampling temperature. Sampling is enabled explicitly
            because ``generate`` ignores ``temperature`` under greedy decoding.

    Returns:
        The decoded sequence (prompt plus continuation) without special tokens.
    """
    # Keep the attention mask and place inputs on the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            num_return_sequences=1,
            do_sample=True,
            temperature=temperature,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# -------------------------------
# Application Layout
# -------------------------------
st.title("One-Stop Gemma Model Fine-tuning, Quantization & Conversion UI")
st.markdown("""
This application is designed for beginners in generative AI. It allows you to fine-tune, quantize, and convert Gemma models with an intuitive UI. You can upload your dataset, clean and preview your data, configure training parameters, and export your model in different formats.
""")

# Sidebar: Model selection and data upload
st.sidebar.header("Configuration")

# User-facing size labels -> Hugging Face model IDs.
# NOTE(review): the 12B model needs far more memory than the 1B/4B ones.
MODEL_IDS = {
    "Gemma-Small": "google/gemma-3-1b-it",
    "Gemma-Medium": "google/gemma-3-4b-it",
    "Gemma-Large": "google/gemma-3-12b-it",
}
selected_model = st.sidebar.selectbox("Select Gemma Model", options=list(MODEL_IDS))
model_id = MODEL_IDS.get(selected_model, "google/gemma-3-1b-it")

loading_placeholder = st.sidebar.empty()
loading_placeholder.info("Loading model...")
tokenizer, model = load_model(model_id, token)
if model is None:
    loading_placeholder.error("Model could not be loaded.")
else:
    loading_placeholder.success("Model loaded.")

# Dataset Upload
uploaded_file = st.sidebar.file_uploader("Upload Dataset (CSV, JSONL, TXT)", type=["csv", "jsonl", "txt"])
data = None
file_ext = None
if uploaded_file is not None:
    file_ext = uploaded_file.name.split('.')[-1].lower()
    data = preprocess_data(uploaded_file, file_ext)
    st.sidebar.subheader("Dataset Preview:")
    if isinstance(data, pd.DataFrame):
        st.sidebar.dataframe(data.head())
    elif isinstance(data, list):
        st.sidebar.write(data[:5])
    else:
        st.sidebar.write(data)
else:
    st.sidebar.info("Awaiting dataset upload.")

# Data Cleaning Options (for TXT files that parsed into a list of lines)
if file_ext == "txt" and isinstance(data, list):
    st.sidebar.subheader("Data Cleaning Options")
    lowercase_option = st.sidebar.checkbox("Convert to lowercase", value=True)
    remove_punct = st.sidebar.checkbox("Remove punctuation", value=True)
    cleaned_data = [clean_text(line, lowercase=lowercase_option, remove_punctuation=remove_punct) for line in data]
    st.sidebar.text_area("Cleaned Data Preview", value="\n".join(cleaned_data[:5]), height=150)

# Main Tabs for Different Operations
tabs = st.tabs(["Fine-tuning", "Quantization", "Model Conversion"])

# -------------------------------
# Fine-tuning Tab
# -------------------------------
with tabs[0]:
    st.header("Fine-tuning")
    st.markdown("Configure hyperparameters and start fine-tuning your Gemma model.")
    col1, col2, col3 = st.columns(3)
    # NOTE: learning_rate and batch_size are not consumed by the simulated loop.
    with col1:
        learning_rate = st.number_input("Learning Rate", value=1e-4, format="%.5f")
    with col2:
        batch_size = st.number_input("Batch Size", value=16, step=1, min_value=1)
    with col3:
        epochs = st.number_input("Epochs", value=3, step=1, min_value=1)

    if st.button("Start Fine-tuning"):
        if data is None:
            st.error("Please upload a dataset first!")
        else:
            st.info("Starting fine-tuning...")
            progress_bar = st.progress(0)
            training_placeholder = st.empty()
            # Simulate training loop (replace with your actual training code)
            for epoch, losses, accs in simulate_training(epochs):
                fig = plot_training_metrics(epoch, losses, accs)
                training_placeholder.pyplot(fig)
                # pyplot keeps every figure alive until it is closed.
                plt.close(fig)
                progress_bar.progress(epoch / epochs)
            st.success("Fine-tuning completed!")

            # Save the fine-tuned model (for demonstration, saving state_dict)
            if model is not None:
                torch.save(model.state_dict(), "fine_tuned_model.pt")
                with open("fine_tuned_model.pt", "rb") as f:
                    st.download_button("Download Fine-tuned Model", data=f,
                                       file_name="fine_tuned_model.pt",
                                       mime="application/octet-stream")
            else:
                st.error("Model not loaded. Cannot save.")

# -------------------------------
# Quantization Tab
# -------------------------------
with tabs[1]:
    st.header("Model Quantization")
    st.markdown("Quantize your model to optimize for inference performance.")
    quantize_choice = st.radio("Select Quantization Type", options=["Dynamic Quantization"], index=0)

    if st.button("Apply Quantization"):
        if model is None:
            st.error("Model not loaded. Cannot quantize.")
        else:
            with st.spinner("Applying quantization..."):
                quantized_model = quantize_model(model)
            st.success("Model quantized successfully!")
            torch.save(quantized_model.state_dict(), "quantized_model.pt")
            with open("quantized_model.pt", "rb") as f:
                st.download_button("Download Quantized Model", data=f,
                                   file_name="quantized_model.pt",
                                   mime="application/octet-stream")

# -------------------------------
# Model Conversion Tab
# -------------------------------
with tabs[2]:
    st.header("Model Conversion")
    st.markdown("Convert your model to a different format for deployment or optimization.")
    conversion_option = st.selectbox("Select Conversion Format", options=["TorchScript", "ONNX"])

    if st.button("Convert Model"):
        if model is None:
            st.error("Model not loaded. Cannot convert.")
        elif conversion_option == "TorchScript":
            try:
                with st.spinner("Converting to TorchScript..."):
                    ts_model = convert_to_torchscript(model)
                    ts_model.save("model_ts.pt")
            except Exception as e:
                # Tracing a full causal LM from a bare input_ids tensor may
                # fail; report it instead of crashing the whole page.
                st.error(f"TorchScript conversion failed: {e}")
            else:
                st.success("Converted to TorchScript!")
                with open("model_ts.pt", "rb") as f:
                    st.download_button("Download TorchScript Model", data=f,
                                       file_name="model_ts.pt",
                                       mime="application/octet-stream")
        elif conversion_option == "ONNX":
            try:
                with st.spinner("Converting to ONNX..."):
                    onnx_path = convert_to_onnx(model, "model.onnx")
            except Exception as e:
                # Exporting a full causal LM may fail; report it instead.
                st.error(f"ONNX conversion failed: {e}")
            else:
                st.success("Converted to ONNX!")
                with open(onnx_path, "rb") as f:
                    st.download_button("Download ONNX Model", data=f,
                                       file_name="model.onnx",
                                       mime="application/octet-stream")

# -------------------------------
# Response Generation Section
# -------------------------------
st.header("Generate Responses with Fine-Tuned Model")
st.markdown("Use the fine-tuned model to generate text responses based on your prompts.")

FINETUNED_CHECKPOINT = "fine_tuned_model.pt"

if model is None:
    st.warning("Model not loaded. Cannot generate responses.")
elif os.path.exists(FINETUNED_CHECKPOINT):
    # The model object is shared through st.cache_resource. Reload the
    # checkpoint only when the selected model or the checkpoint file changed
    # since this session last loaded it, not on every rerun.
    checkpoint_key = (model_id, os.path.getmtime(FINETUNED_CHECKPOINT))
    checkpoint_ok = True
    if st.session_state.get("finetuned_checkpoint_key") != checkpoint_key:
        try:
            model = load_finetuned_model(model, FINETUNED_CHECKPOINT)
            st.session_state["finetuned_checkpoint_key"] = checkpoint_key
        except RuntimeError as e:
            # e.g. a checkpoint saved from a different model size.
            checkpoint_ok = False
            st.error(f"Could not load fine-tuned checkpoint: {e}")

    if checkpoint_ok:
        # Input prompt for generating responses
        prompt = st.text_area("Enter a prompt:", "Once upon a time...")

        # Max length slider
        max_length = st.slider("Max Response Length", min_value=50, max_value=500, value=200, step=10)

        if st.button("Generate Response"):
            with st.spinner("Generating response..."):
                response = generate_response(prompt, model, tokenizer, max_length)
            st.success("Generated Response:")
            st.write(response)
else:
    st.warning("Fine-tuned model not found. Please fine-tune the model first.")

# -------------------------------
# Optional: Cloud Integration Snippet
# -------------------------------
st.header("Cloud Integration")
st.markdown("""
For large-scale training or model storage, consider integrating with Google Cloud Storage or Vertex AI. Below is an example snippet for uploading your model to GCS:
""")

st.code("""
from google.cloud import storage

def upload_to_gcs(bucket_name, source_file_name, destination_blob_name):
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_filename(source_file_name)
    print(f"Uploaded {source_file_name} to {destination_blob_name}")

# Example usage:
# upload_to_gcs("your-bucket-name", "fine_tuned_model.pt", "models/fine_tuned_model.pt")
""", language="python")