Tonic committed on
Commit 757241b · 1 Parent(s): 46248f1

attempts lora adapter and streaming

Files changed (2)
  1. app.py +49 -74
  2. app_alternative.py +159 -0
app.py CHANGED
@@ -1,4 +1,4 @@
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  import torch
  from threading import Thread
  import gradio as gr
@@ -29,42 +29,20 @@ except Exception as e:
  print(f"❌ Error loading model: {e}")
  raise e
 
- class LoRAPipeline:
-     def __init__(self, model, tokenizer):
-         self.model = model
-         self.tokenizer = tokenizer
-
-     def __call__(self, messages, **kwargs):
-         prompt = self.format_messages(messages)
-         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
-
-         with torch.no_grad():
-             outputs = self.model.generate(
-                 **inputs,
-                 **kwargs
-             )
-
-         generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-         response = generated_text[len(prompt):]
-         return response
-
-     def format_messages(self, messages):
-         """Format messages into a prompt string"""
-         formatted = ""
-         for message in messages:
-             role = message["role"]
-             content = message["content"]
-             if role == "system":
-                 formatted += f"System: {content}\n"
-             elif role == "user":
-                 formatted += f"User: {content}\n"
-             elif role == "assistant":
-                 formatted += f"Assistant: {content}\n"
-         formatted += "Assistant: "
-         return formatted
-
- # Create the pipeline
- pipe = LoRAPipeline(model, tokenizer)
 
  def format_conversation_history(chat_history):
      messages = []
@@ -83,7 +61,13 @@ def generate_response(input_data, chat_history, max_new_tokens, system_prompt, t
      processed_history = format_conversation_history(chat_history)
      messages = system_message + processed_history + [new_message]
 
-     # Generate response using the LoRA pipeline
      generation_kwargs = {
          "max_new_tokens": max_new_tokens,
          "do_sample": True,
@@ -92,47 +76,38 @@ def generate_response(input_data, chat_history, max_new_tokens, system_prompt, t
          "top_k": top_k,
          "repetition_penalty": repetition_penalty,
          "pad_token_id": tokenizer.eos_token_id,
      }
 
-     # For streaming, we'll generate token by token
-     prompt = pipe.format_messages(messages)
      inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
-     # Generate with streaming
-     full_response = ""
-     current_length = inputs["input_ids"].shape[1]
 
-     with torch.no_grad():
-         for i in range(max_new_tokens):
-             # Generate one token at a time
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=1,
-                 do_sample=True,
-                 temperature=temperature,
-                 top_p=top_p,
-                 top_k=top_k,
-                 repetition_penalty=repetition_penalty,
-                 pad_token_id=tokenizer.eos_token_id,
-                 use_cache=True
-             )
-
-             # Get the new token
-             new_token = outputs[0][-1].unsqueeze(0).unsqueeze(0)
-
-             # Decode the new token
-             new_text = tokenizer.decode(new_token[0], skip_special_tokens=True)
-
-             if new_text:
-                 full_response += new_text
-                 yield full_response
-
-             # Update inputs for next iteration
-             inputs = {"input_ids": torch.cat([inputs["input_ids"], new_token], dim=1)}
-
-             # Check for end of generation
-             if new_token.item() == tokenizer.eos_token_id:
-                 break
 
  demo = gr.ChatInterface(
      fn=generate_response,
 
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
  import torch
  from threading import Thread
  import gradio as gr

  print(f"❌ Error loading model: {e}")
  raise e
 
+ def format_messages(messages):
+     """Format messages into a prompt string"""
+     formatted = ""
+     for message in messages:
+         role = message["role"]
+         content = message["content"]
+         if role == "system":
+             formatted += f"System: {content}\n"
+         elif role == "user":
+             formatted += f"User: {content}\n"
+         elif role == "assistant":
+             formatted += f"Assistant: {content}\n"
+     formatted += "Assistant: "
+     return formatted
 
  def format_conversation_history(chat_history):
      messages = []

      processed_history = format_conversation_history(chat_history)
      messages = system_message + processed_history + [new_message]
 
+     # Format the prompt
+     prompt = format_messages(messages)
+
+     # Create streamer for proper streaming
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+     # Prepare generation kwargs
      generation_kwargs = {
          "max_new_tokens": max_new_tokens,
          "do_sample": True,

          "top_k": top_k,
          "repetition_penalty": repetition_penalty,
          "pad_token_id": tokenizer.eos_token_id,
+         "streamer": streamer,
+         "use_cache": True
      }
 
+     # Tokenize input
      inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
+     # Start generation in a separate thread
+     thread = Thread(target=model.generate, kwargs={**inputs, **generation_kwargs})
+     thread.start()
 
+     # Stream the response
+     thinking = ""
+     final = ""
+     started_final = False
+
+     for chunk in streamer:
+         if not started_final:
+             if "assistantfinal" in chunk.lower():
+                 split_parts = re.split(r'assistantfinal', chunk, maxsplit=1)
+                 thinking += split_parts[0]
+                 final += split_parts[1]
+                 started_final = True
+             else:
+                 thinking += chunk
+         else:
+             final += chunk
+
+         clean_thinking = re.sub(r'^analysis\s*', '', thinking).strip()
+         clean_final = final.strip()
+         formatted = f"<details open><summary>Click to view Thinking Process</summary>\n\n{clean_thinking}\n\n</details>\n\n{clean_final}"
+         yield formatted
 
  demo = gr.ChatInterface(
      fn=generate_response,
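For reference, the new app.py streams by running generate() on a worker thread while the request handler consumes a TextIteratorStreamer. Below is a minimal standalone sketch of that pattern; the prompt text and the stream_reply helper are illustrative, not part of the commit.

# Minimal sketch of the Thread + TextIteratorStreamer streaming pattern (assumed setup).
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

def stream_reply(prompt, max_new_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so it runs on a worker thread while we consume the streamer here
    thread = Thread(target=model.generate,
                    kwargs={**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens})
    thread.start()
    text = ""
    for chunk in streamer:   # yields decoded text pieces as they are produced
        text += chunk
        yield text           # a Gradio generator fn re-renders the chat bubble on every yield
    thread.join()

# usage: for partial in stream_reply("User: hi\nAssistant: "): print(partial)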
app_alternative.py ADDED
@@ -0,0 +1,159 @@
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ import torch
+ from threading import Thread
+ import gradio as gr
+ import spaces
+ import re
+ from peft import PeftModel
+
+ # Load the base model
+ try:
+     base_model = AutoModelForCausalLM.from_pretrained(
+         "openai/gpt-oss-20b",
+         torch_dtype="auto",
+         device_map="auto",
+         attn_implementation="kernel-community/vllm-flash-attention3"
+     )
+     tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
+
+     # Load the LoRA adapter
+     try:
+         model = PeftModel.from_pretrained(base_model, "Tonic/gpt-oss-20b-multilingual-reasoner")
+         print("✅ LoRA model loaded successfully!")
+     except Exception as lora_error:
+         print(f"⚠️ LoRA adapter failed to load: {lora_error}")
+         print("🔄 Falling back to base model...")
+         model = base_model
+
+ except Exception as e:
+     print(f"❌ Error loading model: {e}")
+     raise e
+
+ def format_messages(messages):
+     """Format messages into a prompt string"""
+     formatted = ""
+     for message in messages:
+         role = message["role"]
+         content = message["content"]
+         if role == "system":
+             formatted += f"System: {content}\n"
+         elif role == "user":
+             formatted += f"User: {content}\n"
+         elif role == "assistant":
+             formatted += f"Assistant: {content}\n"
+     formatted += "Assistant: "
+     return formatted
+
+ def format_conversation_history(chat_history):
+     messages = []
+     for item in chat_history:
+         role = item["role"]
+         content = item["content"]
+         if isinstance(content, list):
+             content = content[0]["text"] if content and "text" in content[0] else str(content)
+         messages.append({"role": role, "content": content})
+     return messages
+
+ @spaces.GPU(duration=60)
+ def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
+     new_message = {"role": "user", "content": input_data}
+     system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
+     processed_history = format_conversation_history(chat_history)
+     messages = system_message + processed_history + [new_message]
+
+     # Format the prompt
+     prompt = format_messages(messages)
+
+     # Alternative streaming approach with manual chunking
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     # Generate in smaller chunks for better streaming
+     chunk_size = 50  # Generate 50 tokens at a time
+     full_response = ""
+
+     with torch.no_grad():
+         for i in range(0, max_new_tokens, chunk_size):
+             current_max_tokens = min(chunk_size, max_new_tokens - i)
+
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=current_max_tokens,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 repetition_penalty=repetition_penalty,
+                 pad_token_id=tokenizer.eos_token_id,
+                 use_cache=True
+             )
+
+             # Decode the new tokens
+             new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+             new_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+             if new_text:
+                 full_response += new_text
+
+                 # Process for thinking/final split
+                 thinking = ""
+                 final = ""
+                 started_final = False
+
+                 if "assistantfinal" in full_response.lower():
+                     split_parts = re.split(r'assistantfinal', full_response, maxsplit=1)
+                     thinking = split_parts[0]
+                     final = split_parts[1] if len(split_parts) > 1 else ""
+                     started_final = True
+                 else:
+                     thinking = full_response
+
+                 clean_thinking = re.sub(r'^analysis\s*', '', thinking).strip()
+                 clean_final = final.strip()
+                 formatted = f"<details open><summary>Click to view Thinking Process</summary>\n\n{clean_thinking}\n\n</details>\n\n{clean_final}"
+                 yield formatted
+
+             # Update inputs for next iteration
+             inputs = {"input_ids": outputs}
+
+             # Check for end of generation
+             if outputs[0][-1].item() == tokenizer.eos_token_id:
+                 break
+
+ demo = gr.ChatInterface(
+     fn=generate_response,
+     additional_inputs=[
+         gr.Slider(label="Max new tokens", minimum=64, maximum=4096, step=1, value=2048),
+         gr.Textbox(
+             label="System Prompt",
+             value="You are a helpful assistant. Reasoning: medium",
+             lines=4,
+             placeholder="Change system prompt"
+         ),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
+         gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+         gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
+         gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
+     ],
+     examples=[
+         [{"text": "Explain Newton's laws clearly and concisely"}],
+         [{"text": "Write a Python function to calculate the Fibonacci sequence"}],
+         [{"text": "What are the benefits of open-weight AI models?"}],
+     ],
+     cache_examples=False,
+     type="messages",
+     description="""
+     # 🙋🏻‍♂️Welcome to 🌟Tonic's gpt-oss-20b Multilingual Reasoner Demo!
+     Wait a couple of seconds initially. You can adjust the reasoning level in the system prompt, e.g. "Reasoning: high".
+     """,
+     fill_height=True,
+     textbox=gr.Textbox(
+         label="Query Input",
+         placeholder="Type your prompt"
+     ),
+     stop_btn="Stop Generation",
+     multimodal=False,
+     theme=gr.themes.Soft()
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
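app_alternative.py loads the LoRA adapter with peft and falls back to the base model if the adapter cannot be loaded. A minimal sketch of that loading pattern follows; the load_and_maybe_adapt helper name is illustrative, while the model and adapter ids are the ones used above.

# Minimal sketch of LoRA loading with a base-model fallback (assumed helper).
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

def load_and_maybe_adapt(base_id="openai/gpt-oss-20b",
                         adapter_id="Tonic/gpt-oss-20b-multilingual-reasoner"):
    base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype="auto", device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(base_id)
    try:
        # Wraps the base model; the LoRA weights are applied on top at inference time
        model = PeftModel.from_pretrained(base, adapter_id)
    except Exception as err:
        print(f"LoRA adapter failed to load ({err}); using the base model")
        model = base
    return model, tokenizer

If adapter overhead at inference time matters, PeftModel also offers merge_and_unload() to fold the LoRA weights into the base weights; the commit keeps the adapter wrapped instead.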