g0th committed
Commit 7db3429 · verified
1 Parent(s): 03836f6

Update app.py

Files changed (1)
  1. app.py +219 -36
app.py CHANGED
@@ -3,45 +3,228 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import PyPDF2
 import torch
 import os
+from huggingface_hub import login, HfApi
+import warnings
+warnings.filterwarnings("ignore")
 
-st.set_page_config(page_title="Perplexity-style Q&A (Mistral Auth)", layout="wide")
-st.title("🧠 AI Study Assistant using Mistral 7B (Authenticated)")
+st.set_page_config(page_title="AI Study Assistant - Mistral 7B", layout="wide")
+st.title("🧠 AI Study Assistant using Mistral 7B")
 
-# ✅ Load Hugging Face token from secrets
-hf_token = os.getenv("HF_TOKEN")
+# Enhanced token validation and authentication
+def validate_hf_token():
+    """Validate and authenticate the Hugging Face token"""
+    hf_token = None
+
+    # Try multiple sources for the token
+    token_sources = [
+        ("Environment Variable", os.getenv("HF_TOKEN")),
+        ("Streamlit Secrets", st.secrets.get("HF_TOKEN", None) if hasattr(st, 'secrets') else None),
+        ("Manual Input", None)  # Will be handled below
+    ]
+
+    for source, token in token_sources:
+        if token:
+            st.success(f"✅ Token found from: {source}")
+            hf_token = token
+            break
+
+    if not hf_token:
+        st.warning("🔑 No token found in environment or secrets. Please enter it manually:")
+        hf_token = st.text_input(
+            "Enter your Hugging Face Token:",
+            type="password",
+            help="Get your token from https://huggingface.co/settings/tokens"
+        )
+
+    if hf_token:
+        try:
+            # Test token validity
+            api = HfApi()
+            user_info = api.whoami(token=hf_token)
+            st.success(f"✅ Authenticated as: {user_info['name']}")
+
+            # Attempt to log in
+            login(token=hf_token, add_to_git_credential=False)
+            return hf_token
+
+        except Exception as e:
+            st.error(f"❌ Token validation failed: {str(e)}")
+            st.info("Please check your token and ensure you have access to the Mistral 7B model")
+            return None
+
+    return None
+
+def check_model_access(token):
+    """Check if the user has access to the Mistral model"""
+    try:
+        api = HfApi()
+        model_info = api.model_info("mistralai/Mistral-7B-Instruct-v0.1", token=token)
+        st.success("✅ Model access confirmed")
+        return True
+    except Exception as e:
+        st.error("❌ Cannot access the Mistral 7B model")
+        st.info("""
+        **To fix this:**
+        1. Visit: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
+        2. Click "Request Access"
+        3. Wait for approval (usually instant for most users)
+        4. Refresh this page
+        """)
+        return False
 
 @st.cache_resource
-def load_model():
-    tokenizer = AutoTokenizer.from_pretrained(
-        "mistralai/Mistral-7B-Instruct-v0.1",
-        token=hf_token
-    )
-    model = AutoModelForCausalLM.from_pretrained(
-        "mistralai/Mistral-7B-Instruct-v0.1",
-        torch_dtype=torch.float16,
-        device_map="auto",
-        token=hf_token
-    )
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
-    return pipe
-
-textgen = load_model()
+def load_model(hf_token):
+    """Load the Mistral model with proper error handling"""
+    try:
+        st.info("🔄 Loading Mistral 7B model... This may take a few minutes on first run.")
+
+        # Load tokenizer first
+        tokenizer = AutoTokenizer.from_pretrained(
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            token=hf_token,
+            trust_remote_code=True
+        )
+
+        # Add padding token if missing
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        # Load model with optimizations
+        model = AutoModelForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            device_map="auto" if torch.cuda.is_available() else None,
+            token=hf_token,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True
+        )
+
+        # Create pipeline
+        pipe = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_new_tokens=512,
+            temperature=0.7,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+        st.success("✅ Model loaded successfully!")
+        return pipe
+
+    except Exception as e:
+        st.error(f"❌ Model loading failed: {str(e)}")
+        st.info("Try refreshing the page or check your internet connection")
+        return None
 
 def extract_text_from_pdf(file):
-    reader = PyPDF2.PdfReader(file)
-    return "\n".join([p.extract_text() for p in reader.pages if p.extract_text()])
-
-query = st.text_input("Ask a question or enter a topic:")
-uploaded_file = st.file_uploader("Or upload a PDF to use as context:", type=["pdf"])
-
-context = ""
-if uploaded_file:
-    context = extract_text_from_pdf(uploaded_file)
-    st.text_area("📄 Extracted PDF Text", context, height=200)
-
-if st.button("Generate Answer"):
-    with st.spinner("Generating answer..."):
-        prompt = f"[INST] Use the following context to answer the question:\n\n{context}\n\nQuestion: {query} [/INST]"
-        result = textgen(prompt)[0]["generated_text"]
-        st.success("Answer:")
-        st.write(result.replace(prompt, "").strip())
+    """Extract text from uploaded PDF with error handling"""
+    try:
+        reader = PyPDF2.PdfReader(file)
+        text = ""
+        for page_num, page in enumerate(reader.pages):
+            page_text = page.extract_text()
+            if page_text.strip():
+                text += f"\n--- Page {page_num + 1} ---\n{page_text}\n"
+
+        if not text.strip():
+            st.warning("⚠️ No text extracted from PDF. It might be image-based.")
+            return ""
+
+        return text
+    except Exception as e:
+        st.error(f"❌ PDF processing failed: {str(e)}")
+        return ""
+
+def format_prompt(context, query):
+    """Create a properly formatted Mistral prompt"""
+    if context.strip():
+        prompt = f"<s>[INST] Use the following context to answer the question comprehensively:\n\nContext:\n{context[:3000]}...\n\nQuestion: {query}\n\nProvide a detailed, accurate answer based on the context. [/INST]"
+    else:
+        prompt = f"<s>[INST] {query} [/INST]"
+
+    return prompt
+
+# Main Application Flow
+def main():
+    # Step 1: Validate token
+    hf_token = validate_hf_token()
+
+    if not hf_token:
+        st.stop()
+
+    # Step 2: Check model access
+    if not check_model_access(hf_token):
+        st.stop()
+
+    # Step 3: Load model
+    textgen = load_model(hf_token)
+
+    if not textgen:
+        st.stop()
+
+    # Step 4: User Interface
+    st.markdown("---")
+
+    col1, col2 = st.columns([2, 1])
+
+    with col1:
+        query = st.text_area(
+            "💭 Ask your question:",
+            height=100,
+            placeholder="e.g., Explain machine learning concepts, summarize this document, etc."
+        )
+
+    with col2:
+        uploaded_file = st.file_uploader(
+            "📎 Upload PDF Context (Optional):",
+            type=["pdf"],
+            help="Upload a PDF to provide context for your question"
+        )
+
+    # Process uploaded file
+    context = ""
+    if uploaded_file:
+        with st.spinner("📖 Extracting text from PDF..."):
+            context = extract_text_from_pdf(uploaded_file)
+
+        if context:
+            with st.expander("📄 View Extracted Text", expanded=False):
+                st.text_area("PDF Content Preview:", context[:1000] + "..." if len(context) > 1000 else context, height=200)
+            st.success(f"✅ Extracted {len(context)} characters from PDF")
+
+    # Generate answer
+    if st.button("🚀 Generate Answer", type="primary"):
+        if not query.strip():
+            st.warning("⚠️ Please enter a question")
+            return
+
+        with st.spinner("🤔 Generating answer..."):
+            try:
+                prompt = format_prompt(context, query)
+
+                # Generate response
+                result = textgen(prompt, max_new_tokens=512, temperature=0.7)
+                generated_text = result[0]["generated_text"]
+
+                # Extract only the generated part
+                answer = generated_text.split("[/INST]")[-1].strip()
+
+                # Display result
+                st.markdown("### 🎯 Answer:")
+                st.markdown(answer)
+
+                # Show token usage info
+                with st.expander("📊 Generation Details", expanded=False):
+                    st.write(f"**Prompt length:** {len(prompt)} characters")
+                    st.write(f"**Response length:** {len(answer)} characters")
+                    st.write(f"**Context used:** {'Yes' if context else 'No'}")
+
+            except Exception as e:
+                st.error(f"❌ Generation failed: {str(e)}")
+                st.info("Try with a shorter question or refresh the page")
+
+# Run the application
+if __name__ == "__main__":
+    main()
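
One note on the post-processing in `main()`: a `text-generation` pipeline returns the prompt together with the completion in `generated_text`, which is why the new code keeps only the text after the final `[/INST]` tag. A minimal standalone sketch of that step, with a hypothetical stub string standing in for real model output (no model download needed):

```python
# Sketch of the prompt formatting and [/INST] post-processing used in app.py.
# "generated_text" below is a hypothetical stand-in for real pipeline output.

def format_prompt(context: str, query: str) -> str:
    """Wrap the question (and optional context) in Mistral's [INST] tags."""
    if context.strip():
        return f"<s>[INST] Context:\n{context}\n\nQuestion: {query} [/INST]"
    return f"<s>[INST] {query} [/INST]"

prompt = format_prompt("", "What is overfitting?")

# text-generation pipelines echo the prompt before the completion:
generated_text = prompt + " Overfitting means the model memorizes its training data."

# Keep only the completion, exactly as main() does:
answer = generated_text.split("[/INST]")[-1].strip()
print(answer)  # Overfitting means the model memorizes its training data.
```

Splitting on the last `[/INST]` is more robust than the old `result.replace(prompt, "")`, since the pipeline may normalize whitespace or special tokens when echoing the prompt, in which case the replace would silently fail and leak the prompt into the answer.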