import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class TextGenerator:
    def __init__(self):
        print("Initializing Text Generator...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load model and tokenizer
        self.model_name = "facebook/opt-350m"
        print(f"Loading model {self.model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            # float16 halves memory on GPU; CPU kernels generally need float32
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        print(f"Model loaded and moved to {self.device}")

    def generate_text(self, prompt, max_length=200, temperature=0.7, top_p=0.9):
        """
        Generate text based on the given prompt.

        Args:
            prompt (str): The text generation prompt
            max_length (int): Maximum number of new tokens to generate
            temperature (float): Controls randomness in generation
            top_p (float): Controls diversity via nucleus sampling

        Returns:
            str: Generated text
        """
        try:
            print(f"Generating text on {self.device}...")

            # Frame the prompt as an instruction to steer the base model
            formatted_prompt = f"Instruction: {prompt}\n\nResponse:"
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    # max_new_tokens counts only generated tokens, so the
                    # prompt length no longer has to be added by hand
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    num_return_sequences=1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    do_sample=True,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3
                )

            # Decode only the newly generated tokens; slicing the decoded
            # string by len(formatted_prompt) is unreliable because
            # tokenization does not round-trip the prompt text exactly
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            # Clean up spacing and punctuation before returning
            return self._format_text(generated_text)
        except Exception as e:
            return f"Error generating text: {str(e)}"

    def _format_text(self, text):
        """
        Format the generated text for better readability.

        Args:
            text (str): The text to format

        Returns:
            str: Formatted text
        """
        # Split into paragraphs
        paragraphs = text.split('\n\n')

        # Capitalize each paragraph and normalize its whitespace
        formatted_paragraphs = []
        for para in paragraphs:
            para = para.strip()
            if para:
                # Capitalize first letter
                para = para[0].upper() + para[1:]
                # Collapse repeated whitespace into single spaces
                para = ' '.join(para.split())
                formatted_paragraphs.append(para)

        # Join paragraphs with a blank line between them
        formatted_text = '\n\n'.join(formatted_paragraphs)

        # Ensure the text ends with terminal punctuation
        if formatted_text and formatted_text[-1] not in '.!?':
            formatted_text += '.'
        return formatted_text
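

# A minimal usage sketch (an assumption, not part of the original file):
# instantiate the generator once and reuse it across prompts, since model
# loading dominates startup cost; the sample prompt below is illustrative.
if __name__ == "__main__":
    generator = TextGenerator()
    result = generator.generate_text(
        "Explain what a tokenizer does in one sentence.",
        max_length=120,   # up to 120 new tokens
        temperature=0.7,  # moderate randomness
        top_p=0.9         # nucleus sampling cutoff
    )
    print(result)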