import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import PyPDF2
import torch
import os
from huggingface_hub import login, HfApi  # HfApi is needed for the token and model-access checks below
import warnings

warnings.filterwarnings("ignore")

st.set_page_config(page_title="AI Study Assistant - Mistral 7B", layout="wide")
st.title("AI Study Assistant using Mistral 7B")


# Enhanced token validation and authentication
def validate_hf_token():
    """Validate and authenticate the Hugging Face token."""
    hf_token = None

    # Try multiple sources for the token
    token_sources = [
        ("Environment Variable", os.getenv("HF_TOKEN")),
        ("Streamlit Secrets", st.secrets.get("HF_TOKEN", None) if hasattr(st, "secrets") else None),
        ("Manual Input", None),  # Handled below
    ]
    for source, token in token_sources:
        if token:
            st.success(f"Token found from: {source}")
            hf_token = token
            break

    if not hf_token:
        st.warning("No token found in environment or secrets. Please enter it manually:")
        hf_token = st.text_input(
            "Enter your Hugging Face Token:",
            type="password",
            help="Get your token from https://huggingface.co/settings/tokens"
        )

    if hf_token:
        try:
            # Test token validity
            api = HfApi()
            user_info = api.whoami(token=hf_token)
            st.success(f"Authenticated as: {user_info['name']}")
            # Attempt to log in
            login(token=hf_token, add_to_git_credential=False)
            return hf_token
        except Exception as e:
            st.error(f"Token validation failed: {str(e)}")
            st.info("Please check your token and ensure you have access to the Mistral 7B model")
            return None
    return None


def check_model_access(token):
    """Check if the user has access to the Mistral model."""
    try:
        api = HfApi()
        model_info = api.model_info("mistralai/Mistral-7B-Instruct-v0.1", token=token)
        st.success("Model access confirmed")
        return True
    except Exception as e:
        st.error("Cannot access the Mistral 7B model")
        st.info("""
        **To fix this:**
        1. Visit: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
        2. Click "Request Access"
        3. Wait for approval (usually instant for most users)
        4. Refresh this page
        """)
        return False


@st.cache_resource
def load_model(hf_token):
    """Load the Mistral model with proper error handling."""
    try:
        st.info("Loading Mistral 7B model... This may take a few minutes on first run.")

        # Load the tokenizer first
        tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            token=hf_token,
            trust_remote_code=True
        )
        # Add a padding token if missing
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load the model with memory optimizations
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            token=hf_token,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        # Create the text-generation pipeline
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        st.success("Model loaded successfully!")
        return pipe
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.info("Try refreshing the page or check your internet connection")
        return None


def extract_text_from_pdf(file):
    """Extract text from an uploaded PDF with error handling."""
    try:
        reader = PyPDF2.PdfReader(file)
        text = ""
        for page_num, page in enumerate(reader.pages):
            page_text = page.extract_text()
            if page_text.strip():
                text += f"\n--- Page {page_num + 1} ---\n{page_text}\n"
        if not text.strip():
            st.warning("No text extracted from PDF. It might be image-based.")
            return ""
        return text
    except Exception as e:
        st.error(f"PDF processing failed: {str(e)}")
        return ""


def format_prompt(context, query):
    """Create a properly formatted Mistral instruction prompt."""
    if context.strip():
        prompt = (
            "<s>[INST] Use the following context to answer the question comprehensively:\n\n"
            f"Context:\n{context[:3000]}...\n\n"
            f"Question: {query}\n\n"
            "Provide a detailed, accurate answer based on the context. [/INST]"
        )
    else:
        prompt = f"<s>[INST] {query} [/INST]"
    return prompt


# Main Application Flow
def main():
    # Step 1: Validate token
    hf_token = validate_hf_token()
    if not hf_token:
        st.stop()

    # Step 2: Check model access
    if not check_model_access(hf_token):
        st.stop()

    # Step 3: Load model
    textgen = load_model(hf_token)
    if not textgen:
        st.stop()

    # Step 4: User interface
    st.markdown("---")
    col1, col2 = st.columns([2, 1])
    with col1:
        query = st.text_area(
            "Ask your question:",
            height=100,
            placeholder="e.g., Explain machine learning concepts, summarize this document, etc."
        )
    with col2:
        uploaded_file = st.file_uploader(
            "Upload PDF Context (Optional):",
            type=["pdf"],
            help="Upload a PDF to provide context for your question"
        )

    # Process the uploaded file
    context = ""
    if uploaded_file:
        with st.spinner("Extracting text from PDF..."):
            context = extract_text_from_pdf(uploaded_file)
        if context:
            with st.expander("View Extracted Text", expanded=False):
                st.text_area(
                    "PDF Content Preview:",
                    context[:1000] + "..." if len(context) > 1000 else context,
                    height=200
                )
            st.success(f"Extracted {len(context)} characters from PDF")

    # Generate the answer
    if st.button("Generate Answer", type="primary"):
        if not query.strip():
            st.warning("Please enter a question")
            return

        with st.spinner("Generating answer..."):
            try:
                prompt = format_prompt(context, query)
                # Generate the response
                result = textgen(prompt, max_new_tokens=512, temperature=0.7)
                generated_text = result[0]["generated_text"]
                # Keep only the part generated after the instruction marker
                answer = generated_text.split("[/INST]")[-1].strip()

                # Display the result
                st.markdown("### Answer:")
                st.markdown(answer)

                # Show generation details
                with st.expander("Generation Details", expanded=False):
                    st.write(f"**Prompt length:** {len(prompt)} characters")
                    st.write(f"**Response length:** {len(answer)} characters")
                    st.write(f"**Context used:** {'Yes' if context else 'No'}")
            except Exception as e:
                st.error(f"Generation failed: {str(e)}")
                st.info("Try with a shorter question or refresh the page")


# Run the application
if __name__ == "__main__":
    main()