Spaces:
Sleeping
Sleeping
| import requests | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
class TopicGenerator:
    """Generate short, creative titles/topics from free-form text with a seq2seq model.

    The tokenizer and model are loaded once in ``__init__`` and reused by every
    ``generate_topics`` call.
    """

    # Few-shot style examples appended to every prompt to steer the model
    # towards short, evocative titles.
    _EXAMPLES = """
Creative examples to inspire your titles/topics:
- "Misty Peaks at Dawn"
- "Graffiti Lanes of Urbania"
- "Chef’s Secret Ingredients"
- "Neon Future Skylines"
- "Puppy’s First Snow"
- "Edge of Adventure"
"""

    def __init__(self, model_name="google/flan-t5-large"):
        """Load the tokenizer and seq2seq model.

        Args:
            model_name: Hugging Face Hub id (or local path) of the checkpoint.
                Defaults to ``google/flan-t5-large``, the original hard-coded model.
        """
        self.topic_generator_processor = AutoTokenizer.from_pretrained(model_name)
        self.topic_generator_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        # Inference only: put the model in eval mode (disables dropout).
        self.topic_generator_model.eval()

    @classmethod
    def _build_prompt(cls, user_input, num_topics):
        """Return the full instruction prompt (base, context, task, examples)."""
        base_prompt = "Generate short, creative titles or topics based on the detailed information provided:"
        # Adjacent string literals are concatenated verbatim, so each sentence
        # carries its own trailing space; without it the sentences run
        # together ("discussions.IMPORTANT:").
        # NOTE(review): the task mentions "the image" although only text is
        # passed in; presumably user_input carries an image description — confirm.
        task = (
            f"Task: Create {num_topics} inventive titles or topics (2-5 words each) that blend the essence of the image with the additional context. "
            "These titles should be imaginative and suitable for use as hashtags, image titles, or starting points for discussions. "
            "IMPORTANT: Be imaginative and concise in your responses. Avoid repeating the same ideas in different words. "
            "Also make sure to provide a title/topic that relates to every context provided while following the examples listed below as a way of being creative and intuitive."
        )
        return f"{base_prompt}\n\nContext: {user_input}\n\n{task}{cls._EXAMPLES}"

    def generate_topics(self, user_input, num_topics=3, sample=False):
        """Generate ``num_topics`` candidate titles/topics for *user_input*.

        Args:
            user_input: Free-form text the titles should be based on.
            num_topics: Number of titles to return; must be >= 1.
            sample: If False (default), use deterministic beam search. This is
                what the original call effectively did, because its
                ``top_k``/``top_p``/``temperature`` are ignored unless
                ``do_sample=True``. If True, use beam-sampling with those
                values for more varied output.

        Returns:
            A list of ``num_topics`` decoded, whitespace-stripped strings.

        Raises:
            ValueError: If ``num_topics`` is less than 1.
        """
        if num_topics < 1:
            raise ValueError(f"num_topics must be >= 1, got {num_topics}")

        prompt = self._build_prompt(user_input, num_topics)
        encoded = self.topic_generator_processor(prompt, return_tensors="pt")

        generation_kwargs = dict(
            max_length=20,
            num_return_sequences=num_topics,
            # generate() rejects num_return_sequences > num_beams, so widen the
            # beam when more than 5 topics are requested.
            num_beams=max(5, num_topics),
            no_repeat_ngram_size=5,
        )
        if sample:
            # These only take effect when sampling is enabled.
            generation_kwargs.update(do_sample=True, top_k=50, top_p=0.95, temperature=0.9)

        outputs = self.topic_generator_model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded.get("attention_mask"),
            **generation_kwargs,
        )
        return [
            self.topic_generator_processor.decode(sequence, skip_special_tokens=True).strip()
            for sequence in outputs
        ]