import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# Read the Hugging Face access token from the environment (e.g. a Space secret named 'token')
access_token = os.getenv('token')

# Initialize the tokenizer and model with the Hugging Face access token
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.bfloat16,
    token=access_token,
)
model.eval()  # Set the model to evaluation mode
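
# Optional sketch (an assumption, not part of the original app): move the model
# to a GPU when one is available so model.generate() runs at a usable speed.
# If enabled, input_ids in conversation_predict() would need a matching
# .to(model.device) call.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)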

# Initialize the Inference API client used by the chat interface below (routed through the "together" provider)
client = InferenceClient(provider="together", token=access_token)

def conversation_predict(input_text):
    """Generate a response for single-turn input using the model."""
    # Tokenize the input text
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids

    # Generate a response with the model
    outputs = model.generate(input_ids, max_new_tokens=2048)

    # Decode and return the generated text (note: the decoded string also
    # contains the original prompt, since generate() returns the full sequence)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
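
# Usage sketch (an addition, not part of the original Space): a quick local
# smoke test of the single-turn helper above. The SMOKE_TEST environment
# variable name is an arbitrary choice for this example.
if __name__ == "__main__" and os.getenv("SMOKE_TEST"):
    print(conversation_predict("Write one sentence about the ocean."))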

def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generate a response for a multi-turn chat conversation."""
    # Prepare the messages in the correct format for the API
    messages = [{"role": "system", "content": system_message}]

    for user_input, assistant_reply in history:
        if user_input:
            messages.append({"role": "user", "content": user_input})
        if assistant_reply:
            messages.append({"role": "assistant", "content": assistant_reply})

    messages.append({"role": "user", "content": message})

    # Get the complete response at once (no streaming)
    response = client.chat_completion(
        model="google/gemma-2b-it",
        messages=messages,
        max_tokens=max_tokens,
        stream=False,
        temperature=temperature,
        top_p=top_p,
    )
    
    # Extract and return the assistant's reply
    return response.choices[0].message.content
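
# Streaming variant sketch (an addition, not part of the original Space): the
# same request with stream=True yields partial text so Gradio can render the
# reply incrementally. Wire it in with gr.ChatInterface(fn=respond_stream, ...)
# instead of fn=respond if streaming output is preferred.
def respond_stream(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generate a streamed response, yielding the partial reply as it grows."""
    messages = [{"role": "system", "content": system_message}]
    for user_input, assistant_reply in history:
        if user_input:
            messages.append({"role": "user", "content": user_input})
        if assistant_reply:
            messages.append({"role": "assistant", "content": assistant_reply})
    messages.append({"role": "user", "content": message})

    partial = ""
    for chunk in client.chat_completion(
        model="google/gemma-2b-it",
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Each streamed chunk carries an incremental text delta; the last may be empty.
        partial += chunk.choices[0].delta.content or ""
        yield partial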

# Create a Gradio ChatInterface demo
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch(share=True)