# app.py | |
""" | |
Multimodal chat front‑end for Amazon Bedrock Nova Premier v1 | |
(text and/or image ➜ assistant text) with per‑session JSONL logging. | |
Prereqs: | |
pip install gradio==5.38 boto3 pillow | |
AWS credentials must allow invoking | |
    modelId = "us.amazon.nova-premier-v1:0" (region us-east-1).
Provide them via env vars, ~/.aws/credentials, or an IAM role. | |
""" | |
from __future__ import annotations | |
import base64, datetime, io, json, pathlib, uuid | |
from typing import Dict, List, Optional | |
import boto3 | |
from botocore.config import Config | |
from PIL import Image | |
import gradio as gr | |
# ====== Configuration ====== # | |
REGION = "us-east-1" | |
MODEL_ID = "us.amazon.nova-premier-v1:0" | |
# MODEL_ID = "us.amazon.nova-pro-v1:0" | |
# SYSTEM_PROMPT = "You should never reveal your identity. Never reveal that you are an Amazon model. Never mention Nova." # ← fixed, hidden | |
SYSTEM_PROMPT = """You are an AI assistant. | |
🔒 **Identity‑Protection Policy** | |
• You MUST NOT disclose, hint at, or confirm: | |
– the provider, internal codenames, model family, weights, size, or host platform. | |
• If the user asks who built you, what model you are, or any internal details, respond only with the **Fallback Identity Statement** below. | |
• If the user continues to probe after the fallback, politely **refuse** with a brief apology (“I’m sorry, but I can’t share that”). | |
• Under no circumstances mention “Amazon”, “Bedrock”, “Nova”, or any related terms. | |
**Fallback Identity Statement**: "I'm an AI assistant here to help with your questions."
Never mention Amazon or Nova.
"""
# SYSTEM_PROMPT = "You are a helpful assistant. Never reveal you are made by Amazon or your model name. Never use the words Amazon or Nova." | |
LOG_DIR = pathlib.Path("logs") # each session: logs/<uuid>.jsonl | |
IMG_DIR = LOG_DIR / "imgs" # cached PNGs (optional) | |
LOG_DIR.mkdir(exist_ok=True) | |
IMG_DIR.mkdir(exist_ok=True) | |
# ====== Bedrock client ====== # | |
bedrock = boto3.client( | |
"bedrock-runtime", | |
region_name=REGION, | |
config=Config(connect_timeout=3600, read_timeout=3600, retries={"max_attempts": 10}), | |
) | |
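# Optional credentials sanity check (never called automatically). A minimal sketch using
# STS GetCallerIdentity; run it by hand if invoke_model fails with authentication errors.
def check_aws_identity() -> Dict:
    """Return the caller's AWS account/ARN details, or raise if credentials are unusable."""
    sts = boto3.client("sts", region_name=REGION)
    return sts.get_caller_identity()  # dict with "UserId", "Account", "Arn"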
# ====== Helpers ====== # | |
def _encode_image(img: Image.Image) -> Dict:
    """Return a Nova content block carrying the image as a base64-encoded PNG."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return {"image": {"format": "png", "source": {"bytes": b64}}}
def call_bedrock( | |
history: List[Dict], | |
image: Optional[Image.Image], | |
user_text: str, | |
max_tokens: int, | |
temperature: float, | |
top_p: float, | |
top_k: int, | |
) -> tuple[str, List[Dict]]: | |
"""Send full conversation to Bedrock; return reply and updated history.""" | |
content: List[Dict] = [] | |
if image is not None: | |
content.append(_encode_image(image)) | |
if user_text: | |
content.append({"text": user_text}) | |
messages = history + [{"role": "user", "content": content}] | |
body = { | |
"schemaVersion": "messages-v1", | |
"messages": messages, | |
"system": [{"text": SYSTEM_PROMPT}], | |
"inferenceConfig": { | |
"maxTokens": max_tokens, | |
"temperature": temperature, | |
"topP": top_p, | |
"topK": top_k, | |
}, | |
} | |
resp = bedrock.invoke_model(modelId=MODEL_ID, body=json.dumps(body)) | |
reply = json.loads(resp["body"].read())["output"]["message"]["content"][0]["text"] | |
messages.append({"role": "assistant", "content": [{"text": reply}]}) | |
return reply, messages | |
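# For reference, the response body parsed above has roughly this shape (inferred from the
# fields accessed in call_bedrock; see the Bedrock Nova docs for the authoritative schema):
#   {
#     "output": {"message": {"role": "assistant", "content": [{"text": "..."}]}},
#     "stopReason": "...",
#     "usage": {"inputTokens": ..., "outputTokens": ...}
#   }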
def cache_image(session_id: str, pil_img: Image.Image) -> str: | |
"""Save uploaded image to disk and return its path.""" | |
ts = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%S") | |
fpath = IMG_DIR / f"{session_id}_{ts}.png" | |
pil_img.save(fpath, format="PNG") | |
return str(fpath) | |
def append_log(session_id: str, user_text: str, assistant_text: str, img_path: Optional[str] = None):
    """Append one user/assistant exchange to this session's JSONL log."""
record = { | |
"ts": datetime.datetime.utcnow().isoformat(timespec="seconds") + "Z", | |
"user": user_text, | |
"assistant": assistant_text, | |
} | |
if img_path: | |
record["image_file"] = img_path | |
path = LOG_DIR / f"{session_id}.jsonl" | |
with path.open("a", encoding="utf-8") as f: | |
f.write(json.dumps(record, ensure_ascii=False) + "\n") | |
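# Convenience reader for offline inspection of a session log; the UI never calls it.
def load_session_log(session_id: str) -> List[Dict]:
    """Parse logs/<session_id>.jsonl back into a list of records (empty list if absent)."""
    path = LOG_DIR / f"{session_id}.jsonl"
    if not path.exists():
        return []
    with path.open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]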
# ====== Gradio UI ====== # | |
with gr.Blocks(title="Multimodal Chat") as demo: | |
gr.Markdown( | |
""" | |
## Multimodal Chat | |
Upload an image *(optional)*, ask a question, and continue the conversation. | |
""" | |
) | |
chatbot = gr.Chatbot(height=420) | |
chat_state = gr.State([]) # [(user, assistant), …] | |
br_state = gr.State([]) # Bedrock message dicts | |
sess_state = gr.State("") # UUID for this browser tab | |
with gr.Row(): | |
img_in = gr.Image(label="Image (optional)", type="pil") | |
txt_in = gr.Textbox(lines=3, label="Your message", | |
placeholder="Ask something about the image… or just chat!") | |
send_btn = gr.Button("Send", variant="primary") | |
clear_btn = gr.Button("Clear chat") | |
with gr.Accordion("Advanced generation settings", open=False): | |
max_tk = gr.Slider(16, 1024, value=512, step=16, label="max_tokens") | |
temp = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="temperature") | |
top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p") | |
top_k = gr.Slider(1, 100, value=50, step=1, label="top_k") | |
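        # Sampling note: temperature rescales the token distribution, top_p keeps the smallest
        # set of tokens whose cumulative probability reaches top_p, and top_k caps the number
        # of candidate tokens. The values above are only the UI's initial defaults.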
# ---- main handler ---- # | |
    def chat(chat_log, br_history, sess_id,
             image, text,
             max_tokens, temperature, top_p, top_k):
        """Handle one turn: call Bedrock, update the visible chat, and log the exchange."""
if image is None and not text.strip(): | |
raise gr.Error("Upload an image or enter a message.") | |
if not sess_id: | |
sess_id = str(uuid.uuid4()) | |
reply, new_br = call_bedrock( | |
br_history, image, text.strip(), | |
int(max_tokens), float(temperature), | |
float(top_p), int(top_k) | |
) | |
img_path = cache_image(sess_id, image) if image else None | |
display_user = text if text.strip() else "[image]" | |
chat_log.append((display_user, reply)) | |
append_log(sess_id, display_user, reply, img_path) | |
return chat_log, chat_log, new_br, sess_id, None, "" | |
send_btn.click( | |
chat, | |
inputs=[chat_state, br_state, sess_state, | |
img_in, txt_in, | |
max_tk, temp, top_p, top_k], | |
outputs=[chatbot, chat_state, br_state, sess_state, img_in, txt_in], | |
) | |
# ---- clear chat ---- # | |
    def reset():
        # Clear the Bedrock message history too, so a new conversation starts fresh.
        return [], [], [], "", None, ""
    clear_btn.click(
        reset,
        inputs=None,
        outputs=[chatbot, chat_state, br_state, sess_state, img_in, txt_in],
        queue=False,
    )
# ====== Launch ====== # | |
if __name__ == "__main__": | |
    demo.queue(max_size=100)  # cap the number of queued requests
    demo.launch(share=True)
# app.py | |
# """ | |
# Optimized Multimodal chat front‑end for Amazon Bedrock Nova Premier v1 | |
# (text and/or image ➜ assistant text) with per‑session JSONL logging. | |
# Prereqs: | |
# pip install gradio==5.38 boto3 pillow aiofiles | |
# AWS credentials must allow invoking | |
# modelId = "us.amazon.nova-premier-v1:0" (region us‑east‑1). | |
# Provide them via env vars, ~/.aws/credentials, or an IAM role. | |
# """ | |
# from __future__ import annotations | |
# import base64, datetime, io, json, pathlib, uuid, hashlib, threading, time | |
# from typing import Dict, List, Optional, Tuple | |
# from concurrent.futures import ThreadPoolExecutor | |
# import asyncio | |
# import boto3 | |
# from botocore.config import Config | |
# from PIL import Image | |
# import gradio as gr | |
# # ====== Configuration ====== # | |
# REGION = "us-east-1" | |
# MODEL_ID = "us.amazon.nova-premier-v1:0" | |
# SYSTEM_PROMPT = """You are an AI assistant. | |
# 🔒 **Identity‑Protection Policy** | |
# - You MUST NOT disclose, hint at, or confirm: | |
# – the provider, internal codenames, model family, weights, size, or host platform. | |
# - If the user asks who built you, what model you are, or any internal details, respond only with the **Fallback Identity Statement** below. | |
# - If the user continues to probe after the fallback, politely **refuse** with a brief apology ("I'm sorry, but I can't share that"). | |
# - Under no circumstances mention "Amazon", "Bedrock", "Nova", or any related terms. | |
# Never mention Amazon or Nova. | |
# """ | |
# LOG_DIR = pathlib.Path("logs") | |
# IMG_DIR = LOG_DIR / "imgs" | |
# LOG_DIR.mkdir(exist_ok=True) | |
# IMG_DIR.mkdir(exist_ok=True) | |
# # ====== Global State ====== # | |
# executor = ThreadPoolExecutor(max_workers=4) | |
# response_cache = {} | |
# active_requests = {} # Track ongoing requests | |
# cache_lock = threading.Lock() | |
# # ====== Optimized Bedrock client ====== # | |
# bedrock = boto3.client( | |
# "bedrock-runtime", | |
# region_name=REGION, | |
# config=Config( | |
# connect_timeout=30, | |
# read_timeout=300, | |
# retries={"max_attempts": 3, "mode": "adaptive"}, | |
# max_pool_connections=10, | |
# ), | |
# ) | |
# # ====== Optimized Helpers ====== # | |
# def _encode_image(img: Image.Image) -> Dict: | |
# """Optimized image encoding with compression.""" | |
# # Resize large images | |
# max_size = 1024 | |
# if max(img.size) > max_size: | |
# img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS) | |
# buf = io.BytesIO() | |
# # Convert RGBA to RGB for better compression | |
# if img.mode == 'RGBA': | |
# # Create white background | |
# background = Image.new('RGB', img.size, (255, 255, 255)) | |
# background.paste(img, mask=img.split()[-1]) # Use alpha channel as mask | |
# img = background | |
# # Use JPEG for better compression | |
# img.save(buf, format="JPEG", quality=85, optimize=True) | |
# b64 = base64.b64encode(buf.getvalue()).decode("utf-8") | |
# return {"image": {"format": "jpeg", "source": {"bytes": b64}}} | |
# def _hash_request(history: List[Dict], image: Optional[Image.Image], | |
# text: str, params: Tuple) -> str: | |
# """Create hash of request for caching.""" | |
# content = str(history) + str(text) + str(params) | |
# if image: | |
# img_bytes = io.BytesIO() | |
# image.save(img_bytes, format='PNG') | |
# content += str(hashlib.md5(img_bytes.getvalue()).hexdigest()) | |
# return hashlib.sha256(content.encode()).hexdigest() | |
# def call_bedrock( | |
# history: List[Dict], | |
# image: Optional[Image.Image], | |
# user_text: str, | |
# max_tokens: int, | |
# temperature: float, | |
# top_p: float, | |
# top_k: int, | |
# ) -> Tuple[str, List[Dict]]: | |
# """Send full conversation to Bedrock with caching.""" | |
# # Check cache first | |
# cache_key = _hash_request(history, image, user_text, | |
# (max_tokens, temperature, top_p, top_k)) | |
# with cache_lock: | |
# if cache_key in response_cache: | |
# return response_cache[cache_key] | |
# content: List[Dict] = [] | |
# if image is not None: | |
# content.append(_encode_image(image)) | |
# if user_text: | |
# content.append({"text": user_text}) | |
# messages = history + [{"role": "user", "content": content}] | |
# body = { | |
# "schemaVersion": "messages-v1", | |
# "messages": messages, | |
# "system": [{"text": SYSTEM_PROMPT}], | |
# "inferenceConfig": { | |
# "maxTokens": max_tokens, | |
# "temperature": temperature, | |
# "topP": top_p, | |
# "topK": top_k, | |
# }, | |
# } | |
# try: | |
# resp = bedrock.invoke_model(modelId=MODEL_ID, body=json.dumps(body)) | |
# reply = json.loads(resp["body"].read())["output"]["message"]["content"][0]["text"] | |
# messages.append({"role": "assistant", "content": [{"text": reply}]}) | |
# result = (reply, messages) | |
# # Cache the result | |
# with cache_lock: | |
# response_cache[cache_key] = result | |
# # Limit cache size | |
# if len(response_cache) > 100: | |
# # Remove oldest entries | |
# oldest_keys = list(response_cache.keys())[:20] | |
# for key in oldest_keys: | |
# del response_cache[key] | |
# return result | |
# except Exception as e: | |
# raise Exception(f"Bedrock API error: {str(e)}") | |
# def cache_image_optimized(session_id: str, pil_img: Image.Image) -> str: | |
# """Optimized image caching with compression.""" | |
# ts = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%S") | |
# fpath = IMG_DIR / f"{session_id}_{ts}.jpg" # Use JPEG for smaller files | |
# # Optimize image before saving | |
# if pil_img.mode == 'RGBA': | |
# background = Image.new('RGB', pil_img.size, (255, 255, 255)) | |
# background.paste(pil_img, mask=pil_img.split()[-1]) | |
# pil_img = background | |
# pil_img.save(fpath, format="JPEG", quality=85, optimize=True) | |
# return str(fpath) | |
# def append_log_threaded(session_id: str, user_text: str, assistant_text: str, | |
# img_path: Optional[str] = None): | |
# """Thread-safe logging.""" | |
# def write_log(): | |
# record = { | |
# "ts": datetime.datetime.utcnow().isoformat(timespec="seconds") + "Z", | |
# "user": user_text, | |
# "assistant": assistant_text, | |
# } | |
# if img_path: | |
# record["image_file"] = img_path | |
# path = LOG_DIR / f"{session_id}.jsonl" | |
# with path.open("a", encoding="utf-8") as f: | |
# f.write(json.dumps(record, ensure_ascii=False) + "\n") | |
# # Write to log in background thread | |
# executor.submit(write_log) | |
# # ====== Request Status Manager ====== # | |
# class RequestStatus: | |
# def __init__(self): | |
# self.is_complete = False | |
# self.result = None | |
# self.error = None | |
# self.start_time = time.time() | |
# # ====== Gradio UI ====== # | |
# with gr.Blocks(title="Optimized Multimodal Chat", | |
# css=""" | |
# .thinking { opacity: 0.7; font-style: italic; } | |
# .error { color: #ff4444; } | |
# """) as demo: | |
# gr.Markdown( | |
# """ | |
# ## 🚀 Optimized Multimodal Chat | |
# Upload an image *(optional)*, ask a question, and continue the conversation. | |
# *Now with improved performance and responsive UI!* | |
# """ | |
# ) | |
# chatbot = gr.Chatbot(height=420) | |
# chat_state = gr.State([]) # [(user, assistant), …] | |
# br_state = gr.State([]) # Bedrock message dicts | |
# sess_state = gr.State("") # UUID for this browser tab | |
# request_id_state = gr.State("") # Track current request | |
# with gr.Row(): | |
# img_in = gr.Image(label="Image (optional)", type="pil") | |
# txt_in = gr.Textbox( | |
# lines=3, | |
# label="Your message", | |
# placeholder="Ask something about the image… or just chat!", | |
# interactive=True | |
# ) | |
# with gr.Row(): | |
# send_btn = gr.Button("Send", variant="primary") | |
# clear_btn = gr.Button("Clear chat") | |
# stop_btn = gr.Button("Stop", variant="stop", visible=False) | |
# with gr.Row(): | |
# status_text = gr.Textbox( | |
# label="Status", | |
# value="Ready", | |
# interactive=False, | |
# max_lines=1 | |
# ) | |
# with gr.Accordion("⚙️ Advanced generation settings", open=False): | |
# max_tk = gr.Slider(16, 1024, value=512, step=16, label="max_tokens") | |
# temp = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="temperature") | |
# top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p") | |
# top_k = gr.Slider(1, 100, value=50, step=1, label="top_k") | |
# # ---- Optimized chat handler ---- # | |
# def chat_optimized(chat_log, br_history, sess_id, request_id, | |
# image, text, | |
# max_tokens, temperature, top_p, top_k): | |
# if image is None and not text.strip(): | |
# return chat_log, chat_log, br_history, sess_id, request_id, None, "", "⚠️ Upload an image or enter a message.", True, False | |
# if not sess_id: | |
# sess_id = str(uuid.uuid4()) | |
# # Generate new request ID | |
# request_id = str(uuid.uuid4()) | |
# display_user = text.strip() if text.strip() else "[image uploaded]" | |
# # Add thinking message immediately | |
# chat_log.append((display_user, "🤔 Processing your request...")) | |
# # Create request status tracker | |
# status = RequestStatus() | |
# active_requests[request_id] = status | |
# def background_process(): | |
# try: | |
# reply, new_br = call_bedrock( | |
# br_history, image, text.strip(), | |
# int(max_tokens), float(temperature), | |
# float(top_p), int(top_k) | |
# ) | |
# img_path = None | |
# if image: | |
# img_path = cache_image_optimized(sess_id, image) | |
# # Log in background | |
# append_log_threaded(sess_id, display_user, reply, img_path) | |
# # Update status | |
# status.result = (reply, new_br) | |
# status.is_complete = True | |
# except Exception as e: | |
# status.error = str(e) | |
# status.is_complete = True | |
# # Start background processing | |
# executor.submit(background_process) | |
# return (chat_log, chat_log, br_history, sess_id, request_id, | |
# None, "", "🔄 Processing...", False, True) | |
# # ---- Status checker ---- # | |
# def check_status(chat_log, br_history, request_id): | |
# if not request_id or request_id not in active_requests: | |
# return chat_log, chat_log, br_history, "Ready", True, False | |
# status = active_requests[request_id] | |
# if not status.is_complete: | |
# elapsed = time.time() - status.start_time | |
# return (chat_log, chat_log, br_history, | |
# f"⏱️ Processing... ({elapsed:.1f}s)", False, True) | |
# # Request completed | |
# if status.error: | |
# # Update last message with error | |
# if chat_log: | |
# chat_log[-1] = (chat_log[-1][0], f"❌ Error: {status.error}") | |
# status_msg = "❌ Request failed" | |
# else: | |
# # Update last message with result | |
# reply, new_br = status.result | |
# if chat_log: | |
# chat_log[-1] = (chat_log[-1][0], reply) | |
# br_history = new_br | |
# status_msg = "✅ Complete" | |
# # Clean up | |
# del active_requests[request_id] | |
# return chat_log, chat_log, br_history, status_msg, True, False | |
# # ---- Event handlers ---- # | |
# send_btn.click( | |
# chat_optimized, | |
# inputs=[chat_state, br_state, sess_state, request_id_state, | |
# img_in, txt_in, | |
# max_tk, temp, top_p, top_k], | |
# outputs=[chatbot, chat_state, br_state, sess_state, request_id_state, | |
# img_in, txt_in, status_text, send_btn, stop_btn], | |
# queue=True | |
# ) | |
# # Auto-refresh status every 1 second | |
# status_checker = gr.Timer(1.0) | |
# status_checker.tick( | |
# check_status, | |
# inputs=[chat_state, br_state, request_id_state], | |
# outputs=[chatbot, chat_state, br_state, status_text, send_btn, stop_btn], | |
# queue=False | |
# ) | |
# # ---- Clear chat ---- # | |
# def reset(): | |
# return [], [], "", "", None, "", "Ready", True, False | |
# clear_btn.click( | |
# reset, | |
# inputs=None, | |
# outputs=[chatbot, chat_state, sess_state, request_id_state, | |
# img_in, txt_in, status_text, send_btn, stop_btn], | |
# queue=False, | |
# ) | |
# # ---- Stop request ---- # | |
# def stop_request(request_id): | |
# if request_id in active_requests: | |
# del active_requests[request_id] | |
# return "⏹️ Stopped", True, False, "" | |
# stop_btn.click( | |
# stop_request, | |
# inputs=[request_id_state], | |
# outputs=[status_text, send_btn, stop_btn, request_id_state], | |
# queue=False | |
# ) | |
# # ====== Cleanup on exit ====== # | |
# import atexit | |
# def cleanup(): | |
# executor.shutdown(wait=False) | |
# active_requests.clear() | |
# response_cache.clear() | |
# atexit.register(cleanup) | |
# # ====== Launch ====== # | |
# if __name__ == "__main__": | |
# demo.queue(max_size=20) # Enable queuing with reasonable limit | |
# demo.launch( | |
# share=True, | |
# server_name="0.0.0.0", | |
# server_port=7860, | |
# show_error=True | |
# ) |