|
import gradio as gr |
|
from google import genai |
|
from google.genai import types |
|
from PIL import Image |
|
from io import BytesIO |
|
import base64 |
|
import os |
|
import json |
|
|
|
|
|
# Read the Gemini API key from the environment and fail fast with a clear
# message. A blank or whitespace-only value is treated the same as a missing
# one, so the failure happens here rather than at the first API request.
api_key = os.environ.get('GEMINI_API_KEY', '').strip()
if not api_key:
    raise ValueError("Please set the GEMINI_API_KEY environment variable.")

# Shared client used by the text and image generation calls below.
client = genai.Client(api_key=api_key)
|
|
|
def generate_item(tag):
    """
    Generate a single feed item: a caption from the Gemini text model and a
    9:16 picture from Imagen.

    Args:
        tag (str): The tag to base the content on.

    Returns:
        dict: A dictionary with 'text' (str, the caption) and 'image_base64'
            (str, the picture re-encoded as PNG and base64-encoded).
    """
    # NOTE(review): the 'π' in the example below (and in the fallback caption)
    # looks like a mis-decoded emoji - confirm against the original file.
    prompt = f"""
    Generate a short, engaging TikTok-style caption about {tag}.
    Return the response as a JSON object with a single key 'caption' containing the caption text.
    Example: {{"caption": "Craving this yummy treat! π #foodie"}}
    Do not include additional commentary or options.
    """
    text_response = client.models.generate_content(
        model='gemini-2.5-flash-preview-04-17',
        contents=[prompt]
    )
    text = _extract_caption(text_response.text, tag)

    image_prompt = f"""
    A vivid, high-quality visual scene representing {tag}, designed for a TikTok video.
    The image should be colorful and engaging, with no text or letters included.
    """
    image_response = client.models.generate_images(
        model='imagen-3.0-generate-002',
        prompt=image_prompt,
        config=types.GenerateImagesConfig(
            number_of_images=1,
            aspect_ratio="9:16",
            person_generation="DONT_ALLOW"
        )
    )

    # Use the first generated image when it actually carries bytes; otherwise
    # fall back to a gray 9:16 placeholder so the feed still renders.
    candidates = image_response.generated_images or []
    payload = candidates[0].image if candidates else None
    image_bytes = payload.image_bytes if payload is not None else None
    if image_bytes:
        image = Image.open(BytesIO(image_bytes))
    else:
        image = Image.new('RGB', (300, 533), color='gray')

    return {'text': text, 'image_base64': _image_to_base64_png(image)}


def _extract_caption(raw_text, tag):
    """Return the 'caption' string from the model's JSON reply, or a canned caption for `tag`."""
    fallback = f"Wow, {tag} is amazing! π #{tag}"
    if not raw_text:
        # No text at all (None or empty) - nothing to parse.
        return fallback

    candidate = raw_text.strip()
    # The reply may be wrapped in a Markdown code fence or surrounded by extra
    # prose; keep only the outermost {...} span before parsing.
    start, end = candidate.find('{'), candidate.rfind('}')
    if start != -1 and end > start:
        candidate = candidate[start:end + 1]

    try:
        caption = json.loads(candidate)['caption']
    except (json.JSONDecodeError, KeyError, TypeError):
        # Invalid JSON, missing key, or a payload that is not a JSON object.
        return fallback

    if not isinstance(caption, str) or not caption.strip():
        return fallback
    return caption


def _image_to_base64_png(image):
    """Serialize a PIL image to PNG and return it as a base64 text string."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()
|
|
|
def start_feed(tag):
    """
    Start a new feed with the given tag by generating one initial item.

    Args:
        tag (str): The tag to generate content for. A missing, empty, or
            whitespace-only value falls back to "trending"; surrounding
            whitespace is removed.

    Returns:
        tuple: (current_tag, feed_items, html_content), where current_tag is
            the normalized tag actually used, feed_items is a one-element
            list, and html_content is the rendered feed.
    """
    # `tag or ""` guards against None so .strip() cannot raise.
    tag = (tag or "").strip() or "trending"
    feed_items = [generate_item(tag)]
    return tag, feed_items, generate_html(feed_items)
|
|
|
def load_more(current_tag, feed_items):
    """
    Append a new item to the existing feed using the current tag.

    Args:
        current_tag (str): The tag currently being used for the feed. When it
            is missing or blank (for example "Load More" pressed before a feed
            was started) the same "trending" fallback as start_feed is used.
        feed_items (list): The current list of feed items; None is treated as
            an empty feed. The list passed in is not modified.

    Returns:
        tuple: (current_tag, updated_feed_items, updated_html_content)
    """
    tag = (current_tag or "").strip() or "trending"
    # Build a new list instead of appending in place so the caller's list
    # object is left untouched.
    updated_items = [*(feed_items or []), generate_item(tag)]
    return tag, updated_items, generate_html(updated_items)
|
|
|
def generate_html(feed_items):
    """
    Generate an HTML string to display the feed items in a TikTok-like vertical layout.

    Each item becomes a 640px-tall card with its image as the background and
    its caption overlaid at the bottom. Caption text and image data are
    HTML-escaped before being inserted into the markup.

    Args:
        feed_items (list): List of dictionaries containing 'text' and 'image_base64'.

    Returns:
        str: HTML string representing the feed.
    """
    import html  # local import keeps this block self-contained

    container_open = """
    <div class="tiktok-feed" style="
        display: flex;
        flex-direction: column;
        align-items: center;
        max-width: 360px;
        margin: 0 auto;
        background-color: #000;
        height: 640px;
        overflow-y: auto;
        scrollbar-width: none;
        -ms-overflow-style: none;
        border: 1px solid #333;
        border-radius: 10px;
    ">
    """

    # Scoped to the feed container so scrollbars elsewhere on the page are
    # not hidden as a side effect.
    scrollbar_style = """
    <style>
        .tiktok-feed::-webkit-scrollbar {
            display: none;
        }
    </style>
    """

    cards = []
    for item in feed_items:
        # The caption comes from a model (or from the user's tag in the
        # fallback path), so it must never be emitted as raw markup.
        caption = html.escape(str(item['text']))
        image_data = html.escape(str(item['image_base64']))
        cards.append(f"""
        <div style="
            width: 100%;
            height: 640px;
            position: relative;
            display: flex;
            flex-direction: column;
            justify-content: flex-end;
            overflow: hidden;
        ">
            <img src="data:image/png;base64,{image_data}" style="
                width: 100%;
                height: 100%;
                object-fit: cover;
                position: absolute;
                top: 0;
                left: 0;
                z-index: 1;
            ">
            <div style="
                position: relative;
                z-index: 2;
                background: linear-gradient(to top, rgba(0,0,0,0.7), transparent);
                padding: 20px;
                color: white;
                font-family: Arial, sans-serif;
                font-size: 18px;
                font-weight: bold;
                text-shadow: 1px 1px 2px rgba(0,0,0,0.5);
            ">
                {caption}
            </div>
        </div>
        """)

    return "".join([container_open, scrollbar_style, *cards, "</div>"])
|
|
|
|
|
# Build the Gradio page: a tag picker, two action buttons, and the feed view.
with gr.Blocks(
    title="TikTok-Style Infinite Feed",
    css="""
    body { background-color: #000; color: #fff; font-family: Arial, sans-serif; }
    .gradio-container { max-width: 400px; margin: 0 auto; padding: 10px; }
    input, select, button { border-radius: 5px; background-color: #222; color: #fff; border: 1px solid #444; }
    button { background-color: #ff2d55; border: none; }
    button:hover { background-color: #e0264b; }
    .gr-button { width: 100%; margin-top: 10px; }
    .gr-form { background-color: #111; padding: 15px; border-radius: 10px; }
    """,
) as demo:

    # Input panel: choose a suggested tag or type a custom one, then act on it.
    with gr.Column(elem_classes="gr-form"):
        gr.Markdown("### Create Your TikTok Feed")

        with gr.Row():
            suggested_tags = gr.Dropdown(
                label="Pick a Tag",
                choices=["food", "travel", "fashion", "tech"],
                value="food",
            )
            tag_input = gr.Textbox(
                label="Or Enter a Custom Tag",
                placeholder="e.g., sushi, adventure",
                value="food",
            )

        with gr.Row():
            start_button = gr.Button("Start Feed")
            load_more_button = gr.Button("Load More")

    # Where the rendered feed markup is shown.
    feed_html = gr.HTML()

    # State carried between button clicks: the active tag and the items so far.
    current_tag = gr.State(value="")
    feed_items = gr.State(value=[])

    def set_tag(selected_tag):
        """Mirror the chosen dropdown entry into the custom-tag textbox."""
        return selected_tag

    # Picking a suggestion copies it into the textbox, which is what
    # "Start Feed" actually reads.
    suggested_tags.change(set_tag, suggested_tags, tag_input)

    start_button.click(
        fn=start_feed,
        inputs=[tag_input],
        outputs=[current_tag, feed_items, feed_html],
    )

    load_more_button.click(
        fn=load_more,
        inputs=[current_tag, feed_items],
        outputs=[current_tag, feed_items, feed_html],
    )

demo.launch()