File size: 3,261 Bytes
366198a
 
 
 
 
 
 
 
7f8b883
366198a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import asyncio
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment.
# Bug fix: load_dotenv was imported but never called, and
# DEEPINFRA_API_KEY was never defined anywhere in this file, so the
# OpenAI(...) call below raised NameError at import time.
load_dotenv()
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")

# Initialize DeepInfra-compatible OpenAI client (DeepInfra exposes an
# OpenAI-compatible API at this base URL).
from openai import OpenAI
openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url="https://api.deepinfra.com/v1/openai",
)

# Weaviate imports
import weaviate
from weaviate.classes.init import Auth
from contextlib import contextmanager

@contextmanager
def weaviate_client():
    """
    Yield a connected Weaviate Cloud client, guaranteeing close() on exit.

    NOTE(review): WEAVIATE_URL and WEAVIATE_API_KEY are not defined in
    this file — presumably module-level constants or env-derived values
    defined elsewhere; confirm before use.
    """
    conn = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
    )
    try:
        yield conn
    finally:
        # Always release the connection, even if the caller's body raises.
        conn.close()

def embed_texts(texts: list[str], batch_size: int = 50) -> list[list[float]]:
    """
    Embed texts in batches to avoid API request-size limits.

    Args:
        texts: Strings to embed.
        batch_size: Maximum number of inputs per embeddings API call.

    Returns:
        One embedding vector per input text, in input order.  When a
        batch fails, the error is printed and each text in that batch
        gets an empty list as a placeholder, keeping the output aligned
        1:1 with `texts`.
    """
    all_embeddings: list[list[float]] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        try:
            resp = openai.embeddings.create(
                model="Qwen/Qwen3-Embedding-8B",
                input=batch,
                encoding_format="float"
            )
            all_embeddings.extend(item.embedding for item in resp.data)
        except Exception as e:
            # Bug fix: the original f-string ran the two indices together
            # ("items 049" for 0..49); separate them with a dash.
            print(f"Embedding batch error (items {i}-{i + len(batch) - 1}): {e}")
            # Empty placeholders preserve positional alignment with `texts`.
            all_embeddings.extend([[] for _ in batch])
    return all_embeddings

def encode_query(query: str) -> list[float] | None:
    """Return a single embedding vector for *query*, or None on failure."""
    vectors = embed_texts([query], batch_size=1)
    # Guard clause: embed_texts substitutes an empty list on failure.
    if not vectors or not vectors[0]:
        print("Failed to generate query embedding.")
        return None
    vec = vectors[0]
    print("Query embedding (first 5 dims):", vec[:5])
    return vec

async def rag_autism(query: str, top_k: int = 3) -> dict:
    """
    Run a RAG retrieval on the 'UserSpecificDocument' collection in Weaviate.

    The embedding call and the Weaviate query are synchronous, blocking
    network operations; the original ran them directly inside this
    coroutine and therefore blocked the event loop.  Delegating to a
    worker thread via asyncio.to_thread keeps the loop responsive while
    preserving the async signature callers await on.

    Args:
        query: Natural-language question to retrieve context for.
        top_k: Maximum number of matching text chunks to return.

    Returns:
        {"answer": [chunk, ...]} with up to `top_k` text chunks; the
        list is empty when embedding or retrieval fails.
    """
    return await asyncio.to_thread(_rag_autism_sync, query, top_k)


def _rag_autism_sync(query: str, top_k: int) -> dict:
    """Blocking worker for rag_autism: embed the query, then search Weaviate."""
    qe = encode_query(query)
    if not qe:
        return {"answer": []}

    try:
        with weaviate_client() as client:
            coll = client.collections.get("UserSpecificDocument")
            res  = coll.query.near_vector(
                near_vector=qe,
                limit=top_k,
                return_properties=["text"]
            )
        # Defensive: treat a missing/empty objects attribute as no hits.
        if not getattr(res, "objects", None):
            return {"answer": []}
        return {
            "answer": [
                obj.properties.get("text", "[No Text]")
                for obj in res.objects
            ]
        }
    except Exception as e:
        # Best-effort: retrieval failure degrades to an empty answer set.
        print("RAG Error:", e)
        return {"answer": []}

# Example test harness
# if __name__ == "__main__":
#     test_queries = [
#         "What are the common early signs of autism in young children?",
#         "What diagnostic criteria are used for autism spectrum disorder?",
#         "What support strategies help improve communication skills in autistic individuals?"
#     ]
#     for q in test_queries:
#         print(f"\nQuery: {q}")
#         out = asyncio.run(rag_autism(q, top_k=3))
#         print("Retrieved contexts:")
#         for idx, ctx in enumerate(out["answer"], 1):
#             print(f"{idx}. {ctx}")