import os
import warnings

import streamlit as st
import torch
import PyPDF2
from huggingface_hub import HfApi, login
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

warnings.filterwarnings("ignore")

st.set_page_config(page_title="AI Study Assistant - Mistral 7B", layout="wide")
st.title("🧠 AI Study Assistant using Mistral 7B")


# Enhanced token validation and authentication
def validate_hf_token():
    """Validate and authenticate the Hugging Face token."""
    hf_token = None

    def secrets_token():
        # st.secrets raises if no secrets.toml exists, so guard the lookup
        try:
            return st.secrets.get("HF_TOKEN", None)
        except Exception:
            return None

    # Try multiple sources for the token
    token_sources = [
        ("Environment Variable", os.getenv("HF_TOKEN")),
        ("Streamlit Secrets", secrets_token()),
        ("Manual Input", None),  # Handled below if nothing else is found
    ]

    for source, token in token_sources:
        if token:
            st.success(f"✅ Token found from: {source}")
            hf_token = token
            break

    if not hf_token:
        st.warning("🔑 No token found in environment or secrets. Please enter manually:")
        hf_token = st.text_input(
            "Enter your Hugging Face Token:",
            type="password",
            help="Get your token from https://huggingface.co/settings/tokens",
        )

    if hf_token:
        try:
            # Test token validity
            api = HfApi()
            user_info = api.whoami(token=hf_token)
            st.success(f"✅ Authenticated as: {user_info['name']}")

            # Attempt to log in
            login(token=hf_token, add_to_git_credential=False)
            return hf_token
        except Exception as e:
            st.error(f"❌ Token validation failed: {str(e)}")
            st.info("Please check your token and ensure you have access to the Mistral 7B model")
            return None

    return None


def check_model_access(token):
    """Check whether the user has access to the Mistral model."""
    try:
        api = HfApi()
        api.model_info("mistralai/Mistral-7B-Instruct-v0.1", token=token)
        st.success("✅ Model access confirmed")
        return True
    except Exception:
        st.error("❌ Cannot access Mistral 7B model")
        st.info("""
**To fix this:**
1. Visit: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
2. Click "Request Access"
3. Wait for approval (usually instant for most users)
4. Refresh this page
""")
        return False


@st.cache_resource
def load_model(hf_token):
    """Load the Mistral model with proper error handling."""
    try:
        st.info("🔄 Loading Mistral 7B model... This may take a few minutes on first run.")

        # Load tokenizer first
        tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            token=hf_token,
            trust_remote_code=True,
        )

        # Add padding token if missing
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with optimizations
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            token=hf_token,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )

        # Create pipeline
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        st.success("✅ Model loaded successfully!")
        return pipe
    except Exception as e:
        st.error(f"❌ Model loading failed: {str(e)}")
        st.info("Try refreshing the page or check your internet connection")
        return None


def extract_text_from_pdf(file):
    """Extract text from an uploaded PDF with error handling."""
    try:
        reader = PyPDF2.PdfReader(file)
        text = ""
        for page_num, page in enumerate(reader.pages):
            page_text = page.extract_text()
            if page_text.strip():
                text += f"\n--- Page {page_num + 1} ---\n{page_text}\n"

        if not text.strip():
            st.warning("⚠️ No text extracted from PDF. It might be image-based.")
            return ""

        return text
    except Exception as e:
        st.error(f"❌ PDF processing failed: {str(e)}")
        return ""


def format_prompt(context, query):
    """Create a properly formatted Mistral instruction prompt."""
    if context.strip():
        # Truncate long contexts so the prompt stays within the model's context window
        context_snippet = context[:3000]
        prompt = (
            "[INST] Use the following context to answer the question comprehensively:\n\n"
            f"Context:\n{context_snippet}\n\n"
            f"Question: {query}\n\n"
            "Provide a detailed, accurate answer based on the context. [/INST]"
        )
    else:
        prompt = f"[INST] {query} [/INST]"
    return prompt


# Main Application Flow
def main():
    # Step 1: Validate token
    hf_token = validate_hf_token()
    if not hf_token:
        st.stop()

    # Step 2: Check model access
    if not check_model_access(hf_token):
        st.stop()

    # Step 3: Load model
    textgen = load_model(hf_token)
    if not textgen:
        st.stop()

    # Step 4: User interface
    st.markdown("---")

    col1, col2 = st.columns([2, 1])

    with col1:
        query = st.text_area(
            "💭 Ask your question:",
            height=100,
            placeholder="e.g., Explain machine learning concepts, summarize this document, etc.",
        )

    with col2:
        uploaded_file = st.file_uploader(
            "📎 Upload PDF Context (Optional):",
            type=["pdf"],
            help="Upload a PDF to provide context for your question",
        )

    # Process uploaded file
    context = ""
    if uploaded_file:
        with st.spinner("📖 Extracting text from PDF..."):
            context = extract_text_from_pdf(uploaded_file)

        if context:
            with st.expander("📄 View Extracted Text", expanded=False):
                st.text_area(
                    "PDF Content Preview:",
                    context[:1000] + "..." if len(context) > 1000 else context,
                    height=200,
                )
            st.success(f"✅ Extracted {len(context)} characters from PDF")

    # Generate answer
    if st.button("🚀 Generate Answer", type="primary"):
        if not query.strip():
            st.warning("⚠️ Please enter a question")
            return

        with st.spinner("🤔 Generating answer..."):
            try:
                prompt = format_prompt(context, query)

                # Generate response
                result = textgen(prompt, max_new_tokens=512, temperature=0.7)
                generated_text = result[0]["generated_text"]

                # Keep only the text generated after the instruction block
                answer = generated_text.split("[/INST]")[-1].strip()

                # Display result
                st.markdown("### 🎯 Answer:")
                st.markdown(answer)

                # Show generation details
                with st.expander("📊 Generation Details", expanded=False):
                    st.write(f"**Prompt length:** {len(prompt)} characters")
                    st.write(f"**Response length:** {len(answer)} characters")
                    st.write(f"**Context used:** {'Yes' if context else 'No'}")

            except Exception as e:
                st.error(f"❌ Generation failed: {str(e)}")
                st.info("Try with a shorter question or refresh the page")


# Run the application
if __name__ == "__main__":
    main()