ankanghosh commited on
Commit
ac3798b
·
verified ·
1 Parent(s): 2311e2d

Update rag_engine.py

Browse files
Files changed (1) hide show
  1. rag_engine.py +22 -23
rag_engine.py CHANGED
@@ -66,33 +66,32 @@ local_metadata_file = "metadata.jsonl"
66
 
67
  def load_model():
68
  try:
69
- # Check if model is already loaded
70
- if not st.session_state.model_initialized:
71
- with st.spinner("Loading tokenizer and model... This may take a minute."):
72
- print("Loading tokenizer...")
73
- tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")
74
-
75
- print("Loading model...")
76
- model = AutoModel.from_pretrained(
77
- "intfloat/e5-small-v2",
78
- torch_dtype=torch.float16, # Use half precision
79
- low_cpu_mem_usage=True,
80
- )
81
-
82
- model.eval()
83
- torch.set_grad_enabled(False)
84
-
85
- # Store in session state
86
- st.session_state.tokenizer = tokenizer
87
- st.session_state.model = model
88
- st.session_state.model_initialized = True
89
-
90
- print("✅ Model loaded successfully")
91
 
92
  return st.session_state.tokenizer, st.session_state.model
93
  except Exception as e:
94
  print(f"❌ Error loading model: {str(e)}")
95
- st.error(f"Error loading model: {str(e)}")
96
  raise
97
 
98
  def download_file_from_gcs(bucket, gcs_path, local_path):
 
66
 
67
  def load_model():
68
  try:
69
+ if st.session_state.model is None:
70
+ # Force model to CPU - more stable than GPU for this use case
71
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
72
+
73
+ print("Loading tokenizer...")
74
+ tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")
75
+
76
+ print("Loading model...")
77
+ model = AutoModel.from_pretrained(
78
+ "intfloat/e5-small-v2",
79
+ torch_dtype=torch.float16 # Use half precision
80
+ )
81
+
82
+ # Move model to the designated device
83
+ model = model.to(st.session_state.device)
84
+ model.eval()
85
+ torch.set_grad_enabled(False)
86
+
87
+ st.session_state.tokenizer = tokenizer
88
+ st.session_state.model = model
89
+
90
+ print("✅ Model loaded successfully")
91
 
92
  return st.session_state.tokenizer, st.session_state.model
93
  except Exception as e:
94
  print(f"❌ Error loading model: {str(e)}")
 
95
  raise
96
 
97
  def download_file_from_gcs(bucket, gcs_path, local_path):