Enhance explanation via prompt engineering

#5
Files changed (1) hide show
  1. app.py +64 -42
app.py CHANGED
@@ -8,50 +8,51 @@ import re
8
  import pandas as pd # type: ignore
9
  from dotenv import load_dotenv # type: ignore # Para cambios locales
10
  from supabase import create_client, Client # type: ignore
11
- from pandasai import Agent
12
 
13
  # from pandasai import SmartDataframe # type: ignore
14
- from pandasai.llm.local_llm import LocalLLM
 
15
  from pandasai import Agent
16
  import matplotlib.pyplot as plt
 
17
 
18
  # ---------------------------------------------------------------------------------
19
  # Funciones auxiliares
20
  # ---------------------------------------------------------------------------------
21
 
22
 
23
- # Ejemplo de prompt generado:
24
- # generate_graph_prompt("Germany", "France", "fertility rate", 2020, 2030)
25
  def generate_graph_prompt(user_query):
26
  prompt = f"""
27
- You are a highly skilled data scientist working with European demographic data.
28
 
29
- Given the user's request: "{user_query}"
30
 
31
- 1. Plot the relevant data according to the user's request.
32
- 2. After generating the plot, write a clear, human-readable explanation of the plot (no code).
33
- 3. Save the explanation in a variable called "explanation".
 
 
34
 
35
- VERY IMPORTANT:
36
- - Declare a result variable as a dictionary that includes:
37
- - type = "plot"
38
- - value = the path to the saved plot
39
- - explanation = the explanation text you wrote
40
 
41
- Example of expected result dictionary:
42
- result = {{
43
- "type": "plot",
44
- "value": "temp_chart.png",
45
- "explanation": explanation
46
- }}
47
 
48
- Only respond with valid Python code.
 
 
 
 
 
49
 
50
- IMPORTANT: Stick strictly to using the data available in the database.
51
- """
52
  return prompt
53
 
54
- # TODO: Mejorar prompt
55
 
56
  # ---------------------------------------------------------------------------------
57
  # Configuración de conexión a Supabase
@@ -101,20 +102,18 @@ def load_data(table):
101
  # Cargar datos iniciales
102
  # ---------------------------------------------------------------------------------
103
 
104
- # # Cargar datos desde la tabla "labor"
105
- data = load_data("labor")
106
-
107
  # TODO: La idea es luego usar todas las tablas, cuando ya funcione.
108
- # Se puede si el modelo funciona con las gráficas, sino que toca mejorarlo
109
- # porque serían consultas más complejas.
110
- # labor_data = load_data("labor")
111
- # fertility_data = load_data("fertility")
112
  # population_data = load_data("population")
113
- # predictions_data = load_data("predictions")
114
 
 
115
 
116
  # ---------------------------------------------------------------------------------
117
- # Inicializar modelo
118
  # ---------------------------------------------------------------------------------
119
 
120
  # ollama_llm = LocalLLM(api_base="http://localhost:11434/v1",
@@ -124,43 +123,66 @@ data = load_data("labor")
124
 
125
  lm_studio_llm = LocalLLM(api_base="http://localhost:1234/v1") # el modelo es gemma-3-12b-it-qat
126
 
127
- agent = Agent([labor_data], config={"llm": lm_studio_llm}) # Inicializar agent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  # ---------------------------------------------------------------------------------
130
  # Configuración de la app en Streamlit
131
  # ---------------------------------------------------------------------------------
132
 
133
  # Título de la app
134
- st.title("_Europe GraphGen_ :blue[Graph generator] :flag-eu:")
135
 
136
  # TODO: Poner instrucciones al usuario sobre cómo hacer un muy buen prompt (sin tecnisismos, pensando en el usuario final)
137
 
138
-
139
  # Entrada de usuario para describir el gráfico
140
  user_input = st.text_input("What graphics do you have in mind")
141
  generate_button = st.button("Generate")
142
 
143
- # Procesar el input del usuario con PandasAI
144
  if generate_button and user_input:
145
  with st.spinner('Generating answer...'):
146
  try:
 
147
  prompt = generate_graph_prompt(user_input)
 
 
 
 
148
  answer = agent.chat(prompt)
149
- explanation = agent.explain()
150
  print(f"\nAnswer type: {type(answer)}\n") # Verificar tipo de objeto
151
  print(f"\nAnswer content: {answer}\n") # Inspeccionar contenido de la respuesta
152
- print(f"\n explanation type: {type(explanation)}\n") # Verificar tipo de objeto
153
- print(f"\n explanation content: {explanation}\n")
 
 
 
 
 
154
 
155
  if isinstance(answer, str) and os.path.isfile(answer):
156
  # Si el output es una ruta válida a imagen
157
  im = plt.imread(answer)
158
  st.image(im)
159
  os.remove(answer) # Limpiar archivo temporal
160
- st.markdown(str(explanation))
 
 
161
  else:
162
  # Si no es una ruta válida, mostrar como texto
163
- st.markdown(str(answer))
164
 
165
  except Exception as e:
166
  st.error(f"Error generating answer: {e}")
 
8
  import pandas as pd # type: ignore
9
  from dotenv import load_dotenv # type: ignore # Para cambios locales
10
  from supabase import create_client, Client # type: ignore
 
11
 
12
  # from pandasai import SmartDataframe # type: ignore
13
+ from pandasai import SmartDatalake # type: ignore # Porque ya usamos más de un df (más de una tabla de nuestra db)
14
+ from pandasai.llm.local_llm import LocalLLM # type: ignore
15
  from pandasai import Agent
16
  import matplotlib.pyplot as plt
17
+ import time
18
 
19
  # ---------------------------------------------------------------------------------
20
  # Funciones auxiliares
21
  # ---------------------------------------------------------------------------------
22
 
23
 
 
 
24
  def generate_graph_prompt(user_query):
25
  prompt = f"""
26
+ You are a senior data scientist analyzing European labor force data.
27
 
28
+ Given the user's request: "{user_query}"
29
 
30
+ 1. Plot the relevant data using matplotlib:
31
+ - Use `df.query("geo == 'X'")` to filter the country, instead of chained comparisons.
32
+ - Avoid using filters like `df[df['geo'] == 'Germany']`.
33
+ - Include clear axis labels and a descriptive title.
34
+ - Save the plot as an image file (e.g., temp_chart.png).
35
 
36
+ 2. After plotting, write a **concise analytical summary** of the trend based on those 5 years. The summary should:
37
+ - Identify the **year with the largest increase** and the percent change.
38
+ - Identify the **year with the largest decrease** and the percent change.
39
+ - Provide a **brief overall trend interpretation** (e.g., steady growth, fluctuating, recovery, etc.).
40
+ - Avoid listing every year individually, summarize intelligently.
41
 
42
+ 3. Store the summary in a variable named `explanation`.
 
 
 
 
 
43
 
44
+ 4. Return a result dictionary structured as follows:
45
+ result = {{
46
+ "type": "plot",
47
+ "value": "temp_chart.png",
48
+ "explanation": explanation
49
+ }}
50
 
51
+ IMPORTANT: Use only the data available in the input DataFrame.
52
+ """
53
  return prompt
54
 
55
+ #TODO: Continuar mejorando el prompt
56
 
57
  # ---------------------------------------------------------------------------------
58
  # Configuración de conexión a Supabase
 
102
  # Cargar datos iniciales
103
  # ---------------------------------------------------------------------------------
104
 
 
 
 
105
  # TODO: La idea es luego usar todas las tablas, cuando ya funcione.
106
+ # Se puede si el modelo funciona con las gráficas, sino que toca mejorarlo porque serían consultas más complejas.
107
+
108
+ labor_data = load_data("labor")
109
+ fertility_data = load_data("fertility")
110
  # population_data = load_data("population")
111
+ # predictions_data = load_data("predictions")
112
 
113
+ # TODO: Buscar la forma de disminuir la latencia (muchos datos = mucha latencia)
114
 
115
  # ---------------------------------------------------------------------------------
116
+ # Inicializar LLM desde Ollama con PandasAI
117
  # ---------------------------------------------------------------------------------
118
 
119
  # ollama_llm = LocalLLM(api_base="http://localhost:11434/v1",
 
123
 
124
  lm_studio_llm = LocalLLM(api_base="http://localhost:1234/v1") # el modelo es gemma-3-12b-it-qat
125
 
126
+ # sdl = SmartDatalake([labor_data, fertility_data, population_data, predictions_data], config={"llm": ollama_llm}) # DataFrame PandasAI-ready.
127
+ # sdl = SmartDatalake([labor_data, fertility_data], config={"llm": ollama_llm})
128
+
129
+ # agent = Agent([labor_data], config={"llm": lm_studio_llm}) # TODO: Probar Agent con multiples dfs
130
+ agent = Agent(
131
+ [
132
+ labor_data,
133
+ fertility_data
134
+ ],
135
+ config={
136
+ "llm": lm_studio_llm,
137
+ "enable_cache": False,
138
+ "enable_filter_extraction": False # evita errores de parseo
139
+ }
140
+ )
141
 
142
  # ---------------------------------------------------------------------------------
143
  # Configuración de la app en Streamlit
144
  # ---------------------------------------------------------------------------------
145
 
146
  # Título de la app
147
+ st.title("Europe GraphGen :blue[Graph generator] :flag-eu:")
148
 
149
  # TODO: Poner instrucciones al usuario sobre cómo hacer un muy buen prompt (sin tecnisismos, pensando en el usuario final)
150
 
 
151
  # Entrada de usuario para describir el gráfico
152
  user_input = st.text_input("What graphics do you have in mind")
153
  generate_button = st.button("Generate")
154
 
 
155
  if generate_button and user_input:
156
  with st.spinner('Generating answer...'):
157
  try:
158
+ print(f"\nGenerating prompt...\n")
159
  prompt = generate_graph_prompt(user_input)
160
+ print(f"\nPrompt generated\n")
161
+
162
+ start_time = time.time()
163
+
164
  answer = agent.chat(prompt)
 
165
  print(f"\nAnswer type: {type(answer)}\n") # Verificar tipo de objeto
166
  print(f"\nAnswer content: {answer}\n") # Inspeccionar contenido de la respuesta
167
+ print(f"\nFull result: {agent.last_result}\n")
168
+
169
+ full_result = agent.last_result
170
+ explanation = full_result.get("explanation", "")
171
+
172
+ elapsed_time = time.time() - start_time
173
+ print(f"\nExecution time: {elapsed_time:.2f} seconds\n")
174
 
175
  if isinstance(answer, str) and os.path.isfile(answer):
176
  # Si el output es una ruta válida a imagen
177
  im = plt.imread(answer)
178
  st.image(im)
179
  os.remove(answer) # Limpiar archivo temporal
180
+
181
+ if explanation:
182
+ st.markdown(f"**Explanation:** {explanation}")
183
  else:
184
  # Si no es una ruta válida, mostrar como texto
185
+ st.markdown(str(answer))
186
 
187
  except Exception as e:
188
  st.error(f"Error generating answer: {e}")