import os

# Point Streamlit's config and cache directories at writable /tmp paths before
# streamlit is imported, so the app also runs where the home directory is read-only.
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
os.environ["STREAMLIT_CACHE_DIR"] = "/tmp/.cache"
os.makedirs("/tmp/.streamlit", exist_ok=True)
os.makedirs("/tmp/.cache", exist_ok=True)

import streamlit as st
import json
from PIL import Image
from transformers import LlavaForConditionalGeneration, AutoProcessor
import torch
import base64
from io import BytesIO


DEFAULT_KEYWORD_COUNT = 5
DEFAULT_MODEL = "llava-hf/llava-1.5-7b-hf"
DEFAULT_TONE = "witty,curious"
DEFAULT_TEMP = 0.5


def convert_to_base64(pil_image):
    """Encode a PIL image as a base64 JPEG string."""
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return img_str
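
# Note: convert_to_base64 is not used by the UI flow below; it is kept as a small
# helper in case a base64-encoded copy of the uploaded image is needed elsewhere.
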
def extract_keywords(keywords_string):
    """Parse a 'Keywords: a, b, c' line into a list of keyword strings."""
    if keywords_string.startswith("Keywords:"):
        keywords = keywords_string.replace("Keywords:", "", 1).strip().split(",")
        return [keyword.strip() for keyword in keywords]
    return []

@st.cache_resource
def load_llava_model(model_name):
    """Load the LLaVA processor and model once and cache them across Streamlit reruns."""
    processor = AutoProcessor.from_pretrained(model_name)
    model = LlavaForConditionalGeneration.from_pretrained(model_name)
    return processor, model
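
# Optional (assumption: a CUDA GPU and the accelerate package are available):
# the model can be loaded in half precision and placed on the GPU, e.g.
#   LlavaForConditionalGeneration.from_pretrained(
#       model_name, torch_dtype=torch.float16, device_map="auto"
#   )
# which makes generation much faster. load_llava_model above keeps the simpler
# default loading so it also works on CPU-only machines.
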
def generate_metadata(image, prompt_template, model_name, temperature):
    processor, model = load_llava_model(model_name)

    # Preprocess the prompt text and the image into model-ready tensors.
    inputs = processor(text=prompt_template, images=image, return_tensors="pt")

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
        )

    generated_text = processor.decode(output[0], skip_special_tokens=True)

    # The decoded sequence echoes the prompt, so keep only the model's reply.
    if "ASSISTANT:" in generated_text:
        model_response = generated_text.split("ASSISTANT:")[-1].strip()
    elif prompt_template in generated_text:
        model_response = generated_text.split(prompt_template)[-1].strip()
    else:
        model_response = generated_text

    return model_response
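
# Page layout: wide mode, a title, and a short description of what the app does.
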
st.set_page_config(layout="wide", page_title="Image Metadata Generator")

st.title("📸 AI-Powered Image Metadata Generator")
st.markdown("Upload an image and let the AI generate a catchy title, description, and keywords!")

# Sidebar controls for model selection and generation settings.
st.sidebar.header("Configuration")
selected_model = st.sidebar.selectbox(
    "Choose a LLaVA Model",
    ["llava-hf/llava-1.5-7b-hf", "llava-hf/bakLlava-v1-hf"],
    index=0,
)
# Minimum of 0.05: sampling requires a strictly positive temperature.
temperature = st.sidebar.slider("Creativity (Temperature)", 0.05, 1.0, DEFAULT_TEMP, 0.05)
keyword_count = st.sidebar.number_input("Number of Keywords", 1, 10, DEFAULT_KEYWORD_COUNT)
tone_input = st.sidebar.text_input("Tone (e.g., witty, curious)", DEFAULT_TONE)
tone = [t.strip() for t in tone_input.split(",")]
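
# Main flow: upload an image, build a prompt from the sidebar settings, run the
# model, then parse and display the generated headline, description, and keywords.
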
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")

    st.subheader("Uploaded Image")
    st.image(image, caption="Uploaded Image", use_column_width=True)
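
    # Generation runs only when the user explicitly clicks the button, so the
    # model is not invoked on every Streamlit rerun.
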
    if st.button("Generate Metadata"):
        with st.spinner("Generating metadata... This might take a moment."):
            primary_tone = tone[0]
            secondary_tone = tone[1] if len(tone) > 1 else tone[0]
            # LLaVA-1.5 chat format: the <image> placeholder marks where the
            # processor inserts the image tokens.
            prompt_template = f"""USER: <image>
As a photojournalist, analyze the image and provide the following in a {primary_tone} and {secondary_tone} tone:
- Image Headline: a short, impactful title
- Image Description: a brief, informative summary
- {keyword_count} Image Keywords, separated by commas
Return them in the format: Headline: ..., Description: ..., Keywords: ...
ASSISTANT:"""

            llava_response = generate_metadata(image, prompt_template, selected_model, temperature)

            if llava_response:
                st.subheader("Generated Metadata")

                # Parse the "Headline: ... / Description: ... / Keywords: ..." lines.
                lines = llava_response.split("\n")

                headline = ""
                description = ""
                keywords = []

                for line in lines:
                    if line.startswith("Headline:"):
                        headline = line.replace("Headline:", "").strip()
                    elif line.startswith("Description:"):
                        description = line.replace("Description:", "").strip()
                    elif line.startswith("Keywords:"):
                        keywords = extract_keywords(line)

                # Strip any stray quotation marks the model may have added.
                headline = headline.strip('"')
                description = description.strip('"')
                lstkeywords = [x.strip('"') for x in keywords]

                st.info(f"**Headline:** {headline}")
                st.info(f"**Description:** {description}")
                st.info(f"**Keywords:** {', '.join(lstkeywords)}")
            else:
                st.error("Failed to generate metadata. Please try again.")

st.markdown("""
---
*This app uses Hugging Face's Transformers library and LLaVA models to generate image metadata.
The quality of the generated metadata depends on the chosen model and the complexity of the image.*
""")