cptsubtext committed · Commit b8a909c · 1 Parent(s): 1640c7d

update files with app

- requirements.txt +4 -3
- src/streamlit_app.py +141 -38
requirements.txt CHANGED
@@ -1,3 +1,4 @@
-… (3 lines removed; contents not shown in this view)
+streamlit
+Pillow
+transformers
+torch
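For reference, a quick way to confirm that the four dependencies resolve in the Space's Python environment is an import smoke test (a minimal sketch, not part of this commit):

    # Smoke test: verify the packages pulled in by requirements.txt import cleanly.
    import PIL
    import streamlit
    import torch
    import transformers

    print("streamlit:", streamlit.__version__)
    print("Pillow:", PIL.__version__)
    print("transformers:", transformers.__version__)
    print("torch:", torch.__version__)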
src/streamlit_app.py CHANGED
@@ -1,40 +1,143 @@
-import altair as alt
-import numpy as np
-import pandas as pd
 import streamlit as st
-… (old lines 6-40, 35 lines in all, removed; contents not shown in this view)
+import os
+import json
+from PIL import Image
+from transformers import LlavaForConditionalGeneration, AutoProcessor
+import torch
+import base64
+from io import BytesIO

+# Configuration (similar to aim.py, but adapted for Streamlit)
+DEFAULT_KEYWORD_COUNT = 5
+DEFAULT_MODEL = "llava-hf/llava-1.5-7b-hf"  # A common Llava model on Hugging Face
+DEFAULT_TONE = "witty,curious"
+DEFAULT_TEMP = 0.5
+
+# Function to convert a PIL Image to base64 for display
+def convert_to_base64(pil_image):
+    buffered = BytesIO()
+    pil_image.save(buffered, format="JPEG")
+    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return img_str
+
+# Function to extract keywords (from aim.py)
+def extract_keywords(keywords_string):
+    if keywords_string.startswith("Keywords: "):
+        keywords = keywords_string.replace("Keywords: ", "").strip().split(",")
+        return [keyword.strip() for keyword in keywords]
+    else:
+        return []
+
+# Load the Llava processor and model once, cached across Streamlit reruns
+@st.cache_resource
+def load_llava_model(model_name):
+    processor = AutoProcessor.from_pretrained(model_name)
+    model = LlavaForConditionalGeneration.from_pretrained(model_name)
+    return processor, model
+
+# Generate metadata using the Transformers Llava model
+def generate_metadata(image, prompt_template, model_name, temperature):
+    processor, model = load_llava_model(model_name)
+
+    # Llava models expect a conversation-style prompt with an <image>
+    # placeholder, e.g. "USER: <image>\nWhat is this?\nASSISTANT:".
+    # A more robust solution might use a chat template from the processor.
+    prompt = f"USER: <image>\n{prompt_template}\nASSISTANT:"
+
+    inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+    # Generate the response
+    with torch.no_grad():
+        output = model.generate(**inputs, max_new_tokens=100, temperature=temperature, do_sample=True, top_p=0.9)
+
+    generated_text = processor.decode(output[0], skip_special_tokens=True)
+
+    # The decoded text contains the prompt followed by the model's answer,
+    # so keep only the part after the assistant marker.
+    if "ASSISTANT:" in generated_text:
+        model_response = generated_text.split("ASSISTANT:")[-1].strip()
+    else:
+        model_response = generated_text  # Fallback if the marker is not found
+
+    return model_response
+
+# Streamlit App
+st.set_page_config(layout="wide", page_title="Image Metadata Generator")
+
+st.title("📸 AI-Powered Image Metadata Generator")
+st.markdown("Upload an image and let the AI generate a catchy title, description, and keywords!")
+
+# Sidebar for configuration
+st.sidebar.header("Configuration")
+selected_model = st.sidebar.selectbox(
+    "Choose a Llava Model",
+    ["llava-hf/llava-1.5-7b-hf", "llava-hf/bakLlava-v1-hf"],  # Add more Llava models as needed
+    index=0,
+)
+temperature = st.sidebar.slider("Creativity (Temperature)", 0.0, 1.0, DEFAULT_TEMP, 0.05)
+keyword_count = st.sidebar.number_input("Number of Keywords", 1, 10, DEFAULT_KEYWORD_COUNT)
+tone_input = st.sidebar.text_input("Tone (e.g., witty, curious)", DEFAULT_TONE)
+tone = [t.strip() for t in tone_input.split(",")]
+
+uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+if uploaded_file is not None:
+    image = Image.open(uploaded_file).convert("RGB")
+
+    st.subheader("Uploaded Image")
+    st.image(image, caption="Uploaded Image", use_column_width=True)
+
+    if st.button("Generate Metadata"):
+        with st.spinner("Generating metadata... This might take a moment."):
+            prompt_template = f"""
+            As a photojournalist, analyze the following image and respond in a {tone[0]} and {tone[1] if len(tone) > 1 else tone[0]} tone with:
+            - Image Headline: A short, impactful title
+            - Image Description: A brief, informative summary
+            - {keyword_count} Image Keywords, separated by commas
+            Return the Image Headline, Image Description, and Image Keywords in the following format: Headline: ..., Description: ..., Keywords: ...
+            """
+
+            # Generate metadata using the selected Llava model
+            llava_response = generate_metadata(image, prompt_template, selected_model, temperature)
+
+            if llava_response:
+                st.subheader("Generated Metadata")
+
+                # Parse the response, similar to aim.py
+                lines = llava_response.split("\n")
+
+                headline = ""
+                description = ""
+                keywords = []
+
+                for line in lines:
+                    if line.startswith("Headline:"):
+                        headline = line.replace("Headline:", "").strip()
+                    elif line.startswith("Description:"):
+                        description = line.replace("Description:", "").strip()
+                    elif line.startswith("Keywords:"):
+                        keywords = extract_keywords(line)
+
+                # Remove stray quotation marks
+                headline = headline.strip('"')
+                description = description.strip('"')
+                lstkeywords = [x.strip('"') for x in keywords]
+
+                st.info(f"**Headline:** {headline}")
+                st.info(f"**Description:** {description}")
+                st.info(f"**Keywords:** {', '.join(lstkeywords)}")
+            else:
+                st.error("Failed to generate metadata. Please try again.")
+
+st.markdown("""
+---
+*This app utilizes Hugging Face's Transformers library and Llava models to generate image metadata.
+The quality of the generated metadata depends on the chosen model and the complexity of the image.*
+""")
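The comments in generate_metadata note that a hand-built "USER: <image>\n...\nASSISTANT:" string is a simplification, and that a chat template from the processor would be more robust. A minimal sketch of that variant (assuming a transformers version whose Llava processor ships a chat template; build_llava_prompt is an illustrative helper name, not part of this commit):

    from transformers import AutoProcessor

    def build_llava_prompt(processor, question):
        # One user turn holding the image slot plus the text question,
        # in the structured conversation format used by llava-hf models.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        # apply_chat_template renders the model-specific prompt string,
        # e.g. "USER: <image>\n{question} ASSISTANT:" for llava-1.5.
        return processor.apply_chat_template(conversation, add_generation_prompt=True)

    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
    print(build_llava_prompt(processor, "Describe this image."))

On GPU hardware it may also be worth passing torch_dtype=torch.float16 to from_pretrained, which roughly halves the memory footprint of the 7B model; the commit loads the model at default precision.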