jomasego committed on
Commit
235071b
·
verified ·
1 Parent(s): 5636a13

Upload llm_assistant.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. llm_assistant.py +340 -0
llm_assistant.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Trade Data Assistant using Google Gemma-2b from Hugging Face
3
+ """
4
+ import os
5
+ import requests
6
+ import json
7
+ from typing import Dict, List, Any, Optional
8
+
9
class TradeAssistant:
    """
    Assistant powered by Google Gemma-2b to help users with trade data analysis.

    The assistant sends prompts to the Hugging Face hosted inference endpoint
    for ``google/gemma-2b-it``. Whenever the model cannot be reached (or its
    reply cannot be parsed) a canned, keyword-based answer from
    ``get_fallback_response`` is returned instead, so callers always receive
    displayable text.
    """

    # Retry policy for the "model is loading" (HTTP 503) case.
    MAX_RETRIES = 2
    RETRY_DELAY_SECONDS = 1
    # Upper bound for a single HTTP request so a stalled call cannot hang the caller.
    REQUEST_TIMEOUT_SECONDS = 10

    def __init__(self, api_token: Optional[str] = None):
        """
        Initialize the Trade Assistant with HuggingFace API token

        Args:
            api_token: Hugging Face API token. When omitted, the
                HUGGINGFACE_API_TOKEN environment variable is used.
        """
        self.api_token = api_token or os.environ.get("HUGGINGFACE_API_TOKEN")
        if not self.api_token:
            print("Warning: No HuggingFace API token provided. Please set HUGGINGFACE_API_TOKEN environment variable.")

        # Model ID for Google Gemma-2b - efficient with strong reasoning
        self.model_id = "google/gemma-2b-it"

        # API endpoint
        self.api_url = f"https://api-inference.huggingface.co/models/{self.model_id}"

        # Headers for API requests. The Authorization header is only added when
        # a token exists; otherwise the literal string "Bearer None" would be sent.
        self.headers = {"Content-Type": "application/json"}
        if self.api_token:
            self.headers["Authorization"] = f"Bearer {self.api_token}"

        # System prompt defining the assistant's role
        self.system_prompt = """
        You are Trade Flow Assistant, an AI helper specializing in international trade data analysis.
        You assist users with:
        1. Finding and interpreting trade data
        2. Explaining economic trends and trade flows
        3. Helping users navigate the Trade Flow Predictor application
        4. Suggesting relevant visualizations and analysis approaches
        5. Explaining trade terminology and concepts

        Focus on providing clear, concise responses with actionable insights.
        When appropriate, suggest specific countries, commodities, or time periods to explore.
        Do not make up data - if you don't know something, say so.
        """

        # Context about the application
        self.app_context = """
        The Trade Flow Predictor application has the following features:
        - Data viewing and visualization for international trade (imports/exports)
        - Filtering by country, product code, year, and flow type
        - Machine learning prediction of future trade values
        - Various chart types (bar, pie, line, treemap)
        - Data download capabilities

        Available tabs:
        - Basics: Simple data lookup by country pairs
        - Exports by Country: View top export destinations
        - Imports by Country: View top import sources
        - Exports by Product: View top exported products
        - Imports by Product: View top imported products
        - Rankings: Compare countries by trade volume
        - Bilateral Trade: Examine trade between specific country pairs
        - Data Download: Download custom datasets
        - Prediction: ML forecasting of future trade values
        - Data Cache: Manage previously retrieved data
        """

    @staticmethod
    def _clean_block(text: str) -> str:
        """Strip the source-code indentation from a multi-line prompt literal."""
        return "\n".join(line.strip() for line in text.strip().splitlines())

    def _build_prompt(self,
                      user_question: str,
                      chat_history: List[Dict[str, str]],
                      include_app_context: bool) -> str:
        """Flatten the conversation into a single Gemma-formatted prompt string."""
        preamble_parts = [self._clean_block(self.system_prompt)]
        # App context is only added on the first turn of a conversation.
        if include_app_context and not chat_history:
            preamble_parts.append(self._clean_block(self.app_context))
        preamble = "\n\n".join(preamble_parts)

        turns = [*chat_history, {"role": "user", "content": user_question}]
        segments = []
        preamble_pending = True
        for turn in turns:
            # Gemma names the assistant side "model" and has no system role,
            # so the instructions are folded into the first user turn.
            role = "model" if turn.get("role") == "assistant" else "user"
            content = turn.get("content", "")
            if preamble_pending and role == "user":
                content = f"{preamble}\n\n{content}"
                preamble_pending = False
            segments.append(f"<start_of_turn>{role}\n{content}<end_of_turn>\n")

        # An open model turn asks the model to produce the next reply.
        segments.append("<start_of_turn>model\n")
        return "".join(segments)

    def _fallback_result(self, user_question: str, message: str, success: bool = True) -> Dict[str, Any]:
        """Build a result dict whose response text is the canned fallback answer."""
        return {
            "success": success,
            "response": self.get_fallback_response(user_question),
            "message": message
        }

    def query(self,
              user_question: str,
              chat_history: Optional[List[Dict[str, str]]] = None,
              include_app_context: bool = True) -> Dict[str, Any]:
        """
        Send a query to the LLM and get a response

        Args:
            user_question: The user's question
            chat_history: Previous conversation history as
                {"role": ..., "content": ...} dicts
            include_app_context: Whether to include app context in the prompt
                (only applied when there is no chat history yet)

        Returns:
            Dict with the keys "success", "response" and "message". Failures
            of the remote call are reported with "success": True and a canned
            fallback text in "response"; "message" explains what happened.
            Only an unexpected (but parseable) reply shape yields
            "success": False.
        """
        # Kept local (as in the original) so no file-level import is required.
        import time

        if chat_history is None:
            chat_history = []

        try:
            # The text-generation endpoint expects "inputs" to be one prompt
            # string, not a list of chat messages.
            payload = {
                "inputs": self._build_prompt(user_question, chat_history, include_app_context),
                "parameters": {
                    "max_new_tokens": 500,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "do_sample": True,
                    # Return only the completion, not the prompt echoed back.
                    "return_full_text": False
                }
            }

            for attempt in range(self.MAX_RETRIES):
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=self.REQUEST_TIMEOUT_SECONDS
                )

                # If request succeeded, process the response
                if response.status_code == 200:
                    try:
                        result = response.json()
                        if isinstance(result, list) and len(result) > 0:
                            # Extract the assistant's response
                            generated_text = result[0].get("generated_text", "")
                            return {
                                "success": True,
                                "response": generated_text.strip(),
                                "message": "Successfully generated response"
                            }
                        return self._fallback_result(
                            user_question,
                            f"Unexpected API response format: {result}",
                            success=False
                        )
                    except (ValueError, KeyError, IndexError) as e:
                        # ValueError covers every JSON decoding error variant.
                        print(f"Error processing response: {str(e)}, Response: {response.text}")
                        return self._fallback_result(user_question, f"Error processing response: {str(e)}")

                # If model is loading (status code 503), wait and retry
                if response.status_code == 503:
                    print(f"Model is loading or temporarily unavailable. Attempt {attempt+1}/{self.MAX_RETRIES}.")
                    if attempt < self.MAX_RETRIES - 1:  # Don't wait after the last attempt
                        time.sleep(self.RETRY_DELAY_SECONDS)
                        continue
                    # If we've exhausted all retries, use fallback
                    return self._fallback_result(
                        user_question,
                        f"Model unavailable (status: {response.status_code}). Using fallback response."
                    )

                # Other errors - try fallback immediately
                error_message = f"API request failed with status code {response.status_code}"
                try:
                    error_message += f": {json.dumps(response.json())}"
                except ValueError:
                    # Error body was not JSON; include the raw text instead.
                    error_message += f": {response.text}"

                print(error_message)  # Log the error for debugging
                return self._fallback_result(user_question, error_message)

            # Only reachable when MAX_RETRIES is configured as 0.
            return self._fallback_result(user_question, "No request attempted. Using fallback response.")

        except Exception as e:
            # Boundary handler: network errors, timeouts, etc. must never
            # propagate to the caller; a fallback answer is returned instead.
            print(f"Exception during API request: {str(e)}")
            return self._fallback_result(user_question, f"Error querying LLM: {str(e)}")

    def get_fallback_response(self, query: str) -> str:
        """
        Provide a fallback response when the model is unavailable or loading

        Matching is done on lower-cased substrings. The check for the specific
        code "8471" runs first so that it is not shadowed by the generic
        "hs code" / "pattern" keywords, which appear in most HS-code questions.

        Args:
            query: The user's question

        Returns:
            A useful fallback response based on the query
        """
        query_lower = (query or "").lower()

        # Most specific match first: a question about this particular code.
        if "8471" in query_lower:
            return "HS Code 8471: Automatic data processing machines and units thereof; magnetic or optical readers, machines for transcribing data onto data media in coded form and machines for processing such data.\n\nThis includes computers, laptops, servers, and related equipment. Major exporters include China, Mexico, the Netherlands, and the United States. This is a high-value category in international trade with complex supply chains spanning multiple countries."

        # Same spellings as recognised by enhance_query_with_context.
        if "hs code" in query_lower or "hscode" in query_lower or "hs-code" in query_lower:
            return "HS Codes (Harmonized System Codes) are standardized numerical codes developed by the World Customs Organization (WCO) to classify traded products. Each code represents a specific category of goods, with the first 2 digits identifying the chapter, the next 2 identifying the heading, and so on. For example, HS code 8471 represents 'Automatic data-processing machines and units thereof; magnetic or optical readers, machines for transcribing data onto data media in coded form and machines for processing such data'."

        if "imports" in query_lower and "exports" in query_lower and ("difference" in query_lower or "vs" in query_lower):
            return "Imports represent goods and services purchased from other countries and brought into the reporting country. Exports represent goods and services produced domestically and sold to buyers in other countries. The difference between exports and imports is called the trade balance. A trade surplus occurs when exports exceed imports, while a trade deficit occurs when imports exceed exports."

        if "recommend" in query_lower or "interesting" in query_lower or "pattern" in query_lower:
            return "While the model is temporarily unavailable, here are some interesting trade patterns to explore:\n\n1. **China-US Trade Tensions**: Examine how trade flows between China and the US have changed since 2018\n\n2. **COVID-19 Impact**: Look at the dramatic shifts in medical supply trade in 2020-2021\n\n3. **Green Technology Trade**: Explore the growing exports of renewable energy equipment, particularly solar panels and wind turbines\n\n4. **Semiconductor Supply Chain**: Investigate the complex global trade network for microchips and electronic components\n\n5. **Changing Agricultural Patterns**: Review how climate change has affected agricultural trade flows globally\n\nYou can explore these patterns using the data visualization tools in the application."

        if "interpret" in query_lower or "understand" in query_lower or "analyze" in query_lower:
            return "To interpret trade data effectively:\n\n1. **Consider Context**: Look at multiple years to identify trends vs. one-time anomalies\n\n2. **Compare Related Metrics**: Examine both value and volume to distinguish price effects from quantity changes\n\n3. **Check Seasonality**: Many products have seasonal trade patterns that repeat annually\n\n4. **Account for Re-exports**: Some countries serve as trade hubs, importing and then re-exporting goods\n\n5. **Use Visualization**: Charts and graphs can reveal patterns that aren't obvious in tables\n\nThe Trade Flow Predictor application provides multiple visualization options to help with this analysis."

        return "I'm sorry, but I can't provide a specific answer right now as the AI model is temporarily unavailable. Please try again in a few minutes. In the meantime, you can explore the trade data visualization tools in the application, or try one of these specific questions:\n\n- What are HS codes?\n- Explain the difference between imports and exports\n- Recommend interesting trade patterns to explore\n- How can I interpret trade data?"

    def get_trade_recommendation(self,
                                 country: Optional[str] = None,
                                 product: Optional[str] = None,
                                 year: Optional[str] = None) -> Dict[str, Any]:
        """
        Get a specific recommendation for trade data exploration

        Args:
            country: Country name or code (optional)
            product: Product name or HS code (optional)
            year: Year for analysis (optional)

        Returns:
            Dict containing the LLM recommendation (same shape as ``query``)
        """
        # Construct a specific prompt for recommendations
        recommendation_prompt = "Please recommend interesting trade patterns to explore"

        if country:
            recommendation_prompt += f" for {country}"
        if product:
            recommendation_prompt += f" related to {product}"
        if year:
            recommendation_prompt += f" in {year}"

        recommendation_prompt += ". Suggest specific data queries and visualizations that would be insightful."

        return self.query(recommendation_prompt)

    def explain_hs_code(self, code: str) -> Dict[str, Any]:
        """
        Explain what a specific HS code represents

        Args:
            code: HS code to explain

        Returns:
            Dict containing the explanation (same shape as ``query``)
        """
        prompt = f"Please explain what the HS code {code} represents in international trade classification. Include information about what products are classified under this code, any notable trade patterns, and major exporting countries if you know them."

        # The application context is irrelevant for a pure classification question.
        return self.query(prompt, include_app_context=False)

    def format_chat_history(self, chat_history_raw: Optional[List[Dict[str, Any]]]) -> List[Dict[str, str]]:
        """
        Format chat history to match the expected format for the LLM API

        Entries that are not dicts, lack a string 'role' or a 'content' value,
        or carry a role other than 'user'/'assistant' are silently dropped.

        Args:
            chat_history_raw: Raw chat history from the frontend (may be None)

        Returns:
            Formatted chat history compatible with the API
        """
        formatted_history = []

        for message in chat_history_raw or []:
            if not isinstance(message, dict):
                continue

            role = message.get('role')
            content = message.get('content')
            # Frontend data is untrusted: guard against null / non-string fields.
            if not isinstance(role, str) or content is None:
                continue

            role = role.lower()

            # Ensure role is either 'user' or 'assistant'
            if role not in ('user', 'assistant'):
                continue

            formatted_history.append({
                "role": role,
                "content": content if isinstance(content, str) else str(content)
            })

        return formatted_history

    def enhance_query_with_context(self, query: str) -> str:
        """
        Enhance a user query with additional context about trade data

        The first matching keyword group wins; queries without any known
        keyword are returned unchanged.

        Args:
            query: Original user query

        Returns:
            Enhanced query
        """
        query_lower = query.lower()

        # Add context to HS code questions
        if "hs code" in query_lower or "hscode" in query_lower or "hs-code" in query_lower:
            return f"{query} Please explain in the context of international trade classification."

        # Add context to country questions
        if "country" in query_lower or "countries" in query_lower:
            return f"{query} Focus on trade-related information and statistics if available."

        # Add context to trend questions ("trend" also matches "trends")
        if "trend" in query_lower:
            return f"{query} Consider both recent trends and historical context where relevant."

        # Default case - return original query
        return query
322
+
323
+
324
+ # Simple test function
325
def test_assistant():
    """Smoke-test the Trade Assistant with one sample question and print the outcome."""
    # Build the assistant first so any token warning is printed before the query line.
    helper = TradeAssistant()
    sample_question = "What are HS codes and how are they used in trade analysis?"

    print(f"Test query: {sample_question}")
    outcome = helper.query(sample_question)

    # Guard clause: report the failure message and stop.
    if not outcome["success"]:
        print(f"\nError: {outcome['message']}")
        return

    print("\nResponse:")
    print(outcome["response"])


# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_assistant()