Spaces: Running on Zero

revert batching

- generate.py +7 -29
- gradio_app.py +4 -4
generate.py CHANGED

@@ -3,7 +3,6 @@ import json
 import logging
 import regex
 import time
-from itertools import chain, islice
 from pathlib import Path
 from typing import Annotated, Iterator
 
@@ -23,16 +22,14 @@ logger = logging.getLogger(__name__)
 
 
 logger.warning("Loading model...")
+model_id = "google/gemma-2b-it"
+# model_id = "Qwen/Qwen1.5-0.5B-Chat"
 if torch.backends.mps.is_available():
     device = "mps"
-
-    batch_size = 1 # batching generates duplicates
+    model = models.transformers(model_id, device=device)
 else:
     device = "cuda"
-    model_id =
-    batch_size = 1 # batching generates duplicates
-
-model = models.transformers(model_id, device=device)
+    model = models.transformers(model_id, device=device)
 
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 sampler = PenalizedMultinomialSampler()
@@ -98,24 +95,6 @@ def samples_prommpt(filename: str, prompt: str, columns: str):
 {{ prompt }}
 """
 
-
-def stream_json_objects_from_batched_tokens_generator(batched_tokens_generator: Iterator[list[str]], json_field: str) -> Iterator[dict]:
-    first_batch = next(batched_tokens_generator)
-    batch_size = len(first_batch)
-    streams = [""] * batch_size
-    skips = [0] * batch_size
-    for tokens_batch in chain([first_batch], batched_tokens_generator):
-        for stream_idx, token in enumerate(tokens_batch):
-            streams[stream_idx] += token
-            if '"' in token or "}" in token:
-                try:
-                    for stream_sample in islice(ijson.items(StringIteratorIO(streams[stream_idx].__iter__()), json_field + ".item", buf_size=1), skips[stream_idx], None):
-                        yield stream_sample
-                        skips[stream_idx] = +1
-                except ijson.IncompleteJSONError:
-                    pass
-
-
 def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
     filename = Path(filename).stem
     logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
@@ -155,8 +134,7 @@ def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int,
         tokenize=False,
         add_generation_prompt=True
     )
-
-
-    for _, sample in zip(range(size), stream_json_objects_from_batched_tokens_generator(batched_samples_generator_tokens, json_field=json_field)):
+    samples_generator_tokens = samples_generator.stream(text, rng=rng)
+    for _, sample in zip(range(size), ijson.items(StringIteratorIO(samples_generator_tokens), "data.item", buf_size=4)):
         yield json.dumps(sample, ensure_ascii=False) + "\n"
-    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
+    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
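The single-stream path this commit reverts to is lazier: ijson.items pulls characters on demand from a file-like wrapper around the token generator, and the small buf_size=4 keeps each read tiny so a sample is yielded almost as soon as its closing brace is generated. StringIteratorIO is not in the standard library; a common recipe, presumably close to what generate.py defines, looks like this:

# Hypothetical reconstruction of StringIteratorIO: adapts an iterator of
# strings into a readable text file object so ijson can pull tokens on demand.
import io


class StringIteratorIO(io.TextIOBase):
    def __init__(self, string_iter):
        self._iter = string_iter
        self._buffer = ""  # leftover text from the previous read()

    def readable(self) -> bool:
        return True

    def read(self, size: int = -1) -> str:
        chunks = [self._buffer]
        length = len(self._buffer)
        self._buffer = ""
        while size < 0 or length < size:
            try:
                chunk = next(self._iter)
            except StopIteration:
                break  # generator exhausted
            chunks.append(chunk)
            length += len(chunk)
        text = "".join(chunks)
        if size >= 0:
            text, self._buffer = text[:size], text[size:]
        return text


# e.g. stream items out of a token generator as soon as each one completes:
# for sample in ijson.items(StringIteratorIO(token_gen), "data.item", buf_size=4):
#     print(sample)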
gradio_app.py CHANGED

@@ -6,11 +6,11 @@ import io
 import pandas as pd
 import spaces
 
-from generate import model_id, stream_jsonl_file
+from generate import model_id, stream_jsonl_file
 
-MAX_SIZE = 20
+MAX_SIZE = 20
 DEFAULT_SEED = 42
-DEFAULT_SIZE =
+DEFAULT_SIZE = 3
 
 @spaces.GPU(duration=120)
 def stream_output(query: str, continue_content: str = ""):
@@ -87,4 +87,4 @@ with gr.Blocks() as demo:
     generate_more_button.click(stream_more_output, filename_comp, outputs)
 
 
-demo.launch()
+demo.launch()
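On the gradio_app.py side, @spaces.GPU(duration=120) is the ZeroGPU decorator (hence "Running on Zero"): it attaches a GPU to each call for up to 120 seconds. stream_output is a generator function, so Gradio re-renders its outputs on every yield. A minimal sketch of that pattern; the component wiring and the stream_jsonl_file arguments below are placeholders, not the real app's:

# Minimal sketch of generator-based streaming in a ZeroGPU Space; the call to
# stream_jsonl_file uses illustrative arguments, not the real app's wiring.
import json

import gradio as gr
import pandas as pd
import spaces

from generate import stream_jsonl_file

DEFAULT_SEED = 42
DEFAULT_SIZE = 3


@spaces.GPU(duration=120)  # ZeroGPU allocates a GPU for up to 120s per call
def stream_output(query: str):
    rows = []
    # stream_jsonl_file yields one JSON line per generated sample
    for line in stream_jsonl_file(query, prompt=query, columns=["text"],
                                  seed=DEFAULT_SEED, size=DEFAULT_SIZE):
        rows.append(json.loads(line))
        yield pd.DataFrame(rows)  # each yield re-renders the table


with gr.Blocks() as demo:
    query = gr.Textbox(label="Describe your dataset")
    table = gr.DataFrame()
    gr.Button("Generate").click(stream_output, query, table)

demo.launch()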