import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class TextGenerator:
    def __init__(self):
        print("Initializing Text Generator...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load model and tokenizer
        self.model_name = "facebook/opt-350m"
        print(f"Loading model {self.model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        print(f"Model loaded and moved to {self.device}")

    def generate_text(self, prompt, max_length=200, temperature=0.7, top_p=0.9):
        """
        Generate text based on the given prompt.

        Args:
            prompt (str): The text generation prompt
            max_length (int): Maximum number of new tokens to generate
            temperature (float): Controls randomness in generation
            top_p (float): Controls diversity via nucleus sampling

        Returns:
            str: Generated text
        """
        try:
            print(f"Generating text on {self.device}...")

            # Format prompt for better generation
            formatted_prompt = f"Instruction: {prompt}\n\nResponse:"

            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512
            ).to(self.device)

            # Pure sampling: beam search (num_beams > 1) would conflict with
            # the temperature/top_p sampling behavior documented above, so it
            # is not used here. max_new_tokens counts only generated tokens,
            # avoiding manual arithmetic against the prompt length.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    num_return_sequences=1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    do_sample=True,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3
                )

            # Decode only the newly generated tokens. Slicing the decoded
            # string by len(formatted_prompt) is unreliable, since the
            # tokenizer may not round-trip the prompt text exactly.
            prompt_length = inputs["input_ids"].shape[1]
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )

            # Format the text
            return self._format_text(generated_text)
        except Exception as e:
            return f"Error generating text: {str(e)}"

    def _format_text(self, text):
        """
        Format the generated text for better readability.

        Args:
            text (str): The text to format

        Returns:
            str: Formatted text
        """
        # Split into paragraphs
        paragraphs = text.split('\n\n')

        # Format each paragraph
        formatted_paragraphs = []
        for para in paragraphs:
            para = para.strip()
            if para:
                # Collapse internal whitespace, then capitalize the first letter
                para = ' '.join(para.split())
                para = para[0].upper() + para[1:]
                formatted_paragraphs.append(para)

        # Join paragraphs with proper spacing
        formatted_text = '\n\n'.join(formatted_paragraphs)

        # Ensure the text ends with terminal punctuation
        if formatted_text and formatted_text[-1] not in '.!?':
            formatted_text += '.'

        return formatted_text
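
# Example usage: a minimal sketch showing how the class above might be driven
# from the command line. The __main__ guard and the sample prompt are
# illustrative additions, not part of the original class.
if __name__ == "__main__":
    generator = TextGenerator()
    # Hypothetical prompt chosen purely for demonstration
    result = generator.generate_text(
        "Explain what a tokenizer does in one paragraph.",
        max_length=100,
        temperature=0.7,
        top_p=0.9,
    )
    print(result)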