Commit 
							
							·
						
						bd0305e
	
1
								Parent(s):
							
							b058a81
								
Create app.py
Browse files
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,99 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList
         | 
| 4 | 
            +
            import time
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            from torch.nn import functional as F
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
# Locations of the fine-tuned checkpoint and its tokenizer on the serving host.
MODEL_PATH = "/mnt/nvme/home/dakota/ckpts/stablelm/7B-sft-combined/checkpoint-8000"
TOKENIZER_PATH = "/mnt/nvme/home/dakota/stablelm_tokenizer"

# Load the weights in half precision, then move the model onto the GPU.
m = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16)
m = m.cuda()
tok = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

# Text-generation pipeline bound to GPU 0; used by `generate`.
generator = pipeline("text-generation", model=m, tokenizer=tok, device=0)
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            start_message = """<|SYSTEM|># StableAssistant
         | 
| 15 | 
            +
            - StableAssistant is A helpful and harmless Open Source AI Language Model developed by Stability and CarperAI.
         | 
| 16 | 
            +
            - StableAssistant is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
         | 
| 17 | 
            +
            - StableAssistant is more than just an information source, StableAssistant is also able to write poetry, short stories, and make jokes.
         | 
| 18 | 
            +
            - StableAssistant will refuse to participate in anything that could harm a human."""
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation on any of a set of token ids.

    Only the first sequence in the batch is inspected, which matches how this
    app calls the pipeline (`num_return_sequences=1`).
    """

    # Default stop ids, unchanged from the original hard-coded list.
    # NOTE(review): presumably the special/end-of-turn ids of the StableLM
    # tokenizer -- confirm against the tokenizer vocabulary.
    DEFAULT_STOP_IDS = (50278, 50279, 50277, 1, 0)

    def __init__(self, stop_ids=None):
        """Create the criterion.

        Args:
            stop_ids: Optional iterable of token ids that end generation.
                Defaults to `DEFAULT_STOP_IDS`.
        """
        super().__init__()
        chosen = self.DEFAULT_STOP_IDS if stop_ids is None else stop_ids
        self.stop_ids = frozenset(int(token_id) for token_id in chosen)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the most recent token of sequence 0 is a stop id."""
        # Convert once to a Python int so the membership test is a single set
        # lookup instead of one tensor comparison per stop id.
        last_token = int(input_ids[0][-1])
        return last_token in self.stop_ids
| 28 | 
            +
             | 
| 29 | 
            +
             | 
def contrastive_generate(text, bad_text):
    """Sample a continuation of `text` while penalising what `bad_text` predicts.

    The model is run on both prompts in lock-step, each with its own key/value
    cache. At every step, for vocabulary entries whose probability under
    `text` exceeds 0.1, the log-probability under `bad_text` is subtracted
    from the log-probability under `text`. The adjusted scores are
    re-normalised and one token is sampled. The same sampled token is then fed
    to both streams.

    Args:
        text: Prompt whose continuation is wanted.
        bad_text: Contrast prompt whose likely continuations are down-weighted.

    Returns:
        The decoded generated tokens (prompt not included). Generation ends
        after 1024 new tokens or when a stop id is sampled; the stop id itself
        is not included.
    """
    max_new_tokens = 1024
    max_prompt_tokens = 4096 - max_new_tokens
    stop_ids = frozenset((50278, 50279, 50277, 1, 0))
    confidence_threshold = 0.1

    with torch.no_grad():
        # NOTE(review): this keeps the FIRST `max_prompt_tokens` tokens, so an
        # over-long prompt loses its most recent part -- confirm that is intended.
        tokens = tok(text, return_tensors="pt")["input_ids"].cuda()[:, :max_prompt_tokens]
        bad_tokens = tok(bad_text, return_tensors="pt")["input_ids"].cuda()[:, :max_prompt_tokens]
        history = None
        bad_history = None
        generated = []

        for _ in range(max_new_tokens):
            out = m(tokens, past_key_values=history, use_cache=True)
            history = out.past_key_values
            bad_out = m(bad_tokens, past_key_values=bad_history, use_cache=True)
            bad_history = bad_out.past_key_values

            # log_softmax avoids the -inf that log(softmax(x)) produces when a
            # probability underflows to zero, which would otherwise turn the
            # subtraction below into inf/nan sampling weights.
            log_probs = F.log_softmax(out.logits.float(), dim=-1)[0, -1].cpu()
            bad_log_probs = F.log_softmax(bad_out.logits.float(), dim=-1)[0, -1].cpu()

            # Only tokens the good prompt already considers likely are penalised.
            confident = log_probs.exp() > confidence_threshold
            log_probs[confident] -= bad_log_probs[confident]

            # Explicit dim: the implicit-dim form of softmax is deprecated.
            probs = F.softmax(log_probs, dim=-1)
            next_id = int(torch.multinomial(probs, 1))
            if next_id in stop_ids:
                break
            generated.append(next_id)

            # With the caches populated, only the new token is fed next step.
            next_token = torch.tensor([[next_id]], device=tokens.device)
            tokens = next_token
            bad_tokens = next_token

        return tok.decode(generated)
| 61 | 
            +
             | 
def generate(text, bad_text=None):
    """Sample up to 1024 new tokens continuing `text` and return only the new text.

    Args:
        text: Full prompt passed to the text-generation pipeline.
        bad_text: Unused; kept so existing callers that pass it still work.

    Returns:
        The generated continuation without the prompt.
    """
    stop = StopOnTokens()
    result = generator(
        text,
        max_new_tokens=1024,
        num_return_sequences=1,
        num_beams=1,
        do_sample=True,
        temperature=1.0,
        top_p=0.95,
        top_k=1000,
        stopping_criteria=StoppingCriteriaList([stop]),
        # Let the pipeline drop the prompt. The previous
        # `.replace(text, "")` removed every occurrence of the prompt string
        # and silently kept the prompt whenever the decoded output did not
        # reproduce it character-for-character.
        return_full_text=False,
    )
    return result[0]["generated_text"]
| 66 | 
            +
             | 
| 67 | 
            +
             | 
def user(user_message, history):
    """Append the user's message to the chat history as a new, unanswered turn.

    Args:
        user_message: Text submitted in the chat textbox.
        history: List of `[user_text, bot_text]` pairs, or `None` when there is
            no history yet (the Clear button's callback returns `None`).

    Returns:
        A tuple `("", new_history)`: the empty string clears the textbox and
        `new_history` is a new list ending with `[user_message, ""]`. The
        input list is not modified.
    """
    # Treat a missing history as empty instead of failing on `None + list`.
    previous_turns = history if history is not None else []
    return "", previous_turns + [[user_message, ""]]
| 70 | 
            +
             | 
| 71 | 
            +
             | 
def bot(history, curr_system_message):
    """Generate the assistant reply for the last turn in `history`.

    The prompt is the system message followed by every turn rendered as
    `<|USER|>{user_text}<|ASSISTANT|>{bot_text}`. The last turn has an empty
    bot text, so the prompt ends with `<|ASSISTANT|>`.

    Args:
        history: List of `[user_text, bot_text]` pairs; the last pair's reply
            slot is filled in place.
        curr_system_message: System prompt placed before the conversation.

    Returns:
        The same `history` list with the final reply filled in. An empty (or
        missing) history is returned unchanged without calling the model.
    """
    # Nothing to answer: avoid the IndexError on `history[-1]`.
    if not history:
        return history

    turns = ["<|USER|>" + item[0] + "<|ASSISTANT|>" + item[1] for item in history]
    messages = curr_system_message + "".join(turns)

    # The artificial one-second sleep that used to follow generation was
    # removed: it only added latency to every reply.
    history[-1][1] = generate(messages)
    return history
| 78 | 
            +
             | 
| 79 | 
            +
             | 
def system_update(msg: str) -> None:
    """Store the latest system-message text in the module-level global.

    Writes `msg` to the global `curr_system_message` and returns nothing.
    """
    global curr_system_message
    curr_system_message = msg
| 83 | 
            +
             | 
| 84 | 
            +
             | 
# Two-column layout: the conversation and a Clear button on the left, the
# editable system prompt and the chat input on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot([])
            clear = gr.Button("Clear")
        with gr.Column():
            system_msg = gr.Textbox(start_message, label="System Message", interactive=True)
            msg = gr.Textbox(label="Chat Message")

    # Step 1 (unqueued): record the user's message and clear the input box.
    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    )
    # Step 2: once step 1 has finished, generate the reply for the last turn.
    submit_event.then(fn=bot, inputs=[chatbot, system_msg], outputs=chatbot)

    # Mirror edits of the system prompt into the module-level global.
    system_msg.change(fn=system_update, inputs=system_msg, outputs=None, queue=False)
    # Clearing hands the chatbot a None value.
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

# Start the server and request a public share link.
demo.launch(share=True)
 
			
