File size: 3,146 Bytes
fe97ada
 
b8cf9e8
 
 
 
 
 
 
b644180
 
b8cf9e8
fe97ada
b8cf9e8
fe97ada
 
 
b8cf9e8
 
 
 
 
 
 
 
 
 
 
b644180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8cf9e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe97ada
b8cf9e8
 
 
fe97ada
 
b644180
fe97ada
b8cf9e8
b644180
b8cf9e8
 
b644180
b8cf9e8
 
 
 
 
 
 
fe97ada
 
b8cf9e8
b644180
b8cf9e8
 
 
fe97ada
b8cf9e8
fe97ada
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os

import chainlit as cl  # importing chainlit for our app
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel, PeftConfig

import bitsandbytes as bnb

# Expose only GPU 0 to this process.
# NOTE(review): this assignment runs after `torch` and `bitsandbytes` are imported;
# it only takes effect if nothing has initialised CUDA yet — consider moving it
# above those imports to make the restriction reliable.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


# Prompt Templates
# Llama 3 chat-format prompt: a system instruction, then a user turn that wraps the
# legal document in [LEGAL_DOC] / [END_LEGAL_DOC] markers, ending with an open
# assistant header so the model writes the summary next. `{input}` is filled in by
# create_prompt() with the document text.
INSTRUCTION_PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Please convert the following legal content into a human-readable summary<|eot_id|><|start_header_id|>user<|end_header_id|>
[LEGAL_DOC]
{input}
[END_LEGAL_DOC]<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

# Assistant turn appended when building a complete (prompt + answer) example;
# `{summary}` is filled with the reference summary and the turn is closed with <|eot_id|>.
RESPONSE_TEMPLATE = """
{summary}<|eot_id|>
"""

# Load the base model in 4-bit NF4 with nested ("double") quantization to cut
# memory use; matmuls are computed in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Adapter configuration. Not referenced by the loading code below (the base model
# id is hard-coded); kept as a module-level name for any other users.
config = PeftConfig.from_pretrained("lakshyaag/llama38binstruct_summarize")

# device_map="auto" lets accelerate place the quantized weights on the available
# device(s) at load time.
base_model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Meta-Llama-3-8B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)

# Attach the fine-tuned PEFT adapter on top of the quantized base model.
model = PeftModel.from_pretrained(base_model, "lakshyaag/llama38binstruct_summarize")

# No explicit `model.to("cuda")`: device_map="auto" has already placed the weights,
# and moving a bitsandbytes 4-bit model with `.to()` is not supported by transformers.

tokenizer = AutoTokenizer.from_pretrained("lakshyaag/llama38binstruct_summarize")

# Reuse EOS as the padding token (presumably the tokenizer ships without one —
# confirm) and pad on the right; a single prompt per request is never padded anyway.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


def create_prompt(sample, include_response=False):
    """
    Build a Llama 3 chat-formatted prompt for the legal-summarization model.

    Parameters:
      - sample: dict representing a row of the dataset; must contain
        "original_text", and also "reference_summary" when
        include_response is True
      - include_response: bool
    Functionality:
      When `include_response` is True, the reference summary (closed with
      <|eot_id|>) and a final "<|end_of_text|>" marker are appended, giving a
      complete training example.
      When False, the prompt stops right after the open assistant header so
      the model continues with its own summary (useful for prompting/testing).
    Returns:
      - full_prompt: str
    """
    full_prompt = INSTRUCTION_PROMPT_TEMPLATE.format(input=sample["original_text"])

    if include_response:
        full_prompt += RESPONSE_TEMPLATE.format(summary=sample["reference_summary"])
        # Only a finished (prompt + answer) example is terminated. Appending
        # end-of-text to an inference prompt would tell the model the text is
        # already over before it has produced the summary.
        full_prompt += "<|end_of_text|>"

    return full_prompt


@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    """Summarize the legal text in an incoming chat message and send back the summary.

    Parameters:
      - message: the user's chat message; its `content` is treated as the
        legal document to summarize.
    """
    prompt = create_prompt({"original_text": message.content}, include_response=False)

    # convert str input into tokenized input and send it to the GPU (CPU fallback)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # model.generate is a blocking call; cl.make_async runs it in a worker thread
    # so the Chainlit event loop stays responsive while the summary is generated.
    generated_ids = await cl.make_async(model.generate)(
        **model_inputs,
        max_new_tokens=256,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns the prompt tokens followed by the new tokens; keep only
    # the new ones. (Splitting the decoded text on "<|end_header_id|>" cannot work:
    # skip_special_tokens=True strips that marker, so the prompt leaked into the reply.)
    prompt_length = model_inputs["input_ids"].shape[-1]
    new_token_ids = generated_ids[:, prompt_length:]
    response = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0].strip()

    # Send the summary back as a new message, the pattern documented for
    # Chainlit's on_message handlers.
    await cl.Message(content=response).send()


if __name__ == "__main__":
    # Chainlit apps are normally started with `chainlit run <this file>`; when the
    # file is executed directly, hand it to Chainlit's documented programmatic entry point.
    from chainlit.cli import run_chainlit

    run_chainlit(__file__)