Spaces:
Runtime error
Runtime error
File size: 1,771 Bytes
56b98a2 f1db41f f7c5bbe f1db41f fddf668 0eb0409 4f68941 010bf7d 0eb0409 010bf7d 4f68941 0eb0409 0c198b3 0eb0409 239e77f 0eb0409 0c198b3 0eb0409 3d94e05 010bf7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import os

import gradio as gr
import torch
import transformers
from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Authenticate with the Hugging Face Hub (required for gated models like Gemma).
# BUG FIX: the original passed the literal string "HF_TOKEN" as the token,
# which is not a valid credential — read the real secret from the environment
# (Spaces exposes repository secrets as environment variables).
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
from functools import lru_cache


@lru_cache(maxsize=1)
def _load_gemma(model_name="google/gemma-1.1-7b-it"):
    """Load and cache the Gemma model and tokenizer (runs once per process).

    BUG FIX: Gemma is a decoder-only (causal) language model, so it must be
    loaded with AutoModelForCausalLM — AutoModelForSeq2SeqLM raises
    "Unrecognized configuration class" at load time, which crashes the Space.
    Caching also avoids re-loading the 7B checkpoint on every request.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def predict(input, history=None):
    """Generate a chatbot reply for *input* using Gemma.

    Args:
        input (str): User's input text. (Name kept for interface
            compatibility even though it shadows the builtin.)
        history (list, optional): Previous inputs/outputs for context.
            Currently passed through unchanged. Defaults to a fresh list.

    Returns:
        tuple: (chatbot_response, history) — the decoded generation and the
        (unmodified) history, matching the Gradio state output.
    """
    # Avoid the shared-mutable-default pitfall of `history=[]`.
    history = [] if history is None else history

    model, tokenizer = _load_gemma()

    # Tokenize the prompt and generate a continuation.
    inputs = tokenizer(input, return_tensors="pt")
    generated = model.generate(**inputs)
    chatbot_response = tokenizer.decode(generated[0], skip_special_tokens=True)

    return chatbot_response, history
# Wire the prediction function into a Gradio app: a textbox plus hidden
# state in, a chatbot pane plus state out. The "state" slots carry the
# conversation history between calls; drop them if history is unused.
_input_components = ["textbox", "state"]
_output_components = ["chatbot", "state"]
interface = gr.Interface(
    fn=predict,
    inputs=_input_components,
    outputs=_output_components,
)
# Launch the Gradio app.
# BUG FIX: the original also called gr.load("models/google/gemma-1.1-7b-it")
# inside a try/except and discarded the result — that builds a second, unused
# interface (and can trigger a large model fetch at startup), so it is removed.
# The model itself is loaded lazily inside predict().
if __name__ == "__main__":
    interface.launch()
|