JustusI committed
Commit 4b6bb4b · verified · 1 Parent(s): 9f967c8

Update app.py

Files changed (1)
app.py +27 -12
app.py CHANGED
@@ -43,17 +43,26 @@ def augment_prompt(query, vectordb):
     return augmented_prompt
 
 
-# Function to handle chat with OpenAI
-def chat_with_openai(query, vectordb, openai_api_key):
-    chat = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=openai_api_key)
+# # Function to handle chat with OpenAI
+# def chat_with_openai(query, vectordb, openai_api_key):
+#     chat = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=openai_api_key)
+#     augmented_query = augment_prompt(query, vectordb)
+#     prompt = HumanMessage(content=augmented_query)
+#     messages = [
+#         SystemMessage(content="You are a helpful assistant."),
+#         prompt
+#     ]
+#     res = chat(messages)
+#     return res.content
+
+
+# Function to handle chat with the Google open-source LLM
+def chat_with_google_llm(query, vectordb, tokenizer, model):
     augmented_query = augment_prompt(query, vectordb)
-    prompt = HumanMessage(content=augmented_query)
-    messages = [
-        SystemMessage(content="You are a helpful assistant."),
-        prompt
-    ]
-    res = chat(messages)
-    return res.content
+    input_ids = tokenizer(augmented_query, return_tensors="pt")  # .to("cuda")
+    outputs = model.generate(input_ids, max_length=512, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return response
 
 
 # Streamlit UI
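
A note on the new generation code: tokenizer(augmented_query, return_tensors="pt") returns a BatchEncoding (a dict of tensors) rather than a plain tensor, so passing it positionally to model.generate is liable to fail at runtime; max_length=512 also counts prompt tokens as well as generated ones, and the decoded output still begins with the prompt text. A minimal sketch of a corrected body for chat_with_google_llm, assuming the same tokenizer and model objects (a suggested fix, not part of this commit):

    inputs = tokenizer(augmented_query, return_tensors="pt")  # add .to(model.device) when on GPU
    outputs = model.generate(
        **inputs,  # unpacks input_ids and attention_mask
        max_new_tokens=512,  # caps only the generated tokens, unlike max_length
        pad_token_id=tokenizer.eos_token_id,
    )
    # Slice off the prompt so only the newly generated text is returned
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)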
 
@@ -65,6 +74,13 @@ zip_file_path = "chroma_db_compressed_.zip"
 extract_path = "./chroma_db_extracted"
 vectordb = load_vector_db(zip_file_path, extract_path)
 
+# Load Google model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+model = AutoModelForCausalLM.from_pretrained(
+    "google/gemma-2b-it",
+    torch_dtype=torch.bfloat16
+)  # .to("cuda")
+
 # Initialize session state for chat history
 if "messages" not in st.session_state:
     st.session_state.messages = []
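
Two caveats on the loading block: google/gemma-2b-it is a gated checkpoint, so the Space needs an accepted license and a Hugging Face token available when from_pretrained runs; and Streamlit re-executes the whole script on every interaction, so an unguarded load re-reads the weights on each message. A sketch of a cached loader using Streamlit's st.cache_resource, with a hypothetical helper name load_gemma:

    @st.cache_resource  # load once per process instead of on every rerun
    def load_gemma():  # hypothetical helper name
        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2b-it",
            torch_dtype=torch.bfloat16,
        )
        return tokenizer, model

    tokenizer, model = load_gemma()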
 
@@ -82,8 +98,7 @@ if prompt := st.chat_input("Enter your query"):
         st.markdown(prompt)
 
     with st.chat_message("assistant"):
-        openai_api_key = st.secrets["OPENAI_API_KEY"]
-        response = chat_with_openai(prompt, vectordb, openai_api_key)
+        response = chat_with_google_llm(prompt, vectordb, tokenizer, model)
         st.markdown(response)
 
     st.session_state.messages.append({"role": "assistant", "content": response})
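
One further refinement worth noting: gemma-2b-it is an instruction-tuned chat model, and raw text prompts tend to produce weaker replies than prompts wrapped in its chat template. A sketch of how the augmented query could be formatted before generation (a suggestion, not part of this commit):

    # Gemma's template has no system role, so the full augmented prompt goes in the user turn
    messages = [{"role": "user", "content": augmented_query}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # appends the model-turn marker
        return_tensors="pt",
    )
    outputs = model.generate(input_ids, max_new_tokens=512)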