yukimama commited on
Commit
9fc2a76
·
verified ·
1 Parent(s): 3b2ce0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -4,13 +4,26 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
  model = GPT2LMHeadModel.from_pretrained("gpt2")
5
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
6
 
 
 
 
7
  def generate_code(prompt, max_length=200):
8
  full_prompt = f"Generate Python code for {prompt}:```python\n"
9
- input_ids = tokenizer.encode(full_prompt, return_tensors="pt")
10
- output = model.generate(input_ids, max_length=max_length, num_return_sequences=1, temperature=0.7)
 
 
 
 
 
 
 
 
 
 
11
  generated_text = tokenizer.decode(output[0], skip_special_tokens=False)
12
 
13
-
14
  start = generated_text.find("```python") + len("```python")
15
  end = generated_text.find("```", start)
16
  if end == -1:
@@ -19,7 +32,7 @@ def generate_code(prompt, max_length=200):
19
 
20
  return code
21
 
22
-
23
  prompt = "a function to calculate the factorial of a number"
24
  malicious_code = generate_code(prompt)
25
  print(malicious_code)
 
4
  model = GPT2LMHeadModel.from_pretrained("gpt2")
5
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
6
 
7
+
8
+ tokenizer.pad_token = tokenizer.eos_token
9
+
10
  def generate_code(prompt, max_length=200):
11
  full_prompt = f"Generate Python code for {prompt}:```python\n"
12
+ inputs = tokenizer.encode(full_prompt, return_tensors="pt")
13
+
14
+ attention_mask = inputs.ne(tokenizer.pad_token_id).long()
15
+
16
+ output = model.generate(
17
+ inputs,
18
+ max_length=max_length,
19
+ num_return_sequences=1,
20
+ temperature=0.7,
21
+ attention_mask=attention_mask,
22
+ pad_token_id=tokenizer.pad_token_id
23
+ )
24
  generated_text = tokenizer.decode(output[0], skip_special_tokens=False)
25
 
26
+
27
  start = generated_text.find("```python") + len("```python")
28
  end = generated_text.find("```", start)
29
  if end == -1:
 
32
 
33
  return code
34
 
35
+ # Example usage
36
  prompt = "a function to calculate the factorial of a number"
37
  malicious_code = generate_code(prompt)
38
  print(malicious_code)