File size: 2,131 Bytes
a8c6d33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from text_generation import Client
import os
from dotenv import load_dotenv

# Load variables from a .env file (python-dotenv) into the process environment
# so the os.getenv call below can pick them up.
load_dotenv()

# Address of the text-generation server; os.getenv returns None when unset.
# NOTE(review): presumably a base URL such as "http://<host>:<port>" — confirm.
# NOTE(review): there is no validation here; if PAPERSPACE_IP is missing the
# Client is built with None and errors will likely surface only on the first
# request. Consider failing fast once callers are checked.
PAPERSPACE_IP = os.getenv("PAPERSPACE_IP")

# Module-level client shared by generate_text for every request.
client = Client(PAPERSPACE_IP)

def generate_text(input_text, max_new_tokens=20, temperature=1):
    """Send *input_text* to the shared ``client`` and return only the completion.

    *max_new_tokens* and *temperature* are forwarded unchanged to
    ``client.generate``; the ``generated_text`` attribute of the response
    object is what gets returned.
    """
    response = client.generate(
        input_text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return response.generated_text

def generate_multi_text(input_text, file_path, max_new_tokens=20, temperature=1,
                        out_path=None, earlystop=None):
    """Run one generation per user turn read from a file.

    Each line of *file_path* is one user turn. The line is wrapped with
    ``formatter`` and appended to *input_text* (the shared prompt prefix), and
    that concatenation is sent to ``generate_text``. Turns are independent:
    earlier responses are not fed back into later prompts.

    Args:
        input_text: Prompt prefix prepended to every formatted turn.
        file_path: Path to a text file with one user turn per line.
        max_new_tokens: Forwarded to ``generate_text``.
        temperature: Forwarded to ``generate_text``.
        out_path: If given, each response is written to this file as
            ``"Turn <n>: <response>"`` (the file is overwritten). If ``None``
            (the default), nothing is written to disk.
        earlystop: If not ``None``, only the first *earlystop* lines are used.

    Returns:
        The list of generated responses, in turn order.
    """
    import contextlib  # local import keeps this block self-contained

    with open(file_path, "r") as file:
        rows = file.readlines()

    if earlystop is not None:
        rows = rows[:earlystop]
    multi_turns = [formatter(row.strip()) for row in rows]
    print(f"You are playing {len(multi_turns)} turns.")

    generated_text = []
    # out_path defaults to None, which used to crash in open(None, "w").
    # nullcontext() yields None on entry, so file writes are skipped then.
    sink = open(out_path, "w") if out_path is not None else contextlib.nullcontext()
    with sink as out_file:
        for turn_number, turn in enumerate(multi_turns, start=1):
            single_turn_resp = generate_text(
                input_text + turn,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
            )
            generated_text.append(single_turn_resp)
            if out_file is not None:
                out_file.write(f"Turn {turn_number}: {single_turn_resp}\n")
            print(turn)
            print(single_turn_resp)
            print("-----------")
    return generated_text

def read_text_file(file_path):
    """Return the whole contents of the text file at *file_path* as one string."""
    with open(file_path, 'r') as handle:
        contents = handle.read()
    return contents

def formatter(user_prompt):
    """Wrap one user utterance in the ``[User]: ... [You]:`` turn template.

    Surrounding whitespace is stripped from *user_prompt* before it is inserted.
    """
    cleaned = user_prompt.strip()
    return f"[User]: {cleaned} \n [You]: \n"

def main():
    """Run the multi-turn generation experiment with hard-coded settings.

    Reads the prompt prefix from ``utils/prompts/prompt_attitude.txt`` and the
    user turns from ``inappropriate.txt`` (both relative to the current working
    directory). It then writes the model responses to a file under
    ``utils/user_turns/`` whose name encodes the temperature and token limit.
    """
    cwd = os.getcwd()
    input_text = read_text_file(os.path.join(cwd, 'utils/prompts/prompt_attitude.txt'))
    max_new_tokens = 40
    temperature = 0.3
    multi_path = os.path.join(cwd, 'inappropriate.txt')
    out_path = os.path.join(cwd, f'utils/user_turns/multi_turns_conversation_t{temperature}_m{max_new_tokens}_promptatt_mistral_inapp.txt')

    # Ensure the output directory exists so opening out_path for writing does
    # not raise FileNotFoundError on a fresh checkout.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    generate_multi_text(
        input_text,
        multi_path,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        out_path=out_path,
    )


if __name__ == "__main__":
    main()