import os
from collections.abc import Iterator

import gradio as gr
from cohere import ClientV2

model_id = "command-a-translate-08-2025"

# Initialize Cohere client
api_key = os.getenv("COHERE_API_KEY")
if not api_key:
    raise ValueError("COHERE_API_KEY environment variable is required")
client = ClientV2(api_key=api_key, client_name="hf-command-a-translate-08-2025")


def generate(message: str, history: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
    """
    Gradio passes the conversation as a list of message dicts, not as (message, history).
    The last message is the current user message, previous are the history.
    """
    # Build messages for Cohere API (text-only)
    messages = []

    # Add conversation history (text-only)
    for item in history:
        role = item.get("role")
        content = item.get("content", "")
        if content is None:
            content = ""
        if not isinstance(content, str):
            content = str(content)
        if role in ("assistant", "user"):
            messages.append({"role": role, "content": content})

    # Add current user message (text-only)
    current_text = message or ""
    if not isinstance(current_text, str):
        current_text = str(current_text)
    if current_text:
        messages.append({"role": "user", "content": current_text})

    try:
        # Stream the response from the Cohere Chat API
        response = client.chat_stream(
            model=model_id,
            messages=messages,
            temperature=0.3,
            max_tokens=max_new_tokens,
        )

        output = ""
        for event in response:
            if getattr(event, "type", None) == "content-delta":
                # event.delta.message.content.text is the streamed text
                text = getattr(event.delta.message.content, "text", "")
                output += text
                yield output

    except Exception as e:
        gr.Warning(f"Error calling Cohere API: {str(e)}")
        yield ""


examples = [
    [
        "Translate everything that follows into Spanish:\n\nEnterprises rely on translation for some of their most sensitive and business-critical documents and cannot risk data leakage, compliance violations, or misunderstandings. Mistranslated documents can reduce trust and have strategic implications."
    ],
    [
        "Take the English text that follows and translate it into German. Only respond with the translated text.\n\nCommand A Translate is available today on the Cohere platform and for research use on Hugging Face. If you are interested in private or on-prem deployments, please contact our sales team for bespoke pricing."
    ],
    [
        "Can you rewrite that in French please?\n\nTo meet the needs of global enterprises, the model supports translation across 23 widely used business languages: English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Arabic, Chinese, Russian, Polish, Turkish, Vietnamese, Dutch, Czech, Indonesian, Ukrainian, Romanian, Greek, Hindi, Hebrew, and Persian."
    ],
]

demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    textbox=gr.Textbox(autofocus=True),
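    # Values from additional_inputs are appended to the fn call, so this slider
    # sets the max_new_tokens argument of generate().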
    additional_inputs=[
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
    ],
    stop_btn=False,
    title="Command A Translate",
    examples=examples,
    run_examples_on_click=True,
    cache_examples=False,
    css_paths="style.css",
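    # Clear cached files roughly every 1800 s, removing files older than 1800 s.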
    delete_cache=(1800, 1800),
)

if __name__ == "__main__":
    demo.launch()