Spaces:
Runtime error
Runtime error
File size: 2,131 Bytes
a8c6d33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from text_generation import Client
import os
from dotenv import load_dotenv
# Pull variables from a local .env file (when one exists) into os.environ
# before reading configuration from it.
load_dotenv()
# Address handed to text_generation.Client; None when the variable is unset.
# NOTE(review): presumably the server's base URL (scheme + host + port) — confirm.
PAPERSPACE_IP = os.environ.get("PAPERSPACE_IP")
# Single shared client used by generate_text for every request.
client = Client(PAPERSPACE_IP)
def generate_text(input_text, max_new_tokens=20, temperature=1):
    """Send one prompt to the shared client and return only the generated text.

    Args:
        input_text: Full prompt string sent to the model.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature forwarded to the client.

    Returns:
        The ``generated_text`` attribute of the client's response.
    """
    response = client.generate(
        input_text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return response.generated_text
def generate_multi_text(input_text, file_path, max_new_tokens=20, temperature=1,
                        out_path=None, earlystop=None):
    """Generate one model response per user turn listed in a text file.

    Each line of ``file_path`` is one user turn. Every turn is stripped,
    wrapped with ``formatter`` and appended to ``input_text``. The resulting
    prompt is sent to the model on its own, so earlier turns and responses
    are NOT carried into later prompts.

    Args:
        input_text: Prompt prefix prepended to every turn.
        file_path: Path to a UTF-8 text file with one user turn per line.
        max_new_tokens: Maximum number of tokens to generate per turn.
        temperature: Sampling temperature forwarded to the model.
        out_path: Optional path of a UTF-8 file that receives one
            ``"Turn <n>: <response>"`` line per turn (1-based). When ``None``
            (the default), no file is written.
        earlystop: If not ``None``, only the first ``earlystop`` lines of
            ``file_path`` are used.

    Returns:
        List of generated responses, in turn order.
    """
    from contextlib import nullcontext

    with open(file_path, "r", encoding="utf-8") as file:
        rows = file.readlines()
    if earlystop is not None:
        rows = rows[:earlystop]
    multi_turns = [formatter(row.strip()) for row in rows]
    print("You are playing " + str(len(multi_turns)) + " turns.")

    generated_text = []
    # Honour the default out_path=None by skipping file output; opening None
    # would raise TypeError. nullcontext() yields None as the "file" handle.
    if out_path is not None:
        out_cm = open(out_path, "w", encoding="utf-8")
    else:
        out_cm = nullcontext()
    with out_cm as out_file:
        for turn_no, turn in enumerate(multi_turns, start=1):
            single_turn_resp = generate_text(
                input_text + turn,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
            )
            generated_text.append(single_turn_resp)
            if out_file is not None:
                out_file.write(f"Turn {turn_no}: {single_turn_resp}\n")
            print(turn)
            print(single_turn_resp)
            print("-----------")
    return generated_text
def read_text_file(file_path):
    """Return the entire contents of the text file at ``file_path``.

    The file is decoded as UTF-8 explicitly instead of with the
    platform-dependent locale default. Prompt files that contain non-ASCII
    characters therefore read the same on every OS.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
def formatter(user_prompt):
    """Wrap a single user message in the chat-style turn template.

    Surrounding whitespace is stripped from ``user_prompt``. The template
    then ends with an empty ``[You]:`` slot for the model to fill in.
    """
    cleaned = user_prompt.strip()
    return "[User]: " + cleaned + " \n [You]: \n"
def main():
    """Run the multi-turn generation over ``inappropriate.txt``.

    All paths are relative to the current working directory:
      * prompt prefix: ``utils/prompts/prompt_attitude.txt``
      * user turns:    ``inappropriate.txt`` (one turn per line)
      * output:        ``utils/user_turns/multi_turns_conversation_t<T>_m<N>_promptatt_mistral_inapp.txt``
        where ``<T>`` is the temperature and ``<N>`` the token budget.
    """
    cwd = os.getcwd()
    input_text = read_text_file(os.path.join(cwd, 'utils/prompts/prompt_attitude.txt'))
    max_new_tokens = 40
    temperature = 0.3
    multi_path = os.path.join(cwd, 'inappropriate.txt')
    out_path = os.path.join(cwd, f'utils/user_turns/multi_turns_conversation_t{temperature}_m{max_new_tokens}_promptatt_mistral_inapp.txt')
    # Writing the output file fails if its directory does not exist yet.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    generate_multi_text(
        input_text,
        multi_path,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        out_path=out_path,
    )


if __name__ == "__main__":
    main()
|