optimize compile by removing if branch
- app.py +16 -6
- tools/llama/generate.py +42 -45
app.py
CHANGED

@@ -1,4 +1,5 @@
 import os
+import queue
 from huggingface_hub import snapshot_download
 import hydra
 
@@ -125,17 +126,26 @@ def inference(
     )
 
     payload = dict(
-
+        response_queue=queue.Queue(),
         request=request,
     )
     llama_queue.put(payload)
 
-
-
-
-
+    codes = []
+    while True:
+        result = payload["response_queue"].get()
+        if result == "next":
+            # TODO: handle next sentence
+            continue
 
-
+        if result == "done":
+            if payload["success"] is False:
+                raise payload["response"]
+            break
+
+        codes.append(result)
+
+    codes = torch.cat(codes, dim=1)
 
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)

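For context, a minimal, self-contained sketch of the handshake the updated app.py relies on: each request carries its own response_queue, the worker thread streams chunks into it, and the string sentinels "next" and "done" mark sentence boundaries and completion. This uses only the standard library; fake_generate is a hypothetical stand-in for the real generate_long call, not part of the repository.

import queue
import threading


def fake_generate():
    # Hypothetical stand-in for generate_long(): yields code chunks,
    # with "next" marking the end of a sample in streaming mode.
    yield "chunk-1"
    yield "chunk-2"
    yield "next"
    yield "chunk-3"


def worker(payload):
    # Mirrors the worker in launch_thread_safe_queue: stream results,
    # record success/failure on the payload, then signal completion.
    response_queue = payload["response_queue"]
    try:
        payload["success"] = True
        for chunk in fake_generate():
            response_queue.put(chunk)
        response_queue.put("done")
    except Exception as e:
        payload["success"] = False
        payload["response"] = e
        response_queue.put("done")


payload = dict(response_queue=queue.Queue(), request={})
threading.Thread(target=worker, args=(payload,), daemon=True).start()

# Consumer side, as in app.py's inference(): drain the queue until "done".
codes = []
while True:
    result = payload["response_queue"].get()
    if result == "next":
        continue
    if result == "done":
        if payload["success"] is False:
            raise payload["response"]
        break
    codes.append(result)

print(codes)  # ['chunk-1', 'chunk-2', 'chunk-3']
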
tools/llama/generate.py
CHANGED

@@ -47,32 +47,32 @@ def logits_to_probs(
     top_p: Optional[int] = None,
     repetition_penalty: float = 1.0,
 ):
-    if previous_tokens is not None and repetition_penalty != 1.0:
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
-        score = torch.where(
-            score < 0, score * repetition_penalty, score / repetition_penalty
-        )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+    # if previous_tokens is not None and repetition_penalty != 1.0:
+    previous_tokens = previous_tokens.long()
+    score = torch.gather(logits, dim=0, index=previous_tokens)
+    score = torch.where(
+        score < 0, score * repetition_penalty, score / repetition_penalty
+    )
+    logits.scatter_(dim=0, index=previous_tokens, src=score)
 
-    if top_p is not None and top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
-        )
-        sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
-        )
-        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+    # if top_p is not None and top_p < 1.0:
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(
+        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+    )
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=0, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
 
     logits = logits / max(temperature, 1e-5)
 
-    if top_k is not None:
-        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
+    # if top_k is not None:
+    #     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+    #     pivot = v.select(-1, -1).unsqueeze(-1)
+    #     logits = torch.where(logits < pivot, -float("Inf"), logits)
 
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs

@@ -470,16 +470,14 @@ def generate_long(
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
 
     if use_prompt:
-
-
-
-
-
-
-
-
-                num_codebooks=model.config.num_codebooks,
-            )
+        encoded_prompts = encode_tokens(
+            tokenizer,
+            prompt_text,
+            prompt_tokens=prompt_tokens,
+            bos=True,
+            device=device,
+            speaker=speaker,
+            num_codebooks=model.config.num_codebooks,
         )
 
     for idx, text in enumerate(texts):

@@ -501,10 +499,6 @@ def generate_long(
     all_codes = []
     seg_idx = 0
 
-    if use_prompt:
-        seg_idx = 1
-        global_encoded.append(encoded[0])
-
     while seg_idx < len(encoded):
         logger.info(
             f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"

@@ -531,6 +525,9 @@ def generate_long(
         else:
             partial_encoded = global_encoded
 
+        if use_prompt:
+            partial_encoded = [encoded_prompts] + partial_encoded
+
         cat_encoded = torch.cat(partial_encoded, dim=1)
         prompt_length = cat_encoded.size(1)
 
@@ -593,14 +590,13 @@ def generate_long(
 
         if is_streaming:
            # This indicates the end of the current sample
-            yield
+            yield "next"
         else:
             all_codes = torch.cat(all_codes, dim=1)
             assert (all_codes >= 0).all(), f"Negative code found: {codes}"
             yield all_codes
 
 
-
 def launch_thread_safe_queue(
     config_name,
     checkpoint_path,

@@ -624,20 +620,21 @@ def launch_thread_safe_queue(
                 break
 
             kwargs = item["request"]
-
+            response_queue = item["response_queue"]
 
             try:
                 item["success"] = True
-
-
-
-                )
-
+                for chunk in generate_long(
+                    model=model, decode_one_token=decode_one_token, **kwargs
+                ):
+                    response_queue.put(chunk)
+
+                response_queue.put("done")
             except Exception as e:
                 item["success"] = False
                 item["response"] = e
 
-
+                response_queue.put("done")
 
     threading.Thread(target=worker, daemon=True).start()
     init_event.wait()
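As a side note on the commit title, a rough sketch of why dropping the data-dependent if branches can help torch.compile: branching on Python-level arguments such as top_p makes the traced decode path depend on the call arguments, so the compiler may guard on and re-trace the function for different branch outcomes, while an unconditional path gives it a single graph. This is a simplified illustration under that assumption, not the repository's code; the filtering logic mirrors the diff above, and the branchless form requires the caller to always pass a valid top_p.

from typing import Optional

import torch


def sample_with_branch(logits: torch.Tensor, top_p: Optional[float] = None) -> torch.Tensor:
    # Branchy version: whether top-p filtering runs depends on a Python value.
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[0] = False  # keep at least one option
        indices_to_remove = remove.scatter(0, sorted_indices, remove)
        logits = logits.masked_fill(indices_to_remove, -float("inf"))
    return torch.softmax(logits, dim=-1)


def sample_branchless(logits: torch.Tensor, top_p: float = 0.7) -> torch.Tensor:
    # Branchless version, in the spirit of the diff: top-p filtering always runs,
    # so there is exactly one code path for torch.compile to trace.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum_probs > top_p
    remove[0] = False  # keep at least one option
    indices_to_remove = remove.scatter(0, sorted_indices, remove)
    logits = logits.masked_fill(indices_to_remove, -float("inf"))
    return torch.softmax(logits, dim=-1)


compiled = torch.compile(sample_branchless)
print(compiled(torch.randn(32), top_p=0.7).shape)  # torch.Size([32])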