Spaces:
Runtime error
Runtime error
File size: 1,167 Bytes
5488167 96ec844 5488167 96ec844 5488167 96ec844 5488167 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import json
from pathlib import Path
import random
# Directory of example input-parameter JSON files used when no LoRA is
# requested. Relative, so it resolves against the current working directory.
DEFAULT_ROOT_DIR = "examples/default/input_params"
# Directory of example input-parameter JSON files used when a LoRA is
# requested. Also relative to the current working directory.
ZH_RAP_LORA_ROOT_DIR = "examples/zh_rap_lora/input_params"


class DataSampler:
    """Pick a random example input-parameter JSON file and load it.

    Candidate files are discovered once, at construction time, by globbing
    ``*.json`` directly inside two directories: one used when no LoRA is
    requested and one used when any LoRA is requested.
    """

    def __init__(self, root_dir=DEFAULT_ROOT_DIR, lora_root_dir=ZH_RAP_LORA_ROOT_DIR):
        """Collect the candidate example files.

        Args:
            root_dir: Directory searched (non-recursively) for ``*.json``
                examples used when no LoRA is requested.
            lora_root_dir: Directory searched (non-recursively) for ``*.json``
                examples used when a LoRA is requested.
        """
        self.root_dir = root_dir
        self.lora_root_dir = lora_root_dir
        # Sorted so candidate order, and therefore seeded sampling, does not
        # depend on the filesystem's directory enumeration order.
        self.input_params_files = sorted(Path(root_dir).glob("*.json"))
        # Globbed exactly once. The directory was previously globbed twice
        # and concatenated, which listed every file twice.
        self.zh_rap_lora_input_params_files = sorted(Path(lora_root_dir).glob("*.json"))

    def load_json(self, file_path):
        """Read the file at *file_path* as UTF-8 and return the parsed JSON."""
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _choose(files, directory):
        """Return a random element of *files*; raise a descriptive error if empty."""
        if not files:
            # Same exception type random.choice raises on an empty sequence,
            # but the message names the directory that had no examples.
            raise IndexError(f"No *.json input-parameter files found in '{directory}'")
        return random.choice(files)

    def sample(self, lora_name_or_path=None):
        """Load one randomly chosen example parameter set.

        Args:
            lora_name_or_path: ``None`` or the string ``"none"`` selects from
                the default examples. Any other value selects from the LoRA
                examples and is written into the result under the key
                ``"lora_name_or_path"``.

        Returns:
            The parsed JSON content of the chosen file.

        Raises:
            IndexError: If the relevant directory contained no ``*.json``
                files when this sampler was constructed.
        """
        if lora_name_or_path is None or lora_name_or_path == "none":
            return self.load_json(self._choose(self.input_params_files, self.root_dir))
        # NOTE(review): every non-"none" LoRA value samples from the same
        # directory, regardless of which LoRA was requested. Confirm this
        # is intended.
        json_path = self._choose(self.zh_rap_lora_input_params_files, self.lora_root_dir)
        json_data = self.load_json(json_path)
        # Override the LoRA recorded in the example with the caller's choice.
        json_data["lora_name_or_path"] = lora_name_or_path
        return json_data
|