MrSimple07 committed
Commit
25c732c · verified · 1 Parent(s): 9761a91

Create app.py

Files changed (1): app.py +541 -0
app.py ADDED
@@ -0,0 +1,541 @@
+ import sys
+ import os
+ import time
+ import logging
+ import gradio as gr
+ import faiss
+ import numpy as np
+ import pandas as pd
+ import requests
+ from geopy.geocoders import Nominatim
+ from sentence_transformers import SentenceTransformer
+ from typing import Tuple, Optional
+ from huggingface_hub import hf_hub_download
+ import geonamescache
+
+ logging.basicConfig(level=logging.INFO)
+
+ token = os.getenv('HF_TOKEN')
+
+ df_path = hf_hub_download(
+     repo_id='MrSimple07/raggg',
+     filename='17_rag_data.csv',
+     repo_type='dataset',
+     token=token
+ )
+ embeddings_path = hf_hub_download(
+     repo_id='MrSimple07/raggg',
+     filename='rag_embeddings_3.npy',
+     repo_type='dataset',
+     token=token
+ )
+
+ df = pd.read_csv(df_path)
+ embeddings = np.load(embeddings_path, mmap_mode='r')
+
+ # Read the Mistral API key from the environment instead of hardcoding a secret.
+ MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY')
+ MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
+
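+ # NOTE (assumed schema): the filtering and formatting code below expects the CSV
+ # to provide at least these columns: name, category, city, address, description,
+ # latitude, longitude, and a pre-built 'combined_field' text column used for matching.
+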
+ category_synonyms = {
+     "museum": [
+         "museums", "art galleries", "natural museums", "modern art museums"
+     ],
+     "cafe": [
+         "coffee shops"
+     ],
+     "restaurant": [
+         "local dining spots", "fine dining", "casual eateries",
+         "family-friendly restaurants", "street food places"
+     ],
+     "parks": [
+         "national parks", "urban green spaces", "botanical gardens",
+         "recreational parks", "wildlife reserves"
+     ],
+     "park": [
+         "national parks", "urban green spaces", "botanical gardens",
+         "recreational parks", "wildlife reserves"
+     ],
+     "spa": ['bath', 'swimming', 'pool']
+ }
+
+ def extract_location_geonames(query: str) -> dict:
+     """Scan the query's word n-grams for a known city or country name via geonamescache."""
+     gc = geonamescache.GeonamesCache()
+     countries = {c['name'].lower(): c['name'] for c in gc.get_countries().values()}
+     cities = {c['name'].lower(): c['name'] for c in gc.get_cities().values()}
+
+     words = query.split()
+
+     for i in range(len(words)):
+         for j in range(i + 1, len(words) + 1):
+             potential_location = ' '.join(words[i:j]).lower()
+
+             # Check if it's a city first
+             if potential_location in cities:
+                 return {
+                     'city': cities[potential_location],
+                 }
+
+             # Then check if it's a country; keep the remaining words as a city guess
+             if potential_location in countries:
+                 remaining = words[:i] + words[j:]
+                 return {
+                     'city': ' '.join(remaining) if remaining else None,
+                     'country': countries[potential_location]
+                 }
+
+     return {'city': query}
+
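+ # Example (illustrative): extract_location_geonames("museums in Paris") returns
+ # {'city': 'Paris'}, while "museums in France" yields
+ # {'city': 'museums in', 'country': 'France'}, since the non-country words are
+ # kept as a city guess.
+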
+ def expand_category_once(query, target_category):
+     """
+     Expand the target category term in the query only once with synonyms and related phrases.
+     """
+     target_lower = target_category.lower()
+     if target_lower in query.lower():
+         synonyms = category_synonyms.get(target_lower, [])
+         if synonyms:
+             expanded_term = f"{target_category} ({', '.join(synonyms)})"
+             query = query.replace(target_category, expanded_term, 1)  # Replace only the first occurrence
+     return query
+
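+ # Example (illustrative): expand_category_once("best cafe in Rome", "cafe")
+ # -> "best cafe (coffee shops) in Rome"
+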
+ CATEGORY_FILTER_WORDS = [
+     'museum', 'art', 'gallery', 'tourism', 'historical',
+     'bar', 'cafe', 'restaurant', 'park', 'landmark',
+     'beach', 'mountain', 'theater', 'church', 'monument',
+     'garden', 'library', 'university', 'shopping', 'market',
+     'hotel', 'resort', 'cultural', 'natural', 'science',
+     'educational', 'entertainment', 'sports', 'memorial', 'historic',
+     'spa', 'landmarks', 'sleep', 'coffee shops', 'shops', 'buildings',
+     'gothic', 'castle', 'fortress', 'aquarium', 'zoo', 'wildlife',
+     'adventure', 'hiking', 'lighthouse', 'vineyard', 'brewery',
+     'winery', 'pub', 'nightclub', 'observatory', 'theme park',
+     'botanical', 'sanctuary', 'heritage', 'island', 'waterfall',
+     'canyon', 'valley', 'desert', 'artisans', 'crafts', 'music hall',
+     'dance clubs', 'opera house', 'skyscraper', 'bridge', 'fountain',
+     'temple', 'shrine', 'archaeological', 'planetarium', 'marketplace',
+     'street art', 'local cuisine', 'eco-tourism', 'carnival', 'festival', 'film'
+ ]
+
+ def extract_category_from_query(query: str) -> Optional[str]:
+     """Return the first category keyword (in list order) found as a substring of the query."""
+     query_lower = query.lower()
+     for word in CATEGORY_FILTER_WORDS:
+         if word in query_lower:
+             return word
+
+     return None
+
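+ # Note: matching is plain substring search, so broad terms early in the list
+ # (e.g. 'art' inside 'street art') win; order CATEGORY_FILTER_WORDS accordingly.
+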
+ def get_location_details(min_lat, max_lat, min_lon, max_lon):
+     """Get detailed location information for a bounding box with improved city detection and error handling"""
+     geolocator = Nominatim(user_agent="location_finder", timeout=10)
+
+     try:
+         # Strategy 1: Try multiple points within the bounding box
+         sample_points = [
+             ((float(min_lat) + float(max_lat)) / 2,
+              (float(min_lon) + float(max_lon)) / 2),
+             (float(min_lat), float(min_lon)),
+             (float(max_lat), float(min_lon)),
+             (float(min_lat), float(max_lon)),
+             (float(max_lat), float(max_lon))
+         ]
+
+         # Collect unique cities from all points
+         cities = set()
+         full_addresses = []
+         address = {}
+
+         for lat, lon in sample_points:
+             try:
+                 location = None
+                 # Multiple retry attempts with exponential backoff
+                 for attempt in range(3):
+                     try:
+                         location = geolocator.reverse(f"{lat}, {lon}", language='en')
+                         break
+                     except Exception:
+                         if attempt == 2:  # Last attempt
+                             print(f"Failed to retrieve location for {lat}, {lon} after 3 attempts")
+                         else:
+                             time.sleep(2 ** attempt)  # Exponential backoff
+
+                 if location:
+                     address = location.raw.get('address', {})
+
+                     # Extract city with multiple fallback options
+                     city = (
+                         address.get('city') or
+                         address.get('town') or
+                         address.get('municipality') or
+                         address.get('county') or
+                         address.get('state')
+                     )
+
+                     if city:
+                         cities.add(city)
+                         full_addresses.append(location.address)
+
+             except Exception as point_error:
+                 print(f"Error processing point {lat}, {lon}: {point_error}")
+                 continue
+
+         # If no cities were found, return default (empty) location information
+         if not cities:
+             print("No cities detected. Returning default location information.")
+             return {
+                 'location_parts': [],
+                 'full_address_parts': '',
+                 'full_address': '',
+                 'city': [],
+                 'state': '',
+                 'country': '',
+                 'cities_or_query': ''
+             }
+
+         # Prioritize cities, keeping all detected cities
+         city_list = list(cities)
+
+         # Use the last processed address for state and country
+         state = address.get('state', '')
+         country = address.get('country', '')
+
+         # Create a formatted list of cities for the query
+         cities_or_query = " or ".join(city_list)
+
+         location_parts = [part for part in [cities_or_query, state, country] if part]
+         full_address_parts = ', '.join(location_parts)
+
+         print(f"Detected Cities: {cities}")
+         print(f"Cities for Query: {cities_or_query}")
+         print(f"Full Address Parts: {full_address_parts}")
+
+         return {
+             'location_parts': city_list,
+             'full_address_parts': full_address_parts,
+             'full_address': full_addresses[0] if full_addresses else '',
+             'city': city_list,
+             'state': state,
+             'country': country,
+             'cities_or_query': cities_or_query
+         }
+
+     except Exception as e:
+         print(f"Comprehensive error in location details retrieval: {e}")
+         import traceback
+         traceback.print_exc()
+
+         return None
+
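+ # Note: Nominatim's public instance is rate-limited (roughly one request per
+ # second under its usage policy), so reverse-geocoding five sample points per
+ # query may be throttled; a self-hosted instance or caching would avoid this.
+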
+ def rag_query(
+     query: str,
+     df: pd.DataFrame,
+     model: SentenceTransformer,
+     precomputed_embeddings: np.ndarray,
+     index: faiss.IndexFlatL2,
+     min_lat: Optional[str] = None,
+     max_lat: Optional[str] = None,
+     min_lon: Optional[str] = None,
+     max_lon: Optional[str] = None,
+     category: Optional[str] = None,
+     city: Optional[str] = None,
+ ) -> Tuple[str, str]:
+     """Enhanced RAG function with prioritized location extraction"""
+     # Note: the global `index` argument is kept for interface compatibility;
+     # a per-query index over the filtered rows is built below.
+     print("\n=== Starting RAG Query ===")
+     print(f"Initial DataFrame size: {len(df)}")
+
+     # Prioritized location extraction
+     location_info = None
+     location_names = []
+
+     # Priority 1: Explicitly provided city name
+     if city:
+         location_names = [city]
+         print(f"Using explicitly provided city: {city}")
+
+     # Priority 2: Coordinates (Nominatim)
+     elif all(coord is not None and coord != "" for coord in [min_lat, max_lat, min_lon, max_lon]):
+         try:
+             location_info = get_location_details(
+                 float(min_lat),
+                 float(max_lat),
+                 float(min_lon),
+                 float(max_lon)
+             )
+
+             # Extract location names from the Nominatim result
+             if location_info:
+                 if location_info.get('city'):
+                     location_names.extend(location_info['city'] if isinstance(location_info['city'], list) else [location_info['city']])
+                 if location_info.get('state'):
+                     location_names.append(location_info['state'])
+                 if location_info.get('country'):
+                     location_names.append(location_info['country'])
+
+             print(f"Using coordinates-based location: {location_names}")
+         except Exception as e:
+             print(f"Location details error: {e}")
+
+     # Priority 3: Extract from the query using GeoNames only if no previous method worked
+     if not location_names:
+         geonames_info = extract_location_geonames(query)
+         if geonames_info.get('city'):
+             location_names = [geonames_info['city']]
+             print(f"Using GeoNames-extracted city: {location_names}")
+
+     # Start with a copy of the original DataFrame
+     filtered_df = df.copy()
+
+     # Filter the DataFrame by location names (case-insensitive)
+     if location_names:
+         location_filter = (
+             filtered_df['city'].str.lower().isin([name.lower() for name in location_names]) |
+             filtered_df['city'].apply(lambda x: any(name.lower() in str(x).lower() for name in location_names)) |
+             filtered_df['combined_field'].apply(lambda x: any(name.lower() in str(x).lower() for name in location_names))
+         )
+
+         filtered_df = filtered_df[location_filter]
+
+         print(f"Location Names Used for Filtering: {location_names}")
+         print(f"Results after location filtering: {len(filtered_df)}")
+
+     # Build an enhanced query string from the available hints
+     enhanced_query_parts = []
+     if query:
+         enhanced_query_parts.append(query)
+     if category:
+         enhanced_query_parts.append(f"{category} category")
+     if city:
+         enhanced_query_parts.append(f"in {city}")
+
+     if all(coord is not None and coord != "" for coord in [min_lat, max_lat, min_lon, max_lon]):
+         enhanced_query_parts.append(f"within latitudes {min_lat} to {max_lat} and longitudes {min_lon} to {max_lon}")
+
+     # Add location context
+     if location_info:
+         location_context = " ".join(filter(None, [
+             ", ".join(location_info.get('city', [])),
+             location_info.get('state', ''),
+             # location_info.get('country', '')
+         ]))
+         if location_context:
+             enhanced_query_parts.append(f"in {location_context}")
+
+     enhanced_query = " ".join(enhanced_query_parts)
+
+     # Expand the category term once; guard against a missing category
+     if enhanced_query and category:
+         enhanced_query = expand_category_once(enhanced_query, category)
+
+     print(f"Enhanced Query: {enhanced_query}")
+
+     detected_category = extract_category_from_query(enhanced_query)
+     if detected_category:
+         category_filter = (
+             filtered_df['category'].str.contains(detected_category, case=False, na=False) |
+             filtered_df['combined_field'].str.contains(detected_category, case=False, na=False)
+         )
+         filtered_df = filtered_df[category_filter]
+
+         print(f"Filtered by query words '{detected_category}': {len(filtered_df)} results")
+
+     try:
+         # Guard: nothing left to search after filtering
+         if filtered_df.empty:
+             return "No matching locations found.", ""
+
+         query_vector = model.encode([enhanced_query])[0]
+
+         # Select the precomputed embeddings for the filtered DataFrame
+         filtered_embeddings = precomputed_embeddings[filtered_df.index]
+
+         # Create a FAISS index over the filtered embeddings
+         filtered_index = faiss.IndexFlatL2(filtered_embeddings.shape[1])
+         filtered_index.add(filtered_embeddings.astype(np.float32))
+
+         # Perform semantic search on the filtered results
+         k = min(20, len(filtered_df))
+         distances, local_indices = filtered_index.search(
+             np.array([query_vector]).astype(np.float32),
+             k
+         )
+
+         # Get the top results
+         results_df = filtered_df.iloc[local_indices[0]]
+
+         # Format results
+         formatted_results = []
+         for i, (_, row) in enumerate(results_df.iterrows(), 1):
+             formatted_results.append(
+                 f"\n=== Result {i} ===\n"
+                 f"Name: {row['name']}\n"
+                 f"Category: {row['category']}\n"
+                 f"City: {row['city']}\n"
+                 f"Address: {row['address']}\n"
+                 f"Description: {row['description']}\n"
+                 f"Latitude: {row['latitude']}\n"
+                 f"Longitude: {row['longitude']}\n"
+             )
+
+         search_results = "\n".join(formatted_results) if formatted_results else "No matching locations found."
+
+         # Optional: Use Mistral for further refinement
+         try:
+             answer = query_mistral(enhanced_query, search_results)
+         except Exception as e:
+             print(f"Error in Mistral query: {e}")
+             answer = "Unable to generate additional insights."
+
+         return search_results, answer
+
+     except Exception as e:
+         print(f"Error in semantic search: {e}")
+         return f"Error performing search: {str(e)}", ""
+
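+ # Design note: IndexFlatL2 performs exact brute-force search, which is fine at
+ # this dataset's scale; rebuilding a small per-query index over the filtered
+ # rows is simpler than maintaining ID mappings into the global index.
+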
+ def query_mistral(prompt: str, context: str, max_retries: int = 3) -> str:
+     """
+     Robust Mistral verification with exponential backoff
+     """
+     # Early return if there is no context or no API key configured
+     if not context or context.strip() == "No matching locations found.":
+         return context
+     if not MISTRAL_API_KEY:
+         logging.warning("MISTRAL_API_KEY not set; returning raw search results")
+         return context
+
+     verification_prompt = f"""Precise Location Curation Task:
+ REQUIREMENTS:
+ - Source Query: {prompt}
+ - Current Context: {context}
+ DETAILED INSTRUCTIONS:
+ 1. State the min/max latitude and min/max longitude from the query at the beginning.
+ 2. Curate a list of up to 15 locations inside these coordinates that are strictly relevant to the Source Query.
+ 3. Include ONLY places strictly relevant to the Source Query.
+ 4. Add a short description of each place (2-3 sentences).
+ 5. Add coordinates (lat and long) if they appear in the Current Context.
+ 6. If the Current Context has no coordinates, give only the name and description.
+ 7. Add the address of each place.
+ 8. Remove any duplicate entries from the list.
+ 9. If the Current Context has fewer than 10 relevant places, quickly generate additional places relevant to the Source Query and inside the coordinates.
+ CRITICAL: Do NOT use placeholders. A quick, concise response is required.
+ """
+
+     for attempt in range(max_retries):
+         try:
+             # Robust API configuration
+             response = requests.post(
+                 MISTRAL_API_URL,
+                 headers={
+                     "Authorization": f"Bearer {MISTRAL_API_KEY}",
+                     "Content-Type": "application/json"
+                 },
+                 json={
+                     "model": "mistral-large-latest",
+                     "messages": [
+                         {"role": "system", "content": "You are a precise location curator specializing in comprehensive travel information."},
+                         {"role": "user", "content": verification_prompt}
+                     ],
+                     "temperature": 0.1,
+                     "max_tokens": 5000
+                 },
+                 timeout=100  # Generous timeout
+             )
+
+             # Raise on HTTP errors
+             response.raise_for_status()
+
+             # Extract the verified response
+             verified_response = response.json()['choices'][0]['message']['content']
+
+             # Validate response length; retry on suspiciously short output
+             if len(verified_response.strip()) < 100:
+                 if attempt == max_retries - 1:
+                     return context
+                 time.sleep(2 ** attempt)  # Exponential backoff
+                 continue
+
+             return verified_response
+
+         except requests.Timeout:
+             logging.warning(f"Mistral API timeout (Attempt {attempt + 1}/{max_retries})")
+             if attempt < max_retries - 1:
+                 time.sleep(2 ** attempt)  # Exponential backoff
+             else:
+                 logging.error("Mistral API consistently timing out")
+                 return context
+
+         except requests.RequestException as e:
+             logging.error(f"Mistral API request error: {e}")
+             if attempt < max_retries - 1:
+                 time.sleep(2 ** attempt)
+             else:
+                 return context
+
+         except Exception as e:
+             logging.error(f"Unexpected error in Mistral verification: {e}")
+             if attempt < max_retries - 1:
+                 time.sleep(2 ** attempt)
+             else:
+                 return context
+
+     return context
+
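+ # Quick local check (illustrative; assumes MISTRAL_API_KEY is exported):
+ #   print(query_mistral("museums in Paris", "=== Result 1 ===\nName: Louvre"))
+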
+ def create_interface(
+     df: pd.DataFrame,
+     model: SentenceTransformer,
+     precomputed_embeddings: np.ndarray,
+     index: faiss.IndexFlatL2
+ ):
+     """Create the Gradio interface with bounding-box, city, and category inputs"""
+     return gr.Interface(
+         fn=lambda q, min_lat, max_lat, min_lon, max_lon, city, cat: rag_query(
+             query=q,
+             df=df,
+             model=model,
+             precomputed_embeddings=precomputed_embeddings,
+             index=index,
+             min_lat=min_lat,
+             max_lat=max_lat,
+             min_lon=min_lon,
+             max_lon=max_lon,
+             city=city,
+             category=cat
+         )[1],
+         inputs=[
+             gr.Textbox(lines=2, label="Question"),
+             gr.Textbox(label="Min Latitude"),
+             gr.Textbox(label="Max Latitude"),
+             gr.Textbox(label="Min Longitude"),
+             gr.Textbox(label="Max Longitude"),
+             gr.Textbox(label="City"),
+             gr.Textbox(label="Category")
+         ],
+         outputs=[
+             gr.Textbox(label="Locations Found"),
+         ],
+         title="Tourist Information System with Bounding Box Search",
+         examples=[
+             ["Museums in area", "40.71", "40.86", "-74.1", "-74.0", "", "museum"],
+             ["Restaurants", "48.8575", "48.9", "2.3514", "2.4", "Paris", "restaurant"],
+             ["Coffee shops", "51.5", "51.6", "-0.2", "-0.1", "London", "cafe"],
+             ["Spa places", "", "", "", "", "Budapest", ""],
+             ["Lambic brewery", "50.84211068618749", "50.849274898691244", "4.339536387173865", "4.361188801802462", "", ""],
+             ["Art nouveau architecture buildings", "44.42563381188614", "44.43347927669681", "26.008709832230608", "26.181744493414488", "", ""],
+             ["Harry Potter filming locations", "51.52428877891333", "51.54738884423489", "-0.1955164690977472", "-0.05082973945560466", "", ""]
+         ]
+     )
+
+ if __name__ == "__main__":
+     try:
+         model = SentenceTransformer('all-MiniLM-L6-v2')
+         precomputed_embeddings = embeddings
+         index = faiss.IndexFlatL2(precomputed_embeddings.shape[1])
+         index.add(precomputed_embeddings.astype(np.float32))
+
+         iface = create_interface(df, model, precomputed_embeddings, index)
+         iface.launch(share=True, debug=True)
+     except Exception as e:
+         logging.error(f"Startup error: {e}")
+         sys.exit(1)
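+
+ # Suggested dependencies (inferred from the imports above; pin versions as needed):
+ # gradio, faiss-cpu, numpy, pandas, requests, geopy, sentence-transformers,
+ # huggingface_hub, geonamescache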