adsurkasur committed
Commit e8c6de7 · 1 Parent(s): 68964c2

Integrate Gemini API for embedding and fact extraction, replacing local model dependencies

app/core/embedding.py CHANGED
@@ -1,23 +1,25 @@
import os
- from app.core.device_setup import device
- from sentence_transformers import SentenceTransformer

- cache_dir = os.getenv("HF_HOME", "/app/hf_cache")
- model_name = 'sentence-transformers/all-MiniLM-L6-v2'
+ # Replace local embedding model with Gemini API integration
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

- # Load the model (this model will be downloaded the first time if it's not cached)
- embedding_model = SentenceTransformer(model_name, cache_folder=cache_dir)
-
- # Move the model to the appropriate device (GPU if available, otherwise CPU)
- embedding_model = embedding_model.to(device)
+ if not GEMINI_API_KEY:
+     raise ValueError("❌ Gemini API Key (GEMINI_API_KEY) has not been set yet")

def generate_embedding(texts: list):
+     import requests
+     url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key={GEMINI_API_KEY}"
+     headers = {"Content-Type": "application/json"}
+     payload = {
+         "contents": [{"parts": [{"text": text}]} for text in texts]
+     }
+
    try:
-         # Use the model's encode method to get embeddings for the input text
-         embedding = embedding_model.encode(texts, convert_to_tensor=True) # Automatically moves tensor to the correct device
-         embedding_cpu = embedding.cpu() # Move to CPU if needed
-         # Convert to list for easier handling if needed
-         return embedding_cpu.numpy().tolist()
+         response = requests.post(url, headers=headers, json=payload)
+         response.raise_for_status()
+         data = response.json()
+         # Parse embeddings from the response (adjust based on Gemini API's actual response structure)
+         return [item.get("embedding", []) for item in data.get("contents", [])]
    except Exception as e:
-         print(f"Error generating embedding: {e}")
+         print(f"Error generating embedding via Gemini API: {e}")
        return None
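Note on the new embedding call: gemini-2.0-flash:generateContent is a text-generation endpoint, so the "contents"/"embedding" fields the new code reads back are unlikely to ever contain vectors. If real embeddings are wanted, the Gemini API exposes a dedicated embeddings method; the sketch below is one way this function could be pointed at it. It assumes the text-embedding-004 model and the batchEmbedContents method, neither of which is part of this commit.

# Hedged sketch only: fetch vectors from the Gemini embeddings endpoint.
# Model name (text-embedding-004) and method (batchEmbedContents) are assumptions,
# not part of this commit.
import os
import requests

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

def generate_embedding(texts: list):
    url = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        f"text-embedding-004:batchEmbedContents?key={GEMINI_API_KEY}"
    )
    payload = {
        "requests": [
            {"model": "models/text-embedding-004", "content": {"parts": [{"text": text}]}}
            for text in texts
        ]
    }
    try:
        response = requests.post(url, headers={"Content-Type": "application/json"}, json=payload)
        response.raise_for_status()
        # Expected shape: {"embeddings": [{"values": [...]}, ...]}, one entry per input text
        return [item.get("values", []) for item in response.json().get("embeddings", [])]
    except Exception as e:
        print(f"Error generating embedding via Gemini API: {e}")
        return None

The return value stays a list of float lists, one per input string, so downstream callers of generate_embedding would not need to change.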
app/core/fact_extraction.py CHANGED
@@ -1,35 +1,44 @@
import os
from app.core.device_setup import device
- from transformers import pipeline
from app.core.fact_management import save_user_fact
from app.core.logging_setup import logger

- model_name = 'dslim/distilbert-NER'
- cache_dir = os.getenv("HF_HOME", "/app/hf_cache")
+ # Replace local model loading with Gemini API integration
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

- # Ensure the device is used for the pipeline
- pipeline_device = 0 if device == "cuda" else -1
- nlp = pipeline("token-classification", model=model_name, device=pipeline_device, model_kwargs={"cache_dir": cache_dir})
+ if not GEMINI_API_KEY:
+     raise ValueError("❌ Gemini API Key (GEMINI_API_KEY) has not been set yet")

- def extract_name(text):
-     """Extract name from text using Transformers."""
-     entities = nlp(text)
-     names = [entity['word'] for entity in entities if entity['entity'].startswith('B-PER')]
-     return names
+ def query_gemini_for_entities(text: str):
+     import requests
+     url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key={GEMINI_API_KEY}"
+     headers = {"Content-Type": "application/json"}
+     payload = {
+         "contents": [{"parts": [{"text": text}]}]
+     }

+     try:
+         response = requests.post(url, headers=headers, json=payload)
+         response.raise_for_status()
+         data = response.json()
+         # Parse entities from the response (adjust based on Gemini API's actual response structure)
+         return data.get("entities", [])
+     except Exception as e:
+         logger.error(f"🚨 Error querying Gemini API for entities: {e}")
+         return []
+
+ # Replace the NLP pipeline with Gemini API calls
def extract_and_store_facts(message):
-     """Extract personal facts like name, location, and interests."""
-     entities = nlp(message)
-
+     entities = query_gemini_for_entities(message)
+
    # Extract name
-     name = next((entity['word'] for entity in entities if entity['entity'].startswith('B-PER')), None)
+     name = next((entity['name'] for entity in entities if entity.get('type') == 'PERSON'), None)
    if name:
-         clean_name = name.split(".")[0] # Store only the first sentence
-         save_user_fact("name", clean_name)
+         save_user_fact("name", name)
        logger.info(f"User name '{name}' stored in memory.")
-
+
    # Extract location
-     location = next((entity['word'] for entity in entities if entity['entity'].startswith('B-LOC')), None)
+     location = next((entity['name'] for entity in entities if entity.get('type') == 'LOCATION'), None)
    if location:
        save_user_fact("location", location)
        logger.info(f"User location '{location}' stored in memory.")
app/routes/chat_hf.py CHANGED
@@ -15,7 +15,6 @@ from app.core.logging_setup import logger
from app.core.prompts import SYSTEM_PROMPT
from app.core.interaction_trends import get_time_of_day
from app.core.search_utils import needs_web_search, search_duckduckgo
- from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import asyncio

@@ -30,113 +29,36 @@ headers = {
    "Authorization": f"Bearer {HUGGINGFACE_TOKEN}"
}

- # Load the model and tokenizer locally
- cache_dir = os.getenv("HF_HOME", "/app/hf_cache")
- model_name = "google/gemma-3-1b-it"
- tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
- tokenizer.pad_token = tokenizer.eos_token
- model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir).to(device)
- model.config.pad_token_id = tokenizer.eos_token_id
-
- # Check model and tokenizer types
- logger.info(f"Model type: {type(model)}")
- logger.info(f"Tokenizer type: {type(tokenizer)}")
- logger.info("Model and tokenizer loaded successfully.")
-
- def build_clean_prompt(messages):
-     role_map = {
-         "system": "System",
-         "user": "User",
-         "assistant": "Arina"
-     }
-     prompt = ""
-     for msg in messages:
-         role = role_map.get(msg["role"], "User")
-         prompt += f"{role}: {msg['content'].strip()}\n"
-     prompt += "Arina:"
-     return prompt
-
- def generate_response(prompt_text):
-     try:
-         logger.info("Starting to generate response.")
-         logger.info(f"Original prompt text: {prompt_text}")
-
-         # Sanity check for prompt structure
-         if "User:" in prompt_text[-80:] or prompt_text.count("User:") > prompt_text.count("Arina:"):
-             logger.warning("⚠️ Possible misalignment in role markers. Last prompt may confuse model.")
-
-         # Tokenize the prompt
-         logger.info("Tokenizing the prompt...")
-         logger.info(f"Tokenizer: {tokenizer}")
-         model_inputs = tokenizer(
-             prompt_text.strip(),
-             return_tensors="pt",
-             truncation=True,
-             max_length=1024 # Can be increased based on your model's context window
-         )
-         logger.info("Prompt tokenized.")
-
-         # Log input token length
-         input_len = model_inputs["input_ids"].shape[-1]
-         logger.info(f"🧾 Prompt token length: {input_len}")
-
-         assert prompt_text.count("User:") == prompt_text.count("Arina:"), "⚠️ Prompt imbalance may confuse model"
-
-         # Generate response
-         logger.info("Generating model response...")
-         logger.info(f"Model: {model}")
-         model_outputs = model.generate(
-             **model_inputs,
-             max_new_tokens=512, # Can be adjusted, output token limit
-             do_sample=True,
-             top_p=0.9,
-             temperature=0.7,
-             repetition_penalty=1.1,
-             pad_token_id=tokenizer.eos_token_id,
-             eos_token_id=tokenizer.eos_token_id
-         )
-         logger.info("Model response generated.")
-
-         # Decode output
-         logger.info("Decoding model output...")
-         full_output = tokenizer.decode(model_outputs[0], skip_special_tokens=True).strip()
-         logger.info("Model output decoded.")
+ # Replace local model loading with Gemini API integration
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

-         # Extract only the part after "Arina:" (if present)
-         if "Arina:" in full_output:
-             response = full_output.split("Arina:", 1)[-1].strip()
-         else:
-             response = full_output.strip()
-
-         # Prevent output starting with "User:" as it is hallucination
-         if response.startswith("User:"):
-             logger.warning("⚠️ Model hallucinated user input.")
-             response = response.split("Arina:")[-1].strip() if "Arina:" in response else response
-
-         # Clean echo if present
-         if response.startswith(prompt_text):
-             response = response[len(prompt_text):].strip()
-
-         # Fallback if empty
-         if not response:
-             logger.warning("⚠️ Empty response generated. Returning fallback.")
-             response = "I'm not sure how to respond to that, but I'm here to help."
+ if not GEMINI_API_KEY:
+     raise ValueError("❌ Gemini API Key (GEMINI_API_KEY) has not been set yet")

-         logger.info(f"✅ Final Arina response: {response}")
-         return response
+ def query_gemini_api(prompt: str) -> str:
+     import requests
+     url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key={GEMINI_API_KEY}"
+     headers = {"Content-Type": "application/json"}
+     payload = {
+         "contents": [{"parts": [{"text": prompt}]}]
+     }

+     try:
+         response = requests.post(url, headers=headers, json=payload)
+         response.raise_for_status()
+         data = response.json()
+         return data.get("contents", [{}])[0].get("parts", [{}])[0].get("text", "")
    except Exception as e:
-         logger.error(f"🚨 Unexpected error in generate_response: {e}")
-         return "❌ An unexpected error occurred while generating a response."
+         logger.error(f"🚨 Error querying Gemini API: {e}")
+         return "⚠️ An error occurred while generating a response."

- # Test the generate_response function
- logger.info(f"generate_response is: {type(generate_response)}")
+ # Replace generate_response and query_huggingface with query_gemini_api
+ def generate_response(prompt_text):
+     return query_gemini_api(prompt_text)

- # Ensure query_huggingface is not shadowing generate_response
+ # Ensure query_huggingface uses the Gemini API
def query_huggingface(prompt: str) -> str:
-     logger.debug(f"Calling generate_response with prompt: {prompt}")
-     response = generate_response(prompt)
-     return response
+     return query_gemini_api(prompt)

router = APIRouter()

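The same response-shape caveat applies to query_gemini_api: generated text normally comes back under "candidates" rather than "contents", so the committed parsing chain of .get() calls would most likely yield an empty string. A minimal parsing sketch, assuming the usual candidates[0].content.parts[0].text path (an assumption about the response shape, not something this commit verifies):

# Hedged sketch only: read generated text from a generateContent response.
# The "candidates" access path and the timeout value are assumptions, not part of this commit.
import os
import requests

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

def query_gemini_api(prompt: str) -> str:
    url = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        f"gemini-2.0-flash:generateContent?key={GEMINI_API_KEY}"
    )
    payload = {"contents": [{"parts": [{"text": prompt}]}]}
    try:
        response = requests.post(url, json=payload, timeout=60)
        response.raise_for_status()
        data = response.json()
        # Typical shape: {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}
        return data["candidates"][0]["content"]["parts"][0]["text"]
    except Exception as e:
        print(f"Error querying Gemini API: {e}")
        return "⚠️ An error occurred while generating a response."

Since generate_response and query_huggingface both delegate to query_gemini_api, adjusting the parsing in that one function would cover the whole route.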