import os
import sys
import time
import logging
from typing import Tuple, Optional

import gradio as gr
import faiss
import numpy as np
import pandas as pd
import requests
import geonamescache
from geopy.geocoders import Nominatim
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

logging.basicConfig(level=logging.INFO)

token = os.getenv('HF_TOKEN')
df_path = hf_hub_download(
    repo_id='MrSimple07/raggg',
    filename='15_rag_data.csv',
    repo_type='dataset',
    token=token,
)
embeddings_path = hf_hub_download(
    repo_id='MrSimple07/raggg',
    filename='rag_embeddings.npy',
    repo_type='dataset',
    token=token,
)
df = pd.read_csv(df_path)
embeddings = np.load(embeddings_path)
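# Sanity check (illustrative; the one-vector-per-row pairing is our assumption
# about the dataset, not something it guarantees). Warn rather than crash.
if len(df) != embeddings.shape[0]:
    logging.warning(
        "Row/embedding count mismatch: %d rows vs %d vectors",
        len(df), embeddings.shape[0],
    )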
# Read the API key from the environment; secrets should never be hardcoded in source.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
category_synonyms = {
    "museum": [
        "museums", "art galleries", "natural museums", "modern art museums"
    ],
    "cafe": [
        "coffee shops"
    ],
    "restaurant": [
        "local dining spots", "fine dining", "casual eateries",
        "family-friendly restaurants", "street food places"
    ],
    "parks": [
        "national parks", "urban green spaces", "botanical gardens",
        "recreational parks", "wildlife reserves"
    ],
    "park": [
        "national parks", "urban green spaces", "botanical gardens",
        "recreational parks", "wildlife reserves"
    ],
    "spa": ["bath", "swimming", "pool"],
}
def extract_location_geonames(query: str) -> dict:
    gc = geonamescache.GeonamesCache()
    countries = {c['name'].lower(): c['name'] for c in gc.get_countries().values()}
    cities = {c['name'].lower(): c['name'] for c in gc.get_cities().values()}

    words = query.split()
    for i in range(len(words)):
        for j in range(i + 1, len(words) + 1):
            potential_location = ' '.join(words[i:j]).lower()
            # Check if it's a city first
            if potential_location in cities:
                return {'city': cities[potential_location]}
            # Then check if it's a country
            if potential_location in countries:
                remaining = words[:i] + words[j:]
                return {
                    'city': ' '.join(remaining) if remaining else None,
                    'country': countries[potential_location],
                }
    return {'city': query}
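# Example (illustrative): extract_location_geonames("museums in Paris") is
# expected to return {'city': 'Paris'}. A query with no recognizable place
# name falls back to {'city': <query>}, so downstream filtering still has
# something to match against.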
def expand_category_once(query, target_category):
    """
    Expand the target category term in the query only once with synonyms
    and related phrases.
    """
    if not target_category:  # Guard: category may be None or empty
        return query
    target_lower = target_category.lower()
    if target_lower in query.lower():
        synonyms = category_synonyms.get(target_lower, [])
        if synonyms:
            expanded_term = f"{target_category} ({', '.join(synonyms)})"
            # Replace only the first occurrence
            query = query.replace(target_category, expanded_term, 1)
    return query
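# Example (illustrative):
#   expand_category_once("best museum in Rome", "museum")
#   -> "best museum (museums, art galleries, natural museums, modern art museums) in Rome"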
CATEGORY_FILTER_WORDS = [
    'museum', 'art', 'gallery', 'tourism', 'historical',
    'bar', 'cafe', 'restaurant', 'park', 'landmark',
    'beach', 'mountain', 'theater', 'church', 'monument',
    'garden', 'library', 'university', 'shopping', 'market',
    'hotel', 'resort', 'cultural', 'natural', 'science',
    'educational', 'entertainment', 'sports', 'memorial', 'historic',
    'spa', 'landmarks', 'sleep', 'coffee shops', 'shops', 'buildings',
    'gothic', 'castle', 'fortress', 'aquarium', 'zoo', 'wildlife',
    'adventure', 'hiking', 'lighthouse', 'vineyard', 'brewery',
    'winery', 'pub', 'nightclub', 'observatory', 'theme park',
    'botanical', 'sanctuary', 'heritage', 'island', 'waterfall',
    'canyon', 'valley', 'desert', 'artisans', 'crafts', 'music hall',
    'dance clubs', 'opera house', 'skyscraper', 'bridge', 'fountain',
    'temple', 'shrine', 'archaeological', 'planetarium', 'marketplace',
    'street art', 'local cuisine', 'eco-tourism', 'carnival', 'festival', 'film'
]
def extract_category_from_query(query: str) -> Optional[str]:
    query_lower = query.lower()
    for word in CATEGORY_FILTER_WORDS:
        if word in query_lower:
            return word
    return None
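# Note: this returns the first match in list order, and the check is a plain
# substring test, so collisions are possible ('art' matches inside 'artisans').
# Ordering the list from most to least specific is one way to mitigate that.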
def get_location_details(min_lat, max_lat, min_lon, max_lon):
    """Get detailed location information for a bounding box, with improved
    city detection and error handling."""
    geolocator = Nominatim(user_agent="location_finder", timeout=10)
    try:
        # Strategy 1: sample multiple points within the bounding box
        sample_points = [
            ((float(min_lat) + float(max_lat)) / 2,
             (float(min_lon) + float(max_lon)) / 2),
            (float(min_lat), float(min_lon)),
            (float(max_lat), float(min_lon)),
            (float(min_lat), float(max_lon)),
            (float(max_lat), float(max_lon)),
        ]

        # Collect unique cities from all sampled points
        cities = set()
        full_addresses = []
        address = {}  # Holds the most recently parsed address
        for lat, lon in sample_points:
            try:
                # Multiple retry attempts with exponential backoff
                location = None
                for attempt in range(3):
                    try:
                        location = geolocator.reverse(f"{lat}, {lon}", language='en')
                        break
                    except Exception:
                        if attempt == 2:  # Last attempt
                            print(f"Failed to retrieve location for {lat}, {lon} after 3 attempts")
                        else:
                            time.sleep(2 ** attempt)  # Exponential backoff

                if location:
                    address = location.raw.get('address', {})
                    # Extract the city with multiple fallback options
                    city = (
                        address.get('city') or
                        address.get('town') or
                        address.get('municipality') or
                        address.get('county') or
                        address.get('state')
                    )
                    if city:
                        cities.add(city)
                    full_addresses.append(location.address)
            except Exception as point_error:
                print(f"Error processing point {lat}, {lon}: {point_error}")
                continue

        # If no cities were found, return default (empty) location information
        if not cities:
            print("No cities detected. Returning default location information.")
            return {
                'location_parts': [],
                'full_address_parts': '',
                'full_address': '',
                'city': [],
                'state': '',
                'country': '',
                'cities_or_query': ''
            }

        # Keep all detected cities
        city_list = list(cities)
        # Use the last processed address for state and country
        state = address.get('state', '')
        country = address.get('country', '')

        # Build an "A or B" city string for use in the query
        cities_or_query = " or ".join(city_list)
        location_parts = [part for part in [cities_or_query, state, country] if part]
        full_address_parts = ', '.join(location_parts)

        print(f"Detected Cities: {cities}")
        print(f"Cities for Query: {cities_or_query}")
        print(f"Full Address Parts: {full_address_parts}")

        return {
            'location_parts': city_list,
            'full_address_parts': full_address_parts,
            'full_address': full_addresses[0] if full_addresses else '',
            'city': city_list,
            'state': state,
            'country': country,
            'cities_or_query': cities_or_query
        }
    except Exception as e:
        print(f"Comprehensive error in location details retrieval: {e}")
        import traceback
        traceback.print_exc()
        return None
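# Example (illustrative): for a box around central Paris,
# get_location_details(48.85, 48.90, 2.35, 2.40) would be expected to return a
# dict whose 'cities_or_query' is roughly "Paris". Actual output depends on
# live Nominatim responses, which are rate-limited per its usage policy.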
def rag_query(
    query: str,
    df: pd.DataFrame,
    model: SentenceTransformer,
    precomputed_embeddings: np.ndarray,
    index: faiss.IndexFlatL2,
    min_lat: str = None,
    max_lat: str = None,
    min_lon: str = None,
    max_lon: str = None,
    category: str = None,
    city: str = None,
) -> Tuple[str, str]:
    """Enhanced RAG function with prioritized location extraction."""
    print("\n=== Starting RAG Query ===")
    print(f"Initial DataFrame size: {len(df)}")

    # Prioritized location extraction
    location_info = None
    location_names = []

    # Priority 1: explicitly provided city name
    if city:
        location_names = [city]
        print(f"Using explicitly provided city: {city}")
    # Priority 2: coordinates (reverse-geocoded via Nominatim)
    elif all(coord is not None and coord != "" for coord in [min_lat, max_lat, min_lon, max_lon]):
        try:
            location_info = get_location_details(
                float(min_lat),
                float(max_lat),
                float(min_lon),
                float(max_lon)
            )
            # Extract location names from the Nominatim result
            if location_info:
                if location_info.get('city'):
                    location_names.extend(
                        location_info['city']
                        if isinstance(location_info['city'], list)
                        else [location_info['city']]
                    )
                if location_info.get('state'):
                    location_names.append(location_info['state'])
                if location_info.get('country'):
                    location_names.append(location_info['country'])
            print(f"Using coordinates-based location: {location_names}")
        except Exception as e:
            print(f"Location details error: {e}")

    # Priority 3: extract from the query via GeoNames if nothing else worked
    if not location_names:
        geonames_info = extract_location_geonames(query)
        if geonames_info.get('city'):
            location_names = [geonames_info['city']]
            print(f"Using GeoNames-extracted city: {location_names}")

    # Start with a copy of the original DataFrame
    filtered_df = df.copy()

    # Filter the DataFrame by location names (case-insensitive)
    if location_names:
        lowered = [name.lower() for name in location_names]
        location_filter = (
            filtered_df['city'].str.lower().isin(lowered) |
            filtered_df['city'].apply(lambda x: any(name in str(x).lower() for name in lowered)) |
            filtered_df['combined_field'].apply(lambda x: any(name in str(x).lower() for name in lowered))
        )
        filtered_df = filtered_df[location_filter]
        print(f"Location Names Used for Filtering: {location_names}")
        print(f"Results after location filtering: {len(filtered_df)}")

    # Build an enhanced query string from all available hints
    enhanced_query_parts = []
    if query:
        enhanced_query_parts.append(query)
    if category:
        enhanced_query_parts.append(f"{category} category")
    if city:
        enhanced_query_parts.append(f"in {city}")
    if min_lat is not None and max_lat is not None and min_lon is not None and max_lon is not None:
        enhanced_query_parts.append(
            f"within latitudes {min_lat} to {max_lat} and longitudes {min_lon} to {max_lon}"
        )

    # Add location context from reverse geocoding
    if location_info:
        location_context = " ".join(filter(None, [
            ", ".join(location_info.get('city', [])),
            location_info.get('state', ''),
            # location_info.get('country', '')
        ]))
        if location_context:
            enhanced_query_parts.append(f"in {location_context}")

    enhanced_query = " ".join(enhanced_query_parts)
    if enhanced_query:
        enhanced_query = expand_category_once(enhanced_query, category)
    print(f"Filtered by city '{city}': {len(filtered_df)} results")
    print(f"Enhanced Query: {enhanced_query}")

    # Filter by any category word detected in the enhanced query
    detected_category = extract_category_from_query(enhanced_query)
    if detected_category:
        category_filter = (
            filtered_df['category'].str.contains(detected_category, case=False, na=False) |
            filtered_df['combined_field'].str.contains(detected_category, case=False, na=False)
        )
        filtered_df = filtered_df[category_filter]
        print(f"Filtered by query words '{detected_category}': {len(filtered_df)} results")

    # Guard: FAISS cannot search an empty index
    if filtered_df.empty:
        return "No matching locations found.", "No matching locations found."

    try:
        query_vector = model.encode([enhanced_query])[0]

        # Select the precomputed embeddings for the filtered rows
        filtered_embeddings = precomputed_embeddings[filtered_df.index]

        # Build a FAISS index over just the filtered embeddings
        filtered_index = faiss.IndexFlatL2(filtered_embeddings.shape[1])
        filtered_index.add(filtered_embeddings.astype(np.float32))

        # Perform semantic search on the filtered subset
        k = min(20, len(filtered_df))
        distances, local_indices = filtered_index.search(
            np.array([query_vector]).astype(np.float32),
            k
        )

        # Get the top results
        results_df = filtered_df.iloc[local_indices[0]]

        # Format results
        formatted_results = []
        for i, (_, row) in enumerate(results_df.iterrows(), 1):
            formatted_results.append(
                f"\n=== Result {i} ===\n"
                f"Name: {row['name']}\n"
                f"Category: {row['category']}\n"
                f"City: {row['city']}\n"
                f"Address: {row['address']}\n"
                f"Description: {row['description']}\n"
                f"Latitude: {row['latitude']}\n"
                f"Longitude: {row['longitude']}\n"
            )
        search_results = "\n".join(formatted_results) if formatted_results else "No matching locations found."

        # Optional: use Mistral for further refinement
        try:
            answer = query_mistral(enhanced_query, search_results)
        except Exception as e:
            print(f"Error in Mistral query: {e}")
            answer = "Unable to generate additional insights."

        return search_results, answer
    except Exception as e:
        print(f"Error in semantic search: {e}")
        return f"Error performing search: {str(e)}", ""
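# Design note: rebuilding a small IndexFlatL2 per request is an exact
# brute-force search over the filtered subset, which keeps the returned
# positions trivially aligned with filtered_df. The module-level `index`
# over all rows is accepted as a parameter but, as written, every query
# goes through this filtered path.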
def query_mistral(prompt: str, context: str, max_retries: int = 3) -> str:
    """
    Robust Mistral verification with exponential backoff.
    """
    # Early return if there is no context to verify
    if not context or context.strip() == "No matching locations found.":
        return context

    verification_prompt = f"""Precise Location Curation Task:
REQUIREMENTS:
- Source Query: {prompt}
- Current Context: {context}
DETAILED INSTRUCTIONS:
1. Write the min and max latitude and min and max longitude from the query at the beginning.
2. Curate a comprehensive list of 15 locations inside these coordinates that are strictly relevant to the place.
3. Take STRICTLY ONLY places relevant to the Source Query.
4. Add a short description of each place (2-3 sentences).
5. Add coordinates (latitude and longitude).
6. Add the address of each place.
7. Remove any duplicate entries from the list.
8. If places > 10, quickly generate new places relevant to the Source Query and inside the coordinates.
CRITICAL: Do NOT use placeholders. A quick and fast response is required.
"""

    for attempt in range(max_retries):
        try:
            # Robust API configuration
            response = requests.post(
                MISTRAL_API_URL,
                headers={
                    "Authorization": f"Bearer {MISTRAL_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "mistral-large-latest",
                    "messages": [
                        {"role": "system", "content": "You are a precise location curator specializing in comprehensive travel information."},
                        {"role": "user", "content": verification_prompt}
                    ],
                    "temperature": 0.1,
                    "max_tokens": 5000
                },
                timeout=100  # Generous timeout for long completions
            )
            response.raise_for_status()

            # Extract the verified response
            verified_response = response.json()['choices'][0]['message']['content']

            # Treat a very short reply as a failed attempt and retry
            if len(verified_response.strip()) < 100:
                if attempt == max_retries - 1:
                    return context
                time.sleep(2 ** attempt)  # Exponential backoff
                continue

            return verified_response
        except requests.Timeout:
            logging.warning(f"Mistral API timeout (Attempt {attempt + 1}/{max_retries})")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # Exponential backoff
            else:
                logging.error("Mistral API consistently timing out")
                return context
        except requests.RequestException as e:
            logging.error(f"Mistral API request error: {e}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
            else:
                return context
        except Exception as e:
            logging.error(f"Unexpected error in Mistral verification: {e}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
            else:
                return context
    return context
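# The retry logic above follows the standard exponential-backoff pattern. A
# minimal standalone sketch of the same idea (illustrative, not used by the
# app; `fn` is any flaky zero-argument callable):
def _with_backoff(fn, retries: int = 3):
    for attempt in range(retries):
        try:
            return fn()
        except Exception:
            if attempt == retries - 1:
                raise  # out of attempts: surface the last error
            time.sleep(2 ** attempt)  # wait 1 s, 2 s, ... between attempts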
def create_interface(
    df: pd.DataFrame,
    model: SentenceTransformer,
    precomputed_embeddings: np.ndarray,
    index: faiss.IndexFlatL2
):
    """Create a Gradio interface with four bounding-box inputs."""
    return gr.Interface(
        fn=lambda q, min_lat, max_lat, min_lon, max_lon, city, cat: rag_query(
            query=q,
            df=df,
            model=model,
            precomputed_embeddings=precomputed_embeddings,
            index=index,
            min_lat=min_lat,
            max_lat=max_lat,
            min_lon=min_lon,
            max_lon=max_lon,
            city=city,
            category=cat
        )[1],
        inputs=[
            gr.Textbox(lines=2, label="Question"),
            gr.Textbox(label="Min Latitude"),
            gr.Textbox(label="Max Latitude"),
            gr.Textbox(label="Min Longitude"),
            gr.Textbox(label="Max Longitude"),
            gr.Textbox(label="City"),
            gr.Textbox(label="Category")
        ],
        outputs=[
            gr.Textbox(label="Locations Found"),
        ],
        title="Tourist Information System with Bounding Box Search",
        examples=[
            ["Museums in area", "40.71", "40.86", "-74.1", "-74.0", "", "museum"],
            ["Restaurants", "48.8575", "48.9", "2.3514", "2.4", "Paris", "restaurant"],
            ["Coffee shops", "51.5", "51.6", "-0.2", "-0.1", "London", "cafe"],
            ["Spa places", "", "", "", "", "Budapest", ""],
            ["Lambic brewery", "50.84211068618749", "50.849274898691244", "4.339536387173865", "4.361188801802462", "", ""],
            ["Art nouveau architecture buildings", "44.42563381188614", "44.43347927669681", "26.008709832230608", "26.181744493414488", "", ""],
            ["Harry Potter filming locations", "51.52428877891333", "51.54738884423489", "-0.1955164690977472", "-0.05082973945560466", "", ""]
        ]
    )
if __name__ == "__main__":
    try:
        model = SentenceTransformer('all-MiniLM-L6-v2')
        precomputed_embeddings = embeddings
        index = faiss.IndexFlatL2(precomputed_embeddings.shape[1])
        index.add(precomputed_embeddings.astype(np.float32))
        iface = create_interface(df, model, precomputed_embeddings, index)
        iface.launch(share=True, debug=True)
    except Exception as e:
        logging.error(f"Startup error: {e}")
        sys.exit(1)
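# Usage note (assumptions: this file is the Space's app.py and both HF_TOKEN
# and MISTRAL_API_KEY are exported in the environment): run `python app.py`
# and open the local or share URL that Gradio prints.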