|
import os |
|
import asyncio |
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
from openai import OpenAI |
|
# Shared OpenAI-compatible client pointed at DeepInfra's endpoint; the
# embedding helper in this module sends its requests through it.
# NOTE(review): DEEPINFRA_API_KEY is not defined in the visible part of this
# file — confirm it is assigned (e.g. loaded from the environment) before
# this statement runs.
openai = OpenAI(
    base_url="https://api.deepinfra.com/v1/openai",
    api_key=DEEPINFRA_API_KEY,
)
|
|
|
|
|
import weaviate |
|
from weaviate.classes.init import Auth |
|
from contextlib import contextmanager |
|
|
|
@contextmanager
def weaviate_client():
    """Yield a connected Weaviate Cloud client and close it afterwards.

    The connection is opened with the module-level ``WEAVIATE_URL`` and
    ``WEAVIATE_API_KEY``. ``close()`` is called from a ``finally`` clause,
    so it runs whether the ``with`` body finishes normally or raises.
    """
    # NOTE(review): WEAVIATE_URL / WEAVIATE_API_KEY are not defined in the
    # visible part of this file — confirm they are set before first use.
    credentials = Auth.api_key(WEAVIATE_API_KEY)
    connection = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=credentials,
    )
    try:
        yield connection
    finally:
        connection.close()
|
|
|
def embed_texts(texts: list[str], batch_size: int = 50) -> list[list[float]]:
    """Embed texts in batches to avoid API limits.

    Args:
        texts: Strings to embed.
        batch_size: Maximum number of texts sent per request; must be >= 1.

    Returns:
        One entry per input text, in input order. Every text in a batch
        that failed gets an empty-list placeholder, so position ``i`` of
        the result always corresponds to ``texts[i]``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # A negative step would make range() empty and silently drop every
    # input; reject bad sizes up front instead.
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    all_embeddings: list[list[float]] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        try:
            resp = openai.embeddings.create(
                model="Qwen/Qwen3-Embedding-8B",
                input=batch,
                encoding_format="float",
            )
            batch_embs = [item.embedding for item in resp.data]
            # A short or long response would shift every later embedding
            # onto the wrong text, so treat it as a failed batch.
            if len(batch_embs) != len(batch):
                raise ValueError(
                    f"expected {len(batch)} embeddings, got {len(batch_embs)}"
                )
        except Exception as e:
            # Best effort: report the failure and keep going so one bad
            # batch does not discard the rest of the work.
            print(f"Embedding batch error (items {i}–{i+len(batch)-1}): {e}")
            batch_embs = [[] for _ in batch]
        all_embeddings.extend(batch_embs)
    return all_embeddings
|
|
|
def encode_query(query: str) -> list[float] | None:
    """Embed one query string.

    Returns the embedding vector, or ``None`` when no non-empty vector
    came back from ``embed_texts``.
    """
    vectors = embed_texts([query], batch_size=1)
    vector = vectors[0] if vectors else None

    # Missing result or an empty placeholder both count as failure.
    if not vector:
        print("Failed to generate query embedding.")
        return None

    print("Query embedding (first 5 dims):", vector[:5])
    return vector
|
|
|
def _retrieve_text_chunks(query: str, top_k: int) -> dict:
    """Blocking part of ``rag_autism``: embed the query, then search Weaviate."""
    qe = encode_query(query)
    if not qe:
        return {"answer": []}

    try:
        with weaviate_client() as client:
            coll = client.collections.get("UserSpecificDocument")
            res = coll.query.near_vector(
                near_vector=qe,
                limit=top_k,
                return_properties=["text"],
            )
            objects = getattr(res, "objects", None)
            if not objects:
                return {"answer": []}
            return {
                "answer": [
                    obj.properties.get("text", "[No Text]")
                    for obj in objects
                ]
            }
    except Exception as e:
        # Best effort: any failure is reported and yields an empty answer
        # instead of propagating to the caller.
        print("RAG Error:", e)
        return {"answer": []}


async def rag_autism(query: str, top_k: int = 3) -> dict:
    """
    Run a RAG retrieval on the 'UserSpecificDocument' collection in Weaviate.

    Args:
        query: Text to embed and search with.
        top_k: Maximum number of results requested from Weaviate.

    Returns:
        ``{"answer": [...]}`` holding the ``text`` property of each match
        (``"[No Text]"`` when the property is absent). The list is empty
        when the query cannot be embedded, nothing matches, or the lookup
        raises.
    """
    # The embedding request and the Weaviate query are synchronous calls;
    # run them in a worker thread so the event loop is not blocked.
    return await asyncio.to_thread(_retrieve_text_chunks, query, top_k)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|