import gradio as gr
import os
import threading
import arrow
import time
import argparse
import logging
from dataclasses import dataclass

import torch
import sentencepiece as spm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import BaseStreamer
from huggingface_hub import hf_hub_download, login

logger = logging.getLogger()
logger.setLevel("INFO")

gr_interface = None

VERSION = "0.1"
@dataclass
class DefaultArgs:
    hf_model_name_or_path: str = "cyberagent/open-calm-1b"
    spm_model_path: str = None
    env: str = "dev"
    port: int = 7860
    make_public: bool = False
if not os.getenv("RUNNING_ON_HF_SPACE"):
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--hf_model_name_or_path", type=str, default="cyberagent/open-calm-small")  # required=True)
    parser.add_argument("--env", type=str, default="dev")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--make_public", action='store_true')
    args = parser.parse_args()
else:
    # fall back to the defaults when running on a Hugging Face Space
    args = DefaultArgs()
def load_model(
    model_name_or_path,
):
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32)
    if torch.cuda.is_available():
        model = model.to("cuda:0")
    return model


logging.info("Loading model")
model = load_model(args.hf_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(args.hf_model_name_or_path)
logging.info("Finished loading model")
class Streamer(BaseStreamer):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.num_invoked = 0
        self.prompt = ""
        self.generated_text = ""
        self.ended = False

    def put(self, t: torch.Tensor):
        d = t.dim()
        if d == 1:
            pass
        elif d == 2:
            t = t[0]
        else:
            raise NotImplementedError

        t = [int(x) for x in t.cpu().numpy()]
        text = self.tokenizer.decode(t, skip_special_tokens=True)

        # the first call receives the prompt tokens; record them but do not emit
        if self.num_invoked == 0:
            self.prompt = text
            self.num_invoked += 1
            return

        self.generated_text += text
        logging.debug(f"[streamer]: {self.generated_text}")

    def end(self):
        self.ended = True
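
# Note: newer transformers releases also ship TextIteratorStreamer, which could
# likely replace the hand-rolled Streamer above. A minimal sketch (assumes
# transformers >= 4.28; not wired into this app, shown for reference only):
#
#   from transformers import TextIteratorStreamer
#   it_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   thread = threading.Thread(
#       target=model.generate,
#       kwargs=dict(input_ids=input_ids, max_new_tokens=128, streamer=it_streamer),
#   )
#   thread.start()
#   for chunk in it_streamer:  # yields decoded text pieces as generation proceeds
#       ...
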
def generate(
    prompt,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    logging.debug(log)

    input_ids = tokenizer(prompt, return_tensors="pt")['input_ids'].to(model.device)

    # cap generation length so prompt + new tokens fit the model's context window
    max_possible_new_tokens = model.config.max_position_embeddings - len(input_ids.squeeze(0))
    max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)

    streamer = Streamer(tokenizer=tokenizer)

    # run generation in a background thread so partial text can be streamed to the UI
    thr = threading.Thread(target=model.generate, kwargs=dict(
        input_ids=input_ids,
        do_sample=do_sample,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_possible_new_tokens,
        streamer=streamer,
        # max_length=4096,
        # top_k=100,
        # top_p=0.9,
        # num_return_sequences=2,
        # num_beams=2,
    ))
    thr.start()
    while not streamer.ended:
        time.sleep(0.05)
        yield streamer.generated_text

    # TODO: optimize for final few tokens
    gen = streamer.generated_text
    log.update(dict(
        generation=gen,
        version=VERSION,
        time=str(arrow.now("+09:00"))))
    logging.info(log)
    yield gen
def process_feedback(
    rating,
    prompt,
    generation,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    log.update(dict(
        time=str(arrow.now("+09:00")),
        version=VERSION,
    ))
    logging.info(log)


if gr_interface:
    gr_interface.close(verbose=False)
with gr.Blocks() as gr_interface:
    with gr.Row():
        gr.Markdown(f"# open-calm-small Playground ({VERSION})")
    with gr.Row():
        gr.Markdown("open-calm-small Playground")
    with gr.Row():
        # left panel
        with gr.Column(scale=1):
            # generation params
            with gr.Box():
                gr.Markdown("hyper parameters")

                # hidden default params
                do_sample = gr.Checkbox(True, label="Do Sample", visible=True)
                no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size", visible=False)

                # visible params
                max_new_tokens = gr.Slider(
                    128,
                    min(512, model.config.max_position_embeddings),
                    value=128,
                    step=128,
                    label="max tokens",
                )
                temperature = gr.Slider(
                    0, 1, value=0.7, step=0.05, label="temperature",
                )
                repetition_penalty = gr.Slider(
                    1, 1.5, value=1.2, step=0.05, label="repetition penalty",
                )

            # grouping params for easier reference
            gr_params = [
                max_new_tokens,
                temperature,
                repetition_penalty,
                do_sample,
                no_repeat_ngram_size,
            ]

        # right panel
        with gr.Column(scale=2):
            # user input block
            with gr.Box():
                textbox_prompt = gr.Textbox(
                    label="入力",
                    placeholder="AIによって私達の暮らしは、",
                    interactive=True,
                    lines=5,
                    value="AIによって私達の暮らしは、"
                )

            with gr.Box():
                with gr.Row():
                    btn_stop = gr.Button(value="キャンセル", variant="secondary")
                    btn_submit = gr.Button(value="実行", variant="primary")

            # model output block
            with gr.Box():
                textbox_generation = gr.Textbox(
                    label="応答",
                    lines=5,
                    value=""
                )

            # feedback buttons (labels are placeholders; the handler only logs the clicked button's value)
            with gr.Box():
                with gr.Row():
                    btn_ratings = [gr.Button(value=v) for v in ["良い", "悪い"]]

    # event handling
    inputs = [textbox_prompt] + gr_params
    click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True)
    btn_stop.click(None, None, None, cancels=click_event, queue=False)
    for btn_rating in btn_ratings:
        btn_rating.click(process_feedback, [btn_rating, textbox_prompt, textbox_generation] + gr_params, queue=False)

gr_interface.queue(max_size=32, concurrency_count=2)
gr_interface.launch(server_port=args.port, share=args.make_public)
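
# Example local invocations (assumes this file is saved as app.py and that
# RUNNING_ON_HF_SPACE is unset, so the argparse defaults above are used):
#
#   python app.py
#   python app.py --hf_model_name_or_path cyberagent/open-calm-1b --port 7860
#   python app.py --make_public   # asks Gradio for a temporary public share link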