koey811 committed on
Commit
1f0f5c4
·
verified ·
1 Parent(s): 0f63834

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -15,7 +15,7 @@ except ImportError:
15
  import torch
16
 
17
  # Load the image captioning model
18
- caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
19
 
20
  # Load the text generation model
21
  text_generation_model = AutoModelForCausalLM.from_pretrained("gpt2")
@@ -28,8 +28,9 @@ def generate_caption(image):
28
 
29
  def generate_story(caption):
30
  # Generate the story based on the caption
31
- input_ids = tokenizer.encode(caption, return_tensors="pt")
32
- output = text_generation_model.generate(input_ids, max_length=100, num_return_sequences=1)
 
33
  story = tokenizer.decode(output[0], skip_special_tokens=True)
34
  return story
35
 
 
15
  import torch
16
 
17
  # Load the image captioning model
18
+ caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
19
 
20
  # Load the text generation model
21
  text_generation_model = AutoModelForCausalLM.from_pretrained("gpt2")
 
28
 
29
def generate_story(caption):
    """Generate a short, child-friendly story from an image caption.

    Wraps *caption* in a storytelling prompt, runs it through the
    module-level GPT-2 ``text_generation_model``, and returns only the
    newly generated text.

    Args:
        caption: Image caption string produced by the captioning model.

    Returns:
        The generated story text (prompt stripped), as a string.
    """
    prompt = f"Imagine you are a storyteller for young children. Based on the image described as '{caption}', create a short and interesting story for children aged 3-10. Keep it positive and happy in tone."
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # max_new_tokens reserves a fixed generation budget regardless of prompt
    # length; max_length=200 would have counted the ~50-token prompt against
    # the story's space. pad_token_id silences GPT-2's missing-pad warning.
    output = text_generation_model.generate(
        input_ids,
        max_new_tokens=200,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only tokens generated AFTER the prompt, so the returned story
    # does not begin with the instruction text.
    story = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
    return story
36