import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
# Read the Hugging Face access token from the environment (e.g. a Space secret named 'token')
access_token = os.getenv('token')
# Initialize the tokenizer and model with the Hugging Face access token
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.bfloat16,
    token=access_token,
)
model.eval() # Set the model to evaluation mode
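# Optional (not in the original code): if a GPU is available, generation is much faster
# after moving the model to it, e.g. model.to("cuda").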
# Initialize the inference client used by respond() below for API-based chat completions
client = InferenceClient(provider="together", token=access_token)
def conversation_predict(input_text):
    """Generate a response for single-turn input using the local model."""
    # Tokenize the input text
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # Generate a response with the model
    outputs = model.generate(input_ids, max_new_tokens=2048)
    # Decode and return the generated response
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
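# Illustrative usage of the local helper above (an assumption; it is not wired into the UI below):
#     conversation_predict("Explain what a tokenizer does in one sentence.")
# Note that the decoded output also contains the prompt text, since the full sequence is returned.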
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generate a response for a multi-turn chat conversation."""
    # Prepare the messages in the correct format for the API
    messages = [{"role": "system", "content": system_message}]
    for user_input, assistant_reply in history:
        if user_input:
            messages.append({"role": "user", "content": user_input})
        if assistant_reply:
            messages.append({"role": "assistant", "content": assistant_reply})
    messages.append({"role": "user", "content": message})
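    # For illustration (hypothetical values), after one prior exchange the list looks like:
    #     [
    #         {"role": "system", "content": "You are a friendly Chatbot."},
    #         {"role": "user", "content": "Hi!"},
    #         {"role": "assistant", "content": "Hello! How can I help?"},
    #         {"role": "user", "content": message},
    #     ]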
    # Get the complete response at once (no streaming)
    response = client.chat_completion(
        model="google/gemma-2b-it",
        messages=messages,
        max_tokens=max_tokens,
        stream=False,
        temperature=temperature,
        top_p=top_p,
    )
    # Extract and return the full response text
    return response.choices[0].message.content
# Create a Gradio ChatInterface demo
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
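# Note: gr.ChatInterface passes each additional input positionally to respond()
# after (message, history): system_message, max_tokens, temperature, top_p.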
if __name__ == "__main__":
    demo.launch(share=True)