File size: 6,046 Bytes
e956bee
 
 
 
 
 
 
a100dd2
 
e956bee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d622f83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f8f8ff
 
05b3239
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import gradio as gr
import threading
import codecs
#from ast import literal_eval
from datetime import datetime

import os
#os.environ['TRANSFORMERS_CACHE'] = '/data/.modelcache/huggingface/hub/'
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:516"

from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
import gc

# Use the GPU when torch reports one as available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# dtype handed to `from_pretrained` when the models are loaded in `infer`.
TORCH_DTYPE = torch.bfloat16
# Models compared side by side; `infer` returns their outputs in this order.
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

# model name -> (tokenizer, model); populated by `infer`.
models = {}
# model name -> decoded completion text; written by `gen_thread`, read by `infer`.
output = {}


def gen_thread(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty):
    """Sample one completion from ``model_name`` and publish it in ``output``.

    Intended as a ``threading.Thread`` target: nothing is returned, the decoded
    text is stored under ``output[model_name]`` instead.

    Args:
        model_name: Key into the module-level ``models`` dict, whose values are
            ``(tokenizer, model)`` pairs.
        inputs: Prompt token ids; ``inputs.shape[1]`` is taken as the prompt
            length (presumably a ``(1, n_tokens)`` tensor - confirm at caller).
        max_new_tokens, min_length, temperature, top_p, repetition_penalty:
            Forwarded unchanged to the model's ``generate`` method.
    """
    prompt_len = inputs.shape[1]
    sampling_kwargs = dict(
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    model = models[model_name][1]
    generated = model.generate(inputs, **sampling_kwargs)

    # Drop the leading prompt tokens so only the newly generated part is decoded.
    tokenizer = models[model_name][0]
    new_tokens = generated[0, prompt_len:]
    output[model_name] = tokenizer.decode(new_tokens)

def to_md(text):
    """Return ``text`` with every newline replaced by an HTML ``<br />`` tag."""
    line_break = "<br />"
    lines = text.split("\n")
    return line_break.join(lines)

def _parse_stop_words(stop):
    """Split the raw stop string into a list of literal, non-empty stop words.

    Words are separated by ``,`` (or ``;``). Backslash escapes typed into the
    UI, such as ``\\n``, are converted to the characters they denote.
    """
    stop_words = []
    for segment in stop.split(";"):
        # Encode with ``backslashreplace`` before unescaping: passing a ``str``
        # straight to the unicode_escape decoder garbles non-ASCII characters.
        decoded = segment.encode("latin-1", "backslashreplace").decode("unicode_escape")
        stop_words.extend(word.strip(' ') for word in decoded.split(','))
    return [word for word in stop_words if word != '']


def infer(
        prompt,
        min_length=2,
        max_new_tokens=10,
        temperature=0.1,
        top_p=1.0,
        repetition_penalty = 1.0,
        stop="\n",
        num_completions=1,
        seed=42,
):
    """Generate a completion for ``prompt`` with every model in ``MODEL_NAMES``.

    Each model runs in its own thread (via ``gen_thread``), and the results are
    collected from the module-level ``output`` dict, which is shared between
    calls. Each completion is cut at the first occurrence of any stop word.

    Args:
        prompt: Text to complete; an empty prompt is replaced by one space.
        min_length: Forwarded to generation as ``min_length``; must lie in
            ``[0, max_new_tokens]``.
        max_new_tokens: Number of tokens to generate, in ``[1, 384]``.
        temperature: Sampling temperature in ``[0, 1]``; ``0`` becomes ``0.01``.
        top_p: Nucleus sampling threshold in ``[0, 1]``.
        repetition_penalty: Value in ``[0.9, 3.0]``.
        stop: Stop words separated by ``,`` (or ``;``); backslash escapes such
            as ``\\n`` are interpreted.
        num_completions: Range-checked (1-5) but otherwise unused.
        seed: Accepted for interface compatibility; currently unused.

    Returns:
        Tuple ``(output of MODEL_NAMES[0], output of MODEL_NAMES[1])``.

    Raises:
        AssertionError: If a numeric argument is outside its allowed range.
    """
    #gc.collect()
    #torch.cuda.empty_cache()

    # Load lazily, per model: a model that failed to load on an earlier call is
    # retried here instead of being skipped because another one succeeded.
    for model_name in MODEL_NAMES:
        if model_name in models:
            continue
        tokenizer = BloomTokenizerFast.from_pretrained(model_name)
        model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
        model = model.to(DEVICE)
        models[model_name] = tokenizer, model

    # UI widgets may deliver numbers as floats or strings; normalise them.
    max_new_tokens = int(max_new_tokens)
    min_length = int(min_length)
    num_completions = int(num_completions)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    # Parse once, up front, so every model is truncated with the same words.
    stop_words = _parse_stop_words(stop)

    assert 1 <= max_new_tokens <= 384, "max_new_tokens must be in [1, 384]"
    assert 0 <= min_length <= max_new_tokens, "min_length must be in [0, max_new_tokens]"
    assert 1 <= num_completions <= 5, "num_completions must be in [1, 5]"
    assert 0.0 <= temperature <= 1.0, "temperature must be in [0, 1]"
    assert 0.0 <= top_p <= 1.0, "top_p must be in [0, 1]"
    assert 0.9 <= repetition_penalty <= 3.0, "repetition_penalty must be in [0.9, 3.0]"

    # gen_thread always samples (do_sample=True); presumably a temperature of
    # exactly zero is not usable there, so nudge it to a small positive value.
    if temperature == 0.0:
        temperature = 0.01
    # Presumably keeps the tokenizer from producing an empty input - confirm.
    if prompt == "":
        prompt = " "

    threads = []
    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
    for model_name in MODEL_NAMES:
        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        worker = threading.Thread(
            target=gen_thread,
            args=(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty),
        )
        threads.append(worker)
        worker.start()

    # Threads were started in MODEL_NAMES order, so zip pairs each thread with
    # the name of the model it is running.
    for model_name, worker in zip(MODEL_NAMES, threads):
        print(f"waiting on: {model_name}\n")
        worker.join()
        print(f"{model_name} thread done\n")

    for model_name in MODEL_NAMES:
        # Successive truncation leaves the text before the earliest stop word.
        for stop_word in stop_words:
            if stop_word in output[model_name]:
                output[model_name] = output[model_name][:output[model_name].find(stop_word)]

        print(f"--- START: {model_name} --- \n{output[model_name]}\n--- END {model_name} ---\n\n")

    print(f"DONE -> ({datetime.now()})\n")
    return output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]


# Example rows shown in the UI. Values follow the order of the interface
# inputs: prompt, min_length, max_tokens, temperature, top_p,
# repetition penalty, stop.
examples = [
    [
        # Question Answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''', 1, 3, 0.2, 1.0, 1.0, "\\n,</s>"],
    [
        # Natural Language Inference (few-shot entailment / contradiction)
        '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
Possible labels: 1. entailment 2. contradiction
Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
Label: entailment
Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
Label: contradiction
Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
Label:''', 1, 2, 0.2, 1.0, 1.0, "\\n,</s>"]
]



# Build the comparison UI and start serving it right away.
# NOTE: `launch()` is chained onto the constructor, so `iface` is bound to
# whatever `launch()` returns, not to the `gr.Interface` object itself.
iface = gr.Interface(
    fn=infer,
    allow_flagging="never",
    # Inputs are passed positionally, so their order must match the leading
    # parameters of `infer`.
    inputs=[
        gr.Textbox(lines=20),  # prompt
        gr.Slider(0, 256, value=1),  # min_length
        gr.Slider(1, 384, value=20),  # max_tokens
        gr.Slider(0.0, 1.0, value=0.2),  # temperature
        gr.Slider(0.0, 1.0, value=0.9),  # top_p
        gr.Slider(0.9, 3.0, value=1.0),  # repetition penalty
        gr.Textbox(lines=1, value="\\n,</s>")  # stop
    ],
    # One textbox per entry of the tuple returned by `infer`.
    outputs=[gr.Textbox(lines=7, label="BLOOM OUTPUT:"), gr.Textbox(lines=7,label="BLOOMZ OUTPUT:")],

    examples=examples,
    cache_examples=True,
    title="BLOOM vs BLOOMZ",
    # The first paragraph is now closed with </p>; it previously ended in a
    # second <p>, which left a stray empty paragraph in the rendered page.
    description='''<p>Compare outputs of the BLOOM and BLOOMZ 176 billion parameter models using the [Petals](https://petals.ml/) network. WARNING: Inference may take a long time. Keep the max_tokens low to speed things up.</p>
    <p>Please consider [joining](https://github.com/bigscience-workshop/petals) the Petals network to help speed up inference.</p><p>Big thanks to [RFTCapital](https://www.rftcapital.com) for providing initial compute resources.</p>'''
).launch()