import os
import base64
from collections.abc import Iterator
import gradio as gr
from cohere import ClientV2
model_id = "command-a-vision-07-2025"
# Initialize Cohere client
api_key = os.getenv("COHERE_API_KEY")
if not api_key:
    raise ValueError("COHERE_API_KEY environment variable is required")
client = ClientV2(api_key=api_key, client_name="hf-command-a-vision-07-2025")
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
def count_files_in_new_message(paths: list[str]) -> int:
    image_count = 0
    for path in paths:
        # Case-insensitive match so uploads like .JPG are counted too
        if path.lower().endswith(IMAGE_FILE_TYPES):
            image_count += 1
    return image_count
def validate_media_constraints(message: dict) -> bool:
    image_count = count_files_in_new_message(message["files"])
    if image_count > 10:
        gr.Warning("Maximum 10 images are supported.")
        return False
    return True
def encode_image_to_base64(image_path: str) -> str:
    """Encode an image file to a base64 data URL."""
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    # Determine the MIME type from the file extension
    if image_path.lower().endswith(".png"):
        mime_type = "image/png"
    elif image_path.lower().endswith((".jpg", ".jpeg")):
        mime_type = "image/jpeg"
    elif image_path.lower().endswith(".webp"):
        mime_type = "image/webp"
    else:
        mime_type = "image/jpeg"  # default
    return f"data:{mime_type};base64,{encoded_string}"
def generate(message: dict, history: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
    """Stream a response from the Cohere API for the current multimodal message."""
    if not validate_media_constraints(message):
        yield ""
        return

    # Build messages for the Cohere API
    messages = []

    # Add conversation history
    for item in history:
        if item["role"] == "assistant":
            messages.append({"role": "assistant", "content": item["content"]})
        else:
            content = item["content"]
            if isinstance(content, str):
                messages.append({"role": "user", "content": [{"type": "text", "text": content}]})
            else:
                # File-only turns arrive as a tuple of paths; don't include empty text
                filepath = content[0]
                messages.append({
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": encode_image_to_base64(filepath)}}
                    ],
                })

    # Add the current message
    current_content = []
    if message["text"]:
        current_content.append({"type": "text", "text": message["text"]})
    for file_path in message["files"]:
        current_content.append({
            "type": "image_url",
            "image_url": {"url": encode_image_to_base64(file_path)},
        })
    # Only add the message if there's content
    if current_content:
        messages.append({"role": "user", "content": current_content})

    try:
        # Stream the chat completion; "content-delta" events carry the generated text
        response = client.chat_stream(
            model=model_id,
            messages=messages,
            temperature=0.3,
            max_tokens=max_new_tokens,
        )
        output = ""
        for event in response:
            if getattr(event, "type", None) == "content-delta":
                # event.delta.message.content.text is the streamed text chunk
                text = getattr(event.delta.message.content, "text", "")
                output += text
                yield output
    except Exception as e:
        gr.Warning(f"Error calling Cohere API: {str(e)}")
        yield ""
examples = [
    [
        {
            "text": "Write a COBOL function to reverse a string",
            "files": [],
        }
    ],
    [
        {
            "text": "Como sair de um helicóptero que caiu na água?",
            "files": [],
        }
    ],
    [
        {
            "text": "What is the total amount of the invoice with and without tax?",
            "files": ["assets/invoice-1.jpg"],
        }
    ],
    [
        {
            "text": "¿Contra qué modelo gana más Aya Vision 8B?",
            "files": ["assets/aya-vision-win-rates.png"],
        }
    ],
    [
        {
            "text": "Erläutern Sie die Ergebnisse in der Tabelle",
            "files": ["assets/command-a-longbech-v2.png"],
        }
    ],
    [
        {
            "text": "Explique la théorie de la relativité en français",
            "files": [],
        }
    ],
]
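
# Chat UI: a multimodal textbox accepts text plus image uploads, and
# responses stream back from generate()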
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    textbox=gr.MultimodalTextbox(
        file_types=list(IMAGE_FILE_TYPES),
        file_count="multiple",
        autofocus=True,
    ),
    multimodal=True,
    additional_inputs=[
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
    ],
    stop_btn=False,
    title="Command A Vision",
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)
if __name__ == "__main__":
    demo.launch()