import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import PyPDF2
import torch
import os
from huggingface_hub import login, HfApi
import warnings
warnings.filterwarnings("ignore")
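
# Assumed runtime dependencies (not pinned in this file): streamlit, transformers,
# torch, PyPDF2, and huggingface_hub; accelerate is also required for
# device_map="auto" when loading the model on a GPU.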

st.set_page_config(page_title="AI Study Assistant - Mistral 7B", layout="wide")
st.title("AI Study Assistant using Mistral 7B")

# Enhanced token validation and authentication
def validate_hf_token():
    """Validate and authenticate the Hugging Face token."""
    hf_token = None

    # Try multiple sources for the token
    token_sources = [
        ("Environment Variable", os.getenv("HF_TOKEN")),
        ("Streamlit Secrets", st.secrets.get("HF_TOKEN", None) if hasattr(st, 'secrets') else None),
        ("Manual Input", None)  # Will be handled below
    ]

    for source, token in token_sources:
        if token:
            st.success(f"✅ Token found from: {source}")
            hf_token = token
            break

    if not hf_token:
        st.warning("No token found in environment or secrets. Please enter it manually:")
        hf_token = st.text_input(
            "Enter your Hugging Face Token:",
            type="password",
            help="Get your token from https://huggingface.co/settings/tokens"
        )

    if hf_token:
        try:
            # Test token validity
            api = HfApi()
            user_info = api.whoami(token=hf_token)
            st.success(f"✅ Authenticated as: {user_info['name']}")
            # Attempt to log in
            login(token=hf_token, add_to_git_credential=False)
            return hf_token
        except Exception as e:
            st.error(f"❌ Token validation failed: {str(e)}")
            st.info("Please check your token and ensure you have access to the Mistral 7B model")
            return None
    return None

def check_model_access(token):
    """Check whether the user has access to the Mistral model."""
    try:
        api = HfApi()
        api.model_info("mistralai/Mistral-7B-Instruct-v0.1", token=token)
        st.success("✅ Model access confirmed")
        return True
    except Exception:
        st.error("❌ Cannot access the Mistral 7B model")
        st.info("""
**To fix this:**
1. Visit: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
2. Click "Request Access"
3. Wait for approval (usually instant for most users)
4. Refresh this page
        """)
        return False

def load_model(hf_token):
    """Load the Mistral model with proper error handling."""
    try:
        st.info("Loading Mistral 7B model... This may take a few minutes on first run.")

        # Load tokenizer first
        tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            token=hf_token,
            trust_remote_code=True
        )

        # Add padding token if missing
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with optimizations
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            token=hf_token,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        # Create pipeline
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

        st.success("✅ Model loaded successfully!")
        return pipe
    except Exception as e:
        st.error(f"❌ Model loading failed: {str(e)}")
        st.info("Try refreshing the page or check your internet connection")
        return None

def extract_text_from_pdf(file):
    """Extract text from uploaded PDF with error handling."""
    try:
        reader = PyPDF2.PdfReader(file)
        text = ""
        for page_num, page in enumerate(reader.pages):
            page_text = page.extract_text()
            if page_text.strip():
                text += f"\n--- Page {page_num + 1} ---\n{page_text}\n"
        if not text.strip():
            st.warning("⚠️ No text extracted from PDF. It might be image-based.")
            return ""
        return text
    except Exception as e:
        st.error(f"❌ PDF processing failed: {str(e)}")
        return ""

def format_prompt(context, query):
    """Create a properly formatted Mistral instruction prompt."""
    if context.strip():
        # Truncate long PDF context so the prompt stays within a manageable length
        prompt = f"<s>[INST] Use the following context to answer the question comprehensively:\n\nContext:\n{context[:3000]}...\n\nQuestion: {query}\n\nProvide a detailed, accurate answer based on the context. [/INST]"
    else:
        prompt = f"<s>[INST] {query} [/INST]"
    return prompt
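
# For reference, format_prompt() produces Mistral's [INST] instruction format,
# e.g. "<s>[INST] <context + question> [/INST]"; the model's reply follows the
# closing [/INST] tag, which is why main() splits the generated text on
# "[/INST]" to isolate the answer.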

# Main application flow
def main():
    # Step 1: Validate the token
    hf_token = validate_hf_token()
    if not hf_token:
        st.stop()

    # Step 2: Check model access
    if not check_model_access(hf_token):
        st.stop()

    # Step 3: Load the model
    textgen = load_model(hf_token)
    if not textgen:
        st.stop()

    # Step 4: User interface
    st.markdown("---")
    col1, col2 = st.columns([2, 1])

    with col1:
        query = st.text_area(
            "Ask your question:",
            height=100,
            placeholder="e.g., Explain machine learning concepts, summarize this document, etc."
        )

    with col2:
        uploaded_file = st.file_uploader(
            "Upload PDF Context (Optional):",
            type=["pdf"],
            help="Upload a PDF to provide context for your question"
        )

    # Process the uploaded file
    context = ""
    if uploaded_file:
        with st.spinner("Extracting text from PDF..."):
            context = extract_text_from_pdf(uploaded_file)
        if context:
            with st.expander("View Extracted Text", expanded=False):
                st.text_area("PDF Content Preview:", context[:1000] + "..." if len(context) > 1000 else context, height=200)
            st.success(f"✅ Extracted {len(context)} characters from PDF")

    # Generate the answer
    if st.button("Generate Answer", type="primary"):
        if not query.strip():
            st.warning("⚠️ Please enter a question")
            return

        with st.spinner("Generating answer..."):
            try:
                prompt = format_prompt(context, query)

                # Generate the response
                result = textgen(prompt, max_new_tokens=512, temperature=0.7)
                generated_text = result[0]["generated_text"]

                # Keep only the generated part (everything after the prompt's [/INST] tag)
                answer = generated_text.split("[/INST]")[-1].strip()

                # Display the result
                st.markdown("### Answer:")
                st.markdown(answer)

                # Show generation details
                with st.expander("Generation Details", expanded=False):
                    st.write(f"**Prompt length:** {len(prompt)} characters")
                    st.write(f"**Response length:** {len(answer)} characters")
                    st.write(f"**Context used:** {'Yes' if context else 'No'}")

            except Exception as e:
                st.error(f"❌ Generation failed: {str(e)}")
                st.info("Try a shorter question or refresh the page")

# Run the application
if __name__ == "__main__":
    main()
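
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py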