# app.py — Gradio chat demo for the Tech-Meld/Hajax_Chat_1.0 model (Hugging Face Space, commit fa136e4).
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Process-wide cache for the loaded model; populated lazily on the first chat
# request under the keys "model" and "tokenizer" so the download happens once.
model_cache = {}
def load_model():
    """Fetch the chat model and its tokenizer from the Hugging Face Hub.

    Returns:
        tuple: ``(model, tokenizer)`` for ``Tech-Meld/Hajax_Chat_1.0``.
    """
    repo = "Tech-Meld/Hajax_Chat_1.0"
    tok = AutoTokenizer.from_pretrained(repo)
    mdl = AutoModelForCausalLM.from_pretrained(repo)
    return mdl, tok
def get_response(input_text, model, tokenizer, max_length=1000):
    """Generate the model's reply to a single user message.

    Args:
        input_text: The user's message (prompt).
        model: A causal LM exposing ``generate``.
        tokenizer: The matching tokenizer (must define ``eos_token``).
        max_length: Cap on the total sequence length (prompt + reply)
            passed to ``generate``; defaults to the original 1000.

    Returns:
        str: The decoded reply, with the prompt and special tokens removed.
    """
    # EOS after the prompt marks the turn boundary for the chat model.
    inputs = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
    # pad_token_id is set explicitly because many causal LMs define no pad token.
    outputs = model.generate(inputs, max_length=max_length, pad_token_id=tokenizer.eos_token_id)
    # Slice off the prompt tokens so only the newly generated text is decoded.
    response = tokenizer.decode(outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True)
    return response
def chat(input_text):
    """Return the model's reply to ``input_text``, loading the model on demand.

    The model and tokenizer are stored in the module-level ``model_cache``
    dict, so the slow download/initialization happens only on the first call.

    Args:
        input_text: The user's message from the Gradio text box.

    Returns:
        str: The generated reply.
    """
    # Mutating the module-level dict does not rebind the name, so the
    # original ``global model_cache`` declaration was unnecessary.
    if "model" not in model_cache:
        model_cache["model"], model_cache["tokenizer"] = load_model()
    return get_response(input_text, model_cache["model"], model_cache["tokenizer"])
# Wire the chat function into a minimal web UI: one text box in, one text box out.
iface = gr.Interface(
    fn=chat,
    inputs="text",
    outputs="text",
    title="Chat with AI",
    description="Type your message and press Enter to chat with the AI.",
)
iface.launch()