import gradio as gr
import sys
import os
import ast
import random
import llm_blender
import descriptions
from datasets import load_dataset
from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks
from typing import List
MAX_BASE_LLM_NUM = 20
MIN_BASE_LLM_NUM = 3
SOURCE_MAX_LENGTH = 256
DEFAULT_SOURCE_MAX_LENGTH = 128
CANDIDATE_MAX_LENGTH = 256
DEFAULT_CANDIDATE_MAX_LENGTH = 128
FUSER_MAX_NEW_TOKENS = 512
DEFAULT_FUSER_MAX_NEW_TOKENS = 256
# MIX-INSTRUCT
EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation')
SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42)
MIX_INSTRUCT_EXAMPLES = []
CANDIDATE_MAP = {}
for i, example in enumerate(SHUFFLED_EXAMPLES_DATASET):
    MIX_INSTRUCT_EXAMPLES.append([
        example['instruction'],
        example['input'],
    ])
    CANDIDATE_MAP[example['instruction']+example['input']] = example['candidates']
    if i > 100:
        break
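# Each entry in CANDIDATE_MAP holds the dataset's pre-generated candidate outputs.
# The only field this app relies on is candidates[i]['text'] (see
# update_base_llm_dropdown_along_examples below); mix-instruct candidates also
# carry metadata such as the generating model, which is not used here.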
# HHH ALIGNMENT
HHH_EXAMPLES = []
subsets = ['harmless', 'helpful', 'honest', 'other']
random.seed(42)
for subset in subsets:
    dataset = load_dataset("HuggingFaceH4/hhh_alignment", subset)
    for example in dataset['test']:
        # Randomly swap the two choices so the better response is not always "Response 1"
        if random.random() < 0.5:
            HHH_EXAMPLES.append([
                subset,
                example['input'],
                example['targets']['choices'][0],
                example['targets']['choices'][1],
                "Response 1" if example['targets']['labels'][0] == 1 else "Response 2",
            ])
        else:
            HHH_EXAMPLES.append([
                subset,
                example['input'],
                example['targets']['choices'][1],
                example['targets']['choices'][0],
                "Response 2" if example['targets']['labels'][0] == 1 else "Response 1",
            ])
def get_hhh_examples(subset, instruction, response1, response2, dummy_text):
    return instruction, response1, response2
# MT_BENCH_HUMAN_JUDGMENTS
MT_BENCH_HUMAN_JUDGE_EXAMPLES = []
dataset = load_dataset("lmsys/mt_bench_human_judgments")
for example in dataset['human']:
    if example['turn'] != 1:
        continue
    MT_BENCH_HUMAN_JUDGE_EXAMPLES.append([
        example['model_a'],
        example['model_b'],
        str(example['conversation_a']),
        str(example['conversation_b']),
        "Model A" if example['winner'] == 'model_a' else "Model B",
    ])
def get_mt_bench_human_judge_examples(model_a, model_b, conversation_a, conversation_b, dummy_text):
    chat_history_a = []
    chat_history_b = []
    # The conversations were stringified for the Examples table; parse them back
    # with ast.literal_eval rather than eval to avoid executing arbitrary code.
    conversation_a = ast.literal_eval(conversation_a)
    conversation_b = ast.literal_eval(conversation_b)
    for i in range(0, len(conversation_a), 2):
        assert conversation_a[i]['role'] == 'user' and conversation_a[i+1]['role'] == 'assistant'
        chat_history_a.append((conversation_a[i]['content'], conversation_a[i+1]['content']))
    for i in range(0, len(conversation_b), 2):
        assert conversation_b[i]['role'] == 'user' and conversation_b[i+1]['role'] == 'assistant'
        chat_history_b.append((conversation_b[i]['content'], conversation_b[i+1]['content']))
    return chat_history_a, chat_history_b
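# Conversations in lmsys/mt_bench_human_judgments are lists of
# {'role': 'user'|'assistant', 'content': str} dicts in alternating order,
# which is what the role assertions above check before pairing each user
# turn with the assistant reply that follows it.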
blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")
blender.loadfuser("llm-blender/gen_fuser_3b")
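# Minimal sketch of the llm_blender calls this app builds on (adapted from the
# LLM-Blender README; exact keyword arguments may differ across versions):
#
#   inputs = ["hello, how are you!"]
#   candidates = [["hi!", "I am fine, thanks!", "bye!"]]
#   ranks = blender.rank(inputs=inputs, candidates=candidates)          # e.g. [[2, 1, 3]]
#   topk = get_topk_candidates_from_ranks(ranks, candidates, top_k=2)   # best 2 per input
#   fused = blender.fuse(inputs=inputs, candidates=topk, batch_size=1)  # one fused string per input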
def update_base_llms_num(k, llm_outputs):
    k = int(k)
    return [gr.Dropdown(choices=[f"LLM-{i+1}" for i in range(k)],
                        value="LLM-1" if k >= 1 else "", visible=True),
            {f"LLM-{i+1}": llm_outputs.get(f"LLM-{i+1}", "") for i in range(k)}]

def display_llm_output(llm_outputs, selected_base_llm_name):
    return gr.Textbox(value=llm_outputs.get(selected_base_llm_name, ""),
                      label=selected_base_llm_name + " (Click Save to save current content)",
                      placeholder=f"Enter {selected_base_llm_name} output here", show_label=True)

def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
    # dict.update, not a call on the dict itself (llm_outputs({...}) raises TypeError)
    llm_outputs.update({selected_base_llm_name: selected_base_llm_output})
    return llm_outputs
def get_preprocess_examples(inst, input):
    # get the number of base LLMs for this example
    candidates = CANDIDATE_MAP[inst+input]
    num_candidates = len(candidates)
    dummy_text = inst+input
    return inst, input, num_candidates, dummy_text

def update_base_llm_dropdown_along_examples(inst, input):
    candidates = CANDIDATE_MAP[inst+input]
    ex_llm_outputs = {f"LLM-{i+1}": candidates[i]['text'] for i in range(len(candidates))}
    k = len(candidates)
    return ex_llm_outputs, "", "", \
        gr.Dropdown(choices=[f"LLM-{i+1}" for i in range(k)], value="LLM-1" if k >= 1 else "", visible=True)
def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
    if not inst and not input:
        raise gr.Error("Please enter instruction or input context")
    if not all([x for x in llm_outputs.values()]):
        empty_llm_names = [llm_name for llm_name, llm_output in llm_outputs.items() if not llm_output]
        # .format must apply to the message string, not to the raised gr.Error
        raise gr.Error("Please enter base LLM outputs for LLMs: {}".format(empty_llm_names))
    return {
        "inst": inst,
        "input": input,
        "candidates": list(llm_outputs.values()),
    }
def check_fuser_inputs(blender_state, blender_config, ranks):
    if "candidates" not in blender_state or len(ranks) == 0:
        raise gr.Error("Please rank LLM outputs first")
    if not (blender_state.get("inst", None) or blender_state.get("input", None)):
        raise gr.Error("Please enter instruction or input context")
    return
def llms_rank(inst, input, llm_outputs, blender_config):
    candidates = list(llm_outputs.values())
    # NOTE: these UI-configured lengths are collected here but not passed to
    # blender.rank below, which runs with the ranker's default truncation.
    rank_params = {
        "source_max_length": blender_config['source_max_length'],
        "candidate_max_length": blender_config['candidate_max_length'],
    }
    ranks = blender.rank(instructions=[inst], inputs=[input], candidates=[candidates])[0]
    return [ranks, ", ".join([f"LLM-{i+1}: {rank}" for i, rank in enumerate(ranks)])]
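# blender.rank returns one rank list per input; ranks[i] is the rank of
# LLM-(i+1)'s output, with rank 1 being the best (an assumption based on the
# LLM-Blender README). llms_fuse below keeps the top_k_for_fuser best-ranked
# candidates via get_topk_candidates_from_ranks and hands them to the fuser.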
def llms_fuse(blender_state, blender_config, ranks):
    inst = blender_state['inst']
    input = blender_state['input']
    candidates = blender_state['candidates']
    top_k_for_fuser = blender_config['top_k_for_fuser']
    # Pass the remaining config entries through to the fuser as generation kwargs
    fuse_params = blender_config.copy()
    fuse_params.pop("top_k_for_fuser")
    fuse_params.pop("source_max_length")
    fuse_params['no_repeat_ngram_size'] = 3
    top_k_candidates = get_topk_candidates_from_ranks([ranks], [candidates], top_k=top_k_for_fuser)[0]
    fuser_outputs = blender.fuse(instructions=[inst], inputs=[input], candidates=[top_k_candidates], **fuse_params, batch_size=1)[0]
    return [fuser_outputs, fuser_outputs]
def display_fuser_output(fuser_output):
    return fuser_output
with gr.Blocks(theme='ParityError/Anime') as demo:
    with gr.Tab("LLM-Blender"):
        # llm-blender interface
        with gr.Row():
            gr.Markdown(descriptions.LLM_BLENDER_OVERALL_DESC)
            gr.Image("https://github.com/yuchenlin/LLM-Blender/blob/main/docs/llm_blender.png?raw=true", height=300)
        gr.Markdown("## Input and Base LLMs")
        with gr.Row():
            with gr.Column():
                inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
                input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
            with gr.Column():
                saved_llm_outputs = gr.State(value={})
                with gr.Group():
                    selected_base_llm_name_dropdown = gr.Dropdown(label="Base LLM",
                        choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)], value="LLM-1", show_label=True)
                    selected_base_llm_output = gr.Textbox(lines=4, label="LLM-1 (Click Save to save current content)",
                        placeholder="Enter LLM-1 output here", show_label=True)
                    with gr.Row():
                        base_llm_outputs_save_button = gr.Button('Save', variant='primary')
                        base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
                        base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
                base_llms_num = gr.Slider(
                    label='Number of base LLMs',
                    minimum=MIN_BASE_LLM_NUM,
                    maximum=MAX_BASE_LLM_NUM,
                    step=1,
                    value=MIN_BASE_LLM_NUM,
                )
        blender_state = gr.State(value={})
        saved_rank_outputs = gr.State(value=[])
        saved_fuse_outputs = gr.State(value=[])
        gr.Markdown("## Blender Outputs")
        with gr.Group():
            rank_outputs = gr.Textbox(lines=1, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
            fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
        with gr.Row():
            rank_button = gr.Button('Rank LLM Outputs', variant='primary')
            fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary')
            clear_button = gr.Button('Clear Blender Outputs', variant='primary')
        blender_config = gr.State(value={
            "source_max_length": DEFAULT_SOURCE_MAX_LENGTH,
            "candidate_max_length": DEFAULT_CANDIDATE_MAX_LENGTH,
            "top_k_for_fuser": 3,
            "max_new_tokens": DEFAULT_FUSER_MAX_NEW_TOKENS,
            "temperature": 0.7,
            "top_p": 1.0,
        })
        with gr.Accordion(label='Advanced options', open=False):
            source_max_length = gr.Slider(
                label='Max length of Instruction + Input',
                minimum=1,
                maximum=SOURCE_MAX_LENGTH,
                step=1,
                value=DEFAULT_SOURCE_MAX_LENGTH,
            )
            candidate_max_length = gr.Slider(
                label='Max length of LLM-Output Candidate',
                minimum=1,
                maximum=CANDIDATE_MAX_LENGTH,
                step=1,
                value=DEFAULT_CANDIDATE_MAX_LENGTH,
            )
            top_k_for_fuser = gr.Slider(
                label='Top-k ranked candidates to fuse',
                minimum=1,
                maximum=3,
                step=1,
                value=3,
            )
            max_new_tokens = gr.Slider(
                label='Max new tokens fuser can generate',
                minimum=1,
                maximum=FUSER_MAX_NEW_TOKENS,
                step=1,
                value=DEFAULT_FUSER_MAX_NEW_TOKENS,
            )
            temperature = gr.Slider(
                label='Temperature of fuser generation',
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=0.7,
            )
            top_p = gr.Slider(
                label='Top-p of fuser generation',
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                value=1.0,
            )
        examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
        batch_examples = gr.Examples(
            examples=MIX_INSTRUCT_EXAMPLES,
            fn=get_preprocess_examples,
            cache_examples=True,
            examples_per_page=5,
            inputs=[inst_textbox, input_textbox],
            outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
        )
        base_llms_num.input(
            fn=update_base_llms_num,
            inputs=[base_llms_num, saved_llm_outputs],
            outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
        )
        examples_dummy_textbox.change(
            fn=update_base_llm_dropdown_along_examples,
            inputs=[inst_textbox, input_textbox],
            outputs=[saved_llm_outputs, rank_outputs, fuser_outputs, selected_base_llm_name_dropdown],
        ).then(
            fn=display_llm_output,
            inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
            outputs=selected_base_llm_output,
        )
        selected_base_llm_name_dropdown.change(
            fn=display_llm_output,
            inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
            outputs=selected_base_llm_output,
        )
        base_llm_outputs_save_button.click(
            fn=save_llm_output,
            inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
            outputs=saved_llm_outputs,
        )
        base_llm_outputs_clear_all_button.click(
            fn=lambda: [{}, ""],
            inputs=[],
            outputs=[saved_llm_outputs, selected_base_llm_output],
        )
        base_llm_outputs_clear_single_button.click(
            fn=lambda: "",
            inputs=[],
            outputs=selected_base_llm_output,
        )
        rank_button.click(
            fn=check_save_ranker_inputs,
            inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
            outputs=blender_state,
        ).success(
            fn=llms_rank,
            inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
            outputs=[saved_rank_outputs, rank_outputs],
        )
        fuse_button.click(
            fn=check_fuser_inputs,
            inputs=[blender_state, blender_config, saved_rank_outputs],
            outputs=[],
        ).success(
            fn=llms_fuse,
            inputs=[blender_state, blender_config, saved_rank_outputs],
            outputs=[saved_fuse_outputs, fuser_outputs],
        )
        clear_button.click(
            fn=lambda: ["", "", {}, []],
            inputs=[],
            outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
        )
        # update blender config (dict.update returns None, so "or y" hands the
        # updated dict back to the State component)
        source_max_length.change(
            fn=lambda x, y: y.update({"source_max_length": x}) or y,
            inputs=[source_max_length, blender_config],
            outputs=blender_config,
        )
        candidate_max_length.change(
            fn=lambda x, y: y.update({"candidate_max_length": x}) or y,
            inputs=[candidate_max_length, blender_config],
            outputs=blender_config,
        )
        top_k_for_fuser.change(
            fn=lambda x, y: y.update({"top_k_for_fuser": x}) or y,
            inputs=[top_k_for_fuser, blender_config],
            outputs=blender_config,
        )
        max_new_tokens.change(
            fn=lambda x, y: y.update({"max_new_tokens": x}) or y,
            inputs=[max_new_tokens, blender_config],
            outputs=blender_config,
        )
        temperature.change(
            fn=lambda x, y: y.update({"temperature": x}) or y,
            inputs=[temperature, blender_config],
            outputs=blender_config,
        )
        top_p.change(
            fn=lambda x, y: y.update({"top_p": x}) or y,
            inputs=[top_p, blender_config],
            outputs=blender_config,
        )
| with gr.Tab("PairRM"): | |
| # PairRM interface | |
| with gr.Row(): | |
| gr.Markdown(descriptions.PairRM_OVERALL_DESC) | |
| gr.Image("https://yuchenlin.xyz/LLM-Blender/pairranker.png") | |
| with gr.Tab("Compare two responses"): | |
| instruction = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True) | |
| with gr.Row(): | |
| response1 = gr.Textbox(lines=4, label="Response 1", placeholder="Enter response 1 here", show_label=True) | |
| response2 = gr.Textbox(lines=4, label="Response 2", placeholder="Enter response 2 here", show_label=True) | |
| with gr.Row(): | |
| compare_button = gr.Button('Compare', variant='primary') | |
| clear_button = gr.Button('Clear', variant='primary') | |
| with gr.Row(): | |
| compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True) | |
| compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True) | |
            def compare_fn(inst, response1, response2):
                if not inst:
                    raise gr.Error("Please enter instruction")
                if not response1 or not response2:
                    raise gr.Error("Please enter response 1 and response 2")
                comparison_results = blender.compare([inst], [response1], [response2], return_logits=True)
                logit = comparison_results[0]
                if logit > 0:
                    result = "Response 1 is better than Response 2"
                    prob = f"Confidence: {round(logit, 2)}"
                elif logit < 0:
                    result = "Response 2 is better than Response 1"
                    prob = f"Confidence: {round(abs(logit), 2)}"
                else:
                    result = "Response 1 and Response 2 are equally good"
                    prob = "No confidence for tie"
                return [result, prob]
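            # With return_logits=True, blender.compare returns one logit per pair;
            # the sign picks the preferred response and the magnitude is surfaced
            # here as a rough confidence score (a raw logit, not a calibrated
            # probability).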
            compare_button.click(
                fn=compare_fn,
                inputs=[instruction, response1, response2],
                outputs=[compare_result, compare_result_prob],
            )
            clear_button.click(
                fn=lambda: ["", ""],
                inputs=[],
                outputs=[compare_result, compare_result_prob],
            )
            hhh_dummy_textbox1 = gr.Textbox(lines=1, label="subset", placeholder="", show_label=False, visible=False)
            hhh_dummy_textbox2 = gr.Textbox(lines=1, label="Better Response", placeholder="", show_label=False, visible=False)
            gr.Markdown("## Examples from [HuggingFaceH4/hhh_alignment](https://huggingface.co/datasets/HuggingFaceH4/hhh_alignment)")
            gr.Examples(
                HHH_EXAMPLES,
                fn=get_hhh_examples,
                cache_examples=True,
                examples_per_page=5,
                inputs=[hhh_dummy_textbox1, instruction, response1, response2, hhh_dummy_textbox2],
                outputs=[instruction, response1, response2],
            )
| with gr.Tab("Compare assistant's response in two multi-turn conversations"): | |
| gr.Markdown("NOTE: Comparison of two conversations is based on that the user query in each turn is the same of two conversations.") | |
| def append_message(message, chat_history): | |
| if not message: | |
| return "", chat_history | |
| if len(chat_history) == 0: | |
| chat_history.append((message, "(Please enter your bot response)")) | |
| else: | |
| if chat_history[-1][1] == "(Please enter your bot response)": | |
| chat_history[-1] = (chat_history[-1][0], message) | |
| else: | |
| chat_history.append((message, "(Please enter your bot response)")) | |
| return "", chat_history | |
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Conversation A")
                    chatbot1 = gr.Chatbot()
                    msg1 = gr.Textbox(lines=1, label="Enter Chat history for Conversation A", placeholder="Enter your message here", show_label=True)
                    clear1 = gr.ClearButton([msg1, chatbot1])
                    msg1.submit(append_message, [msg1, chatbot1], [msg1, chatbot1])
                with gr.Column():
                    gr.Markdown("### Conversation B")
                    chatbot2 = gr.Chatbot()
                    msg2 = gr.Textbox(lines=1, label="Enter Chat history for Conversation B", placeholder="Enter your message here", show_label=True)
                    clear2 = gr.ClearButton([msg2, chatbot2])
                    msg2.submit(append_message, [msg2, chatbot2], [msg2, chatbot2])
            with gr.Row():
                compare_button = gr.Button('Compare', variant='primary')
            with gr.Row():
                compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
                compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)
            def compare_conv_fn(chat_history1, chat_history2):
                if len(chat_history1) == 0 or len(chat_history2) == 0:
                    raise gr.Error("Please enter chat history for both conversations")
                # Use gr.Error (as elsewhere in this app) instead of assert, so the
                # message reaches the user and survives python -O
                if chat_history1[-1][1] == "(Please enter your bot response)" \
                        or chat_history2[-1][1] == "(Please enter your bot response)":
                    raise gr.Error("Please complete chat history for both conversations")
                chat1_messages = []
                for item in chat_history1:
                    chat1_messages.append({
                        "role": "USER",
                        "content": item[0],
                    })
                    chat1_messages.append({
                        "role": "ASSISTANT",
                        "content": item[1],
                    })
                chat2_messages = []
                for item in chat_history2:
                    chat2_messages.append({
                        "role": "USER",
                        "content": item[0],
                    })
                    chat2_messages.append({
                        "role": "ASSISTANT",
                        "content": item[1],
                    })
                comparison_results = blender.compare_conversations([chat1_messages], [chat2_messages], return_logits=True)
                logit = comparison_results[0]
                if logit > 0:
                    result = "Assistant's response in Conversation A is better than Conversation B"
                    prob = f"Confidence: {round(logit, 2)}"
                elif logit < 0:
                    result = "Assistant's response in Conversation B is better than Conversation A"
                    prob = f"Confidence: {round(abs(logit), 2)}"
                else:
                    result = "Assistant's responses in Conversation A and Conversation B are equally good"
                    prob = "No confidence for tie"
                return [result, prob]
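            # blender.compare_conversations compares the two conversations under
            # the assumption stated in the note above (identical user queries per
            # turn); as with blender.compare, the sign of the returned logit
            # indicates the preferred conversation.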
            compare_button.click(
                fn=compare_conv_fn,
                inputs=[chatbot1, chatbot2],
                outputs=[compare_result, compare_result_prob],
            )
            model_a_dummy_textbox = gr.Textbox(lines=1, label="Model A", placeholder="", show_label=False, visible=False)
            model_b_dummy_textbox = gr.Textbox(lines=1, label="Model B", placeholder="", show_label=False, visible=False)
            winner_dummy_textbox = gr.Textbox(lines=1, label="Better Model in conversation", placeholder="", show_label=False, visible=False)
            chatbot1_dummy_textbox = gr.Textbox(lines=1, label="Conversation A", placeholder="", show_label=False, visible=False)
            chatbot2_dummy_textbox = gr.Textbox(lines=1, label="Conversation B", placeholder="", show_label=False, visible=False)
            gr.Markdown("## Examples from [lmsys/mt_bench_human_judgments](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)")
            gr.Examples(
                MT_BENCH_HUMAN_JUDGE_EXAMPLES,
                fn=get_mt_bench_human_judge_examples,
                cache_examples=True,
                examples_per_page=5,
                inputs=[model_a_dummy_textbox, model_b_dummy_textbox, chatbot1_dummy_textbox, chatbot2_dummy_textbox, winner_dummy_textbox],
                outputs=[chatbot1, chatbot2],
            )
    gr.Markdown(descriptions.CITATION)

demo.queue(max_size=20).launch()