import streamlit as st
import sqlite3
import pandas as pd
import openai
import os
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
os.environ["OPENAI_API_KEY"] = os.getenv("SECRET_KEY", "")
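# Assumption: the OpenAI key is stored in an environment variable named SECRET_KEY
# (for example a Hugging Face Space secret) and is set before the app starts;
# adjust the variable name to match your deployment.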
def init_database():
    conn = sqlite3.connect('GPTPromptTemplates.db')
    cursor = conn.cursor()
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS USERS (
            USER_ID INTEGER PRIMARY KEY,
            User_Name VARCHAR(255)
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS TEMPLATES (
            TEMPLATE_ID INTEGER PRIMARY KEY,
            USER_ID INTEGER,
            Prompt_Name VARCHAR(255),
            Prompt_Text TEXT
        )
    ''')
    cursor.execute('''
        CREATE UNIQUE INDEX IF NOT EXISTS idx_templates_prompt_name ON TEMPLATES (USER_ID, Prompt_Name)
    ''')
    conn.commit()
    conn.close()
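# The unique index on (USER_ID, Prompt_Name) makes INSERT OR REPLACE in
# insert_prompt_template behave like an upsert: saving a template under an
# existing name replaces that row instead of creating a duplicate.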
def insert_prompt_template(user_id, prompt_name, prompt_text):
    conn = sqlite3.connect('GPTPromptTemplates.db')
    cursor = conn.cursor()
    cursor.execute('INSERT OR REPLACE INTO TEMPLATES (USER_ID, Prompt_Name, Prompt_Text) VALUES (?, ?, ?)', (user_id, prompt_name, prompt_text))
    conn.commit()
    conn.close()
def delete_prompt_template(user_id, prompt_name):
    conn = sqlite3.connect('GPTPromptTemplates.db')
    cursor = conn.cursor()
    cursor.execute('DELETE FROM TEMPLATES WHERE USER_ID = ? AND Prompt_Name = ?', (user_id, prompt_name))
    conn.commit()
    conn.close()
def get_prompt(user_id, prompt_name):
    conn = sqlite3.connect('GPTPromptTemplates.db')
    cursor = conn.cursor()
    cursor.execute('SELECT Prompt_Name, Prompt_Text FROM TEMPLATES WHERE Prompt_Name = ? AND USER_ID = ?', (prompt_name, user_id))
    template = cursor.fetchone()
    conn.close()
    if template is None:
        return '', ''
    else:
        return template[0], template[1]
def get_default_prompt(user_id):
    conn = sqlite3.connect('GPTPromptTemplates.db')
    cursor = conn.cursor()
    cursor.execute('SELECT Prompt_Name, Prompt_Text FROM TEMPLATES WHERE USER_ID = ? ORDER BY Prompt_Name ASC LIMIT 1', (user_id,))
    template = cursor.fetchone()
    conn.close()
    if template is None:
        return '', ''
    else:
        return template[0], template[1]
def get_prompt_list(user_id):
    conn = sqlite3.connect('GPTPromptTemplates.db')
    # Parameterised query instead of string formatting to avoid SQL injection.
    templates = pd.read_sql_query('SELECT DISTINCT Prompt_Name FROM TEMPLATES WHERE USER_ID = ? ORDER BY Prompt_Name ASC', conn, params=(user_id,))
    conn.close()
    return templates
def template_change_value():
    name, prompt = get_prompt(st.session_state.user_id, st.session_state.template_select)
    st.session_state.name = name
    st.session_state.prompt = prompt

def template_return_value(template_name):
    st.session_state.template_select = template_name
    name, prompt = get_prompt(st.session_state.user_id, st.session_state.template_select)
    st.session_state.name = name
    st.session_state.prompt = prompt
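# Optional seeding sketch (an assumption, not part of the original app): the
# hypothetical helper below pre-populates the TEMPLATES table for user 1 so the
# template dropdown is not empty on first run. Uncomment and call it once if
# you want a starter prompt available.
# def seed_default_template():
#     init_database()
#     insert_prompt_template(1, 'Summarize', 'Summarize the following text in three bullet points.')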
def main():
    st.title("Working with Chat GPT with templates")
    init_database()
    col1, col2, col3 = st.columns([1, 1, 1])
    user_id = 1
    name, prompt = get_default_prompt(user_id)
    prompt_list = get_prompt_list(user_id)
    model_names = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']
    if "initialized" not in st.session_state:
        st.session_state.user_id = user_id
        st.session_state.name = name
        st.session_state.prompt = prompt
        st.session_state.prompt_list = prompt_list
        st.session_state.template_select = name
        st.session_state.output = ''
        st.session_state.model_name = 'gpt-4'
        st.session_state.initialized = True
    with col1:
        input_text = st.text_area('Please insert data for transforming', '', key="input_data", height=450)
        if st.button("Apply"):
            query = prompt
            with st.spinner('In progress...'):
                # Alternative retrieval-augmented path, kept for reference:
                # text_splitter = CharacterTextSplitter(chunk_size=4096, chunk_overlap=0)
                # texts = text_splitter.split_text(input_text)
                # embeddings = OpenAIEmbeddings()
                # docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]).as_retriever()
                # docs = docsearch.get_relevant_documents(query)
                if st.session_state.model_name == 'gpt-4':
                    max_tkns = 5500
                else:
                    max_tkns = 3000
                openai.api_key = os.environ["OPENAI_API_KEY"]
                response = openai.ChatCompletion.create(
                    model=st.session_state.model_name,
                    messages=[
                        {"role": "system", "content": query},
                        {"role": "user", "content": input_text},
                    ],
                    temperature=0.7,
                    max_tokens=max_tkns,
                )
                st.session_state.output = response["choices"][0]["message"]["content"].replace("\\n", "\n")
                # chain = load_qa_chain(ChatOpenAI(model=st.session_state.model_name, max_tokens=max_tkns, temperature=0), chain_type="stuff")
                # st.session_state.output = chain.run(input_documents=docs, question=query)
            # Show feedback before triggering the rerun so the message is not skipped.
            st.success("Ready!")
            st.experimental_rerun()
    with col2:
        st.session_state.model_name = st.selectbox("GPT model: ", model_names, key="gpt_model")
        template_return_value(st.selectbox("Template: ", st.session_state.prompt_list, key="prompt_template"))
        new_name = st.text_input("Template name:", value=st.session_state.name, key="template_name")
        input_query = st.text_area("Prompt:", value=st.session_state.prompt, key="template_text", height=200)
        col4, col5 = st.columns([1, 1])
        if col4.button("Save"):
            insert_prompt_template(user_id, new_name, input_query)
            st.session_state.prompt_list = get_prompt_list(user_id)
            st.success("Prompt saved!")
            st.experimental_rerun()
        if col5.button("Delete"):
            delete_prompt_template(user_id, new_name)
            st.session_state.prompt_list = get_prompt_list(user_id)
            st.success("Prompt deleted!")
            st.experimental_rerun()
    with col3:
        txt_result = st.text_area('Result', value=st.session_state.output, key="output_data", height=450)

if __name__ == "__main__":
    main()