File size: 5,710 Bytes
e8c0a63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
# import ctranslate2
# from transformers import AutoTokenizer
# from huggingface_hub import snapshot_download
from codeexecutor import get_majority_vote, type_check, postprocess_completion, draw_polynomial_plot
import re
import os

# Prompt prefix for the real model. Only the commented-out, model-backed
# get_prediction below references it; the active get_prediction stub does not.
model_prompt = (
    "Explain and solve the following mathematical problem step by step, "
    "showing all work: "
)

# Model loading is disabled. Restore these lines (and the matching imports at
# the top of the file) to run the real model instead of the canned stub:
# tokenizer = AutoTokenizer.from_pretrained("AI-MO/NuminaMath-7B-TIR")
# model_path = snapshot_download(repo_id="Makima57/deepseek-math-Numina")
# generator = ctranslate2.Generator(model_path, device="cpu", compute_type="int8")

# Number of completions sampled per question for majority voting
# (chat_interface passes this to majority_vote_with_steps).
iterations = 4

# # Function to generate predictions using the model
# def get_prediction(question):
#     input_text = model_prompt + question
#     input_tokens = tokenizer.tokenize(input_text)
#     results = generator.generate_batch(
#         [input_tokens],
#         max_length=512,
#         sampling_temperature=0.7,
#         sampling_topk=40,
#     )
#     output_tokens = results[0].sequences[0]
#     predicted_answer = tokenizer.convert_tokens_to_string(output_tokens)
#     return predicted_answer

def get_prediction(question):
    """Return a canned model completion (offline stand-in for the real model).

    The *question* argument is ignored: every call returns the same worked
    sympy solution for summing the polynomials 2x + 3 and 3x, ending with a
    boxed answer of 5x + 3.
    """
    canned_completion = (
        "Solve the following mathematical problem: what is  sum of polynomial 2x+3 and 3x?\n"
        "### Solution: To solve the problem of summing the polynomials \\(2x + 3\\) and \\(3x\\), we can follow these steps:\n"
        "\n"
        "1. Define the polynomials.\n"
        "2. Sum the polynomials.\n"
        "3. Simplify the resulting polynomial expression.\n"
        "\n"
        "Let's implement this in Python using the sympy library.\n"
        "\n"
        "```python\n"
        "import sympy as sp\n"
        "\n"
        "# Define the variable\n"
        "x = sp.symbols('x')\n"
        "\n"
        "# Define the polynomials\n"
        "poly1 = 2*x + 3\n"
        "poly2 = 3*x\n"
        "\n"
        "# Sum the polynomials\n"
        "sum_poly = poly1 + poly2\n"
        "\n"
        "# Simplify the resulting polynomial\n"
        "simplified_sum_poly = sp.simplify(sum_poly)\n"
        "\n"
        "# Print the simplified polynomial\n"
        "print(simplified_sum_poly)\n"
        "```\n"
        "```output\n"
        "5*x + 3\n"
        "```\n"
        "The sum of the polynomials \\(2x + 3\\) and \\(3x\\) is \\(\\boxed{5x + 3}\\).\n"
    )
    return canned_completion

# Function to parse the prediction to extract the answer and steps
def parse_prediction(prediction):
    """Split a raw model completion into ``(answer, steps_text)``.

    Lines of the form ``Answer: ...`` / ``answer = ...`` (case of the first
    letter only) supply the answer; if several appear, the last one wins, and
    all of them are dropped from the steps. When no such line exists, the last
    line of the stripped completion is taken as the answer and every line is
    kept as a step.
    """
    answer_line = re.compile(r'^\s*(?:Answer|answer)\s*[:=]\s*(.*)')
    all_lines = prediction.strip().split('\n')

    found = None
    kept = []
    for text in all_lines:
        hit = answer_line.match(text)
        if hit is None:
            kept.append(text)
            continue
        found = hit.group(1).strip()

    if found is None:
        # Fallback: no explicit answer marker, so treat the final line as the answer.
        found = all_lines[-1].strip()
        kept = all_lines

    return found, '\n'.join(kept).strip()

def extract_boxed_answer(text):
    """Return the contents of the first ``\\boxed{...}`` in *text*, or None.

    Braces inside the box are matched, so nested LaTeX such as
    ``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1}{2}`` instead of being cut at
    the first closing brace, and the content may span several lines.

    Returns None when *text* is empty/None, contains no ``\\boxed{``, or the
    first box is never closed.
    """
    if not text:
        return None

    marker = '\\boxed{'
    start = text.find(marker)
    if start == -1:
        return None

    content_start = start + len(marker)
    depth = 1  # we are already inside the box's opening brace
    for pos in range(content_start, len(text)):
        char = text[pos]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return text[content_start:pos]

    # The opening brace of the first box is never balanced.
    return None


# Function to perform majority voting and get steps
def majority_vote_with_steps(question, num_iterations=10):
    """Sample the model repeatedly and return the majority-voted answer.

    For each of ``num_iterations`` samples, the completion from
    ``get_prediction`` is passed to ``postprocess_completion`` (with
    ``return_status=True, last_code_block=True``). When that reports success,
    its answer is used and the full completion is kept as the steps;
    otherwise the answer and steps are parsed from the raw text with
    ``parse_prediction``.

    Returns:
        ``(answer, steps_solution, plotfile)``:
        - ``answer``: result of ``get_majority_vote`` over all sampled answers,
          or None when no samples were drawn (``num_iterations <= 0``).
        - ``steps_solution``: steps of the first sample whose answer equals the
          voted answer, or ``"No steps found"``.
        - ``plotfile``: result of ``draw_polynomial_plot`` when at least one
          successfully post-processed sample produced the voted answer and
          ``type_check`` reports ``"Polynomial"``; otherwise None.
    """
    all_answers = []
    steps_list = []
    success_flags = []

    for _ in range(num_iterations):
        prediction = get_prediction(question)
        answer, success = postprocess_completion(
            prediction, return_status=True, last_code_block=True
        )
        if success:
            steps = prediction
        else:
            answer, steps = parse_prediction(prediction)
        all_answers.append(answer)
        steps_list.append(steps)
        success_flags.append(success)

    # Nothing was sampled: avoid calling get_majority_vote on an empty list.
    if not all_answers:
        return None, "No steps found", None

    majority_voted_ans = get_majority_vote(all_answers)

    # Show the reasoning of the first sample that reached the winning answer.
    steps_solution = next(
        (
            steps
            for ans, steps in zip(all_answers, steps_list)
            if ans == majority_voted_ans
        ),
        "No steps found",
    )

    # Only try to plot answers that came out of a successful post-processing
    # run, rather than depending on whether the *last* sample happened to succeed.
    from_success = any(
        ok and ans == majority_voted_ans
        for ans, ok in zip(all_answers, success_flags)
    )
    plotfile = None
    if from_success and type_check(majority_voted_ans) == "Polynomial":
        plotfile = draw_polynomial_plot(majority_voted_ans)

    return majority_voted_ans, steps_solution, plotfile

# Function to handle chat-like interaction
def chat_interface(history, question):
    """Answer *question* and append the exchange to the chat transcript.

    Args:
        history: Prior conversation as a list of ``(user_message, bot_message)``
            pairs, the tuple format rendered by ``gr.Chatbot``. ``None`` or an
            empty list starts a new conversation. The caller's list is copied,
            not modified.
        question: The user's math question.

    Returns:
        ``(updated_history, plotfile)``, where ``plotfile`` is the third value
        returned by ``majority_vote_with_steps`` (the plot for the
        ``gr.Image`` output, if any).
    """
    final_answer, steps_solution, plotfile = majority_vote_with_steps(question, iterations)

    reply = f"Answer: {final_answer}\nSteps:\n{steps_solution}"

    # gr.Chatbot shows each entry as one (user, bot) exchange, so the question
    # and its reply belong in the same pair rather than in two separate rows.
    updated_history = list(history) if history else []
    updated_history.append((question, reply))

    return updated_history, plotfile

# Gradio app setup with chat UI
# Gradio UI: the conversation so far plus a new question go in; the updated
# conversation and an optional polynomial plot come out (see chat_interface).
interface = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Chatbot(elem_id="chat_history", label="Chat with MathBot"),
        gr.Textbox(
            elem_id="math_question",
            label="Your Question",
            placeholder="Ask a math question...",
        ),
    ],
    outputs=[
        gr.Chatbot(label="Chat History"),  # conversation shown as chat bubbles
        gr.Image(label="Polynomial Plot"),
    ],
    title="🔢 Math Question Solver - Chat Mode",
    description=(
        "Chat with MathBot and ask any math-related question. "
        "It will explain the solution step by step and provide a majority-voted answer."
    ),
    # Per the Gradio docs, "auto" flags every submission; flagged data is
    # written under flagging_dir.
    allow_flagging="auto",
    flagging_dir="./flagged_data",
)

if __name__ == "__main__":
    # Serve the app. For a quick check without the UI, call e.g.
    # chat_interface([], "what is the sum of 2x+3 and 3x") and print the result.
    interface.launch()