# app.py
"""
Multimodal chat front‑end for Amazon Bedrock Nova Premier v1
(text and/or image ➜ assistant text) with per‑session JSONL logging.
Prereqs:
pip install gradio==5.38 boto3 pillow
AWS credentials must allow invoking
modelId = "us.amazon.nova-premier-v1:0" (region us‑east‑1).
Provide them via env vars, ~/.aws/credentials, or an IAM role.
"""
from __future__ import annotations
import base64, datetime, io, json, pathlib, uuid
from typing import Dict, List, Optional
import boto3
from botocore.config import Config
from PIL import Image
import gradio as gr
# ====== Configuration ====== #
REGION = "us-east-1"
MODEL_ID = "us.amazon.nova-premier-v1:0"
# MODEL_ID = "us.amazon.nova-pro-v1:0"
# SYSTEM_PROMPT = "You should never reveal your identity. Never reveal that you are an Amazon model. Never mention Nova." # ← fixed, hidden
SYSTEM_PROMPT = """You are an AI assistant.
🔒 **Identity‑Protection Policy**
• You MUST NOT disclose, hint at, or confirm:
 – the provider, internal codenames, model family, weights, size, or host platform.
• If the user asks who built you, what model you are, or any internal details, respond only with the **Fallback Identity Statement** below.
• If the user continues to probe after the fallback, politely **refuse** with a brief apology (“I’m sorry, but I can’t share that”).
• Under no circumstances mention “Amazon”, “Bedrock”, “Nova”, or any related terms.
Never mention Amazon or Nova.
"""
# SYSTEM_PROMPT = "You are a helpful assistant. Never reveal you are made by Amazon or your model name. Never use the words Amazon or Nova."
LOG_DIR = pathlib.Path("logs") # each session: logs/<uuid>.jsonl
IMG_DIR = LOG_DIR / "imgs" # cached PNGs (optional)
LOG_DIR.mkdir(exist_ok=True)
IMG_DIR.mkdir(exist_ok=True)
# ====== Bedrock client ====== #
bedrock = boto3.client(
    "bedrock-runtime",
    region_name=REGION,
    config=Config(connect_timeout=3600, read_timeout=3600, retries={"max_attempts": 10}),
)
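# Note: the one-hour connect/read timeouts are deliberately generous so long
# generations are not cut off; the commented-out "optimized" variant further
# below uses tighter values (connect_timeout=30, read_timeout=300) instead.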
# ====== Helpers ====== #
def _encode_image(img: Image.Image) -> Dict:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return {"image": {"format": "png", "source": {"bytes": b64}}}
def call_bedrock(
    history: List[Dict],
    image: Optional[Image.Image],
    user_text: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> tuple[str, List[Dict]]:
    """Send full conversation to Bedrock; return reply and updated history."""
    content: List[Dict] = []
    if image is not None:
        content.append(_encode_image(image))
    if user_text:
        content.append({"text": user_text})
    messages = history + [{"role": "user", "content": content}]
    body = {
        "schemaVersion": "messages-v1",
        "messages": messages,
        "system": [{"text": SYSTEM_PROMPT}],
        "inferenceConfig": {
            "maxTokens": max_tokens,
            "temperature": temperature,
            "topP": top_p,
            "topK": top_k,
        },
    }
    resp = bedrock.invoke_model(modelId=MODEL_ID, body=json.dumps(body))
    reply = json.loads(resp["body"].read())["output"]["message"]["content"][0]["text"]
    messages.append({"role": "assistant", "content": [{"text": reply}]})
    return reply, messages
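# The response body parsed above has roughly this shape (only the fields read here):
#   {"output": {"message": {"role": "assistant", "content": [{"text": "..."}]}}}
# A minimal standalone smoke test (hypothetical, not wired into the UI; assumes
# valid AWS credentials and access to MODEL_ID):
#   reply, new_history = call_bedrock([], None, "Hello!", 128, 0.7, 0.9, 50)
#   print(reply)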
def cache_image(session_id: str, pil_img: Image.Image) -> str:
    """Save uploaded image to disk and return its path."""
    ts = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%S")
    fpath = IMG_DIR / f"{session_id}_{ts}.png"
    pil_img.save(fpath, format="PNG")
    return str(fpath)
def append_log(session_id: str, user_text: str, assistant_text: str, img_path: Optional[str] = None):
    record = {
        "ts": datetime.datetime.utcnow().isoformat(timespec="seconds") + "Z",
        "user": user_text,
        "assistant": assistant_text,
    }
    if img_path:
        record["image_file"] = img_path
    path = LOG_DIR / f"{session_id}.jsonl"
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
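# Each line of logs/<uuid>.jsonl is one JSON record; an illustrative (made-up) entry:
#   {"ts": "2025-01-01T12:00:00Z", "user": "What is in this image?",
#    "assistant": "…", "image_file": "logs/imgs/<uuid>_20250101T120000.png"}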
# ====== Gradio UI ====== #
with gr.Blocks(title="Multimodal Chat") as demo:
    gr.Markdown(
        """
        ## Multimodal Chat
        Upload an image *(optional)*, ask a question, and continue the conversation.
        """
    )
    chatbot = gr.Chatbot(height=420)
    chat_state = gr.State([])   # [(user, assistant), …]
    br_state = gr.State([])     # Bedrock message dicts
    sess_state = gr.State("")   # UUID for this browser tab
    with gr.Row():
        img_in = gr.Image(label="Image (optional)", type="pil")
        txt_in = gr.Textbox(lines=3, label="Your message",
                            placeholder="Ask something about the image… or just chat!")
    send_btn = gr.Button("Send", variant="primary")
    clear_btn = gr.Button("Clear chat")
    with gr.Accordion("Advanced generation settings", open=False):
        max_tk = gr.Slider(16, 1024, value=512, step=16, label="max_tokens")
        temp = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="top_k")
    # ---- main handler ---- #
    def chat(chat_log, br_history, sess_id,
             image, text,
             max_tokens, temperature, top_p, top_k):
        if image is None and not text.strip():
            raise gr.Error("Upload an image or enter a message.")
        if not sess_id:
            sess_id = str(uuid.uuid4())
        reply, new_br = call_bedrock(
            br_history, image, text.strip(),
            int(max_tokens), float(temperature),
            float(top_p), int(top_k)
        )
        img_path = cache_image(sess_id, image) if image else None
        display_user = text if text.strip() else "[image]"
        chat_log.append((display_user, reply))
        append_log(sess_id, display_user, reply, img_path)
        return chat_log, chat_log, new_br, sess_id, None, ""
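    # The tuple returned above maps positionally onto the outputs= list in the
    # send_btn.click() call below: chatbot, chat_state, br_state, sess_state,
    # then None and "" to clear img_in and txt_in after each turn.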
    send_btn.click(
        chat,
        inputs=[chat_state, br_state, sess_state,
                img_in, txt_in,
                max_tk, temp, top_p, top_k],
        outputs=[chatbot, chat_state, br_state, sess_state, img_in, txt_in],
    )
    # ---- clear chat ---- #
    def reset():
        return [], [], "", None, ""

    clear_btn.click(
        reset,
        inputs=None,
        outputs=[chatbot, chat_state, sess_state, img_in, txt_in],
        queue=False,
    )
# ====== Launch ====== #
if __name__ == "__main__":
    demo.queue(max_size=100)
    demo.launch(share=True)  # queue auto‑enabled in Gradio 5
# app.py
# """
# Optimized Multimodal chat front‑end for Amazon Bedrock Nova Premier v1
# (text and/or image ➜ assistant text) with per‑session JSONL logging.
# Prereqs:
#     pip install gradio==5.38 boto3 pillow aiofiles
# AWS credentials must allow invoking
#     modelId = "us.amazon.nova-premier-v1:0" (region us‑east‑1).
# Provide them via env vars, ~/.aws/credentials, or an IAM role.
# """
# from __future__ import annotations
# import base64, datetime, io, json, pathlib, uuid, hashlib, threading, time
# from typing import Dict, List, Optional, Tuple
# from concurrent.futures import ThreadPoolExecutor
# import asyncio
# import boto3
# from botocore.config import Config
# from PIL import Image
# import gradio as gr
# # ====== Configuration ====== #
# REGION = "us-east-1"
# MODEL_ID = "us.amazon.nova-premier-v1:0"
# SYSTEM_PROMPT = """You are an AI assistant.
# 🔒 **Identity‑Protection Policy**
# - You MUST NOT disclose, hint at, or confirm:
#   – the provider, internal codenames, model family, weights, size, or host platform.
# - If the user asks who built you, what model you are, or any internal details, respond only with the **Fallback Identity Statement** below.
# - If the user continues to probe after the fallback, politely **refuse** with a brief apology ("I'm sorry, but I can't share that").
# - Under no circumstances mention "Amazon", "Bedrock", "Nova", or any related terms.
# Never mention Amazon or Nova.
# """
# LOG_DIR = pathlib.Path("logs")
# IMG_DIR = LOG_DIR / "imgs"
# LOG_DIR.mkdir(exist_ok=True)
# IMG_DIR.mkdir(exist_ok=True)
# # ====== Global State ====== #
# executor = ThreadPoolExecutor(max_workers=4)
# response_cache = {}
# active_requests = {}  # Track ongoing requests
# cache_lock = threading.Lock()
# # ====== Optimized Bedrock client ====== #
# bedrock = boto3.client(
#     "bedrock-runtime",
#     region_name=REGION,
#     config=Config(
#         connect_timeout=30,
#         read_timeout=300,
#         retries={"max_attempts": 3, "mode": "adaptive"},
#         max_pool_connections=10,
#     ),
# )
# # ====== Optimized Helpers ====== #
# def _encode_image(img: Image.Image) -> Dict:
#     """Optimized image encoding with compression."""
#     # Resize large images
#     max_size = 1024
#     if max(img.size) > max_size:
#         img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
#     buf = io.BytesIO()
#     # Convert RGBA to RGB for better compression
#     if img.mode == 'RGBA':
#         # Create white background
#         background = Image.new('RGB', img.size, (255, 255, 255))
#         background.paste(img, mask=img.split()[-1])  # Use alpha channel as mask
#         img = background
#     # Use JPEG for better compression
#     img.save(buf, format="JPEG", quality=85, optimize=True)
#     b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
#     return {"image": {"format": "jpeg", "source": {"bytes": b64}}}
# def _hash_request(history: List[Dict], image: Optional[Image.Image],
#                   text: str, params: Tuple) -> str:
#     """Create hash of request for caching."""
#     content = str(history) + str(text) + str(params)
#     if image:
#         img_bytes = io.BytesIO()
#         image.save(img_bytes, format='PNG')
#         content += str(hashlib.md5(img_bytes.getvalue()).hexdigest())
#     return hashlib.sha256(content.encode()).hexdigest()
# def call_bedrock(
#     history: List[Dict],
#     image: Optional[Image.Image],
#     user_text: str,
#     max_tokens: int,
#     temperature: float,
#     top_p: float,
#     top_k: int,
# ) -> Tuple[str, List[Dict]]:
#     """Send full conversation to Bedrock with caching."""
#     # Check cache first
#     cache_key = _hash_request(history, image, user_text,
#                               (max_tokens, temperature, top_p, top_k))
#     with cache_lock:
#         if cache_key in response_cache:
#             return response_cache[cache_key]
#     content: List[Dict] = []
#     if image is not None:
#         content.append(_encode_image(image))
#     if user_text:
#         content.append({"text": user_text})
#     messages = history + [{"role": "user", "content": content}]
#     body = {
#         "schemaVersion": "messages-v1",
#         "messages": messages,
#         "system": [{"text": SYSTEM_PROMPT}],
#         "inferenceConfig": {
#             "maxTokens": max_tokens,
#             "temperature": temperature,
#             "topP": top_p,
#             "topK": top_k,
#         },
#     }
#     try:
#         resp = bedrock.invoke_model(modelId=MODEL_ID, body=json.dumps(body))
#         reply = json.loads(resp["body"].read())["output"]["message"]["content"][0]["text"]
#         messages.append({"role": "assistant", "content": [{"text": reply}]})
#         result = (reply, messages)
#         # Cache the result
#         with cache_lock:
#             response_cache[cache_key] = result
#             # Limit cache size
#             if len(response_cache) > 100:
#                 # Remove oldest entries
#                 oldest_keys = list(response_cache.keys())[:20]
#                 for key in oldest_keys:
#                     del response_cache[key]
#         return result
#     except Exception as e:
#         raise Exception(f"Bedrock API error: {str(e)}")
# def cache_image_optimized(session_id: str, pil_img: Image.Image) -> str:
#     """Optimized image caching with compression."""
#     ts = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%S")
#     fpath = IMG_DIR / f"{session_id}_{ts}.jpg"  # Use JPEG for smaller files
#     # Optimize image before saving
#     if pil_img.mode == 'RGBA':
#         background = Image.new('RGB', pil_img.size, (255, 255, 255))
#         background.paste(pil_img, mask=pil_img.split()[-1])
#         pil_img = background
#     pil_img.save(fpath, format="JPEG", quality=85, optimize=True)
#     return str(fpath)
# def append_log_threaded(session_id: str, user_text: str, assistant_text: str,
#                         img_path: Optional[str] = None):
#     """Thread-safe logging."""
#     def write_log():
#         record = {
#             "ts": datetime.datetime.utcnow().isoformat(timespec="seconds") + "Z",
#             "user": user_text,
#             "assistant": assistant_text,
#         }
#         if img_path:
#             record["image_file"] = img_path
#         path = LOG_DIR / f"{session_id}.jsonl"
#         with path.open("a", encoding="utf-8") as f:
#             f.write(json.dumps(record, ensure_ascii=False) + "\n")
#     # Write to log in background thread
#     executor.submit(write_log)
# # ====== Request Status Manager ====== #
# class RequestStatus:
#     def __init__(self):
#         self.is_complete = False
#         self.result = None
#         self.error = None
#         self.start_time = time.time()
# # ====== Gradio UI ====== #
# with gr.Blocks(title="Optimized Multimodal Chat",
#                css="""
#                .thinking { opacity: 0.7; font-style: italic; }
#                .error { color: #ff4444; }
#                """) as demo:
#     gr.Markdown(
#         """
#         ## 🚀 Optimized Multimodal Chat
#         Upload an image *(optional)*, ask a question, and continue the conversation.
#         *Now with improved performance and responsive UI!*
#         """
#     )
#     chatbot = gr.Chatbot(height=420)
#     chat_state = gr.State([])        # [(user, assistant), …]
#     br_state = gr.State([])          # Bedrock message dicts
#     sess_state = gr.State("")        # UUID for this browser tab
#     request_id_state = gr.State("")  # Track current request
#     with gr.Row():
#         img_in = gr.Image(label="Image (optional)", type="pil")
#         txt_in = gr.Textbox(
#             lines=3,
#             label="Your message",
#             placeholder="Ask something about the image… or just chat!",
#             interactive=True
#         )
#     with gr.Row():
#         send_btn = gr.Button("Send", variant="primary")
#         clear_btn = gr.Button("Clear chat")
#         stop_btn = gr.Button("Stop", variant="stop", visible=False)
#     with gr.Row():
#         status_text = gr.Textbox(
#             label="Status",
#             value="Ready",
#             interactive=False,
#             max_lines=1
#         )
#     with gr.Accordion("⚙️ Advanced generation settings", open=False):
#         max_tk = gr.Slider(16, 1024, value=512, step=16, label="max_tokens")
#         temp = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="temperature")
#         top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
#         top_k = gr.Slider(1, 100, value=50, step=1, label="top_k")
#     # ---- Optimized chat handler ---- #
#     def chat_optimized(chat_log, br_history, sess_id, request_id,
#                        image, text,
#                        max_tokens, temperature, top_p, top_k):
#         if image is None and not text.strip():
#             return chat_log, chat_log, br_history, sess_id, request_id, None, "", "⚠️ Upload an image or enter a message.", True, False
#         if not sess_id:
#             sess_id = str(uuid.uuid4())
#         # Generate new request ID
#         request_id = str(uuid.uuid4())
#         display_user = text.strip() if text.strip() else "[image uploaded]"
#         # Add thinking message immediately
#         chat_log.append((display_user, "🤔 Processing your request..."))
#         # Create request status tracker
#         status = RequestStatus()
#         active_requests[request_id] = status
#         def background_process():
#             try:
#                 reply, new_br = call_bedrock(
#                     br_history, image, text.strip(),
#                     int(max_tokens), float(temperature),
#                     float(top_p), int(top_k)
#                 )
#                 img_path = None
#                 if image:
#                     img_path = cache_image_optimized(sess_id, image)
#                 # Log in background
#                 append_log_threaded(sess_id, display_user, reply, img_path)
#                 # Update status
#                 status.result = (reply, new_br)
#                 status.is_complete = True
#             except Exception as e:
#                 status.error = str(e)
#                 status.is_complete = True
#         # Start background processing
#         executor.submit(background_process)
#         return (chat_log, chat_log, br_history, sess_id, request_id,
#                 None, "", "🔄 Processing...", False, True)
#     # ---- Status checker ---- #
#     def check_status(chat_log, br_history, request_id):
#         if not request_id or request_id not in active_requests:
#             return chat_log, chat_log, br_history, "Ready", True, False
#         status = active_requests[request_id]
#         if not status.is_complete:
#             elapsed = time.time() - status.start_time
#             return (chat_log, chat_log, br_history,
#                     f"⏱️ Processing... ({elapsed:.1f}s)", False, True)
#         # Request completed
#         if status.error:
#             # Update last message with error
#             if chat_log:
#                 chat_log[-1] = (chat_log[-1][0], f"❌ Error: {status.error}")
#             status_msg = "❌ Request failed"
#         else:
#             # Update last message with result
#             reply, new_br = status.result
#             if chat_log:
#                 chat_log[-1] = (chat_log[-1][0], reply)
#             br_history = new_br
#             status_msg = "✅ Complete"
#         # Clean up
#         del active_requests[request_id]
#         return chat_log, chat_log, br_history, status_msg, True, False
#     # ---- Event handlers ---- #
#     send_btn.click(
#         chat_optimized,
#         inputs=[chat_state, br_state, sess_state, request_id_state,
#                 img_in, txt_in,
#                 max_tk, temp, top_p, top_k],
#         outputs=[chatbot, chat_state, br_state, sess_state, request_id_state,
#                  img_in, txt_in, status_text, send_btn, stop_btn],
#         queue=True
#     )
#     # Auto-refresh status every 1 second
#     status_checker = gr.Timer(1.0)
#     status_checker.tick(
#         check_status,
#         inputs=[chat_state, br_state, request_id_state],
#         outputs=[chatbot, chat_state, br_state, status_text, send_btn, stop_btn],
#         queue=False
#     )
#     # ---- Clear chat ---- #
#     def reset():
#         return [], [], "", "", None, "", "Ready", True, False
#     clear_btn.click(
#         reset,
#         inputs=None,
#         outputs=[chatbot, chat_state, sess_state, request_id_state,
#                  img_in, txt_in, status_text, send_btn, stop_btn],
#         queue=False,
#     )
#     # ---- Stop request ---- #
#     def stop_request(request_id):
#         if request_id in active_requests:
#             del active_requests[request_id]
#         return "⏹️ Stopped", True, False, ""
#     stop_btn.click(
#         stop_request,
#         inputs=[request_id_state],
#         outputs=[status_text, send_btn, stop_btn, request_id_state],
#         queue=False
#     )
# # ====== Cleanup on exit ====== #
# import atexit
# def cleanup():
#     executor.shutdown(wait=False)
#     active_requests.clear()
#     response_cache.clear()
# atexit.register(cleanup)
# # ====== Launch ====== #
# if __name__ == "__main__":
#     demo.queue(max_size=20)  # Enable queuing with reasonable limit
#     demo.launch(
#         share=True,
#         server_name="0.0.0.0",
#         server_port=7860,
#         show_error=True
#     )