Tech-Meld committed
Commit fa136e4 · verified · 1 Parent(s): 6bc76bb

Update app.py

Files changed (1):
  1. app.py (+10 -7)
app.py CHANGED
@@ -1,8 +1,8 @@
 import gradio as gr
-
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-@gr.cache(allow_output_mutation=True)
+model_cache = {}
+
 def load_model():
     model_id = "Tech-Meld/Hajax_Chat_1.0"
     tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -15,17 +15,20 @@ def get_response(input_text, model, tokenizer):
     response = tokenizer.decode(outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True)
     return response
 
-model, tokenizer = load_model()
-
 def chat(input_text):
-    response = get_response(input_text, model, tokenizer)
-    return response
+    global model_cache
+    if "model" not in model_cache:
+        model_cache["model"], model_cache["tokenizer"] = load_model()
+    model = model_cache["model"]
+    tokenizer = model_cache["tokenizer"]
+    response = get_response(input_text, model, tokenizer)
+    return response
 
 iface = gr.Interface(
     chat,
     "text",
     "text",
-    title="Chat with Hajax",
+    title="Chat with AI",
     description="Type your message and press Enter to chat with the AI.",
 )
 iface.launch()
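
For context: the removed decorator, @gr.cache(allow_output_mutation=True), matches Streamlit's st.cache signature rather than anything in Gradio's API, so the old script would fail before the interface ever launched. The commit replaces it with a plain module-level dict that loads the model lazily on the first chat request. Below is a minimal sketch of app.py as it stands after this commit; only the lines visible in the diff are verbatim. The parts the diff hides (the rest of load_model and the top of get_response) are filled in with assumed, typical transformers calls: the AutoModelForCausalLM.from_pretrained load, the return order of load_model, the tokenizer invocation, and the generate parameters are all guesses, not the committed code.

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

model_cache = {}

def load_model():
    model_id = "Tech-Meld/Hajax_Chat_1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Assumed: the diff does not show the rest of this function.
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return model, tokenizer  # return order inferred from the unpacking in chat()

def get_response(input_text, model, tokenizer):
    # Assumed tokenization and generation; only the decode and return
    # lines below appear verbatim in the diff.
    inputs = tokenizer(input_text, return_tensors="pt").input_ids
    outputs = model.generate(inputs, max_new_tokens=256)
    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True)
    return response

def chat(input_text):
    global model_cache
    if "model" not in model_cache:
        # Lazy load: the first request pays the load cost, later ones reuse it.
        model_cache["model"], model_cache["tokenizer"] = load_model()
    model = model_cache["model"]
    tokenizer = model_cache["tokenizer"]
    response = get_response(input_text, model, tokenizer)
    return response

iface = gr.Interface(
    chat,
    "text",
    "text",
    title="Chat with AI",
    description="Type your message and press Enter to chat with the AI.",
)
iface.launch()

Design note: with the module-level cache, the weights are loaded once per process instead of once per message, so the app starts immediately and the first chat request absorbs the load time. The global statement is kept verbatim from the diff, though it is not strictly needed, since model_cache is mutated rather than rebound.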