File size: 1,771 Bytes
56b98a2
f1db41f
f7c5bbe
f1db41f
fddf668
 
0eb0409
4f68941
010bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0eb0409
 
 
010bf7d
4f68941
 
0eb0409
 
0c198b3
0eb0409
239e77f
0eb0409
0c198b3
0eb0409
3d94e05
 
010bf7d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os

import gradio as gr
import torch
import transformers
from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Authenticate with the Hugging Face Hub. Gemma checkpoints are gated, so a
# real user token is required. The original code passed the literal string
# "HF_TOKEN" as the token, which can never authenticate; read it from the
# environment instead (falls back to None, letting huggingface_hub use any
# cached credentials, if the variable is unset).
login(token=os.environ.get("HF_TOKEN"))

# Single model identifier used everywhere in this module.
_MODEL_NAME = "google/gemma-1.1-7b-it"

# Process-wide cache so the multi-gigabyte model and its tokenizer are loaded
# once, not re-downloaded/re-instantiated on every chat request (the original
# reloaded both inside predict() on each call).
_GEMMA_CACHE = {}


def _load_gemma():
  """Load and cache the Gemma model and tokenizer (first call only).

  Returns:
      tuple: (model, tokenizer) for ``_MODEL_NAME``.
  """
  if not _GEMMA_CACHE:
    # Gemma is a decoder-only (causal) language model, so AutoModelForCausalLM
    # is the correct auto class; AutoModelForSeq2SeqLM (used originally)
    # rejects this checkpoint with a ValueError.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    _GEMMA_CACHE["model"] = AutoModelForCausalLM.from_pretrained(_MODEL_NAME)
    _GEMMA_CACHE["tokenizer"] = AutoTokenizer.from_pretrained(_MODEL_NAME)
  return _GEMMA_CACHE["model"], _GEMMA_CACHE["tokenizer"]


def predict(input, history=None):
  """Generate a chatbot reply for *input* using Gemma.

  Args:
      input (str): User's input text.
      history (list, optional): Previous conversation state; passed through
          unchanged. Defaults to a fresh empty list per call (the original
          ``history=[]`` default was a shared mutable object across calls).

  Returns:
      tuple: (chatbot_response, history) — the decoded reply and the
      (unmodified) history, matching the Gradio outputs ["chatbot", "state"].
  """
  # Normalize the default here so each call gets its own list.
  if history is None:
    history = []

  model, tokenizer = _load_gemma()

  # Tokenize the prompt and generate a continuation.
  inputs = tokenizer(input, return_tensors="pt")
  generated = model.generate(**inputs)

  # Causal LMs echo the prompt at the start of the generated sequence; slice
  # it off so only the newly generated reply is decoded.
  prompt_len = inputs["input_ids"].shape[-1]
  chatbot_response = tokenizer.decode(
      generated[0][prompt_len:], skip_special_tokens=True
  )

  return chatbot_response, history

# Create the Gradio interface: a textbox + carried-over state in, a chatbot
# panel + updated state out, all routed through predict().
interface = gr.Interface(
    fn=predict,
    inputs=["textbox", "state"],  # "state" input can be removed if not used
    outputs=["chatbot", "state"]  # Remove "state" output if history is not used
)

# NOTE(review): gr.load() RETURNS a new Gradio demo built around the hosted
# model, but the return value is discarded here, so this call only spends time
# fetching model metadata and has no effect on `interface` above — presumably
# dead code left over from an earlier approach; confirm and consider removing.
try:
  gr.load("models/google/gemma-1.1-7b-it")  # Assuming model weights are available
except Exception as e:
  # Broad catch is deliberate here: a failed (unused) load should not prevent
  # the app from launching below.
  print(f"An error occurred while loading the model: {e}")  # Improved error handling

# Launch the Gradio interface (blocks the main thread serving HTTP requests).
interface.launch()