dolphinium commited on
Commit
bba916d
·
1 Parent(s): f3249d4

enhance vis generation code prompt.

Browse files
Files changed (1) hide show
  1. llm_prompts.py +66 -127
llm_prompts.py CHANGED
@@ -214,7 +214,7 @@ This is the most critical part of your task. A bad choice leads to a useless, bo
214
  "deal_values_by_route": {{
215
  "type": "terms",
216
  "field": "route_branch",
217
- "limit": 10,
218
  "sort": "total_deal_value desc",
219
  "facet": {{
220
  "total_deal_value": "sum(total_deal_value_in_million)"
@@ -239,7 +239,6 @@ Convert the following user query into a single, raw JSON "Analysis Plan" object.
239
  **Current User Query:** `{natural_language_query}`
240
  """
241
 
242
- # The other prompt functions remain unchanged.
243
  def get_synthesis_report_prompt(query, quantitative_data, qualitative_data, plan):
244
  """
245
  Generates the prompt for synthesizing a final report from the query results.
@@ -277,7 +276,7 @@ This data shows the high-level aggregates.
277
  {json.dumps(quantitative_data, indent=2)}
278
  ```
279
 
280
- **3. Qualitative Data (The 'Why'):
281
  These are the single most significant documents driving the numbers for each category.
282
  {qualitative_prompt_str}
283
 
@@ -307,117 +306,53 @@ Your report must be in clean, professional Markdown and follow this structure pr
307
 
308
  def get_visualization_code_prompt(query_context, facet_data):
309
  """
310
- Generates the prompt for creating Python visualization code.
311
  """
312
  return f"""
313
- You are a Python Data Visualization expert specializing in Matplotlib and Seaborn.
314
- Your task is to generate robust, error-free Python code to create a single, insightful visualization based on the user's query and the provided Solr facet data.
315
 
316
- **User's Analytical Goal:**
317
- \"{query_context}\"
318
 
319
- **Aggregated Data (from Solr Facets):**
320
  ```json
321
  {json.dumps(facet_data, indent=2)}
322
  ```
323
 
324
  ---
325
  ### **CRITICAL INSTRUCTIONS: CODE GENERATION RULES**
326
- You MUST follow these rules to avoid errors.
327
-
328
- **1. Identify the Data Structure FIRST:**
329
- Before writing any code, analyze the `facet_data` JSON to determine its structure. There are three common patterns. Choose the correct template below.
330
-
331
- * **Pattern A: Simple `terms` Facet.** The JSON has ONE main key (besides "count") which contains a list of "buckets". Each bucket has a "val" and a "count". Use this for standard bar charts.
332
- * **Pattern B: Multiple `query` Facets.** The JSON has MULTIPLE keys (besides "count"), and each key is an object containing metrics like "count" or "sum(...)". Use this for comparing a few distinct items (e.g., "oral vs injection").
333
- * **Pattern C: Nested `terms` Facet.** The JSON has one main key with a list of "buckets", but inside EACH bucket, there are nested metric objects. This is used for grouped comparisons (e.g., "compare 2024 vs 2025 across categories"). This almost always requires `pandas`.
334
-
335
- **2. Use the Correct Parsing Template:**
336
-
337
- ---
338
- **TEMPLATE FOR PATTERN A (Simple Bar Chart from `terms` facet):**
339
- ```python
340
- import matplotlib.pyplot as plt
341
- import seaborn as sns
342
- import pandas as pd
343
-
344
- plt.style.use('seaborn-v0_8-whitegrid')
345
- fig, ax = plt.subplots(figsize=(12, 8))
346
-
347
- # Dynamically find the main facet key (the one with 'buckets')
348
- facet_key = None
349
- for key, value in facet_data.items():
350
- if isinstance(value, dict) and 'buckets' in value:
351
- facet_key = key
352
- break
353
 
354
- if facet_key:
355
- buckets = facet_data[facet_key].get('buckets', [])
356
- # Check if buckets contain data
357
- if buckets:
358
- df = pd.DataFrame(buckets)
359
- # Check for a nested metric or use 'count'
360
- if 'total_deal_value' in df.columns and pd.api.types.is_dict_like(df['total_deal_value'].iloc):
361
- # Example for nested sum metric
362
- df['value'] = df['total_deal_value'].apply(lambda x: x.get('sum', 0))
363
- y_axis_label = 'Sum of Total Deal Value'
364
- else:
365
- df.rename(columns={{'count': 'value'}}, inplace=True)
366
- y_axis_label = 'Count'
367
-
368
- sns.barplot(data=df, x='val', y='value', ax=ax, palette='viridis')
369
- ax.set_xlabel('Category')
370
- ax.set_ylabel(y_axis_label)
371
- else:
372
- ax.text(0.5, 0.5, 'No data in buckets to plot.', ha='center')
373
-
374
-
375
- ax.set_title('Your Insightful Title Here')
376
- # Correct way to rotate labels to prevent errors
377
- plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
378
- plt.tight_layout()
379
- ```
380
  ---
381
- **TEMPLATE FOR PATTERN B (Comparison Bar Chart from `query` facets):**
382
- ```python
383
- import matplotlib.pyplot as plt
384
- import seaborn as sns
385
- import pandas as pd
386
-
387
- plt.style.use('seaborn-v0_8-whitegrid')
388
- fig, ax = plt.subplots(figsize=(10, 6))
389
-
390
- labels = []
391
- values = []
392
- # Iterate through top-level keys, skipping the 'count'
393
- for key, data_dict in facet_data.items():
394
- if key == 'count' or not isinstance(data_dict, dict):
395
- continue
396
- # Extract the label (e.g., 'oral_deals' -> 'Oral')
397
- label = key.replace('_deals', '').replace('_', ' ').title()
398
- # Find the metric value, which is NOT 'count'
399
- metric_value = 0
400
- for sub_key, sub_value in data_dict.items():
401
- if sub_key != 'count':
402
- metric_value = sub_value
403
- break # Found the metric
404
- labels.append(label)
405
- values.append(metric_value)
406
-
407
- if labels:
408
- sns.barplot(x=labels, y=values, ax=ax, palette='mako')
409
- ax.set_ylabel('Total Deal Value') # Or other metric name
410
- ax.set_xlabel('Category')
411
- else:
412
- ax.text(0.5, 0.5, 'No query facet data to plot.', ha='center')
413
 
 
414
 
415
- ax.set_title('Your Insightful Title Here')
416
- plt.tight_layout()
417
- ```
418
- ---
419
- **TEMPLATE FOR PATTERN C (Grouped Bar Chart from nested `terms` facet):**
420
  ```python
 
421
  import matplotlib.pyplot as plt
422
  import seaborn as sns
423
  import pandas as pd
@@ -425,48 +360,52 @@ import pandas as pd
425
  plt.style.use('seaborn-v0_8-whitegrid')
426
  fig, ax = plt.subplots(figsize=(14, 8))
427
 
428
- # Find the key that has the buckets
 
429
  facet_key = None
430
  for key, value in facet_data.items():
431
  if isinstance(value, dict) and 'buckets' in value:
432
  facet_key = key
433
  break
434
 
 
 
435
  if facet_key and facet_data[facet_key].get('buckets'):
436
- # This list comprehension is robust for parsing nested metrics
437
- plot_data = []
438
  for bucket in facet_data[facet_key]['buckets']:
439
- category = bucket['val']
440
- # Find all nested metrics (e.g., total_deal_value_2025)
441
  for sub_key, sub_value in bucket.items():
442
  if isinstance(sub_value, dict) and 'sum' in sub_value:
443
- # Extracts year from 'total_deal_value_2025' -> '2025'
444
- year = sub_key.split('_')[-1]
445
  value = sub_value['sum']
446
- plot_data.append({{'Category': category, 'Year': year, 'Value': value}})
447
-
448
- if plot_data:
449
- df = pd.DataFrame(plot_data)
450
- sns.barplot(data=df, x='Category', y='Value', hue='Year', ax=ax)
451
- ax.set_ylabel('Total Deal Value')
452
- ax.set_xlabel('Business Model')
453
- # Correct way to rotate labels to prevent errors
454
- plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
455
- else:
456
- ax.text(0.5, 0.5, 'No nested data found to plot.', ha='center')
 
 
 
457
  else:
458
- ax.text(0.5, 0.5, 'No data in buckets to plot.', ha='center')
 
 
459
 
460
- ax.set_title('Your Insightful Title Here')
461
  plt.tight_layout()
462
  ```
 
463
  ---
464
- **3. Final Code Generation:**
465
- - **DO NOT** include `plt.show()`.
466
- - **DO** set a dynamic and descriptive `ax.set_title()`, `ax.set_xlabel()`, and `ax.set_ylabel()`.
467
- - **DO NOT** wrap the code in ```python ... ```. Output only the raw Python code.
468
- - Adapt the chosen template to the specific keys and metrics in the provided `facet_data`.
469
-
470
- **Your Task:**
471
- Now, generate the Python code.
472
  """
 
214
  "deal_values_by_route": {{
215
  "type": "terms",
216
  "field": "route_branch",
217
+ "limit": 2,
218
  "sort": "total_deal_value desc",
219
  "facet": {{
220
  "total_deal_value": "sum(total_deal_value_in_million)"
 
239
  **Current User Query:** `{natural_language_query}`
240
  """
241
 
 
242
  def get_synthesis_report_prompt(query, quantitative_data, qualitative_data, plan):
243
  """
244
  Generates the prompt for synthesizing a final report from the query results.
 
276
  {json.dumps(quantitative_data, indent=2)}
277
  ```
278
 
279
+ **3. Qualitative Data (The 'Why'):**
280
  These are the single most significant documents driving the numbers for each category.
281
  {qualitative_prompt_str}
282
 
 
306
 
307
  def get_visualization_code_prompt(query_context, facet_data):
308
  """
309
+ Generates a flexible prompt for creating Python visualization code.
310
  """
311
  return f"""
312
+ You are a world-class Python data visualization expert specializing in Matplotlib and Seaborn.
313
+ Your primary task is to generate a single, insightful, and robust Python script to visualize the provided data. The visualization should directly answer the user's analytical goal.
314
 
315
+ **1. User's Analytical Goal:**
316
+ "{query_context}"
317
 
318
+ **2. Aggregated Data (from Solr Facets):**
319
  ```json
320
  {json.dumps(facet_data, indent=2)}
321
  ```
322
 
323
  ---
324
  ### **CRITICAL INSTRUCTIONS: CODE GENERATION RULES**
325
+ You MUST follow these rules meticulously to ensure the code runs without errors in a server environment.
326
+
327
+ **A. Analyze the Data & Choose the Right Chart:**
328
+ - **Inspect the Data:** Before writing any code, carefully examine the structure of the `facet_data` JSON. Is it a simple list of categories and counts? Is it a nested structure comparing metrics across categories? Is it a time-series?
329
+ - **Select the Best Chart Type:** Based on the data and the user's goal, choose the most effective chart.
330
+ - **Bar Chart:** Ideal for comparing quantities across different categories (e.g., top companies by deal value).
331
+ - **Grouped Bar Chart:** Use when comparing a metric across categories for a few groups (e.g., deal values for 2023 vs. 2024 by company).
332
+ - **Line Chart:** Best for showing a trend over time (e.g., number of approvals per year).
333
+ - **Pie Chart:** Use ONLY for showing parts of a whole, and only with a few (2-5) categories. Generally, bar charts are better.
334
+ - **Tell a Story:** Your visualization should be more than just a plot; it should reveal the key insight from the data.
335
+ - **Direct Answer** If user ask for like this: compare x with y there should be a comparison visualization between x and y nothing more.
336
+
337
+ **B. Non-Negotiable Code Requirements:**
338
+ 1. **Imports:** You must import `matplotlib.pyplot as plt`, `seaborn as sns`, and `pandas as pd`.
339
+ 2. **Use Pandas:** ALWAYS parse the `facet_data` into a pandas DataFrame. This is more robust and flexible than iterating through dictionaries directly.
340
+ 3. **Figure and Axes:** Use `fig, ax = plt.subplots()` to create the figure and axes objects. This gives you better control.
341
+ 4. **Styling:** Apply a clean and professional style, for example: `plt.style.use('seaborn-v0_8-whitegrid')` and use a suitable Seaborn palette (e.g., `palette='viridis'`).
342
+ 5. **NO `plt.show()`:** Your code will be run on a server. **DO NOT** include `plt.show()`.
343
+ 6. **Save the Figure:** The execution environment expects a Matplotlib figure object named `fig`. Your code does not need to handle the saving path directly, but it **MUST** produce the final `fig` object correctly. The calling function will handle saving it.
344
+ 7. **Titles and Labels:** You MUST set a clear and descriptive title and labels for the x and y axes. The title should reflect the user's query.
345
+ 8. **Axis Label Readability:** If x-axis labels are long, you MUST rotate them to prevent overlap. Use this robust method: `plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")`.
346
+ 9. **Layout:** Use `plt.tight_layout()` at the end to ensure all elements fit nicely.
347
+ 10. **Error Handling:** Your code should be robust. If the `facet_data` contains no "buckets" or data to plot, the code should not crash. It should instead produce a plot with a message like "No data available to plot."
 
 
 
 
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  ---
350
+ ### **High-Quality Example (Grouped Bar Chart)**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
 
352
+ This example shows how to parse a nested facet structure into a DataFrame and create an insightful grouped bar chart. Adapt its principles to your specific task.
353
 
 
 
 
 
 
354
  ```python
355
+ # --- Imports and Style ---
356
  import matplotlib.pyplot as plt
357
  import seaborn as sns
358
  import pandas as pd
 
360
  plt.style.use('seaborn-v0_8-whitegrid')
361
  fig, ax = plt.subplots(figsize=(14, 8))
362
 
363
+ # --- Data Parsing ---
364
+ # Dynamically find the main facet key (the one with 'buckets')
365
  facet_key = None
366
  for key, value in facet_data.items():
367
  if isinstance(value, dict) and 'buckets' in value:
368
  facet_key = key
369
  break
370
 
371
+ plot_data = []
372
+ # Check if a valid key and buckets were found
373
  if facet_key and facet_data[facet_key].get('buckets'):
374
+ # This robustly parses nested metrics (e.g., a sum for each year)
 
375
  for bucket in facet_data[facet_key]['buckets']:
376
+ category = bucket.get('val', 'N/A')
377
+ # Find all nested metrics inside the bucket
378
  for sub_key, sub_value in bucket.items():
379
  if isinstance(sub_value, dict) and 'sum' in sub_value:
380
+ # Extracts '2025' from a key like 'total_value_2025'
381
+ group = sub_key.split('_')[-1]
382
  value = sub_value['sum']
383
+ plot_data.append({{'Category': category, 'Group': group, 'Value': value}})
384
+
385
+ # --- Plotting ---
386
+ if plot_data:
387
+ df = pd.DataFrame(plot_data)
388
+ sns.barplot(data=df, x='Category', y='Value', hue='Group', ax=ax, palette='viridis')
389
+
390
+ # --- Labels and Titles ---
391
+ ax.set_title('Comparison of Total Value by Category and Group')
392
+ ax.set_xlabel('Category')
393
+ ax.set_ylabel('Total Value')
394
+
395
+ # --- Formatting ---
396
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
397
  else:
398
+ # --- Handle No Data ---
399
+ ax.text(0.5, 0.5, 'No data available to plot.', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
400
+ ax.set_title('Data Visualization')
401
 
402
+ # --- Final Layout ---
403
  plt.tight_layout()
404
  ```
405
+
406
  ---
407
+ ### **Your Task:**
408
+
409
+ Now, generate the raw Python code to create the best possible visualization for the user's goal based on the provided data.
410
+ Do not wrap the code in ```python ... ```.
 
 
 
 
411
  """