File size: 3,574 Bytes
03561be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json

from mmgpt.datasets.dolly_dataset import DollyDataset


# Prompt templates in the Alpaca-LoRA "### Instruction / ### Response" layout.
# Each "prompt_*" entry is a str.format template; the placeholders it accepts
# are the {names} that appear in its text.
TEMPLATE = {
    "description": "Template used by Alpaca-LoRA.",
    # Preamble + instruction + "### Input:" section; takes {question} and {options}.
    # Used by LangDialPrompter.__call__ when options are supplied.
    "prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n",
    # Preamble + instruction, no input section; takes {question}.
    # NOTE(review): not referenced by the code visible in this file.
    "prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Response:\n",
    # Instruction/response scaffold WITHOUT the preamble sentence; takes {question}.
    # Used by LangDialPrompter.__call__ when no options are supplied.
    "prompt_dial": "\n\n### Instruction:\n{question}\n\n### Response:\n",
    # Marker that LangDialPrompter.get_response splits on to isolate the reply.
    "response_split": "### Response:",
}

class LangDialPrompter:
    """Format dialogue prompts from ``TEMPLATE`` and extract model replies."""

    def __call__(self, question, options=None):
        """Return the prompt text for ``question``.

        Args:
            question: Text substituted for ``{question}`` in the template.
            options: Optional iterable of strings. When non-empty, the items
                are joined with ``", "`` and the ``prompt_choice`` template
                is used; otherwise the ``prompt_dial`` template is used.

        Returns:
            The formatted prompt string.
        """
        if not options:
            return TEMPLATE["prompt_dial"].format(question=question)

        joined_options = ", ".join(options)
        # ``image`` is passed through unchanged; ``prompt_choice`` has no
        # ``{image}`` field, so str.format simply ignores it.
        return TEMPLATE["prompt_choice"].format(
            image="<image>", question=question, options=joined_options
        )

    def get_response(self, output: str) -> str:
        """Return the text after the last response marker, stripped.

        If the marker does not occur, the whole of ``output`` is returned
        (stripped).
        """
        _, _, reply = output.rpartition(TEMPLATE["response_split"])
        return reply.strip()

class BaiZeDataset(DollyDataset):
    """Multi-turn dialogue dataset built from ``[|Human|]``/``[|AI|]`` transcripts.

    Only the ``"input"`` field of each annotation is read. It must hold a
    whole conversation as a single string in which every human turn is
    introduced by ``"[|Human|] "`` and every assistant turn by ``"[|AI|] "``:

    ```json
    [
        {
            "input": "<preamble> [|Human|] Hi! [|AI|] Hello, how can I help? [|Human|] "
        }
    ]
    ```

    The text before the first ``"[|Human|] "`` marker and the text after the
    last one are discarded (see ``process_text``).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``DollyDataset`` and install the prompter."""
        super().__init__(*args, **kwargs)
        self.prompter = LangDialPrompter()

    def load_annotation(self, ann_path):
        """Load the JSON annotation file at ``ann_path`` into ``self.annotation``.

        The file is opened in a ``with`` block so the handle is closed even
        if parsing fails, and decoded as UTF-8 (the JSON interchange
        encoding) instead of the platform default.
        """
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.annotation = json.load(ann_file)

    def process_text(self, anns):
        """Split one annotation into per-turn instruction/answer pairs.

        Args:
            anns: Mapping with an ``"input"`` string in the format described
                in the class docstring.

        Returns:
            A list of ``dict(instruction=..., answer=...)``, one per kept
            turn. Newlines are removed from question and answer before the
            question is run through ``self.prompter``. The first turn's
            instruction is additionally prefixed with the task preamble.

        Raises:
            ValueError: If a kept segment does not contain exactly one
                ``"[|AI|] "`` marker.
        """
        # TODO remove this
        begin_string = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        convs = anns['input'].split("[|Human|] ")
        conv_list = []
        # convs[0] is whatever precedes the first human marker and convs[-1]
        # is whatever follows the last one; both are skipped.
        # NOTE(review): this always drops the final segment — presumably the
        # data ends with a dangling "[|Human|] " marker; confirm against data.
        for conv_id, one_conv in enumerate(convs[1:-1]):
            question, answer = one_conv.split("[|AI|] ")
            question = question.replace("\n", "")
            answer = answer.replace("\n", "")
            instruction = self.prompter(question)
            if conv_id == 0:
                # Only the opening turn carries the preamble sentence.
                instruction = begin_string + instruction
            conv_list.append(dict(instruction=instruction, answer=answer))
        return conv_list

    def __getitem__(self, index):
        """Tokenize every turn of conversation ``index`` and concatenate them.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and ``labels`` as
            flat lists extended turn by turn, plus ``instruction`` and
            ``answer`` as lists holding one string per turn.
        """
        turns = self.process_text(self.annotation[index])

        input_ids = []
        attention_mask = []
        labels = []
        instruction = []
        answer = []
        for turn in turns:
            # ``tokenize`` is not defined in this class; presumably inherited
            # from DollyDataset and returning a mapping with these three keys.
            tokens = self.tokenize(turn)
            input_ids.extend(tokens["input_ids"])
            attention_mask.extend(tokens["attention_mask"])
            labels.extend(tokens["labels"])
            instruction.append(turn["instruction"])
            answer.append(turn["answer"])

        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            instruction=instruction,
            answer=answer,
        )