FlameF0X committed
Commit 3604e16 · verified · 1 Parent(s): 9882ae6

Update app.py

Files changed (1)
  app.py  +27 -17
app.py CHANGED
@@ -2,18 +2,27 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-# Load model and tokenizer
-print("Loading model and tokenizer...")
-model_id = "PingVortex/VLM-1"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(model_id)
-print("Model loaded successfully!")
+# Define available models
+model_options = {
+    "VLM-1-K1": "PingVortex/VLM-1-K1",
+    "VLM-1-K2": "PingVortex/VLM-1-K2",
+    "VLM-1-K3": "PingVortex/VLM-1-K3"
+}
 
-def generate_response(message, history):
+# Load models and tokenizers
+models = {}
+tokenizers = {}
+for name, model_id in model_options.items():
+    print(f"Loading {name}...")
+    tokenizers[name] = AutoTokenizer.from_pretrained(model_id)
+    models[name] = AutoModelForCausalLM.from_pretrained(model_id)
+    print(f"{name} loaded successfully!")
+
+def generate_response(message, history, model_choice):
+    tokenizer = tokenizers[model_choice]
+    model = models[model_choice]
     input_ids = tokenizer(message, return_tensors="pt").input_ids
-    # Truncate to last 1024 tokens if needed
-    input_ids = input_ids[:, -1024:]
-
+    input_ids = input_ids[:, -1024:]  # Truncate to last 1024 tokens if needed
     with torch.no_grad():
         output = model.generate(
             input_ids,
@@ -23,18 +32,19 @@ def generate_response(message, history):
             top_p=0.9,
             pad_token_id=tokenizer.eos_token_id
         )
-
     new_tokens = output[0][input_ids.shape[1]:]
     response = tokenizer.decode(new_tokens, skip_special_tokens=True)
-
     return response.strip()
 
 # Create the Gradio interface
-demo = gr.ChatInterface(
-    generate_response,
-    theme="soft",
-    examples=["Hello, who are you?", "What can you do?", "Tell me a short story"],
-)
+with gr.Blocks() as demo:
+    model_choice = gr.Dropdown(choices=list(model_options.keys()), label="Select Model", value="VLM-1-K1")
+    chatbot = gr.ChatInterface(
+        lambda message, history: generate_response(message, history, model_choice.value),
+        theme="soft",
+        examples=["Hello, who are you?", "What can you do?", "Tell me a short story"],
+    )
+    model_choice.change(fn=lambda x: None, inputs=model_choice, outputs=[])
 
 if __name__ == "__main__":
     demo.launch()
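
Note on the new wiring (an observation about this diff, not part of the commit): the lambda passed to gr.ChatInterface reads model_choice.value, which is the Dropdown's static initial value on the server. Gradio does not update that attribute when a user changes the selection in the browser, so as written every request runs "VLM-1-K1", and the model_choice.change(fn=lambda x: None, ...) handler is a no-op. The usual way to feed a component's live value into the chat callback is ChatInterface's additional_inputs parameter. A minimal sketch under that assumption, reusing model_options and generate_response from the diff above:

    # Sketch only, not the committed code. Assumes gr.ChatInterface's
    # additional_inputs parameter (Gradio >= 3.44), which appends the
    # dropdown's current selection as the third argument of the callback.
    import gradio as gr

    with gr.Blocks(theme="soft") as demo:
        # Components already rendered inside the Blocks scope are used
        # in place when passed via additional_inputs, not re-rendered.
        model_choice = gr.Dropdown(
            choices=list(model_options.keys()),  # model_options as defined above
            label="Select Model",
            value="VLM-1-K1",
        )
        gr.ChatInterface(
            generate_response,                 # fn(message, history, model_choice)
            additional_inputs=[model_choice],  # live selection, not the static .value
        )

    if __name__ == "__main__":
        demo.launch()

With the selection flowing through additional_inputs, the no-op .change() handler can be dropped entirely.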