File size: 1,167 Bytes
5488167
 
 
 
 
96ec844
 
5488167
 
 
 
 
96ec844
 
5488167
 
 
 
 
96ec844
 
 
 
 
 
 
 
 
 
5488167
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import json
from pathlib import Path
import random


DEFAULT_ROOT_DIR = "examples/default/input_params"
ZH_RAP_LORA_ROOT_DIR = "examples/zh_rap_lora/input_params"

class DataSampler:
    def __init__(self, root_dir=DEFAULT_ROOT_DIR):
        self.root_dir = root_dir
        self.input_params_files = list(Path(self.root_dir).glob("*.json"))
        self.zh_rap_lora_input_params_files = list(Path(ZH_RAP_LORA_ROOT_DIR).glob("*.json"))
        self.zh_rap_lora_input_params_files += list(Path(ZH_RAP_LORA_ROOT_DIR).glob("*.json"))

    def load_json(self, file_path):
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def sample(self, lora_name_or_path=None):
        if lora_name_or_path is None or lora_name_or_path == "none":
            json_path = random.choice(self.input_params_files)
            json_data = self.load_json(json_path)
        else:
            json_path = random.choice(self.zh_rap_lora_input_params_files)
            json_data = self.load_json(json_path)
            # Update the lora_name in the json_data
            json_data["lora_name_or_path"] = lora_name_or_path

        return json_data