EYEDOL committed on
Commit babb493 · verified · 1 Parent(s): e1a9f6f

Update app.py

Files changed (1)
  1. app.py +73 -96
app.py CHANGED
@@ -1,24 +1,23 @@
  # -*- coding: utf-8 -*-
  """
- Fixed and self-contained Swahili multimodal assistant for Hugging Face Spaces.
-
- Key fixes / improvements over original:
- - Robust loading of an LLM repo that may lack `model_type` in config.json by
- loading the model object directly and using `trust_remote_code=True` as a
- fallback. Avoids `pipeline(... )` raising ValueError on AutoConfig.
- - Correct handling of `pipeline(..., device=...)` which expects an int GPU
- index or -1 for CPU (previously passed a string like "cpu").
- - Streaming generation implemented by calling `model.generate(..., streamer=TextIteratorStreamer(...))`
- in a background thread so the main thread can iterate over the streamer.
- - Use standard HF env var `HF_TOKEN` and graceful error message if not set.
- - Minor robustness improvements (resampling audio, handling mono/stereo, temp
- filenames, etc.).
-
- Drop this file into your Space and replace the old app.py contents.
+ Salama Assistant fixed full app.py with PEFT adapter loading (base + adapter)
+
+ Drop this file into your Hugging Face Space (replace your existing app.py).
+
+ Requirements:
+ - transformers
+ - peft
+ - onnxruntime
+ - librosa
+ - huggingface_hub
+ - gradio
+
+ Note: install `peft` (e.g. add to requirements.txt: "peft>=0.4.0") or pip install in your environment.
  """

  import os
- import re
+ import json
+ import tempfile
  import threading
  import numpy as np
  import gradio as gr
@@ -38,26 +37,29 @@ from transformers import (
  TextIteratorStreamer,
  )

+ # PEFT imports
+ from peft import PeftModel, PeftConfig
+
  # -------------------- Configuration --------------------
  STT_MODEL_ID = "EYEDOL/SALAMA_C3"
- LLM_MODEL_ID = "EYEDOL/Llama-3.2-1B_ON_ALPACA5"
+ ADAPTER_REPO_ID = "EYEDOL/Llama-3.2-1B_ON_ALPACA5" # adapter-only repo
+ BASE_MODEL_ID = "unsloth/Llama-3.2-1B-Instruct" # full base model referenced by adapter
  TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
  TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"

  TEMP_DIR = "temp"
  os.makedirs(TEMP_DIR, exist_ok=True)

- # Use the standard environment variable name used by Spaces
+ # Use HF token from env; Spaces normally provide HF_TOKEN
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
  if not HF_TOKEN:
- raise ValueError("HF_TOKEN not found. Please set it in Hugging Face Space repository secrets.")
-
- # Attempt login to HF hub (Spaces typically already provides token, but this keeps parity)
- try:
- login(token=HF_TOKEN)
- print("Successfully logged into Hugging Face Hub!")
- except Exception as e:
- print("Warning: could not call huggingface_hub.login(). Proceeding — ensure your token is valid in the environment. Error:", e)
+ print("Warning: HF_TOKEN not found in env. Public models may still load, but private repos require a token.")
+ else:
+ try:
+ login(token=HF_TOKEN)
+ print("Successfully logged into Hugging Face Hub!")
+ except Exception as e:
+ print("Warning: huggingface_hub.login() failed:", e)


  class WeeboAssistant:
@@ -79,7 +81,6 @@ class WeeboAssistant:
  # ---------------- STT ----------------
  print(f"Loading STT model: {STT_MODEL_ID}")
  self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
- # Speech seq2seq model (e.g. Whisper-like)
  self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
  STT_MODEL_ID,
  torch_dtype=self.torch_dtype,
@@ -93,56 +94,53 @@ class WeeboAssistant:
  pass
  print("STT model loaded successfully.")

- # ---------------- LLM ----------------
- print(f"Loading LLM: {LLM_MODEL_ID}")
- self.llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True)
+ # ---------------- LLM (base + PEFT adapter) ----------------
+ print(f"Loading base LLM: {BASE_MODEL_ID} and applying adapter: {ADAPTER_REPO_ID}")

- # Attempt robust loading. If the repo lacks a model_type in config.json,
- # try loading with trust_remote_code=True (this allows custom model code in repo).
+ # 1) Tokenizer: prefer base tokenizer
  try:
- config = AutoConfig.from_pretrained(LLM_MODEL_ID)
- # If config loaded but missing model_type, continue to try direct load
- if not getattr(config, "model_type", None):
- raise ValueError("config missing model_type - forcing trusted load")
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
+ except Exception as e:
+ print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e)
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)

- # Try to load into a causal LM class (works for many standard model types)
+ # 2) Load base model
+ device_map = "auto" if torch.cuda.is_available() else None
+ try:
  self.llm_model = AutoModelForCausalLM.from_pretrained(
- LLM_MODEL_ID,
- config=config,
+ BASE_MODEL_ID,
  torch_dtype=self.torch_dtype,
  low_cpu_mem_usage=True,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+ except Exception as e:
+ # Helpful error info and hint
+ raise RuntimeError(
+ "Failed to load base model. Ensure the base model ID is correct and the HF_TOKEN has access if private. Error: "
+ + str(e)
  )
- except Exception as first_err:
- print("Standard AutoConfig/AutoModel load failed or model_type missing. Trying trust_remote_code=True. Error:", first_err)
- # Try using trust_remote_code which will import repo-specific model code if present
- try:
- config = AutoConfig.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
- self.llm_model = AutoModelForCausalLM.from_pretrained(
- LLM_MODEL_ID,
- config=config,
- torch_dtype=self.torch_dtype,
- trust_remote_code=True,
- low_cpu_mem_usage=True,
- device_map="auto" if torch.cuda.is_available() else None,
- )
- except Exception as second_err:
- # Final fallback: try to load without special configs — may still fail for custom repos
- print("Fallback load also failed:", second_err)
- raise RuntimeError(
- "Unable to load LLM model. Check the model repo, ensure config.json contains a model_type or that trust_remote_code is allowed."
- )
-
- # If device_map wasn't used and model is on CPU, ensure model is moved to CPU
- if self.device == "cpu":
- try:
- # Many Hugging Face helpers use device_map; if not used, move model
- self.llm_model = self.llm_model.to("cpu")
- except Exception:
- pass

- # For convenience, create a pipeline for non-streaming quick calls (device expects int or -1)
- device_index = 0 if torch.cuda.is_available() else -1
+ # 3) Load and apply PEFT adapter (adapter-only repo)
  try:
+ # This discovers adapter config (adapter_config.json) and applies weights
+ peft_config = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
+ self.llm_model = PeftModel.from_pretrained(
+ self.llm_model,
+ ADAPTER_REPO_ID,
+ device_map=device_map,
+ torch_dtype=self.torch_dtype,
+ low_cpu_mem_usage=True,
+ )
+ except Exception as e:
+ raise RuntimeError(
+ "Failed to load/apply PEFT adapter from adapter repo. Make sure adapter files (adapter_config.json and adapter_model.safetensors) are present and HF_TOKEN has access if private. Error: "
+ + str(e)
+ )
+
+ # 4) Optionally create a non-streaming pipeline for quick tests
+ try:
+ device_index = 0 if torch.cuda.is_available() else -1
  self.llm_pipeline = pipeline(
  "text-generation",
  model=self.llm_model,
@@ -150,15 +148,14 @@ class WeeboAssistant:
  device=device_index,
  model_kwargs={"torch_dtype": self.torch_dtype},
  )
- except Exception:
- # pipeline is optional; if it fails we still support the streaming flow via model.generate
+ except Exception as e:
+ print("Warning: could not create text-generation pipeline. Streaming generate will still work. Error:", e)
  self.llm_pipeline = None

- print("LLM loaded successfully.")
+ print("LLM base + adapter loaded successfully.")

  # ---------------- TTS ----------------
  print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
- # ONNX runtime session; providers include CUDA if available
  providers = ["CPUExecutionProvider"]
  if torch.cuda.is_available():
  providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
@@ -171,22 +168,17 @@ class WeeboAssistant:

  # ---------------- Utility methods ----------------
  def transcribe_audio(self, audio_tuple):
- """Take a Gradio audio tuple (sample_rate, np_audio) and return transcription string."""
  if audio_tuple is None:
  return ""
  sample_rate, audio_data = audio_tuple
- # Convert to mono
  if audio_data.ndim > 1:
  audio_data = audio_data.mean(axis=1)
- # Normalize to float32
  if audio_data.dtype != np.float32:
- # handle common integer audio dtypes
  if np.issubdtype(audio_data.dtype, np.integer):
  max_val = np.iinfo(audio_data.dtype).max
  audio_data = audio_data.astype(np.float32) / float(max_val)
  else:
  audio_data = audio_data.astype(np.float32)
- # Resample if needed
  if sample_rate != self.STT_SAMPLE_RATE:
  audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
  if len(audio_data) < 1000:
@@ -200,20 +192,15 @@ class WeeboAssistant:
  return transcription.strip()

  def generate_speech(self, text):
- """Synthesize speech using the ONNX TTS model and return a filepath to a WAV file."""
  if not text:
  return None
  text = text.strip()
- # Tokenize with numpy arrays for ONNX
  inputs = self.tts_tokenizer(text, return_tensors="np")
  input_name = self.tts_session.get_inputs()[0].name
  ort_inputs = {input_name: inputs["input_ids"]}
  audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()

- # ONNX model might produce float audio in range [-1,1] or int16 depending on model. We'll safe-guard.
- # Normalize to int16 WAV
  if np.issubdtype(audio_waveform.dtype, np.floating):
- # Clip and convert
  audio_clip = np.clip(audio_waveform, -1.0, 1.0)
  audio_int16 = (audio_clip * 32767).astype(np.int16)
  else:
@@ -224,23 +211,17 @@ class WeeboAssistant:
  return output_path

  def get_llm_response(self, chat_history):
- """Return a TextIteratorStreamer that yields generated text pieces as the model produces them.
-
- This implementation uses self.llm_model.generate(...) with a TextIteratorStreamer and
- runs generate in a background thread so the caller can iterate over streamer.
- """
- # Build prompt from system + conversation. Adjust this template to match your LLM's preferred format.
- prompt_lines = [self.SYSTEM_PROMPT.strip(), "\n"]
+ prompt_lines = [self.SYSTEM_PROMPT.strip(), "
+ "]
  for user_msg, assistant_msg in chat_history:
  if user_msg:
- # tag user messages clearly so model understands dialogue turns
  prompt_lines.append("User: " + user_msg)
  if assistant_msg:
  prompt_lines.append("Assistant: " + assistant_msg)
  prompt_lines.append("Assistant: ")
- prompt = "\n".join(prompt_lines)
+ prompt = "
+ ".join(prompt_lines)

- # Tokenize and prepare inputs on the same device as the model
  inputs = self.llm_tokenizer(prompt, return_tensors="pt")
  try:
  model_device = next(self.llm_model.parameters()).device
@@ -261,7 +242,6 @@ class WeeboAssistant:
  eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
  )

- # Launch generation in a thread so we can yield from the streamer in the main thread
  gen_thread = threading.Thread(target=self.llm_model.generate, kwargs=generation_kwargs, daemon=True)
  gen_thread.start()

@@ -275,7 +255,6 @@ assistant = WeeboAssistant()
  # -------------------- Gradio pipelines --------------------

  def s2s_pipeline(audio_input, chat_history):
- # `chat_history` is expected to be a list of (user_text, assistant_text) tuples
  user_text = assistant.transcribe_audio(audio_input)
  if not user_text or user_text.startswith("("):
  chat_history.append((user_text or "(No valid speech detected)", None))
@@ -289,11 +268,9 @@ def s2s_pipeline(audio_input, chat_history):
  llm_response_text = ""
  for text_chunk in response_stream:
  llm_response_text += text_chunk
- # Update last turn in chat history
  chat_history[-1] = (user_text, llm_response_text)
  yield chat_history, None, llm_response_text

- # Once finished, synthesize audio
  final_audio_path = assistant.generate_speech(llm_response_text)
  yield chat_history, final_audio_path, llm_response_text

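
The substance of the change is that EYEDOL/Llama-3.2-1B_ON_ALPACA5 is treated as an adapter-only repo: the new app.py loads the full base model first and then applies the adapter on top with PEFT. A minimal sketch of that loading pattern, assuming `transformers` and `peft` are installed (repo IDs as named in the diff; dtype and device handling simplified):

# Sketch: load the base model, then apply the adapter-only repo on top of it.
# Assumes `transformers` and `peft` are installed; repo IDs are the ones used in the diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL_ID = "unsloth/Llama-3.2-1B-Instruct"      # full base model
ADAPTER_REPO_ID = "EYEDOL/Llama-3.2-1B_ON_ALPACA5"   # adapter-only repo

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
)
# Reads adapter_config.json and the adapter weights from the adapter repo and wraps the base model.
model = PeftModel.from_pretrained(base_model, ADAPTER_REPO_ID)
model.eval()

Loading via PeftModel.from_pretrained keeps the base weights untouched; if the adapter is a LoRA, merge_and_unload() can optionally fold it into the base weights for faster inference.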
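get_llm_response() keeps the earlier streaming approach: model.generate() runs in a background thread with a TextIteratorStreamer, and the caller iterates over the streamer to receive text chunks. (In the added lines the "\n" literals inside get_llm_response appear split across two physical lines, likely a paste artifact; the removed lines show the intended single-line form prompt = "\n".join(prompt_lines).) A condensed sketch of the streaming pattern, with a hypothetical stream_reply helper and the model/tokenizer objects loaded above:

# Sketch: stream tokens from generate() via TextIteratorStreamer running in a background thread.
# `stream_reply` is a hypothetical helper; `model` and `tokenizer` are the objects loaded above.
import threading
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, prompt, max_new_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt").to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens,
                  eos_token_id=tokenizer.eos_token_id)
    # generate() blocks, so run it in a daemon thread and consume the streamer here.
    threading.Thread(target=model.generate, kwargs=kwargs, daemon=True).start()
    for chunk in streamer:
        yield chunk

# Usage:
# for piece in stream_reply(model, tokenizer, "User: Habari!\nAssistant: "):
#     print(piece, end="", flush=True)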