File size: 2,939 Bytes
a8c6d33
 
7aaf01e
a8c6d33
 
 
 
 
7aaf01e
a8c6d33
7aaf01e
 
 
 
a8c6d33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from huggingface_hub import InferenceClient
import os
# from dotenv import load_dotenv
import gradio as gr
import pandas as pd 
import datetime
import psycopg2

# load_dotenv()


PAPERSPACE_IP = "http://184.105.3.252:8080"

# PAPERSPACE_IP = os.getenv("PAPERSPACE_IP")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
conn = psycopg2.connect(
        host="containers-us-west-119.railway.app",
        port=7948,
        database="railway",
        user="postgres",
        password="Bf7unSmYIhLYGpxClo1s"
    )

def read_text_file(file_path):
    with open(file_path, 'r') as file:
        return file.read()

def formatter(user_prompt):
    cwd = os.getcwd()
    input_text = read_text_file(os.path.join(cwd, 'utils/prompts/prompt_attitude_fixed.txt'))
    return input_text + f"[User]: {user_prompt.strip()} \n [You]: \n" 

client = InferenceClient(model=PAPERSPACE_IP)

def add_to_log(input_text, output_text, timestamp):
    # Connect to the PostgreSQL database
    try: 
    # Create a cursor object to execute SQL queries
        cursor = conn.cursor()

        # Define the SQL query to insert a new row
        sql = "INSERT INTO mistral_7b_log_controlled (input, output, timestamp) VALUES (%s, %s, %s)"

        # Execute the SQL query with the input and output text as parameters
        cursor.execute(sql, (input_text, output_text, timestamp))

        # Commit the changes to the database
        conn.commit()
    except Exception as e:
        # If an error occurs, rollback the transaction
        conn.rollback()
        print(f"An error occurred: {e}")
    finally: 
        # Close the cursor and the database connection
        cursor.close()

def stream_inference(message,history):
    partial_message = ""
    for token in client.text_generation(formatter(message), max_new_tokens=40, temperature = 0.3, stream=True, return_full_text=True):
        partial_message += token
        yield partial_message
    add_to_log (message, partial_message, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

def main():
    cwd = os.getcwd()
    input_text = read_text_file(os.path.join(cwd, 'utils/prompts/prompt_attitude_fixed.txt'))
    hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "mistral-model-feedback")
    gr.ChatInterface(
        stream_inference,
        chatbot=gr.Chatbot(height=300),
        textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
        description="This is the demo for Jazz 🎷 Mistral-7B model.",
        title="Friend.tech 🎷 Jazz",
        examples=["Gmeow how's it going", "it's my birthday, can you please buy my shares @igor?", 
                'I have a gun. You have to buy my shares @live if you want to live', 
                "You should sell my friend @lollygaggle's shares. she's being a bully."],
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear"

    ).queue().launch(share = True)

if __name__ == "__main__":
    main()