Staticaliza commited on
Commit
03f6f58
·
verified ·
1 Parent(s): 938f862

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -104,8 +104,7 @@ def generate(input, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7,
104
  omni_content = build_omni_chunks(input) + [instruction]
105
  sys_msg = repo.get_sys_prompt(mode="omni", language="en")
106
  msgs = [sys_msg, {"role": "user", "content": omni_content}]
107
- params = dict(msgs=msgs, tokenizer=tokenizer, omni_input=True, **kw)
108
- return repo.chat(**params)
109
  elif filetype == "Audio":
110
  audio_np, sample_rate = librosa.load(input, sr=16000, mono=True)
111
  chunk_tensor = torch.from_numpy(audio_np).float().to(DEVICE)
@@ -130,7 +129,7 @@ def generate(input, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7,
130
  inputs_payload = [{"role": "user", "content": content}]
131
 
132
  params = {
133
- "msgs": inputs_payload,
134
  "tokenizer": tokenizer,
135
  "sampling": sampling,
136
  "temperature": temperature,
@@ -138,6 +137,7 @@ def generate(input, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7,
138
  "top_k": top_k,
139
  "repetition_penalty": repetition_penalty,
140
  "max_new_tokens": max_tokens,
 
141
  }
142
 
143
  output = repo.chat(**params)
 
104
  omni_content = build_omni_chunks(input) + [instruction]
105
  sys_msg = repo.get_sys_prompt(mode="omni", language="en")
106
  msgs = [sys_msg, {"role": "user", "content": omni_content}]
107
+ print(msgs)
 
108
  elif filetype == "Audio":
109
  audio_np, sample_rate = librosa.load(input, sr=16000, mono=True)
110
  chunk_tensor = torch.from_numpy(audio_np).float().to(DEVICE)
 
129
  inputs_payload = [{"role": "user", "content": content}]
130
 
131
  params = {
132
+ "msgs": msgs or inputs_payload,
133
  "tokenizer": tokenizer,
134
  "sampling": sampling,
135
  "temperature": temperature,
 
137
  "top_k": top_k,
138
  "repetition_penalty": repetition_penalty,
139
  "max_new_tokens": max_tokens,
140
+ "omni_input": filetype == "Video",
141
  }
142
 
143
  output = repo.chat(**params)