File size: 32,670 Bytes
a23082c
 
b8f6b7f
a23082c
 
 
 
 
b8f6b7f
 
a23082c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8f6b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a23082c
 
 
 
 
 
 
114747f
a23082c
 
 
 
 
 
 
 
 
 
114747f
a23082c
 
 
 
b8f6b7f
a23082c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68bd1d5
 
a23082c
 
 
 
 
 
 
 
 
 
 
b8f6b7f
a23082c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
import functools
import logging
import os
from typing import Dict, List

import numpy as np
import numpy.fft as fft
import scipy.linalg as la
import scipy.special as special
import sympy as sp
from scipy.integrate import quad
from scipy.stats import binom, norm, poisson

from llama_index.core.agent.workflow import ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.tools.code_interpreter import CodeInterpreterToolSpec
from llama_index.tools.wolfram_alpha import WolframAlphaToolSpec

# Setup logging
logger = logging.getLogger(__name__)

# --- Math Tool Functions (with enhanced logging and error handling) ---


def _to_serializable(value):
    """Recursively convert NumPy arrays/scalars into plain Python objects.

    ndarrays become nested lists; NumPy scalars become bool/int/float/complex.
    dicts, lists and tuples are walked recursively so NumPy values nested in
    them are converted as well. Anything else is returned unchanged.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    # np.bool_ is not a subclass of np.integer, so it needs its own check.
    if isinstance(value, np.bool_):
        return bool(value)
    # Abstract scalar bases work on NumPy 1.x and 2.x alike (the np.float_ /
    # np.complex_ aliases were removed in NumPy 2.0).
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    if isinstance(value, np.complexfloating):
        return complex(value)
    if isinstance(value, dict):
        return {k: _to_serializable(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_to_serializable(v) for v in value]
    if isinstance(value, tuple):
        return tuple(_to_serializable(v) for v in value)
    return value


# Helper decorator for error handling and logging
def math_tool_handler(func):
    """Wrap a math tool with logging, error-to-string conversion and NumPy cleanup.

    The wrapper:
      * logs each call and (truncated) result;
      * converts NumPy arrays/scalars in the result (including inside nested
        dicts/lists/tuples) to plain Python values;
      * returns ``"Error in <name>: <msg>"`` for TypeError, ValueError (which
        includes sympy.SympifyError), LinAlgError and ZeroDivisionError, and
        ``"Unexpected error in <name>: <msg>"`` for any other exception.

    ``functools.wraps`` preserves the wrapped function's name, docstring and
    signature, which FunctionTool.from_defaults uses to build each tool's
    name, description and argument schema.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        func_name = func.__name__
        logger.info(f"Executing math tool: {func_name} with args: {args}, kwargs: {kwargs}")
        try:
            result = func(*args, **kwargs)
            logger.info(f"Tool {func_name} executed successfully. Result: {str(result)[:200]}...")
            # Ensure result is serializable (convert numpy types if necessary).
            return _to_serializable(result)
        except (TypeError, ValueError, np.linalg.LinAlgError, ZeroDivisionError) as e:
            logger.warning(f"Math error in {func_name}: {e}")
            return f"Error in {func_name}: {e}"
        except Exception as e:
            logger.error(f"Unexpected error in {func_name}: {e}", exc_info=True)
            return f"Unexpected error in {func_name}: {e}"
    return wrapper

# --- Symbolic math functions ---
@math_tool_handler
def solve_symbolic_equation(equation: str, variable: str = "x") -> str:
    """Solve a symbolic equation (e.g., 'x**2 - 4') for the given variable.

    Accepted forms:
      * a bare expression, solved as ``expression == 0`` (e.g. 'x**2 - 4');
      * an equation using ``=`` or Python-style ``==`` (e.g. 'x**2 = 4');
      * any other string SymPy can sympify directly (e.g. 'x**2 <= 4'),
        which is passed to ``sympy.solve`` as parsed.

    Returns "Solutions: <sympy result>".
    """
    symbol = sp.symbols(variable)
    if "==" in equation:
        # Python-style equality: splitting on a single '=' would leave a
        # stray '=' in the right-hand side.
        lhs, rhs = equation.split("==", 1)
        expr = sp.Eq(sp.sympify(lhs.strip()), sp.sympify(rhs.strip()))
    elif "=" in equation and not any(op in equation for op in ("<=", ">=", "!=")):
        lhs, rhs = equation.split("=", 1)
        expr = sp.Eq(sp.sympify(lhs.strip()), sp.sympify(rhs.strip()))
    else:
        # Plain expression (treated as == 0) or a relational operator that
        # SymPy parses itself.
        expr = sp.sympify(equation)
    solutions = sp.solve(expr, symbol)
    return f"Solutions: {solutions}"

@math_tool_handler
def compute_derivative(expression: str, variable: str = "x") -> str:
    """Compute the symbolic derivative of an expression (e.g., 'sin(x)*x**2').

    Differentiates with respect to ``variable``; returns "Derivative: <result>".
    """
    wrt = sp.symbols(variable)
    derivative = sp.diff(sp.sympify(expression), wrt)
    return f"Derivative: {derivative}"

@math_tool_handler
def compute_integral(expression: str, variable: str = "x") -> str:
    """Compute the symbolic indefinite integral of an expression (e.g., '1/x').

    Integrates with respect to ``variable``; the returned text appends "+ C".
    """
    integration_var = sp.symbols(variable)
    antiderivative = sp.integrate(sp.sympify(expression), integration_var)
    return f"Integral: {antiderivative} + C"

@math_tool_handler
def compute_limit(
    expression: str, variable: str = "x", point: str = "oo"
) -> str:
    """Compute the limit of an expression (e.g., 'sin(x)/x') as variable approaches point (e.g., '0', 'oo').

    ``point`` may be 'oo', '-oo', 'zoo' (complex infinity; case-insensitive)
    or anything SymPy can sympify. Numbers are also accepted (e.g. 0 or 0.5)
    because tool callers do not always send strings.
    """
    symbol = sp.symbols(variable)
    expr = sp.sympify(expression)
    # Normalise to text first: calling .lower() on an int/float would crash.
    point_str = str(point).strip()
    infinities = {"oo": sp.oo, "-oo": -sp.oo, "zoo": sp.zoo}
    pt = infinities.get(point_str.lower())
    if pt is None:
        pt = sp.sympify(point_str)
    lim = sp.limit(expr, symbol, pt)
    return f"Limit at {point}: {lim}"

@math_tool_handler
def simplify_expression(expression: str) -> str:
    """Simplify a symbolic expression (e.g., 'sin(x)**2 + cos(x)**2').

    Returns "Simplified expression: <result>".
    """
    simplified = sp.simplify(sp.sympify(expression))
    return f"Simplified expression: {simplified}"

@math_tool_handler
def expand_expression(expression: str) -> str:
    """Expand a symbolic expression (e.g., '(x+y)**2').

    Returns "Expanded expression: <result>".
    """
    expanded = sp.expand(sp.sympify(expression))
    return f"Expanded expression: {expanded}"

@math_tool_handler
def factor_expression(expression: str) -> str:
    """Factor a symbolic expression (e.g., 'x**2 - y**2').

    Returns "Factored expression: <result>".
    """
    factored = sp.factor(sp.sympify(expression))
    return f"Factored expression: {factored}"


# --- Matrix math functions ---
@math_tool_handler
def matrix_addition(a: List[List[float]], b: List[List[float]]) -> List[List[float]]:
    """Add two matrices element-wise. Input: [[1, 2], [3, 4]], [[5, 6], [7, 8]].

    Raises ValueError when the two shapes differ.
    """
    left, right = np.array(a), np.array(b)
    if left.shape == right.shape:
        return left + right
    raise ValueError("Matrices must have the same shape for addition.")

@math_tool_handler
def matrix_subtraction(a: List[List[float]], b: List[List[float]]) -> List[List[float]]:
    """Subtract matrix B from matrix A element-wise. Input: [[5, 6], [7, 8]], [[1, 2], [3, 4]].

    Raises ValueError when the two shapes differ.
    """
    minuend, subtrahend = np.array(a), np.array(b)
    if minuend.shape == subtrahend.shape:
        return minuend - subtrahend
    raise ValueError("Matrices must have the same shape for subtraction.")

@math_tool_handler
def matrix_multiplication(a: List[List[float]], b: List[List[float]]) -> List[List[float]]:
    """Multiply two matrices. Input: [[1, 2], [3, 4]], [[5, 6], [7, 8]].

    Raises:
        ValueError: if either input is not 2-D or the inner dimensions differ.
    """
    A = np.array(a)
    B = np.array(b)
    # A 1-D input has no shape[1]; reject it with a clear message instead of
    # letting an IndexError escape as an "unexpected" error.
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("Both inputs must be 2-D matrices (lists of rows) for matrix multiplication.")
    if A.shape[1] != B.shape[0]:
        raise ValueError("Inner dimensions must match for matrix multiplication.")
    return np.matmul(A, B)

@math_tool_handler
def matrix_inverse(matrix: List[List[float]]) -> List[List[float]]:
    """Compute the inverse of a square matrix. Input: [[1, 2], [3, 4]].

    Raises:
        ValueError: if the input is not a 2-D square matrix.
        numpy.linalg.LinAlgError: if the matrix is singular.
    """
    M = np.array(matrix)
    # The ndim check keeps 1-D input from raising IndexError on shape[1].
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError("Matrix must be square to compute inverse.")
    return np.linalg.inv(M)

@math_tool_handler
def matrix_determinant(matrix: List[List[float]]) -> float:
    """Compute the determinant of a square matrix. Input: [[1, 2], [3, 4]].

    Raises:
        ValueError: if the input is not a 2-D square matrix.
    """
    M = np.array(matrix)
    # The ndim check keeps 1-D input from raising IndexError on shape[1].
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError("Matrix must be square to compute determinant.")
    return np.linalg.det(M)

@math_tool_handler
def matrix_transpose(matrix: List[List[float]]) -> List[List[float]]:
    """Transpose a matrix. Input: [[1, 2, 3], [4, 5, 6]].

    Rows become columns; the result is returned as a NumPy array view.
    """
    return np.array(matrix).T

@math_tool_handler
def matrix_rank(matrix: List[List[float]]) -> int:
    """Compute the rank of a matrix. Input: [[1, 2], [2, 4]].

    Delegates to numpy.linalg.matrix_rank (SVD-based numerical rank).
    """
    return np.linalg.matrix_rank(np.array(matrix))

@math_tool_handler
def matrix_trace(matrix: List[List[float]]) -> float:
    """Compute the trace of a square matrix. Input: [[1, 2], [3, 4]].

    Raises:
        ValueError: if the input is not a 2-D square matrix.
    """
    M = np.array(matrix)
    # Without the ndim check, 1-D input raises IndexError and a square-leading
    # 3-D input would silently return an array instead of a scalar.
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError("Matrix must be square to compute trace.")
    return np.trace(M)

@math_tool_handler
def matrix_norm(matrix: List[List[float]], ord_str: str = "fro") -> float:
    """Compute the norm of a matrix. ord_str can be 'fro' (Frobenius), 'nuc' (nuclear), inf, -inf, 1, -1, 2, -2. Input: [[1, 2], [3, 4]].

    ``ord_str`` is matched case-insensitively, and numeric values (e.g. 2 or
    float('inf')) are accepted as well as strings.

    Raises:
        ValueError: for an unsupported order.
    """
    M = np.array(matrix)
    ord_map = {"fro": "fro", "nuc": "nuc", "inf": np.inf, "-inf": -np.inf, "1": 1, "-1": -1, "2": 2, "-2": -2}
    # Tool callers often send numbers rather than strings; normalise to the
    # string keys above so 2 and "2" (or inf and "inf") behave the same.
    ord_val = ord_map.get(str(ord_str).strip().lower())
    if ord_val is None:
        raise ValueError(f"Invalid ord_str: {ord_str}. Must be one of {list(ord_map.keys())}")
    return np.linalg.norm(M, ord=ord_val)

@math_tool_handler
def eigenvalues(matrix: List[List[float]]) -> List[complex]:
    """Compute eigenvalues of a square matrix. Input: [[1, -1], [1, 1]].

    Values may be complex even for a real input matrix.

    Raises:
        ValueError: if the input is not a 2-D square matrix.
    """
    M = np.array(matrix)
    # The ndim check keeps 1-D input from raising IndexError on shape[1].
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError("Matrix must be square to compute eigenvalues.")
    return np.linalg.eigvals(M)

@math_tool_handler
def eigenvectors(matrix: List[List[float]]) -> List[List[complex]]:
    """Compute eigenvectors of a square matrix. Returns list of eigenvectors. Input: [[1, -1], [1, 1]].

    Item i of the result is the eigenvector for the i-th eigenvalue in the
    order produced by numpy.linalg.eig.

    Raises:
        ValueError: if the input is not a 2-D square matrix.
    """
    M = np.array(matrix)
    # The ndim check keeps 1-D input from raising IndexError on shape[1].
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError("Matrix must be square to compute eigenvectors.")
    _, vecs = np.linalg.eig(M)
    # numpy.linalg.eig stores eigenvectors as columns; transpose so each
    # returned row (list item) is one eigenvector.
    return vecs.T

@math_tool_handler
def svd_decompose(matrix: List[List[float]]) -> Dict[str, List]:
    """Compute the singular value decomposition (U, S, Vh) of a matrix. Input: [[1, 2], [3, 4], [5, 6]].

    Returns a dict with keys "U", "S" (1-D list of singular values) and "Vh",
    each as plain nested Python lists.
    """
    M = np.array(matrix)
    U, S, Vh = np.linalg.svd(M)
    # Return plain lists so the result is serialisable on its own, without
    # depending on how the wrapping decorator post-processes NumPy values
    # nested inside dicts.
    return {"U": U.tolist(), "S": S.tolist(), "Vh": Vh.tolist()}

@math_tool_handler
def lu_decompose(matrix: List[List[float]]) -> Dict[str, List]:
    """Compute the LU decomposition (P, L, U) of a matrix. Input: [[1, 2], [3, 4]].

    Uses scipy.linalg.lu (so that M = P @ L @ U) and returns the three factors
    as plain nested Python lists under keys "P", "L" and "U".
    """
    M = np.array(matrix)
    P, L, U = la.lu(M)
    # Return plain lists so the result is serialisable on its own, without
    # depending on how the wrapping decorator post-processes NumPy values
    # nested inside dicts.
    return {"P": P.tolist(), "L": L.tolist(), "U": U.tolist()}

@math_tool_handler
def qr_decompose(matrix: List[List[float]]) -> Dict[str, List]:
    """Compute the QR decomposition (Q, R) of a matrix. Input: [[1, 2], [3, 4]].

    Returns the factors from numpy.linalg.qr as plain nested Python lists
    under keys "Q" and "R".
    """
    M = np.array(matrix)
    Q, R = np.linalg.qr(M)
    # Return plain lists so the result is serialisable on its own, without
    # depending on how the wrapping decorator post-processes NumPy values
    # nested inside dicts.
    return {"Q": Q.tolist(), "R": R.tolist()}

# --- Statistics functions ---
@math_tool_handler
def mean(values: List[float]) -> float:
    """Compute the mean of a list of numbers. Input: [1, 2, 3, 4, 5].

    Raises ValueError for an empty list.
    """
    if values:
        return np.mean(np.array(values))
    raise ValueError("Input list cannot be empty for mean calculation.")

@math_tool_handler
def median(values: List[float]) -> float:
    """Compute the median of a list of numbers. Input: [1, 3, 2, 4, 5].

    Raises ValueError for an empty list.
    """
    if values:
        return np.median(np.array(values))
    raise ValueError("Input list cannot be empty for median calculation.")

@math_tool_handler
def std_dev(values: List[float], ddof: int = 1) -> float:
    """Compute the sample standard deviation (ddof=1) or population (ddof=0) of a list. Input: [1, 2, 3, 4, 5].

    Raises:
        ValueError: if the list is empty or has no more than ``ddof`` elements
            (the divisor N - ddof would be zero or negative).
    """
    # N - ddof must be positive; with N == ddof NumPy returns NaN plus a
    # RuntimeWarning instead of failing.
    if not values or len(values) <= ddof:
        raise ValueError(f"Input list must have more than {ddof} elements for std dev with ddof={ddof}.")
    return np.std(np.array(values), ddof=ddof)

@math_tool_handler
def variance(values: List[float], ddof: int = 1) -> float:
    """Compute the sample variance (ddof=1) or population (ddof=0) of a list. Input: [1, 2, 3, 4, 5].

    Raises:
        ValueError: if the list is empty or has no more than ``ddof`` elements
            (the divisor N - ddof would be zero or negative).
    """
    # N - ddof must be positive; with N == ddof NumPy returns NaN plus a
    # RuntimeWarning instead of failing.
    if not values or len(values) <= ddof:
        raise ValueError(f"Input list must have more than {ddof} elements for variance with ddof={ddof}.")
    return np.var(np.array(values), ddof=ddof)

@math_tool_handler
def percentile(values: List[float], percent: float) -> float:
    """Compute the q-th percentile (0<=q<=100) of a list. Input: [1, 2, 3, 4, 5], 75.

    Raises ValueError for an empty list or a percent outside [0, 100].
    """
    if not values:
        raise ValueError("Input list cannot be empty for percentile calculation.")
    if 0 <= percent <= 100:
        return np.percentile(np.array(values), percent)
    raise ValueError("Percent must be between 0 and 100.")

@math_tool_handler
def covariance(x: List[float], y: List[float], ddof: int = 1) -> float:
    """Compute sample covariance (ddof=1) or population (ddof=0) between two lists. Input: [1, 2, 3], [4, 5, 6].

    Raises:
        ValueError: if the lists differ in length, are empty, or have no more
            than ``ddof`` elements (the divisor N - ddof would be non-positive).
    """
    X = np.array(x)
    Y = np.array(y)
    if X.size != Y.size:
        raise ValueError("Input lists must have the same length for covariance.")
    # N - ddof must be positive; with N == ddof np.cov yields NaN plus a warning.
    if X.size == 0 or X.size <= ddof:
        raise ValueError(f"Input lists must have more than {ddof} elements for covariance with ddof={ddof}.")
    # np.cov returns the 2x2 covariance matrix; the off-diagonal entry is cov(x, y).
    return np.cov(X, Y, ddof=ddof)[0, 1]

@math_tool_handler
def correlation(x: List[float], y: List[float]) -> float:
    """Compute Pearson correlation coefficient between two lists. Input: [1, 2, 3], [1, 2, 3.1].

    If the coefficient is NaN (e.g. one input has zero variance), a warning
    is logged and 0.0 is returned instead.
    Raises ValueError for mismatched lengths or fewer than 2 points.
    """
    xs, ys = np.array(x), np.array(y)
    if xs.size != ys.size:
        raise ValueError("Input lists must have the same length for correlation.")
    if xs.size < 2:
        raise ValueError("Need at least 2 data points for correlation.")
    # Off-diagonal entry of the 2x2 correlation matrix is r(x, y).
    r = np.corrcoef(xs, ys)[0, 1]
    if not np.isnan(r):
        return r
    logger.warning("Correlation resulted in NaN, likely due to zero standard deviation in one or both inputs.")
    return 0.0

@math_tool_handler
def linear_regression(x: List[float], y: List[float]) -> Dict[str, float]:
    """Perform simple linear regression (y = mx + c). Returns slope (m) and intercept (c). Input: [1, 2, 3], [2, 4.1, 5.9].

    Least-squares fit via numpy.polyfit with degree 1.

    Raises:
        ValueError: if lengths differ, fewer than 2 points are given, or all
            x values are identical (the slope is undefined).
    """
    X = np.array(x)
    Y = np.array(y)
    if X.size != Y.size:
        raise ValueError("Input lists must have the same length for linear regression.")
    if X.size < 2:
        raise ValueError("Need at least 2 data points for linear regression.")
    # A vertical line has no finite slope; polyfit would only emit a
    # RankWarning and return a meaningless fit.
    if np.all(X == X[0]):
        raise ValueError("x values must not all be identical for linear regression.")
    slope, intercept = np.polyfit(X, Y, 1)
    # Return builtin floats so the dict is serialisable on its own, without
    # depending on how the wrapping decorator post-processes NumPy scalars.
    return {"slope": float(slope), "intercept": float(intercept)}

# --- Numerical functions ---
@math_tool_handler
def find_polynomial_roots(coefficients: List[float]) -> List[complex]:
    """Find roots of a polynomial given coefficients [a_n, a_n-1, ..., a_0]. Input: [1, -3, 2] for x^2-3x+2.

    Coefficients are ordered from the highest power down; raises ValueError
    when the list is empty.
    """
    if coefficients:
        return np.roots(coefficients)
    raise ValueError("Coefficient list cannot be empty.")

@math_tool_handler
def interpolate_value(x_vals: List[float], y_vals: List[float], x: float) -> float:
    """Linear interpolate a value at x given data points (x_vals, y_vals). Input: [0, 1, 2], [0, 1, 4], 1.5.

    Points are sorted by x before calling numpy.interp. Raises ValueError for
    mismatched lengths or fewer than 2 points.
    """
    n_points = len(x_vals)
    if n_points != len(y_vals):
        raise ValueError("x_vals and y_vals must have the same length.")
    if n_points < 2:
        raise ValueError("Need at least 2 data points for interpolation.")
    # numpy.interp requires increasing sample points.
    order = np.argsort(x_vals)
    xs = np.array(x_vals)[order]
    ys = np.array(y_vals)[order]
    return np.interp(x, xs, ys)

@math_tool_handler
def numerical_integration(
    func_str: str, a: float, b: float, variable: str = "x"
) -> float:
    """Numerically integrate func_str (e.g., 'x**2 * sin(x)') from a to b. Input: 'x**2', 0, 1.

    The expression is parsed with SymPy, converted to a NumPy function and
    integrated with scipy.integrate.quad; the estimated absolute error is logged.

    Raises:
        ValueError: if func_str cannot be parsed or contains symbols other
            than ``variable``.
    """
    symbol = sp.symbols(variable)
    # Security: sympify/lambdify can execute arbitrary code (SymPy documents
    # that sympify uses eval). Treat func_str as trusted input only.
    try:
        func = sp.sympify(func_str)
        f_lambdified = sp.lambdify(symbol, func, modules=["numpy"])
    except (sp.SympifyError, SyntaxError) as sym_err:
        raise ValueError(f"Invalid function string: {func_str}. Error: {sym_err}") from sym_err

    # Any other symbol would be an undefined name inside the lambdified
    # function and fail deep inside quad with an opaque NameError.
    unknown = getattr(func, "free_symbols", set()) - {symbol}
    if unknown:
        names = ", ".join(sorted(str(s) for s in unknown))
        raise ValueError(f"Function string {func_str!r} contains symbols other than {variable!r}: {names}")

    result, abserr = quad(f_lambdified, a, b)
    logger.info(f"Numerical integration estimated absolute error: {abserr}")
    return result

@math_tool_handler
def solve_ode(
    func_str: str, y0: float, t_eval: List[float], args: tuple = ()
) -> List[float]:
    """Solve a first-order ODE dy/dt = f(t, y) using scipy.integrate.solve_ivp.

    Args:
        func_str: Python expression defining f(t, y), e.g. '-y + sin(t)'.
            Names available to it: ``t``, ``y`` (current state, a length-1
            NumPy array), ``math``, ``np``, the NumPy functions/constants
            sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, exp,
            log, log10, sqrt, abs, pi, e, and p1, p2, ... for ``args``.
        y0: Initial condition y(t_eval[0]).
        t_eval: Time points at which to report the solution. Integration runs
            from t_eval[0] to t_eval[-1] (backwards if decreasing); the list
            must be sorted in that direction or solve_ivp rejects it.
        args: Optional additional parameters, exposed to func_str as p1, p2, ...

    Returns:
        The solution y(t) at each point of t_eval.

    Raises:
        ValueError: empty t_eval, or func_str fails to evaluate.
        RuntimeError: the solver reports failure.

    Input: func_str='-y', y0=1, t_eval=[0, 1, 2, 3, 4].
    """
    from scipy.integrate import solve_ivp
    import math  # exposed to func_str for math.* calls

    if not t_eval:
        raise ValueError("t_eval list cannot be empty.")

    # SECURITY: eval() with an empty __builtins__ is NOT a sandbox and is
    # known to be escapable via attribute access on exposed objects. func_str
    # must be trusted; flag for review before exposing to external input.
    numpy_names = (
        "sin", "cos", "tan", "arcsin", "arccos", "arctan",
        "sinh", "cosh", "tanh", "exp", "log", "log10", "sqrt", "abs", "pi", "e",
    )
    base_namespace = {name: getattr(np, name) for name in numpy_names}
    base_namespace["math"] = math
    base_namespace["np"] = np

    def ode_func(t, y, *params):
        local_vars = dict(base_namespace)
        local_vars["t"] = t
        local_vars["y"] = y
        for index, param in enumerate(params, start=1):
            local_vars[f"p{index}"] = param
        try:
            derivative = eval(func_str, {"__builtins__": {}}, local_vars)
        except Exception as e:
            # Log, then re-raise as ValueError for the tool handler to report.
            logger.error(f"Error evaluating ODE function {func_str} at t={t}, y={y}: {e}")
            raise ValueError(f"Error in ODE function definition: {e}") from e
        # A constant expression (e.g. '1') evaluates to a scalar; solve_ivp
        # expects an array shaped like y.
        return np.atleast_1d(derivative)

    # Start at t_eval[0] so y0 really is y(t_eval[0]); solve_ivp supports
    # t_span[0] > t_span[1] for backward integration.
    t_span = (t_eval[0], t_eval[-1])

    sol = solve_ivp(ode_func, t_span, [y0], t_eval=t_eval, args=args)

    if not sol.success:
        raise RuntimeError(f"ODE solver failed: {sol.message}")

    return sol.y[0]  # single-component system: the y(t) trajectory

# --- Vector functions ---
@math_tool_handler
def dot_product(a: List[float], b: List[float]) -> float:
    """Compute dot product of two vectors. Input: [1, 2, 3], [4, 5, 6].

    Raises ValueError when the vectors' shapes differ.
    """
    lhs, rhs = np.array(a), np.array(b)
    if lhs.shape != rhs.shape:
        raise ValueError("Vectors must have the same dimension for dot product.")
    return lhs.dot(rhs)

@math_tool_handler
def cross_product(a: List[float], b: List[float]) -> List[float]:
    """Compute cross product of two 3D vectors. Input: [1, 0, 0], [0, 1, 0].

    Raises ValueError unless both inputs have exactly three elements.
    """
    u, v = np.array(a), np.array(b)
    if not (u.size == 3 and v.size == 3):
        raise ValueError("Cross product is only defined for 3D vectors.")
    return np.cross(u, v)

@math_tool_handler
def vector_magnitude(a: List[float]) -> float:
    """Compute magnitude (Euclidean norm) of a vector. Input: [3, 4].

    Raises ValueError for an empty vector.
    """
    if a:
        return np.linalg.norm(np.array(a))
    raise ValueError("Input vector cannot be empty.")

@math_tool_handler
def vector_normalize(a: List[float]) -> List[float]:
    """Normalize a vector to unit length. Input: [3, 4].

    Raises ValueError when the vector's Euclidean length is zero.
    """
    vec = np.array(a)
    length = np.linalg.norm(vec)  # named to avoid shadowing scipy.stats.norm
    if length == 0:
        raise ValueError("Cannot normalize a zero vector.")
    return vec / length

@math_tool_handler
def vector_angle(a: List[float], b: List[float], degrees: bool = False) -> float:
    """Compute the angle (in radians or degrees) between two vectors. Input: [1, 0], [0, 1].

    Raises:
        ValueError: if the vectors differ in shape or either has zero length.
    """
    # Compute directly with NumPy instead of calling the decorated tool
    # functions: those return error *strings* on failure, which would surface
    # here as a confusing TypeError in the arithmetic below.
    A = np.array(a)
    B = np.array(b)
    if A.shape != B.shape:
        raise ValueError("Vectors must have the same dimension to compute an angle.")
    norm_a = np.linalg.norm(A)
    norm_b = np.linalg.norm(B)
    if norm_a == 0 or norm_b == 0:
        raise ValueError("Cannot compute angle with zero vector(s).")
    # Clip argument to arccos to avoid domain errors due to floating point inaccuracies
    cos_theta = np.clip(np.dot(A, B) / (norm_a * norm_b), -1.0, 1.0)
    angle_rad = np.arccos(cos_theta)
    return np.degrees(angle_rad) if degrees else angle_rad

# --- Probability functions ---
@math_tool_handler
def binomial_pmf(k: int, n: int, p: float) -> float:
    """Compute binomial probability mass function P(X=k | n, p). Input: k=2, n=5, p=0.5.

    Raises ValueError when p is outside [0, 1] or k is outside [0, n].
    """
    p_in_range = 0 <= p <= 1
    if not p_in_range:
        raise ValueError("Probability p must be between 0 and 1.")
    k_in_range = 0 <= k <= n
    if not k_in_range:
        raise ValueError("k must be between 0 and n (inclusive).")
    return binom.pmf(k, n, p)

@math_tool_handler
def normal_pdf(x: float, mu: float = 0, sigma: float = 1) -> float:
    """Compute normal distribution probability density function N(x | mu, sigma). Input: x=0, mu=0, sigma=1.

    Raises ValueError when sigma is not positive.
    """
    if sigma <= 0:
        raise ValueError("Standard deviation sigma must be positive.")
    density = norm.pdf(x, loc=mu, scale=sigma)
    return density

@math_tool_handler
def normal_cdf(x: float, mu: float = 0, sigma: float = 1) -> float:
    """Compute normal distribution cumulative distribution function P(X<=x | mu, sigma). Input: x=0, mu=0, sigma=1.

    Raises ValueError when sigma is not positive.
    """
    if sigma <= 0:
        raise ValueError("Standard deviation sigma must be positive.")
    cumulative = norm.cdf(x, loc=mu, scale=sigma)
    return cumulative

@math_tool_handler
def poisson_pmf(k: int, lam: float) -> float:
    """Compute Poisson probability mass function P(X=k | lambda). Input: k=2, lam=3.

    Raises ValueError when lam is negative or k is not a non-negative int.
    """
    if lam < 0:
        raise ValueError("Rate parameter lambda must be non-negative.")
    k_invalid = k < 0 or not isinstance(k, int)
    if k_invalid:
        raise ValueError("k must be a non-negative integer.")
    return poisson.pmf(k, lam)

# --- Special functions ---
@math_tool_handler
def gamma_function(x: float) -> float:
    """Compute the gamma function Gamma(x). Input: 5.

    Thin wrapper over scipy.special.gamma.
    """
    value = special.gamma(x)
    return value

@math_tool_handler
def beta_function(x: float, y: float) -> float:
    """Compute the beta function B(x, y). Input: 2, 3.

    Thin wrapper over scipy.special.beta.
    """
    value = special.beta(x, y)
    return value

@math_tool_handler
def erf_function(x: float) -> float:
    """Compute the error function erf(x). Input: 1.

    Thin wrapper over scipy.special.erf.
    """
    value = special.erf(x)
    return value

# --- Fourier Transform functions ---
@math_tool_handler
def fft_transform(y: List[float]) -> List[complex]:
    """Compute the Fast Fourier Transform (FFT) of a real sequence y. Input: [0, 1, 0, -1].

    Raises ValueError for an empty sequence.
    """
    if y:
        signal = np.array(y)
        return fft.fft(signal)
    raise ValueError("Input list cannot be empty for FFT.")

@math_tool_handler
def ifft_transform(y_complex: List[complex]) -> List[complex]:
    """Compute the inverse Fast Fourier Transform (IFFT) of a complex sequence. Input: result from fft_transform.

    Raises ValueError for an empty sequence.
    """
    if y_complex:
        spectrum = np.array(y_complex)
        return fft.ifft(spectrum)
    raise ValueError("Input list cannot be empty for IFFT.")

# --- Tool List Creation ---

def get_python_math_tools() -> List[FunctionTool]:
    """Returns a list of FunctionTools for the Python math functions.

    A fresh FunctionTool is built for every math function on each call, and
    each tool's description is prefixed with "(Python) " so the agent can
    tell them apart from other tool sources.
    """
    tool_functions = (
        # Symbolic
        solve_symbolic_equation,
        compute_derivative,
        compute_integral,
        compute_limit,
        simplify_expression,
        expand_expression,
        factor_expression,
        # Matrix
        matrix_addition,
        matrix_subtraction,
        matrix_multiplication,
        matrix_inverse,
        matrix_determinant,
        matrix_transpose,
        matrix_rank,
        matrix_trace,
        matrix_norm,
        eigenvalues,
        eigenvectors,
        svd_decompose,
        lu_decompose,
        qr_decompose,
        # Statistics
        mean,
        median,
        std_dev,
        variance,
        percentile,
        covariance,
        correlation,
        linear_regression,
        # Numerical
        find_polynomial_roots,
        interpolate_value,
        numerical_integration,
        solve_ode,
        # Vector
        dot_product,
        cross_product,
        vector_magnitude,
        vector_normalize,
        vector_angle,
        # Probability
        binomial_pmf,
        normal_pdf,
        normal_cdf,
        poisson_pmf,
        # Special Functions
        gamma_function,
        beta_function,
        erf_function,
        # Fourier
        fft_transform,
        ifft_transform,
    )
    py_tools = [FunctionTool.from_defaults(fn=fn) for fn in tool_functions]
    # Tag each description with its source for the agent's benefit.
    for tool in py_tools:
        tool.metadata.description = f"(Python) {tool.metadata.description}"
    logger.info(f"Created {len(py_tools)} Python math tools.")
    return py_tools

# --- Wolfram Alpha Tool ---
# Module-level cache: None until the first call, then a (possibly empty) list.
_wolfram_alpha_tools = None


def get_wolfram_alpha_tools() -> List[FunctionTool]:
    """Initializes and returns Wolfram Alpha tools (singleton).

    The first call reads WOLFRAM_ALPHA_APP_ID and builds the tools; later
    calls return the cached list. A missing app id or an initialization
    failure caches an empty list, which is also returned on later calls.
    """
    global _wolfram_alpha_tools
    if _wolfram_alpha_tools is not None:
        return _wolfram_alpha_tools

    logger.info("Initializing WolframAlphaToolSpec...")
    app_id = os.getenv("WOLFRAM_ALPHA_APP_ID")
    if not app_id:
        logger.warning("WOLFRAM_ALPHA_APP_ID not set. Wolfram Alpha tools will be unavailable.")
        _wolfram_alpha_tools = []
        return _wolfram_alpha_tools

    try:
        tools = WolframAlphaToolSpec(app_id=app_id).to_tool_list()
        # Tag each description with its source for the agent's benefit.
        for tool in tools:
            tool.metadata.description = f"(WolframAlpha) {tool.metadata.description}"
        logger.info(f"WolframAlpha tools initialized: {len(tools)} tools.")
    except Exception as e:
        logger.error(f"Failed to initialize WolframAlpha tools: {e}", exc_info=True)
        tools = []
    _wolfram_alpha_tools = tools
    return _wolfram_alpha_tools


# Code-interpreter tool from LlamaIndex's CodeInterpreterToolSpec, built at import time.
# NOTE(review): this was previously described as "safe execution" (e.g. in
# docker), but nothing in this module configures isolation — confirm how
# CodeInterpreterToolSpec actually executes code before relying on it.
try:
    code_interpreter_spec = CodeInterpreterToolSpec()
    # The spec may expose several tools; select the one named "code_interpreter".
    code_interpreter_tools = code_interpreter_spec.to_tool_list()
    if not code_interpreter_tools:
        raise RuntimeError("CodeInterpreterToolSpec did not return any tools.")
    code_interpreter_tool = next(
        (t for t in code_interpreter_tools if t.metadata.name == "code_interpreter"),
        None,
    )
    if code_interpreter_tool is None:
        raise RuntimeError("Could not find 'code_interpreter' tool in CodeInterpreterToolSpec results.")
    logger.info("CodeInterpreterToolSpec initialized successfully.")
except Exception as e:
    logger.error(f"Failed to initialize CodeInterpreterToolSpec: {e}", exc_info=True)
    # Fail at import time on purpose: initialize_math_agent() always adds
    # code_interpreter_tool to the agent's tool list.
    raise RuntimeError("CodeInterpreterToolSpec failed to initialize. Cannot create MathAgent.") from e

# --- Agent Initialization ---

def initialize_math_agent() -> ReActAgent:
    """Initialize the MathAgent with Python, WolframAlpha and code-interpreter tools.

    Configuration is read from the environment:
      - ``MATH_AGENT_LLM_MODEL``: Gemini model name
        (default ``"gemini-2.5-pro-preview-03-25"``).
      - ``GEMINI_API_KEY``: required API key for the Gemini LLM.

    Returns:
        A ``ReActAgent`` named ``"math_agent"`` that is allowed to hand off to
        ``planner_agent`` and ``reasoning_agent``.

    Raises:
        ValueError: If ``GEMINI_API_KEY`` is not set.
        Exception: Any error raised while building the LLM, collecting tools or
            constructing the agent is logged and re-raised unchanged.
    """
    import textwrap  # local import: only used to strip the prompt's source indentation

    logger.info("Initializing MathAgent...")

    # Configuration
    agent_llm_model = os.getenv("MATH_AGENT_LLM_MODEL", "gemini-2.5-pro-preview-03-25")
    gemini_api_key = os.getenv("GEMINI_API_KEY")

    if not gemini_api_key:
        logger.error("GEMINI_API_KEY not found in environment variables for MathAgent.")
        raise ValueError("GEMINI_API_KEY must be set for MathAgent")

    try:
        llm = GoogleGenAI(
            api_key=gemini_api_key,
            model=agent_llm_model,
            temperature=0.05,  # low temperature: favour consistent output for calculations
        )
        logger.info(f"Using agent LLM: {agent_llm_model}")

        # Check the dedicated math tools *before* appending the always-present
        # code interpreter; checking the combined list could never be empty, so
        # the warning was unreachable.
        math_tools = get_python_math_tools() + get_wolfram_alpha_tools()
        if not math_tools:
            logger.warning("No math tools available (Python or WolframAlpha). MathAgent may be ineffective.")
        all_tools = math_tools + [code_interpreter_tool]

        # System prompt (consider loading from file). dedent() removes the source
        # indentation so it is not sent to the LLM as leading whitespace.
        system_prompt = textwrap.dedent(
            """\
            You are MathAgent, a powerful mathematical problem solver. Your goal is to accurately answer mathematical questions using the available tools.

            Available Tools:
            - Python Tools: A comprehensive suite for symbolic math (SymPy), numerical computation (NumPy/SciPy), statistics, linear algebra, calculus, ODEs, and transforms. Prefixed with '(Python)'. Use these for precise calculations when the method is clear.
            - WolframAlpha Tool: Accesses Wolfram Alpha for complex queries, natural language math questions, data, and real-world facts. Prefixed with '(WolframAlpha)'. Use this for broader questions, knowledge-based math, or when Python tools are insufficient.
            - Code Interpreter Tool ('code_interpreter'): Executes Python code that you write. Use it for multi-step or custom computations that the dedicated Python tools do not cover.

            Workflow:
            1. **Thought**: Analyze the question. Determine the mathematical concepts involved. Decide the best tool or sequence of tools to use. Prefer Python tools for specific, well-defined calculations. Use WolframAlpha for complex, ambiguous, or knowledge-based queries.
            2. **Action**: Call the chosen tool with the correct arguments. Ensure inputs match the tool's requirements (e.g., list of lists for matrices, strings for symbolic expressions).
            3. **Observation**: Examine the tool's output. Check for errors or unexpected results.
            4. **Iteration**: If the result is incorrect or incomplete, rethink the approach. Try a different tool, adjust parameters, or break the problem down further. If a Python tool fails, consider rephrasing for WolframAlpha.
            5. **Final Answer**: Once the correct answer is obtained, state it clearly and concisely. Provide the numerical result, symbolic expression, or explanation as requested.
            6. **Hand-Off**: Pass the final mathematical result or analysis to **planner_agent** for integration into the overall response.

            Constraints:
            - Always use a tool for calculations; do not perform calculations yourself.
            - Clearly state which tool you are using and why.
            - Handle potential errors gracefully and report them if they prevent finding a solution.
            - Pay close attention to input formats required by each tool (e.g., lists for vectors/matrices, strings for symbolic expressions).

            If your response exceeds the maximum token limit and cannot be completed in a single reply, please conclude your output with the marker [CONTINUE]. In subsequent interactions, I will prompt you with “continue” to receive the next portion of the response.
            """
        )

        agent = ReActAgent(
            name="math_agent",
            description=(
                "MathAgent solves mathematical problems using a suite of Python tools (SymPy, NumPy, SciPy) and WolframAlpha. "
                "It handles symbolic math, numerical computation, statistics, linear algebra, calculus, and more."
            ),
            tools=all_tools,
            llm=llm,
            system_prompt=system_prompt,
            can_handoff_to=["planner_agent", "reasoning_agent"],
        )
        logger.info("MathAgent initialized successfully.")
        return agent

    except Exception as e:
        logger.error(f"Error during MathAgent initialization: {e}", exc_info=True)
        raise

# Example usage (for testing if run directly)
if __name__ == "__main__":
    import sys  # local import: only the script entry point needs exit codes / stderr

    # NOTE: the module-level code-interpreter setup has already run (and logged)
    # during import, before this logging configuration takes effect.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger.info("Running math_agent.py directly for testing...")

    # GEMINI_API_KEY is mandatory; WOLFRAM_ALPHA_APP_ID is optional.
    required_keys = ["GEMINI_API_KEY"]
    missing_keys = [key for key in required_keys if not os.getenv(key)]
    if missing_keys:
        print(
            f"Error: Required environment variable(s) not set: {', '.join(missing_keys)}. Cannot run test.",
            file=sys.stderr,
        )
        sys.exit(1)  # non-zero status so scripts/CI notice the failure

    if not os.getenv("WOLFRAM_ALPHA_APP_ID"):
        print("Warning: WOLFRAM_ALPHA_APP_ID not set. WolframAlpha tools will be unavailable for testing.")

    try:
        test_agent = initialize_math_agent()
    except Exception as e:
        print(f"Error during testing: {e}", file=sys.stderr)
        sys.exit(1)
    print("Math Agent initialized successfully for testing.")
    # To exercise the agent end to end, send it a query. The exact call depends on
    # the installed llama_index version — verify before use, e.g.:
    #   asyncio.run(test_agent.run(user_msg="What is the integral of x**2 from 0 to 1?"))