JustusI committed on
Commit
4e89ebf
·
verified ·
1 Parent(s): a0f157e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -25
app.py CHANGED
@@ -44,26 +44,26 @@ def augment_prompt(query, vectordb):
44
  return augmented_prompt
45
 
46
 
47
- # # Function to handle chat with OpenAI
48
- # def chat_with_openai(query, vectordb, openai_api_key):
49
- # chat = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=openai_api_key)
50
- # augmented_query = augment_prompt(query, vectordb)
51
- # prompt = HumanMessage(content=augmented_query)
52
- # messages = [
53
- # SystemMessage(content="You are a helpful assistant."),
54
- # prompt
55
- # ]
56
- # res = chat(messages)
57
- # return res.content
58
 
59
 
60
- # Function to handle chat with the Google open-source LLM
61
- def chat_with_google_llm(query, vectordb, tokenizer, model):
62
- augmented_query = augment_prompt(query, vectordb)
63
- input_ids = tokenizer(augmented_query, return_tensors="pt") #.to("cuda")
64
- outputs = model.generate(input_ids, max_length=512, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
65
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
- return response
67
 
68
 
69
  # Streamlit UI
@@ -75,12 +75,12 @@ zip_file_path = "chroma_db_compressed_.zip"
75
  extract_path = "./chroma_db_extracted"
76
  vectordb = load_vector_db(zip_file_path, extract_path)
77
 
78
- # Load Google model and tokenizer
79
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
80
- model = AutoModelForCausalLM.from_pretrained(
81
- "google/gemma-2b-it",
82
- torch_dtype=torch.bfloat16
83
- )#.to("cuda")
84
 
85
  # Initialize session state for chat history
86
  if "messages" not in st.session_state:
@@ -99,11 +99,26 @@ if prompt := st.chat_input("Enter your query"):
99
  st.markdown(prompt)
100
 
101
  with st.chat_message("assistant"):
102
- response = chat_with_google_llm(prompt, vectordb, tokenizer, model)
 
103
  st.markdown(response)
104
 
105
  st.session_state.messages.append({"role": "assistant", "content": response})
 
 
 
 
 
 
 
 
 
 
106
 
 
 
 
 
107
  # # Query input
108
  # query = st.text_input("Enter your query", "")
109
 
 
44
  return augmented_prompt
45
 
46
 
47
# Answer *query* with OpenAI's gpt-3.5-turbo, grounding the reply in the
# retrieval context produced by augment_prompt(query, vectordb).
def chat_with_openai(query, vectordb, openai_api_key):
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=openai_api_key)
    # Build the two-message conversation: fixed system role + augmented user turn.
    conversation = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content=augment_prompt(query, vectordb)),
    ]
    # Invoke the chat model and return only the text of its reply.
    reply = llm(conversation)
    return reply.content
58
 
59
 
60
+ # # Function to handle chat with the Google open-source LLM
61
+ # def chat_with_google_llm(query, vectordb, tokenizer, model):
62
+ # augmented_query = augment_prompt(query, vectordb)
63
+ # input_ids = tokenizer(augmented_query, return_tensors="pt") #.to("cuda")
64
+ # outputs = model.generate(input_ids, max_length=512, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
65
+ # response = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
+ # return response
67
 
68
 
69
  # Streamlit UI
 
75
  extract_path = "./chroma_db_extracted"
76
  vectordb = load_vector_db(zip_file_path, extract_path)
77
 
78
+ # # Load Google model and tokenizer
79
+ # tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
80
+ # model = AutoModelForCausalLM.from_pretrained(
81
+ # "google/gemma-2b-it",
82
+ # torch_dtype=torch.bfloat16
83
+ # )#.to("cuda")
84
 
85
  # Initialize session state for chat history
86
  if "messages" not in st.session_state:
 
99
  st.markdown(prompt)
100
 
101
  with st.chat_message("assistant"):
102
+ openai_api_key = st.secrets["OPENAI_API_KEY"]
103
+ response = chat_with_openai(prompt, vectordb, openai_api_key)
104
  st.markdown(response)
105
 
106
  st.session_state.messages.append({"role": "assistant", "content": response})
107
+
108
+ # User input
109
+ # if prompt := st.chat_input("Enter your query"):
110
+ # st.session_state.messages.append({"role": "user", "content": prompt})
111
+ # with st.chat_message("user"):
112
+ # st.markdown(prompt)
113
+
114
+ # with st.chat_message("assistant"):
115
+ # response = chat_with_google_llm(prompt, vectordb, tokenizer, model)
116
+ # st.markdown(response)
117
 
118
+ # st.session_state.messages.append({"role": "assistant", "content": response})
119
+
120
+
121
+
122
  # # Query input
123
  # query = st.text_input("Enter your query", "")
124