rasulbrur commited on
Commit
cec4377
·
1 Parent(s): 6843abb

changed transformer version

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. rag/sql_db.py +16 -6
Dockerfile CHANGED
@@ -13,7 +13,7 @@ RUN apt-get update && \
13
  ENV HF_HOME=/app/cache
14
 
15
  # Create cache directory and set permissions
16
- RUN mkdir -p /app/cache && chmod -R 777 /app/cache
17
 
18
  # Copy your code into the container
19
  COPY . .
 
13
  ENV HF_HOME=/app/cache
14
 
15
  # Create cache directory and set permissions
16
+ RUN mkdir -p /app/cache /app/db && chmod -R 777 /app/cache /app/db
17
 
18
  # Copy your code into the container
19
  COPY . .
rag/sql_db.py CHANGED
@@ -1,4 +1,3 @@
1
- # rag/retriever.py
2
  import os
3
  import pandas as pd
4
  import faiss
@@ -8,13 +7,20 @@ from .embedder import Embedder
8
  from datetime import datetime
9
 
10
  class SQL_Key_Pair:
11
- def __init__(self, file_path="financial data sp500 companies.csv", model_name="all-MiniLM-L6-v2", db_path="financial_data.db"):
 
 
12
  self.embedder = Embedder(model_name)
13
  self.index = None
14
  self.documents = []
15
  self.data = None
16
  self.embeddings = None
17
- self.db_conn = sqlite3.connect(db_path)
 
 
 
 
 
18
  self.create_db_table()
19
  self.load_data(file_path)
20
 
@@ -36,6 +42,7 @@ class SQL_Key_Pair:
36
  )
37
  """)
38
  self.db_conn.commit()
 
39
 
40
  def load_data(self, file_path):
41
  """
@@ -119,7 +126,6 @@ class SQL_Key_Pair:
119
 
120
  return retrieved_text
121
 
122
-
123
  def query_csv(self, query, k=3):
124
  """
125
  Query the CSV data with a user query.
@@ -140,7 +146,6 @@ class SQL_Key_Pair:
140
 
141
  return "\n".join(responses)
142
 
143
-
144
  def entity_based_query(self, entities):
145
  return self.keyword_match_search(entities)
146
 
@@ -168,4 +173,9 @@ class SQL_Key_Pair:
168
  return f"Error querying database: {str(e)}"
169
 
170
  def __del__(self):
171
- self.db_conn.close()
 
 
 
 
 
 
 
1
  import os
2
  import pandas as pd
3
  import faiss
 
7
  from datetime import datetime
8
 
9
  class SQL_Key_Pair:
10
+ def __init__(self, file_path="financial_data.csv", model_name="all-MiniLM-L6-v2", db_path="/app/db/financial_data.db"):
11
+ # Ensure the database directory exists
12
+ os.makedirs(os.path.dirname(db_path), exist_ok=True)
13
  self.embedder = Embedder(model_name)
14
  self.index = None
15
  self.documents = []
16
  self.data = None
17
  self.embeddings = None
18
+ try:
19
+ self.db_conn = sqlite3.connect(db_path)
20
+ print(f"Connected to SQLite database at {db_path}")
21
+ except sqlite3.OperationalError as e:
22
+ print(f"Failed to connect to database: {e}")
23
+ raise
24
  self.create_db_table()
25
  self.load_data(file_path)
26
 
 
42
  )
43
  """)
44
  self.db_conn.commit()
45
+ print("Created custom_financials table")
46
 
47
  def load_data(self, file_path):
48
  """
 
126
 
127
  return retrieved_text
128
 
 
129
  def query_csv(self, query, k=3):
130
  """
131
  Query the CSV data with a user query.
 
146
 
147
  return "\n".join(responses)
148
 
 
149
  def entity_based_query(self, entities):
150
  return self.keyword_match_search(entities)
151
 
 
173
  return f"Error querying database: {str(e)}"
174
 
175
  def __del__(self):
176
+ try:
177
+ if hasattr(self, 'db_conn') and self.db_conn:
178
+ self.db_conn.close()
179
+ print("Closed SQLite database connection")
180
+ except Exception as e:
181
+ print(f"Error closing database connection: {e}")