import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class CodeGenerator:
    def __init__(self):
        print("Initializing Code Generator...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load model and tokenizer; use half precision on GPU to save memory.
        self.model_name = "microsoft/CodeGPT-small-py-adaptedGPT2"
        print(f"Loading model {self.model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)
        print(f"Model loaded and moved to {self.device}")

    def generate_code(self, prompt, max_length=150, temperature=0.7, top_p=0.95):
        """
        Generate code based on the given prompt.

        Args:
            prompt (str): The prompt describing the code to generate
            max_length (int): Maximum number of new tokens to generate
            temperature (float): Controls randomness in generation
            top_p (float): Controls diversity via nucleus sampling

        Returns:
            str: Generated code
        """
        try:
            print(f"Generating code on {self.device}...")

            # Frame the prompt as a comment header to steer the model
            # toward emitting Python code.
            formatted_prompt = f"# Python\n# Task: {prompt}\n# Solution:\n"
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    num_return_sequences=1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    do_sample=True,
                    repetition_penalty=1.1,
                    no_repeat_ngram_size=3,
                )

            # Decode only the newly generated tokens so the prompt is
            # excluded, rather than slicing the decoded string.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            generated_code = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            return self._format_code(generated_code)
        except Exception as e:
            return f"Error generating code: {e}"

    def _format_code(self, code):
        """
        Format the generated code for better readability.

        Args:
            code (str): The code to format

        Returns:
            str: Formatted code
        """
        # Remove any leading/trailing whitespace
        code = code.strip()

        # Split into lines and keep only the first occurrence of each
        # distinct (stripped) line.
        lines = code.split('\n')
        unique_lines = []
        seen_lines = set()
        for line in lines:
            stripped_line = line.strip()
            if stripped_line and stripped_line not in seen_lines:
                seen_lines.add(stripped_line)
                unique_lines.append(line)

        # Fix common indentation issues with a simple heuristic: a line may
        # indent at most one level (4 spaces) deeper than the last line that
        # opened a block with a trailing colon.
        formatted_lines = []
        indent_level = 0  # maximum indent allowed for the next line
        for line in unique_lines:
            # Skip empty lines
            if not line.strip():
                formatted_lines.append('')
                continue
            stripped = line.lstrip()
            current_indent = len(line) - len(stripped)
            indent = min(current_indent, indent_level)
            formatted_lines.append(' ' * indent + stripped)
            # A trailing colon opens a block, so the next line may go one
            # level deeper; otherwise it may not indent further.
            indent_level = indent + 4 if stripped.endswith(':') else indent

        formatted_code = '\n'.join(formatted_lines)

        # Add placeholder docstrings if missing
        if 'def ' in formatted_code and '"""' not in formatted_code:
            formatted_code = self._add_docstrings(formatted_code)

        # Collapse runs of blank lines between functions/classes
        formatted_code = re.sub(r'\n{3,}', '\n\n', formatted_code)

        # Remove any duplicate code blocks
        formatted_code = self._remove_duplicate_blocks(formatted_code)
        return formatted_code

    def _remove_duplicate_blocks(self, code):
        """
        Remove duplicate code blocks.

        Args:
            code (str): The code to clean

        Returns:
            str: Code with duplicates removed
        """
        # Split into blocks at each `def` or `class` keyword
        blocks = re.split(r'(?=\n\s*(?:def|class)\s)', code)
        unique_blocks = []
        seen_blocks = set()
        for block in blocks:
            # Normalize whitespace so reformatted duplicates still match
            normalized = re.sub(r'\s+', ' ', block.strip())
            if normalized and normalized not in seen_blocks:
                seen_blocks.add(normalized)
                unique_blocks.append(block)
        return ''.join(unique_blocks).strip()

    def _add_docstrings(self, code):
        """
        Add placeholder docstrings to functions that are missing them.

        Args:
            code (str): The code to add docstrings to

        Returns:
            str: Code with docstrings
        """
        lines = code.split('\n')
        formatted_lines = []
        i = 0
        while i < len(lines):
            line = lines[i]
            formatted_lines.append(line)
            # After a function definition, insert a docstring if the next
            # line does not already start one.
            if line.strip().startswith('def '):
                if i + 1 >= len(lines) or '"""' not in lines[i + 1]:
                    # Indent the docstring one level deeper than the `def`.
                    body_indent = ' ' * (len(line) - len(line.lstrip()) + 4)
                    docstring = f'{body_indent}"""\n{body_indent}Docstring\n{body_indent}"""'
                    formatted_lines.append(docstring)
            i += 1
        return '\n'.join(formatted_lines)
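

# A minimal usage sketch, not part of the original module: instantiate the
# generator and print the formatted output for a sample task. The prompt
# text and parameter values below are illustrative assumptions.
if __name__ == "__main__":
    generator = CodeGenerator()
    result = generator.generate_code(
        "write a function that returns the factorial of a number",
        max_length=120,
        temperature=0.7,
    )
    print(result)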