File size: 4,193 Bytes
1278b3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e00a0cb
ccded5c
 
 
 
 
 
 
e00a0cb
ccded5c
 
 
 
 
 
e00a0cb
 
ccded5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e00a0cb
ccded5c
 
 
 
 
 
 
e00a0cb
1278b3f
ccded5c
 
 
 
 
 
e00a0cb
ccded5c
 
1278b3f
ccded5c
 
 
 
 
1278b3f
ccded5c
e00a0cb
 
1278b3f
ccded5c
e00a0cb
ccded5c
 
 
 
 
 
 
 
 
 
e00a0cb
ccded5c
 
e00a0cb
1278b3f
e00a0cb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
from mistralai import Mistral
from agent.agent_config import prompts
from agent.agent_config import tool_schema
from config import MISTRAL_API_KEY
from tools.code_index import retrieve_context
from tools.github_tools import fetch_github_issue, get_issue_details, post_comment

tools = tool_schema.tools

# Dispatch table: tool name as exposed to the model -> local Python callable.
names_to_functions = dict(
    fetch_github_issue=fetch_github_issue,
    get_issue_details=get_issue_details,
    retrieve_context=retrieve_context,
    post_comment=post_comment,
)

# Whitelist of tool names the agent loop is permitted to execute
# (iterating a dict yields its keys, so this is exactly the table above).
allowed_tools = set(names_to_functions)

system_message = prompts.system_message

# Model/client configuration used by run_agent.
api_key = MISTRAL_API_KEY
model = "devstral-small-latest"
client = Mistral(api_key=api_key)

def _parse_tool_arguments(raw_arguments) -> dict:
    """Decode a tool call's ``arguments`` payload into a keyword-argument dict.

    Accepts either a JSON-encoded string or an already-decoded dict; a missing
    or empty payload is treated as "no arguments".

    Raises:
        ValueError: if the payload is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass) or does not decode to a JSON object.
    """
    if not raw_arguments:
        return {}
    if isinstance(raw_arguments, dict):
        # Copy so later in-place corrections don't mutate the SDK's object.
        return dict(raw_arguments)
    parsed = json.loads(raw_arguments)
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


async def run_agent(issue_url: str, branch_name: str = "main"):
    """
    Run the agent workflow on a given GitHub issue URL.

    Async generator: repeatedly asks the model for its next step, executes
    the whitelisted tools it requests, feeds each result back into the
    conversation, and yields human-readable progress strings.

    Args:
        issue_url: URL of the GitHub issue the agent should work on.
        branch_name: Branch the model is told (via the user prompt) to use
            when retrieving code context.

    Yields:
        Progress/status strings. The run ends after a successful
        ``post_comment`` call, when the model answers without requesting a
        tool, or once ``MAX_STEPS`` tool-call attempts have been made; the
        last two cases finish with "Task Completed".
    """

    MAX_STEPS = 5
    tool_calls = 0
    issue_description_cache = None

    user_message = {
        "role": "user",
        "content": f"Please suggest a fix on this issue {issue_url} and use {branch_name} branch for retrieving code context."
    }
    messages = [system_message, user_message]

    yield "⚡️ OpenSorus agent started..."

    while True:
        # NOTE(review): client.chat.complete is a synchronous call inside an
        # async generator, so it blocks the event loop while waiting for the
        # model; consider the SDK's async variant.
        response = client.chat.complete(
            model=model,
            messages=messages,
            tools=tools,
            tool_choice="any",
        )
        msg = response.choices[0].message
        messages.append(msg)

        requested_calls = getattr(msg, "tool_calls", None)
        if not requested_calls:
            # No tool requested: treat the message content as the final answer.
            yield f"OpenSorus (final): {msg.content}"
            break

        for tool_call in requested_calls:
            function_name = tool_call.function.name
            # Count every attempt, including rejected ones, so a model that keeps
            # requesting unknown tools (tool_choice="any" forces a tool call)
            # still hits MAX_STEPS instead of looping forever.
            tool_calls += 1

            if function_name not in allowed_tools:
                yield f"Agent tried to call unknown tool: {function_name}"
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": (
                        f"Error: Tool '{function_name}' is not available. "
                        f"You can only use the following tools: {', '.join(sorted(allowed_tools))}."
                    ),
                })
                continue

            try:
                function_params = _parse_tool_arguments(tool_call.function.arguments)
            except ValueError as exc:
                # Report malformed arguments back to the model rather than
                # aborting the whole run.
                yield f"⚠️ Could not parse arguments for tool `{function_name}`: {exc}"
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": f"Error: invalid arguments for tool '{function_name}': {exc}",
                })
                continue

            yield f"🔧 Agent is calling tool: `{function_name}`"

            # Replace a wrong issue_description *before* calling, so
            # retrieve_context runs once with the correct text.
            if (
                function_name == "retrieve_context"
                and issue_description_cache
                and "issue_description" in function_params
                and function_params["issue_description"] != issue_description_cache
            ):
                yield "⚠️ Overriding incorrect issue_description with correct one from cache."
                function_params["issue_description"] = issue_description_cache

            function_result = names_to_functions[function_name](**function_params)

            if function_name == "get_issue_details" and isinstance(function_result, dict):
                issue_title = function_result.get("title")
                issue_body = function_result.get("body")
                # .get returns None for a missing key, and the body itself may be
                # None, so never concatenate the raw values.
                if issue_title or issue_body:
                    issue_description_cache = f"{issue_title or ''}\n{issue_body or ''}"
                    yield "📝 Issue description cached."
                else:
                    issue_description_cache = None

            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": str(function_result)
            })

            if function_name == "post_comment":
                yield "✅ Comment posted. Task complete."
                return

        if tool_calls >= MAX_STEPS:
            yield f"Agent stopped after {MAX_STEPS} tool calls to protect against rate limiting."
            break

    yield "Task Completed"