import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import PyPDF2
import torch
import os
from huggingface_hub import login, HfApi
import warnings
warnings.filterwarnings("ignore")
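# Assumed dependencies (inferred from the imports above):
#   pip install streamlit transformers torch PyPDF2 huggingface_hub
# Launch with: streamlit run <this_script>.py  (the script name depends on your setup)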

st.set_page_config(page_title="AI Study Assistant - Mistral 7B", layout="wide")
st.title("🧠 AI Study Assistant using Mistral 7B")

# Enhanced token validation and authentication
def validate_hf_token():
    """Validate and authenticate Hugging Face token"""
    hf_token = None
    
    # Try multiple sources for the token
    token_sources = [
        ("Environment Variable", os.getenv("HF_TOKEN")),
        ("Streamlit Secrets", st.secrets.get("HF_TOKEN", None) if hasattr(st, 'secrets') else None),
        ("Manual Input", None)  # Will be handled below
    ]
    
    for source, token in token_sources:
        if token:
            st.success(f"✅ Token found from: {source}")
            hf_token = token
            break
    
    if not hf_token:
        st.warning("🔑 No token found in environment or secrets. Please enter manually:")
        hf_token = st.text_input(
            "Enter your Hugging Face Token:", 
            type="password",
            help="Get your token from https://huggingface.co/settings/tokens"
        )
    
    if hf_token:
        try:
            # Test token validity
            api = HfApi()
            user_info = api.whoami(token=hf_token)
            st.success(f"✅ Authenticated as: {user_info['name']}")
            
            # Attempt to login
            login(token=hf_token, add_to_git_credential=False)
            return hf_token
            
        except Exception as e:
            st.error(f"❌ Token validation failed: {str(e)}")
            st.info("Please check your token and ensure you have access to the Mistral 7B model")
            return None
    
    return None

def check_model_access(token):
    """Check if user has access to the Mistral model"""
    try:
        api = HfApi()
        model_info = api.model_info("mistralai/Mistral-7B-Instruct-v0.1", token=token)
        st.success("✅ Model access confirmed")
        return True
    except Exception as e:
        st.error("❌ Cannot access Mistral 7B model")
        st.info("""
        **To fix this:**
        1. Visit: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
        2. Click "Request Access" 
        3. Wait for approval (usually instant for most users)
        4. Refresh this page
        """)
        return False

@st.cache_resource
def load_model(hf_token):
    """Load the Mistral model with proper error handling"""
    try:
        st.info("🔄 Loading Mistral 7B model... This may take a few minutes on first run.")
        
        # Load tokenizer first
        tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            token=hf_token,
            trust_remote_code=True
        )
        
        # Add padding token if missing
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Load model with optimizations
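        # Rough sizing note: a 7B model in float16 typically needs about 14-16 GB of
        # GPU memory; without CUDA it falls back to float32 on the CPU, which is much
        # slower and needs roughly twice the RAM.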
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            token=hf_token,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
        
        # Create pipeline
        pipe = pipeline(
            "text-generation", 
            model=model, 
            tokenizer=tokenizer, 
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        
        st.success("✅ Model loaded successfully!")
        return pipe
        
    except Exception as e:
        st.error(f"❌ Model loading failed: {str(e)}")
        st.info("Try refreshing the page or check your internet connection")
        return None

def extract_text_from_pdf(file):
    """Extract text from uploaded PDF with error handling"""
    try:
        reader = PyPDF2.PdfReader(file)
        text = ""
        for page_num, page in enumerate(reader.pages):
            page_text = page.extract_text()
            if page_text.strip():
                text += f"\n--- Page {page_num + 1} ---\n{page_text}\n"
        
        if not text.strip():
            st.warning("⚠️ No text extracted from PDF. It might be image-based.")
            return ""
            
        return text
    except Exception as e:
        st.error(f"❌ PDF processing failed: {str(e)}")
        return ""

def format_prompt(context, query):
    """Create properly formatted Mistral prompt"""
    if context.strip():
        prompt = f"<s>[INST] Use the following context to answer the question comprehensively:\n\nContext:\n{context[:3000]}...\n\nQuestion: {query}\n\nProvide a detailed, accurate answer based on the context. [/INST]"
    else:
        prompt = f"<s>[INST] {query} [/INST]"
    
    return prompt

# Main Application Flow
def main():
    # Step 1: Validate token
    hf_token = validate_hf_token()
    
    if not hf_token:
        st.stop()
    
    # Step 2: Check model access
    if not check_model_access(hf_token):
        st.stop()
    
    # Step 3: Load model
    textgen = load_model(hf_token)
    
    if not textgen:
        st.stop()
    
    # Step 4: User Interface
    st.markdown("---")
    
    col1, col2 = st.columns([2, 1])
    
    with col1:
        query = st.text_area(
            "πŸ’­ Ask your question:", 
            height=100,
            placeholder="e.g., Explain machine learning concepts, summarize this document, etc."
        )
    
    with col2:
        uploaded_file = st.file_uploader(
            "πŸ“Ž Upload PDF Context (Optional):", 
            type=["pdf"],
            help="Upload a PDF to provide context for your question"
        )
    
    # Process uploaded file
    context = ""
    if uploaded_file:
        with st.spinner("📖 Extracting text from PDF..."):
            context = extract_text_from_pdf(uploaded_file)
            
        if context:
            with st.expander("📄 View Extracted Text", expanded=False):
                st.text_area("PDF Content Preview:", context[:1000] + "..." if len(context) > 1000 else context, height=200)
            st.success(f"✅ Extracted {len(context)} characters from PDF")
    
    # Generate answer
    if st.button("🚀 Generate Answer", type="primary"):
        if not query.strip():
            st.warning("⚠️ Please enter a question")
            return
        
        with st.spinner("🤔 Generating answer..."):
            try:
                prompt = format_prompt(context, query)
                
                # Generate response
                result = textgen(prompt, max_new_tokens=512, temperature=0.7)
                generated_text = result[0]["generated_text"]
                
                # Extract only the generated part
                answer = generated_text.split("[/INST]")[-1].strip()
                
                # Display result
                st.markdown("### 🎯 Answer:")
                st.markdown(answer)
                
                # Show token usage info
                with st.expander("📊 Generation Details", expanded=False):
                    st.write(f"**Prompt length:** {len(prompt)} characters")
                    st.write(f"**Response length:** {len(answer)} characters")
                    st.write(f"**Context used:** {'Yes' if context else 'No'}")
                
            except Exception as e:
                st.error(f"❌ Generation failed: {str(e)}")
                st.info("Try with a shorter question or refresh the page")

# Run the application
if __name__ == "__main__":
    main()