import re
import textwrap

import torch


class CodeGenerator:
    """Generate Python code from a natural-language task description.

    Wraps the ``microsoft/CodeGPT-small-py-adaptedGPT2`` causal language
    model: ``generate_code`` builds a comment-style prompt, generates a
    completion and passes it through ``_format_code`` for light clean-up.
    """

    def __init__(self):
        """Load the tokenizer and model onto CUDA if available, else CPU."""
        # Imported lazily so the pure-text helpers (_format_code and friends)
        # can be imported and used without loading ``transformers``.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        print("Initializing Code Generator...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load model and tokenizer
        self.model_name = "microsoft/CodeGPT-small-py-adaptedGPT2"
        print(f"Loading model {self.model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # float16 weights on CUDA, float32 on CPU.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)
        print(f"Model loaded and moved to {self.device}")

    def generate_code(self, prompt, max_length=150, temperature=0.7, top_p=0.95):
        """
        Generate code based on the given prompt

        Args:
            prompt (str): The prompt describing the code to generate
            max_length (int): Maximum number of new tokens to generate
                (the prompt itself is not counted)
            temperature (float): Sampling temperature; a value <= 0 switches
                to greedy (deterministic) decoding
            top_p (float): Nucleus-sampling probability mass; ignored for
                greedy decoding

        Returns:
            str: Generated and formatted code, or a message starting with
            "Error generating code:" if anything goes wrong (failures are
            reported in the return value rather than raised).
        """
        try:
            print(f"Generating code on {self.device}...")

            # Format prompt for better code generation
            formatted_prompt = f"# Python\n# Task: {prompt}\n# Solution:\n"

            # Prompts longer than 512 tokens are truncated.
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)
            prompt_length = inputs["input_ids"].shape[1]

            generation_kwargs = {
                "max_new_tokens": max_length,
                "num_return_sequences": 1,
                # Reuse EOS for padding (presumably the tokenizer defines no
                # dedicated pad token).
                "pad_token_id": self.tokenizer.eos_token_id,
                "repetition_penalty": 1.1,
                "no_repeat_ngram_size": 3,
            }
            # transformers does not accept a zero temperature when sampling,
            # so non-positive temperatures request greedy decoding instead.
            if temperature is None or temperature > 0:
                generation_kwargs.update(
                    do_sample=True, temperature=temperature, top_p=top_p
                )
            else:
                generation_kwargs["do_sample"] = False

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **generation_kwargs)

            # Decode only the newly generated tokens. Slicing the decoded text
            # by len(formatted_prompt) cuts into the generated code whenever
            # the prompt was truncated or does not decode back verbatim.
            new_tokens = outputs[0][prompt_length:]
            generated_code = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            return self._format_code(generated_code)
        except Exception as e:
            # Callers receive failures as a string result, not an exception.
            return f"Error generating code: {str(e)}"

    def _format_code(self, code):
        """
        Clean up raw model output for better readability

        Steps: strip surrounding whitespace and trailing whitespace on each
        line, drop lines that merely repeat the line right before them, add
        placeholder docstrings when the code has none, collapse runs of blank
        lines to a single blank line, and remove duplicate top-level blocks.

        The model's indentation is preserved as-is; re-indenting it
        heuristically turns valid nested code into different or invalid code.

        Args:
            code (str): The code to format

        Returns:
            str: Formatted code
        """
        lines = code.strip().split('\n')

        # Only *immediately* repeated lines are treated as degenerate
        # repetition. Lines like ``else:`` or ``return result`` legitimately
        # recur in different places and must not be removed globally.
        # Blank lines are kept here and collapsed below.
        deduped = []
        for line in lines:
            line = line.rstrip()
            if line and deduped and line == deduped[-1]:
                continue
            deduped.append(line)
        formatted_code = '\n'.join(deduped)

        # Add docstrings if missing
        if 'def ' in formatted_code and '"""' not in formatted_code:
            formatted_code = self._add_docstrings(formatted_code)

        # Keep at most one blank line between statements/definitions.
        formatted_code = re.sub(r'\n{3,}', '\n\n', formatted_code)

        # Remove any duplicate code blocks
        return self._remove_duplicate_blocks(formatted_code)

    def _remove_duplicate_blocks(self, code):
        """
        Remove duplicate top-level code blocks

        The code is split into blocks at column-0 ``def``, ``async def``,
        ``class`` and decorator lines. A decorator stays attached to the
        definition that follows it. Only top-level definitions start a block,
        so identical methods in different classes are not mistaken for
        duplicates. Blocks are compared with all whitespace runs collapsed;
        the first occurrence is kept.

        Args:
            code (str): The code to clean

        Returns:
            str: Code with duplicates removed, stripped of surrounding whitespace
        """
        blocks = []
        current = []
        after_decorator = False
        for line in code.split('\n'):
            starts_block = line.startswith(('def ', 'async def ', 'class ', '@'))
            if starts_block and current and not after_decorator:
                blocks.append(current)
                current = []
            current.append(line)
            if line.strip():
                after_decorator = line.startswith('@')
        blocks.append(current)

        unique_blocks = []
        seen_blocks = set()
        for block in blocks:
            text = '\n'.join(block)
            # Normalize block by collapsing whitespace
            normalized = ' '.join(text.split())
            if normalized and normalized not in seen_blocks:
                seen_blocks.add(normalized)
                unique_blocks.append(text)

        return '\n'.join(unique_blocks).strip()

    def _add_docstrings(self, code):
        """
        Add placeholder docstrings to functions that lack one

        A docstring is inserted only after a complete single-line header
        (``def``/``async def`` ... ending in ``:``) whose next non-blank line
        is indented deeper and does not already contain a docstring.
        The docstring uses that body line's indentation. One-liners such as
        ``def f(): return 1`` and multi-line signatures are left untouched,
        because inserting a line after them would break the syntax.

        Args:
            code (str): The code to add docstrings to

        Returns:
            str: Code with docstrings
        """
        lines = code.split('\n')
        result = []
        for i, line in enumerate(lines):
            result.append(line)
            header = line.strip()
            if not (header.startswith(('def ', 'async def ')) and header.endswith(':')):
                continue

            body_line = next(
                (candidate for candidate in lines[i + 1:] if candidate.strip()), None
            )
            if body_line is None or '"""' in body_line or "'''" in body_line:
                continue

            def_indent = len(line) - len(line.lstrip())
            body_indent = len(body_line) - len(body_line.lstrip())
            if body_indent <= def_indent:
                continue  # header has no indented body; leave it alone

            result.append(textwrap.indent('"""\nDocstring\n"""', ' ' * body_indent))
        return '\n'.join(result)