# Hugging Face Spaces app (Space status when this page was captured: "Sleeping").
import os
import time

import requests
import streamlit as st
# Hugging Face access token; None when the HF_API_KEY environment variable is unset.
HF_API_KEY = os.getenv("HF_API_KEY")

# Hosted Inference API endpoints.
IMG2TEXT_API = "https://api-inference.huggingface.co/models/nlpconnect/vit-gpt2-image-captioning"  # image -> caption
CHAT_API = "https://api-inference.huggingface.co/models/facebook/blenderbot-3B"  # text -> reply

# Auth header for the API calls. Without a token, send no Authorization header
# at all rather than the meaningless literal "Bearer None".
HEADERS = {"Authorization": f"Bearer {HF_API_KEY}"} if HF_API_KEY else {}
# App title
st.title("Multimodal Chatbot")

# Chat history lives in session state so it survives Streamlit re-running this
# script on every interaction. Seed it with the greeting exactly once; putting
# the append outside this guard would add another greeting on every rerun.
if "messages" not in st.session_state:
    initial_message = "Hello! I'm a chatbot. You can upload an image or ask me anything to get started!"
    st.session_state.messages = [{"role": "assistant", "content": initial_message}]

# Replay the conversation so far.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.write(msg["content"])
# Image upload widget, restricted by file extension to JPG/PNG.
uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "png", "jpeg"])

# Chat input box; presumably yields the submitted text on the run triggered by
# submission and a falsy value otherwise (see Streamlit st.chat_input docs).
user_input = st.chat_input("Ask about this image or anything...")

# Caption of an image captioned during this run, if any.
image_caption = None
# ---------------------------------------------------------------------------
# Request helpers
# ---------------------------------------------------------------------------
MAX_RETRIES = 3  # attempts per request while the hosted model is loading (HTTP 503)
RETRY_DELAY_SECONDS = 5  # pause between those attempts
REQUEST_TIMEOUT_SECONDS = 60  # so a stalled request cannot hang the app forever
IMAGE_FALLBACK_REPLY = "I received your image. What would you like to ask about it?"


def _post_with_retries(url, **kwargs):
    """POST to a Hugging Face endpoint, retrying while the model warms up.

    ``kwargs`` are forwarded to ``requests.post``. Returns the last response;
    it is still a 503 only if every attempt found the model loading (no
    pointless sleep happens after the final attempt).

    Raises:
        requests.exceptions.RequestException: on connection errors/timeouts.
    """
    for attempt in range(1, MAX_RETRIES + 1):
        response = requests.post(url, timeout=REQUEST_TIMEOUT_SECONDS, **kwargs)
        if response.status_code != 503 or attempt == MAX_RETRIES:
            return response
        st.warning(
            f"⏳ Model warming up... Retrying in {RETRY_DELAY_SECONDS} seconds. "
            f"Attempt {attempt}/{MAX_RETRIES}"
        )
        time.sleep(RETRY_DELAY_SECONDS)


def _extract_generated_text(payload):
    """Return the non-blank ``generated_text`` of an API reply, else None.

    Models answer either with a dict or with a list whose first item is a
    dict; any other shape (or a missing/blank field) yields None.
    """
    if isinstance(payload, list) and payload:
        payload = payload[0]
    if isinstance(payload, dict):
        text = payload.get("generated_text")
        if isinstance(text, str) and text.strip():
            return text
    return None


def _generated_text(response):
    """Decode a 200 response's JSON body and extract its generated text."""
    try:
        payload = response.json()
    except ValueError:  # body is not valid JSON
        return None
    return _extract_generated_text(payload)


def _caption_image(img_bytes):
    """Caption raw image bytes; report any failure in the UI and return None."""
    try:
        response = _post_with_retries(
            IMG2TEXT_API,
            # Same auth as the JSON calls, but the body is the raw image bytes.
            headers={**HEADERS, "Content-Type": "application/octet-stream"},
            data=img_bytes,
        )
    except requests.exceptions.RequestException as exc:
        st.error(f"⚠️ Image API request failed: {exc}")
        return None
    if response.status_code == 503:
        st.error("⚠️ Image captioning model is still loading. Please try again in a moment.")
        return None
    if response.status_code != 200:
        st.error(f"⚠️ Image API Error: {response.status_code} - {response.text}")
        return None
    caption = _generated_text(response)
    if caption is None:
        st.error("⚠️ Unexpected response format from image captioning API.")
    return caption


def _ask_chatbot(prompt):
    """Send *prompt* to the chat model.

    Returns ``(reply, error)`` where exactly one is None; ``error`` is a
    user-facing description of what went wrong.
    """
    try:
        response = _post_with_retries(CHAT_API, headers=HEADERS, json={"inputs": prompt})
    except requests.exceptions.RequestException as exc:
        return None, f"⚠️ Chatbot request failed: {exc}"
    if response.status_code == 503:
        return None, "⚠️ Chatbot model is still loading. Please try again in a moment."
    if response.status_code != 200:
        return None, f"⚠️ Chatbot Error {response.status_code}: {response.text}"
    reply = _generated_text(response)
    if reply is None:
        return None, "⚠️ Unable to generate a response."
    return reply, None


def _process_image(img_bytes):
    """Show an uploaded image, caption it and let the bot react to the caption.

    Returns the caption, or None if captioning failed. Nothing is added to the
    chat history on failure, so a later rerun can retry without duplicates.
    """
    user_box = st.chat_message("user")
    user_box.image(img_bytes, caption="Uploaded Image", use_column_width=True)
    caption = _caption_image(img_bytes)
    if caption is None:
        return None
    user_box.write(f"**Image to text context generated:** {caption}")
    st.session_state.messages.append({"role": "user", "content": "[Image Uploaded]"})
    st.session_state.image_caption = caption

    bot_context = (
        f"Consider this image: {caption}. Please provide a relevant and engaging response to the image."
    )
    reply, error = _ask_chatbot(bot_context)
    if reply is None:
        # Best effort: still acknowledge the image, but surface why.
        st.warning(error)
        reply = IMAGE_FALLBACK_REPLY
    st.session_state.messages.append({"role": "assistant", "content": reply})
    with st.chat_message("assistant"):
        st.write(reply)
    return caption


def _upload_key(file):
    """Identifier for an upload, used to process each image only once."""
    # ``file_id`` distinguishes uploads on newer Streamlit versions; fall back
    # to name + size where it is absent.
    return getattr(file, "file_id", None) or f"{file.name}:{file.size}"


# ---------------------------------------------------------------------------
# Image upload
# ---------------------------------------------------------------------------
# Streamlit re-runs this script on every interaction and the uploader keeps
# returning the same file, so remember which upload was already handled;
# otherwise every chat message would re-post the image and duplicate history.
if uploaded_file is not None:
    if uploaded_file.type not in ("image/jpeg", "image/png"):
        st.error("⚠️ Please upload a valid JPG or PNG image.")
    else:
        upload_key = _upload_key(uploaded_file)
        if st.session_state.get("processed_upload_key") != upload_key:
            # getvalue() returns the whole file regardless of the read cursor.
            image_caption = _process_image(uploaded_file.getvalue())
            if image_caption is not None:
                st.session_state.processed_upload_key = upload_key

# ---------------------------------------------------------------------------
# Text question
# ---------------------------------------------------------------------------
if user_input:
    # The chat model only sees text, so prefix the question with the most
    # recent image caption (if any) as context.
    last_caption = st.session_state.get("image_caption")
    combined_input = f"Image context: {last_caption}. {user_input}" if last_caption else user_input

    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.write(user_input)

    reply, error = _ask_chatbot(combined_input)
    # Always bound: either the model's reply or a readable error message.
    bot_reply = reply if reply is not None else error
    st.session_state.messages.append({"role": "assistant", "content": bot_reply})
    with st.chat_message("assistant"):
        st.write(bot_reply)
# Clear button: discard the conversation and start over.
if st.button("Clear Chat"):
    # Dropping the key (rather than assigning []) lets the
    # `"messages" not in st.session_state` initialisation at the top of the
    # script run again on the rerun, so the welcome message is restored.
    st.session_state.pop("messages", None)
    # st.experimental_rerun was replaced by st.rerun and later removed; only
    # fall back to it on old Streamlit versions that lack st.rerun.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()