# Studymaker2 / app.py
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import PyPDF2
import torch
import os
from huggingface_hub import login, HfApi
import warnings
warnings.filterwarnings("ignore")
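# Dependencies: streamlit, transformers, torch, PyPDF2 and huggingface_hub are
# imported above; the `accelerate` package is also needed at runtime when
# device_map="auto" is used in load_model() below. Run locally with:
#   streamlit run app.py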
st.set_page_config(page_title="AI Study Assistant - Mistral 7B", layout="wide")
st.title("🧠 AI Study Assistant using Mistral 7B")
# Enhanced token validation and authentication
def validate_hf_token():
"""Validate and authenticate Hugging Face token"""
hf_token = None
# Try multiple sources for the token
token_sources = [
("Environment Variable", os.getenv("HF_TOKEN")),
("Streamlit Secrets", st.secrets.get("HF_TOKEN", None) if hasattr(st, 'secrets') else None),
("Manual Input", None) # Will be handled below
]
for source, token in token_sources:
if token:
st.success(f"βœ… Token found from: {source}")
hf_token = token
break
if not hf_token:
st.warning("πŸ”‘ No token found in environment or secrets. Please enter manually:")
hf_token = st.text_input(
"Enter your Hugging Face Token:",
type="password",
help="Get your token from https://huggingface.co/settings/tokens"
)
if hf_token:
try:
# Test token validity
api = HfApi()
user_info = api.whoami(token=hf_token)
st.success(f"βœ… Authenticated as: {user_info['name']}")
# Attempt to login
login(token=hf_token, add_to_git_credential=False)
return hf_token
except Exception as e:
st.error(f"❌ Token validation failed: {str(e)}")
st.info("Please check your token and ensure you have access to Mistral 7B model")
return None
return None
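# Note: a Hugging Face token with "read" scope is sufficient here; it is only
# used to verify the account and download the gated Mistral weights.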
def check_model_access(token):
"""Check if user has access to the Mistral model"""
try:
api = HfApi()
model_info = api.model_info("mistralai/Mistral-7B-Instruct-v0.1", token=token)
st.success("βœ… Model access confirmed")
return True
except Exception as e:
st.error("❌ Cannot access Mistral 7B model")
st.info("""
**To fix this:**
1. Visit: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
2. Click "Request Access"
3. Wait for approval (usually instant for most users)
4. Refresh this page
""")
return False
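# st.cache_resource keeps the returned pipeline in memory across Streamlit
# reruns, so the 7B model is downloaded and initialised only once per process.
# In float16 the weights alone take roughly 14 GB, so a GPU-backed instance is
# strongly recommended; on CPU the code falls back to float32 and is very slow.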
@st.cache_resource
def load_model(hf_token):
"""Load the Mistral model with proper error handling"""
try:
st.info("πŸ”„ Loading Mistral 7B model... This may take a few minutes on first run.")
# Load tokenizer first
tokenizer = AutoTokenizer.from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.1",
token=hf_token,
trust_remote_code=True
)
# Add padding token if missing
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load model with optimizations
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.1",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
token=hf_token,
trust_remote_code=True,
low_cpu_mem_usage=True
)
# Create pipeline
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=512,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
st.success("βœ… Model loaded successfully!")
return pipe
except Exception as e:
st.error(f"❌ Model loading failed: {str(e)}")
st.info("Try refreshing the page or check your internet connection")
return None
def extract_text_from_pdf(file):
"""Extract text from uploaded PDF with error handling"""
try:
reader = PyPDF2.PdfReader(file)
text = ""
for page_num, page in enumerate(reader.pages):
page_text = page.extract_text()
if page_text.strip():
text += f"\n--- Page {page_num + 1} ---\n{page_text}\n"
if not text.strip():
st.warning("⚠️ No text extracted from PDF. It might be image-based.")
return ""
return text
except Exception as e:
st.error(f"❌ PDF processing failed: {str(e)}")
return ""
def format_prompt(context, query):
"""Create properly formatted Mistral prompt"""
if context.strip():
prompt = f"<s>[INST] Use the following context to answer the question comprehensively:\n\nContext:\n{context[:3000]}...\n\nQuestion: {query}\n\nProvide a detailed, accurate answer based on the context. [/INST]"
else:
prompt = f"<s>[INST] {query} [/INST]"
return prompt
# Main Application Flow
def main():
    # Step 1: Validate token
    hf_token = validate_hf_token()
    if not hf_token:
        st.stop()

    # Step 2: Check model access
    if not check_model_access(hf_token):
        st.stop()

    # Step 3: Load model
    textgen = load_model(hf_token)
    if not textgen:
        st.stop()

    # Step 4: User Interface
    st.markdown("---")
    col1, col2 = st.columns([2, 1])

    with col1:
        query = st.text_area(
            "💭 Ask your question:",
            height=100,
            placeholder="e.g., Explain machine learning concepts, summarize this document, etc."
        )

    with col2:
        uploaded_file = st.file_uploader(
            "📎 Upload PDF Context (Optional):",
            type=["pdf"],
            help="Upload a PDF to provide context for your question"
        )

    # Process uploaded file
    context = ""
    if uploaded_file:
        with st.spinner("📖 Extracting text from PDF..."):
            context = extract_text_from_pdf(uploaded_file)
        if context:
            with st.expander("📄 View Extracted Text", expanded=False):
                st.text_area("PDF Content Preview:", context[:1000] + "..." if len(context) > 1000 else context, height=200)
            st.success(f"✅ Extracted {len(context)} characters from PDF")

    # Generate answer
    if st.button("🚀 Generate Answer", type="primary"):
        if not query.strip():
            st.warning("⚠️ Please enter a question")
            return

        with st.spinner("🤔 Generating answer..."):
            try:
                prompt = format_prompt(context, query)

                # Generate response
                result = textgen(prompt, max_new_tokens=512, temperature=0.7)
                generated_text = result[0]["generated_text"]

                # Extract only the generated part
                answer = generated_text.split("[/INST]")[-1].strip()

                # Display result
                st.markdown("### 🎯 Answer:")
                st.markdown(answer)

                # Show token usage info
                with st.expander("📊 Generation Details", expanded=False):
                    st.write(f"**Prompt length:** {len(prompt)} characters")
                    st.write(f"**Response length:** {len(answer)} characters")
                    st.write(f"**Context used:** {'Yes' if context else 'No'}")
            except Exception as e:
                st.error(f"❌ Generation failed: {str(e)}")
                st.info("Try with a shorter question or refresh the page")


# Run the application
if __name__ == "__main__":
    main()