Spaces:
Runtime error
Commit e8c6de7
Parent(s): 68964c2
Integrate Gemini API for embedding and fact extraction, replacing local model dependencies
- app/core/embedding.py +17 -15
- app/core/fact_extraction.py +28 -19
- app/routes/chat_hf.py +23 -101
app/core/embedding.py
CHANGED
@@ -1,23 +1,25 @@
 import os
-from app.core.device_setup import device
-from sentence_transformers import SentenceTransformer
-…
-# Move the model to the appropriate device (GPU if available, otherwise CPU)
-embedding_model = embedding_model.to(device)
+
+# Replace local embedding model with Gemini API integration
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
+
+if not GEMINI_API_KEY:
+    raise ValueError("❌ Gemini API Key (GEMINI_API_KEY) has not been set yet")
 
 def generate_embedding(texts: list):
+    import requests
+    url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key={GEMINI_API_KEY}"
+    headers = {"Content-Type": "application/json"}
+    payload = {
+        "contents": [{"parts": [{"text": text}]} for text in texts]
+    }
+
     try:
-        …
-        # …
-        return …
+        response = requests.post(url, headers=headers, json=payload)
+        response.raise_for_status()
+        data = response.json()
+        # Parse embeddings from the response (adjust based on Gemini API's actual response structure)
+        return [item.get("embedding", []) for item in data.get("contents", [])]
     except Exception as e:
-        print(f"Error generating embedding: {e}")
+        print(f"Error generating embedding via Gemini API: {e}")
         return None
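Note on the new generate_embedding: as its inline comment concedes, the parsing is a placeholder. A generateContent response carries no top-level "contents" field, and Gemini serves embeddings from a dedicated embedding model rather than gemini-2.0-flash, so this call will return empty lists. A minimal sketch against the documented batchEmbedContents endpoint instead (the text-embedding-004 model choice and the 30-second timeout are assumptions, not part of this commit):

    import os
    import requests

    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

    def generate_embedding(texts: list):
        # batchEmbedContents embeds several texts in one request; each entry
        # names the model it targets.
        url = (
            "https://generativelanguage.googleapis.com/v1beta/"
            f"models/text-embedding-004:batchEmbedContents?key={GEMINI_API_KEY}"
        )
        payload = {
            "requests": [
                {"model": "models/text-embedding-004",
                 "content": {"parts": [{"text": text}]}}
                for text in texts
            ]
        }
        try:
            response = requests.post(url, json=payload, timeout=30)
            response.raise_for_status()
            # Each result under "embeddings" is {"values": [<float>, ...]}.
            return [e.get("values", []) for e in response.json().get("embeddings", [])]
        except Exception as e:
            print(f"Error generating embedding via Gemini API: {e}")
            return None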
app/core/fact_extraction.py
CHANGED
@@ -1,35 +1,44 @@
 import os
 from app.core.device_setup import device
-from transformers import pipeline
 from app.core.fact_management import save_user_fact
 from app.core.logging_setup import logger
 
-…
-nlp = pipeline("token-classification", model=model_name, device=pipeline_device, model_kwargs={"cache_dir": cache_dir})
+# Replace local model loading with Gemini API integration
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
 
-def …
+if not GEMINI_API_KEY:
+    raise ValueError("❌ Gemini API Key (GEMINI_API_KEY) has not been set yet")
+
+def query_gemini_for_entities(text: str):
+    import requests
+    url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key={GEMINI_API_KEY}"
+    headers = {"Content-Type": "application/json"}
+    payload = {
+        "contents": [{"parts": [{"text": text}]}]
+    }
+
+    try:
+        response = requests.post(url, headers=headers, json=payload)
+        response.raise_for_status()
+        data = response.json()
+        # Parse entities from the response (adjust based on Gemini API's actual response structure)
+        return data.get("entities", [])
+    except Exception as e:
+        logger.error(f"🚨 Error querying Gemini API for entities: {e}")
+        return []
+
+# Replace the NLP pipeline with Gemini API calls
 def extract_and_store_facts(message):
-    …
+    entities = query_gemini_for_entities(message)
+
     # Extract name
-    name = next((entity['…
+    name = next((entity['name'] for entity in entities if entity.get('type') == 'PERSON'), None)
     if name:
-        …
-        save_user_fact("name", clean_name)
+        save_user_fact("name", name)
         logger.info(f"User name '{name}' stored in memory.")
 
     # Extract location
-    location = next((entity['…
+    location = next((entity['name'] for entity in entities if entity.get('type') == 'LOCATION'), None)
     if location:
         save_user_fact("location", location)
         logger.info(f"User location '{location}' stored in memory.")
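The same caveat applies to query_gemini_for_entities: a generateContent response has no "entities" field, so data.get("entities", []) always comes back empty and no facts get stored. One workable pattern is to ask the model for JSON in exactly the shape extract_and_store_facts consumes; the prompt wording below and the reuse of query_gemini_api from chat_hf.py are assumptions, not part of this commit:

    import json

    def query_gemini_for_entities(text: str):
        # Request the schema the caller expects:
        # [{"name": ..., "type": "PERSON" | "LOCATION"}, ...]
        instruction = (
            "Extract named entities from the message below. Reply with only a "
            'JSON array of objects such as {"name": "Alice", "type": "PERSON"} '
            'or {"name": "Paris", "type": "LOCATION"}.\n\nMessage: ' + text
        )
        try:
            return json.loads(query_gemini_api(instruction))  # assumed shared helper
        except (json.JSONDecodeError, TypeError):
            return []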
app/routes/chat_hf.py
CHANGED
@@ -15,7 +15,6 @@ from app.core.logging_setup import logger
 from app.core.prompts import SYSTEM_PROMPT
 from app.core.interaction_trends import get_time_of_day
 from app.core.search_utils import needs_web_search, search_duckduckgo
-from transformers import AutoTokenizer, AutoModelForCausalLM
 import os
 import asyncio
 
@@ -30,113 +29,36 @@ headers = {
     "Authorization": f"Bearer {HUGGINGFACE_TOKEN}"
 }
 
-# …
-model_name = "google/gemma-3-1b-it"
-tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-tokenizer.pad_token = tokenizer.eos_token
-model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir).to(device)
-model.config.pad_token_id = tokenizer.eos_token_id
-
-# Check model and tokenizer types
-logger.info(f"Model type: {type(model)}")
-logger.info(f"Tokenizer type: {type(tokenizer)}")
-logger.info("Model and tokenizer loaded successfully.")
-
-def build_clean_prompt(messages):
-    role_map = {
-        "system": "System",
-        "user": "User",
-        "assistant": "Arina"
-    }
-    prompt = ""
-    for msg in messages:
-        role = role_map.get(msg["role"], "User")
-        prompt += f"{role}: {msg['content'].strip()}\n"
-    prompt += "Arina:"
-    return prompt
-
-def generate_response(prompt_text):
-    try:
-        logger.info("Starting to generate response.")
-        logger.info(f"Original prompt text: {prompt_text}")
-
-        # Sanity check for prompt structure
-        if "User:" in prompt_text[-80:] or prompt_text.count("User:") > prompt_text.count("Arina:"):
-            logger.warning("⚠️ Possible misalignment in role markers. Last prompt may confuse model.")
-
-        # Tokenize the prompt
-        logger.info("Tokenizing the prompt...")
-        logger.info(f"Tokenizer: {tokenizer}")
-        model_inputs = tokenizer(
-            prompt_text.strip(),
-            return_tensors="pt",
-            truncation=True,
-            max_length=1024  # Can be increased based on your model's context window
-        )
-        logger.info("Prompt tokenized.")
-
-        # Log input token length
-        input_len = model_inputs["input_ids"].shape[-1]
-        logger.info(f"🧾 Prompt token length: {input_len}")
-
-        assert prompt_text.count("User:") == prompt_text.count("Arina:"), "⚠️ Prompt imbalance may confuse model"
-
-        # Generate response
-        logger.info("Generating model response...")
-        logger.info(f"Model: {model}")
-        model_outputs = model.generate(
-            **model_inputs,
-            max_new_tokens=512,  # Can be adjusted, output token limit
-            do_sample=True,
-            top_p=0.9,
-            temperature=0.7,
-            repetition_penalty=1.1,
-            pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=tokenizer.eos_token_id
-        )
-        logger.info("Model response generated.")
-
-        # Decode output
-        logger.info("Decoding model output...")
-        full_output = tokenizer.decode(model_outputs[0], skip_special_tokens=True).strip()
-        logger.info("Model output decoded.")
-
-        if "Arina:" in full_output:
-            response = full_output.split("Arina:", 1)[-1].strip()
-        else:
-            response = full_output.strip()
-
-        # Prevent output starting with "User:" as it is hallucination
-        if response.startswith("User:"):
-            logger.warning("⚠️ Model hallucinated user input.")
-            response = response.split("Arina:")[-1].strip() if "Arina:" in response else response
-
-        # Clean echo if present
-        if response.startswith(prompt_text):
-            response = response[len(prompt_text):].strip()
-
-        # Fallback if empty
-        if not response:
-            logger.warning("⚠️ Empty response generated. Returning fallback.")
-            response = "I'm not sure how to respond to that, but I'm here to help."
-
-        …
+# Replace local model loading with Gemini API integration
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
+
+if not GEMINI_API_KEY:
+    raise ValueError("❌ Gemini API Key (GEMINI_API_KEY) has not been set yet")
+
+def query_gemini_api(prompt: str) -> str:
+    import requests
+    url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key={GEMINI_API_KEY}"
+    headers = {"Content-Type": "application/json"}
+    payload = {
+        "contents": [{"parts": [{"text": prompt}]}]
+    }
+
+    try:
+        response = requests.post(url, headers=headers, json=payload)
+        response.raise_for_status()
+        data = response.json()
+        return data.get("contents", [{}])[0].get("parts", [{}])[0].get("text", "")
     except Exception as e:
-        logger.error(f"🚨 …
-        return "…
+        logger.error(f"🚨 Error querying Gemini API: {e}")
+        return "⚠️ An error occurred while generating a response."
 
-# …
+# Replace generate_response and query_huggingface with query_gemini_api
+def generate_response(prompt_text):
+    return query_gemini_api(prompt_text)
 
-# Ensure query_huggingface …
+# Ensure query_huggingface uses the Gemini API
 def query_huggingface(prompt: str) -> str:
-    …
-    response = generate_response(prompt)
-    return response
+    return query_gemini_api(prompt)
 
 router = APIRouter()
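One last caveat on query_gemini_api: the documented generateContent response nests generated text under candidates -> content -> parts, not under a top-level "contents" key, so the chained .get() calls above normally fall through to the empty string. A sketch of parsing that follows the documented shape:

    def extract_text(data: dict) -> str:
        # generateContent returns
        # {"candidates": [{"content": {"parts": [{"text": ...}]}}], ...};
        # join the text segments of the first candidate.
        candidates = data.get("candidates", [])
        if not candidates:
            return ""
        parts = candidates[0].get("content", {}).get("parts", [])
        return "".join(part.get("text", "") for part in parts)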