prithivMLmods committed
Commit 8a0565e · verified · 1 Parent(s): e244517

Update app.py

Files changed (1)
  1. app.py +89 -113
app.py CHANGED
@@ -1,145 +1,121 @@
-import spaces
 import os
-import json
-import subprocess
-from llama_cpp import Llama
-from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
-from llama_cpp_agent.providers import LlamaCppPythonProvider
-from llama_cpp_agent.chat_history import BasicChatHistory
-from llama_cpp_agent.chat_history.messages import Roles
+from collections.abc import Iterator
+from threading import Thread
+
 import gradio as gr
-from huggingface_hub import hf_hub_download
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
-huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
-hf_hub_download(
-    repo_id="prithivMLmods/GWQ-9B-Preview-Q5_K_M-GGUF",
-    filename="gwq-9b-preview-q5_k_m.gguf",
-    local_dir="./models"
-)
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-hf_hub_download(
-    repo_id="prithivMLmods/GWQ-9B-Preview2-Q5_K_M-GGUF",
-    filename="gwq-9b-preview2-q5_k_m.gguf",
-    local_dir="./models"
+model_id = "prithivMLmods/GWQ-9B-Preview"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
 )
+model.config.sliding_window = 4096
+model.eval()
 
-llm = None
-llm_model = None
-
-@spaces.GPU(duration=120)
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    model,
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    top_k,
-    repeat_penalty,
-):
-    chat_template = MessagesFormatterType.GEMMA_2
 
-    global llm
-    global llm_model
-
-    if llm is None or llm_model != model:
-        llm = Llama(
-            model_path=f"models/{model}",
-            flash_attn=True,
-            n_gpu_layers=81,
-            n_batch=1024,
-            n_ctx=8192,
-        )
-        llm_model = model
+@spaces.GPU
+def generate(
+    message: str,
+    chat_history: list[dict],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = chat_history.copy()
+    conversation.append({"role": "user", "content": message})
 
-    provider = LlamaCppPythonProvider(llm)
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
 
-    agent = LlamaCppAgent(
-        provider,
-        system_prompt=f"{system_message}",
-        predefined_messages_formatter_type=chat_template,
-        debug_output=True
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
     )
-
-    settings = provider.get_provider_default_settings()
-    settings.temperature = temperature
-    settings.top_k = top_k
-    settings.top_p = top_p
-    settings.max_tokens = max_tokens
-    settings.repeat_penalty = repeat_penalty
-    settings.stream = True
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
 
-    messages = BasicChatHistory()
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
 
-    for msn in history:
-        user = {
-            'role': Roles.user,
-            'content': msn[0]
-        }
-        assistant = {
-            'role': Roles.assistant,
-            'content': msn[1]
-        }
-        messages.add_message(user)
-        messages.add_message(assistant)
-
-    stream = agent.get_chat_response(
-        message,
-        llm_sampling_settings=settings,
-        chat_history=messages,
-        returns_streaming_generator=True,
-        print_output=False
-    )
-
-    outputs = ""
-    for output in stream:
-        outputs += output
-        yield outputs
 
 demo = gr.ChatInterface(
-    respond,
+    fn=generate,
     additional_inputs=[
-        gr.Dropdown([
-                'gwq-9b-preview-q5_k_m.gguf',
-                'gwq-9b-preview2-q5_k_m.gguf'
-            ],
-            value="gwq-9b-preview-q5_k_m.gguf",
-            label="Model"
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
         ),
-        gr.Textbox(value="You are a helpful assistant.", label="System message"),
-        gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(
+            label="Temperature",
             minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
             maximum=1.0,
-            value=0.95,
             step=0.05,
-            label="Top-p",
+            value=0.9,
         ),
         gr.Slider(
-            minimum=0,
-            maximum=100,
-            value=40,
-            step=1,
             label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
         ),
         gr.Slider(
-            minimum=0.0,
-            maximum=2.0,
-            value=1.1,
-            step=0.1,
             label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
         ),
     ],
-    title="GWQ Prev",
-    chatbot=gr.Chatbot(
-        scale=1,
-        show_copy_button=True,
-        type="messages"
-    )
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    ],
+    cache_examples=False,
+    type="messages",
+    css_paths="style.css",
+    fill_height=True,
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.queue(max_size=20).launch()
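
The rewrite replaces the llama-cpp-agent / GGUF path with a plain transformers pipeline: the checkpoint is loaded once at startup, each request is tokenized with apply_chat_template, and model.generate runs on a worker thread while a TextIteratorStreamer feeds partial text back to the Gradio callback. Below is a minimal sketch of that streaming pattern outside the Space, assuming the same prithivMLmods/GWQ-9B-Preview checkpoint, a GPU with enough memory, and an illustrative prompt that is not part of the commit.

# Sketch only: mirrors the TextIteratorStreamer + Thread pattern from the new app.py.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "prithivMLmods/GWQ-9B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

# Chat history in the same list-of-dicts format that gr.ChatInterface(type="messages") passes in.
conversation = [{"role": "user", "content": "Explain the plot of Cinderella in a sentence."}]
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# generate() blocks, so it runs on a worker thread; the streamer yields decoded text as tokens arrive.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=256, do_sample=True, temperature=0.6),
).start()

for piece in streamer:
    print(piece, end="", flush=True)  # incremental output, like the strings the Gradio callback yields

Accumulating the pieces and yielding "".join(outputs), as the new generate() does, keeps the ChatInterface showing the full response so far rather than isolated fragments.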