aymnsk commited on
Commit
4a4659c
·
verified ·
1 Parent(s): 9fd2f28

Update agents/programmer.py

Browse files
Files changed (1) hide show
  1. agents/programmer.py +4 -6
agents/programmer.py CHANGED
@@ -2,26 +2,24 @@
2
 
3
  from agents.base_agent import BaseAgent, ACPMessage
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
- import torch
6
 
7
  class ProgrammerAgent(BaseAgent):
8
  def __init__(self):
9
  super().__init__(name="CodeBot", role="Programmer")
10
- self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
11
- self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
12
 
13
  def generate_code_reply(self, prompt: str) -> str:
14
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
15
  outputs = self.model.generate(
16
  inputs["input_ids"],
17
- max_length=100,
18
  do_sample=True,
19
- temperature=0.7,
20
  pad_token_id=self.tokenizer.eos_token_id
21
  )
22
  reply = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
23
  return reply[len(prompt):].strip()
24
-
25
  def receive_message(self, message: ACPMessage) -> ACPMessage:
26
  if message.performative == "request":
27
  prompt = f"Write Python code to: {message.content.strip()}\n\n"
 
2
 
3
  from agents.base_agent import BaseAgent, ACPMessage
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
5
 
6
  class ProgrammerAgent(BaseAgent):
7
  def __init__(self):
8
  super().__init__(name="CodeBot", role="Programmer")
9
+ self.tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-multi")
10
+ self.model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-multi")
11
 
12
  def generate_code_reply(self, prompt: str) -> str:
13
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
14
  outputs = self.model.generate(
15
  inputs["input_ids"],
16
+ max_length=128,
17
  do_sample=True,
18
+ temperature=0.6,
19
  pad_token_id=self.tokenizer.eos_token_id
20
  )
21
  reply = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
22
  return reply[len(prompt):].strip()
 
23
  def receive_message(self, message: ACPMessage) -> ACPMessage:
24
  if message.performative == "request":
25
  prompt = f"Write Python code to: {message.content.strip()}\n\n"