import os

import chainlit as cl  # importing chainlit for our app
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only the first GPU to this app


# Prompt Templates
INSTRUCTION_PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Please convert the following legal content into a human-readable summary<|eot_id|><|start_header_id|>user<|end_header_id|>

[LEGAL_DOC]
{input}
[END_LEGAL_DOC]<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

RESPONSE_TEMPLATE = """
{summary}<|eot_id|>
"""


def create_prompt(sample, include_response=False):
    """
    Parameters:
      - sample: dict representing a row of the dataset
      - include_response: bool

    Functionality:
      Builds the full prompt string `full_prompt` for the model.

      If `include_response` is True, the reference summary is appended
      (useful for training); otherwise the prompt stops at the assistant
      header, ready for generation at inference time.

    Returns:
      - full_prompt: str
    """

    full_prompt = INSTRUCTION_PROMPT_TEMPLATE.format(input=sample["original_text"])

    if include_response:
        full_prompt += RESPONSE_TEMPLATE.format(summary=sample["reference_summary"])
        # Terminate the sequence only when the target summary is included;
        # appending <|end_of_text|> to an inference prompt would tell the
        # model to stop before generating anything.
        full_prompt += "<|end_of_text|>"

    return full_prompt
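
# Example usage (hypothetical sample; the keys match those referenced above):
#   sample = {"original_text": "The Licensee shall indemnify ...",
#             "reference_summary": "You agree to cover our legal costs ..."}
#   create_prompt(sample)                         # inference prompt, no summary
#   create_prompt(sample, include_response=True)  # training prompt with summary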


@cl.on_chat_start
async def start_chat():
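    # Quantization sketch: load weights in 4-bit NF4 ("normal float 4"),
    # quantize the quantization constants a second time (double quant),
    # and run matmuls in float16 to fit the 8B model on a single GPU.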
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    model_id = "lakshyaag/llama38binstruct_summarize"

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
        cache_dir=os.path.join(os.getcwd(), ".cache"),
    )

    # device_map="auto" already places the quantized model on the GPU;
    # calling .to("cuda") on a 4-bit bitsandbytes model raises an error,
    # so no manual move is needed here.

    tokenizer = AutoTokenizer.from_pretrained(
        model_id, cache_dir=os.path.join(os.getcwd(), ".cache")
    )

    # Llama 3 ships without a pad token; reuse EOS so padded batches work
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    cl.user_session.set("model", model)
    cl.user_session.set("tokenizer", tokenizer)


@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    model = cl.user_session.get("model")
    tokenizer = cl.user_session.get("tokenizer")

    # wrap the user's message in the instruction template the model was
    # fine-tuned on, then tokenize it (message is a cl.Message, not a str)
    prompt = create_prompt({"original_text": message.content})
    encoded_input = tokenizer(prompt, return_tensors="pt")

    # move the tokenized inputs to the same device as the model
    model_inputs = encoded_input.to(model.device)

    # generate response and set desired generation parameters
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=256,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # decode only the newly generated tokens (everything after the prompt),
    # dropping special tokens such as <|eot_id|>
    generated_tokens = generated_ids[:, model_inputs["input_ids"].shape[1] :]
    response = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    await cl.Message(content=response).send()


# Chainlit has no `cl.run()` entry point; launch the app from the CLI,
# e.g.: chainlit run app.py