Update agents/debugger.py
Browse files- agents/debugger.py +9 -15
agents/debugger.py
CHANGED
@@ -7,35 +7,29 @@ import torch
|
|
7 |
class DebuggerAgent(BaseAgent):
|
8 |
def __init__(self):
|
9 |
super().__init__(name="BugBot", role="Debugger")
|
10 |
-
self.tokenizer = AutoTokenizer.from_pretrained("
|
11 |
-
self.model = AutoModelForCausalLM.from_pretrained("
|
12 |
|
13 |
-
def
|
14 |
-
prompt = f"Review this Python
|
15 |
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
|
16 |
outputs = self.model.generate(
|
17 |
inputs["input_ids"],
|
18 |
-
max_length=
|
19 |
do_sample=True,
|
20 |
-
temperature=0.
|
21 |
pad_token_id=self.tokenizer.eos_token_id
|
22 |
)
|
23 |
reply = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
24 |
return reply[len(prompt):].strip()
|
25 |
|
26 |
def receive_message(self, message: ACPMessage) -> ACPMessage:
|
27 |
-
if message.performative
|
28 |
-
|
29 |
return self.create_message(
|
30 |
receiver=message.sender,
|
31 |
performative="request",
|
32 |
-
content=
|
33 |
-
)
|
34 |
-
elif message.performative == "acknowledge":
|
35 |
-
return self.create_message(
|
36 |
-
receiver=message.sender,
|
37 |
-
performative="inform",
|
38 |
-
content="Got it. Let me know when you have more code."
|
39 |
)
|
40 |
else:
|
41 |
return self.create_message(
|
|
|
7 |
class DebuggerAgent(BaseAgent):
|
8 |
def __init__(self):
|
9 |
super().__init__(name="BugBot", role="Debugger")
|
10 |
+
self.tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-multi")
|
11 |
+
self.model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-multi")
|
12 |
|
13 |
+
def generate_debug_comment(self, code: str) -> str:
|
14 |
+
prompt = f"# Review this Python code and ask a smart debugging or testing question:\n{code}\n# Question:"
|
15 |
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
|
16 |
outputs = self.model.generate(
|
17 |
inputs["input_ids"],
|
18 |
+
max_length=128,
|
19 |
do_sample=True,
|
20 |
+
temperature=0.7,
|
21 |
pad_token_id=self.tokenizer.eos_token_id
|
22 |
)
|
23 |
reply = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
24 |
return reply[len(prompt):].strip()
|
25 |
|
26 |
def receive_message(self, message: ACPMessage) -> ACPMessage:
|
27 |
+
if message.performative in ["inform", "request"]:
|
28 |
+
comment = self.generate_debug_comment(message.content)
|
29 |
return self.create_message(
|
30 |
receiver=message.sender,
|
31 |
performative="request",
|
32 |
+
content=comment or "Can you improve error handling?"
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
)
|
34 |
else:
|
35 |
return self.create_message(
|