File size: 7,855 Bytes
293413b
 
cb8303f
293413b
cb8303f
 
 
 
 
aa6b888
 
588cb6a
 
a2d5223
588cb6a
 
 
 
cb8303f
 
 
bdbefdd
588cb6a
 
a29c4ff
588cb6a
58272f8
588cb6a
a29c4ff
 
48a65b5
 
9c89db3
48a65b5
9c89db3
 
 
48a65b5
9c89db3
e54e8f7
 
 
9c89db3
a29c4ff
588cb6a
48a65b5
588cb6a
 
 
 
a2d5223
a29c4ff
 
a2d5223
 
cb8303f
48a65b5
588cb6a
a29c4ff
588cb6a
cb8303f
48a65b5
9c89db3
 
 
 
48a65b5
a2d5223
aa6b888
a2d5223
 
58272f8
aa6b888
588cb6a
aa6b888
 
8583b57
9c89db3
a29c4ff
8583b57
a29c4ff
aa6b888
a2d5223
cb8303f
588cb6a
 
 
cb8303f
293413b
588cb6a
a2d5223
 
 
 
 
588cb6a
 
 
cb8303f
 
293413b
 
8583b57
293413b
 
 
 
 
 
 
 
 
 
588cb6a
293413b
 
 
 
8583b57
 
 
 
 
 
 
 
 
 
293413b
 
 
 
 
 
 
 
8583b57
 
293413b
 
 
 
 
8583b57
59b5835
 
392cd96
 
293413b
 
392cd96
293413b
 
8583b57
 
cb8303f
 
 
 
 
 
 
 
588cb6a
 
 
 
a2d5223
 
a29c4ff
 
 
48a65b5
9c89db3
48a65b5
a29c4ff
 
 
 
 
a2d5223
 
 
9c89db3
a2d5223
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# app.py

import json
import time
import numpy as np
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import login, hf_hub_download
import logging
import os
import faiss

# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ASGI application object; the route and startup handlers below attach to it.
app = FastAPI()

# Authenticate with Hugging Face.
# The token is read from the HF_TOKEN environment variable. Import fails fast
# with ValueError when it is missing, so the service never starts without it.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    logger.error("HF_TOKEN environment variable not set.")
    raise ValueError("HF_TOKEN not set")
login(token=hf_token)

# Model identifiers.
# sentence_transformer_model: name passed to SentenceTransformer for embedding
#   FAQ questions and incoming queries.
# repo_id / filename: Hugging Face Hub repository and GGUF file downloaded at
#   startup and loaded with llama_cpp.Llama.
sentence_transformer_model = "all-MiniLM-L6-v2"
repo_id = "bartowski/deepcogito_cogito-v1-preview-llama-3B-GGUF"
filename = "deepcogito_cogito-v1-preview-llama-3B-Q4_K_M.gguf"  # Q4_K_M quantized GGUF file

# Canned question/answer pairs. The questions are embedded at startup, and
# stream_response returns the stored answer directly (skipping the LLM) when
# an incoming query is similar enough to one of them.
faqs = [
    {"question": "What is your name?", "answer": "My name is Tim Luka Horstmann."},
    {"question": "Where do you live?", "answer": "I live in Paris, France."},
    {"question": "What is your education?", "answer": "I am currently pursuing a MSc in Data and AI at Institut Polytechnique de Paris. I have an MPhil in Advanced Computer Science from the University of Cambridge, and a BSc in Business Informatics from RheinMain University of Applied Sciences."},
    {"question": "What are your skills?", "answer": "I am proficient in Python, Java, SQL, Cypher, SPARQL, VBA, JavaScript, HTML/CSS, and Ruby. I also use tools like PyTorch, Hugging Face, Scikit-Learn, NumPy, Pandas, Matplotlib, Jupyter, Git, Bash, IoT, Ansible, QuickSight, and Wordpress."},
    {"question": "How are you?", "answer": "I’m doing great, thanks for asking! I’m enjoying life in Paris and working on some exciting AI projects."},
    {"question": "What do you do?", "answer": "I’m a Computer Scientist and AI enthusiast, currently pursuing a MSc in Data and AI at Institut Polytechnique de Paris and interning as a Machine Learning Research Engineer at Hi! PARIS."},
    {"question": "How’s it going?", "answer": "Things are going well, thanks! I’m busy with my studies and research, but I love the challenges and opportunities I get to explore."},
]

# One-time startup loading, executed at import. Any failure is logged with a
# traceback and re-raised so the process does not start half-initialised.
# Module globals defined here and used by the handlers below:
#   cv_chunks, cv_embeddings, faiss_index, embedder, faq_embeddings, generator.
try:
    # Load precomputed CV embeddings. Each item in cv_embeddings.json is
    # expected to carry a "chunk" (text) and an "embedding" (vector) key.
    logger.info("Loading CV embeddings from cv_embeddings.json")
    with open("cv_embeddings.json", "r", encoding="utf-8") as f:
        cv_data = json.load(f)
        cv_chunks = [item["chunk"] for item in cv_data]
        cv_embeddings = np.array([item["embedding"] for item in cv_data]).astype('float32')
    # Vectors are L2-normalised in place and stored in an inner-product index,
    # so search scores are cosine similarities.
    faiss.normalize_L2(cv_embeddings)
    faiss_index = faiss.IndexFlatIP(cv_embeddings.shape[1])
    faiss_index.add(cv_embeddings)
    logger.info("FAISS index built successfully")

    # Load the sentence-embedding model on CPU; used for FAQ questions and queries.
    logger.info("Loading SentenceTransformer model")
    embedder = SentenceTransformer(sentence_transformer_model, device="cpu")
    logger.info("SentenceTransformer model loaded")

    # Embed the FAQ questions once and normalise them, so a dot product with a
    # normalised query embedding yields cosine similarity.
    faq_questions = [faq["question"] for faq in faqs]
    faq_embeddings = embedder.encode(faq_questions, convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(faq_embeddings)

    # Download the GGUF weights from the Hub and load them with llama.cpp.
    # local_dir is "/app/cache" only when HF_HOME is set; otherwise None is
    # passed and hf_hub_download chooses the location.
    logger.info(f"Loading {filename} model")
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir="/app/cache" if os.getenv("HF_HOME") else None,
        token=hf_token,
    )
    # n_gpu_layers=0 requests no GPU offload; n_ctx=1024 bounds prompt plus
    # generated tokens, which matters because the prompt embeds CV context.
    generator = Llama(
        model_path=model_path,
        n_ctx=1024,
        n_threads=2,
        n_batch=512,
        n_gpu_layers=0,
        verbose=True,
    )
    logger.info(f"{filename} model loaded")

except Exception as e:
    logger.error(f"Startup error: {str(e)}", exc_info=True)
    raise

def retrieve_context(query, top_k=2):
    """Return the CV chunks most similar to *query*, joined by newlines.

    The query is embedded with the module-level ``embedder``, L2-normalised,
    and searched against ``faiss_index`` (inner product on normalised vectors,
    i.e. cosine similarity).

    Args:
        query: The user question to embed.
        top_k: Maximum number of CV chunks to return.

    Returns:
        The matching chunks in ranked order, separated by ``"\\n"``. This may
        contain fewer than ``top_k`` chunks (or be empty) when the index holds
        fewer vectors than requested.

    Raises:
        Exception: Any embedding or search error is logged and re-raised.
    """
    try:
        query_embedding = embedder.encode(query, convert_to_numpy=True).astype("float32")
        query_embedding = query_embedding.reshape(1, -1)
        faiss.normalize_L2(query_embedding)
        _, indices = faiss_index.search(query_embedding, top_k)
        # FAISS pads missing results with -1. Indexing cv_chunks with -1 would
        # silently return the *last* chunk, so padded slots are skipped.
        return "\n".join(cv_chunks[i] for i in indices[0] if i >= 0)
    except Exception as e:
        logger.error("Error in retrieve_context: %s", e, exc_info=True)
        raise

# Minimum cosine similarity for a query to be answered straight from the FAQs.
_FAQ_MATCH_THRESHOLD = 0.9
# Generated text is flushed when it ends on one of these characters...
_FLUSH_SUFFIXES = (" ", ".", ",", "!", "?")
# ...or once the pending buffer grows beyond this many characters.
_MAX_BUFFER_CHARS = 20


def _sse_event(payload):
    """Frame *payload* as one Server-Sent Events message.

    The SSE wire format is line based: a newline inside a ``data:`` value ends
    that field, and a blank line ends the whole event. Multi-line payloads are
    therefore sent as one ``data:`` line per payload line; a compliant client
    rejoins them with ``"\\n"``. Single-line payloads produce exactly
    ``"data: <payload>\\n\\n"``.
    """
    lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    return "".join(f"data: {line}\n" for line in lines) + "\n"


def stream_response(query):
    """Yield an SSE stream answering *query*.

    If the query closely matches one of the ``faqs`` questions (cosine
    similarity above ``_FAQ_MATCH_THRESHOLD``) the stored answer is sent
    directly. Otherwise the most relevant CV chunks are retrieved, injected
    into the system prompt, and the LLM completion is streamed in small
    word-aligned pieces.

    Args:
        query: The user's question.

    Yields:
        SSE-framed strings. The final event always carries ``[DONE]``.

    Raises:
        Exception: Errors from embedding, retrieval or generation propagate to
            the caller.
    """
    logger.info(f"Processing query: {query}")
    start_time = time.time()
    first_token_logged = False

    # FAQ check first: compare the normalised query embedding with every
    # (already normalised) FAQ question embedding.
    query_embedding = embedder.encode(query, convert_to_numpy=True).astype("float32")
    query_embedding = query_embedding.reshape(1, -1)
    faiss.normalize_L2(query_embedding)
    similarities = np.dot(faq_embeddings, query_embedding.T).flatten()
    if similarities.size:  # an empty FAQ list simply disables the shortcut
        best = int(np.argmax(similarities))
        if similarities[best] > _FAQ_MATCH_THRESHOLD:
            yield _sse_event(faqs[best]["answer"])
            yield _sse_event("[DONE]")
            return

    context = retrieve_context(query, top_k=2)
    messages = [
        {
            "role": "system",
            "content": (
                "You are Tim Luka Horstmann, a Computer Scientist. A user is asking you a question. Respond as yourself, using the first person, in a friendly and concise manner. "
                "For questions about your CV, base your answer *exclusively* on the provided CV information below and do not add any details not explicitly stated. "
                "For casual questions not covered by the CV, respond naturally but limit answers to general truths about yourself (e.g., your current location is Paris, France, or your field is AI) "
                "and say 'I don’t have specific details to share about that' if pressed for specifics beyond the CV or FAQs. Do not invent facts, experiences, or opinions not supported by the CV or FAQs. "
                f"CV: {context}"
            )
        },
        {"role": "user", "content": query}
    ]

    buffer = ""
    for chunk in generator.create_chat_completion(
        messages=messages,
        max_tokens=512,
        stream=True,
        temperature=0.3,
        top_p=0.7,
        repeat_penalty=1.2
    ):
        text = chunk['choices'][0]['delta'].get('content', '')
        if not text:
            # Role-only or empty deltas carry nothing to send.
            continue
        if not first_token_logged:
            logger.info(f"First token time: {time.time() - start_time:.2f}s")
            first_token_logged = True
        buffer += text
        # Flush on a word boundary (space or punctuation) or once the buffer
        # is long enough, so clients receive readable pieces.
        if buffer.endswith(_FLUSH_SUFFIXES) or len(buffer) > _MAX_BUFFER_CHARS:
            yield _sse_event(buffer)
            buffer = ""
    if buffer:  # Flush remaining buffer
        yield _sse_event(buffer)
    yield _sse_event("[DONE]")


# Request body for POST /api/predict.
# The predict handler reads only the first element of `data` as the query.
class QueryRequest(BaseModel):
    # List payload; element 0 is used as the user's question.
    data: list

@app.post("/api/predict")
async def predict(request: QueryRequest):
    """Stream an answer to the question in ``request.data[0]`` as SSE.

    Args:
        request: Body whose ``data`` list holds the user's question as its
            first element.

    Returns:
        A ``text/event-stream`` response produced by ``stream_response``.

    Raises:
        HTTPException: 400 when ``data`` is empty or its first element is not
            a non-blank string.
    """
    # Pydantic already guarantees `data` is a list; only emptiness needs checking.
    if not request.data:
        raise HTTPException(status_code=400, detail="Invalid input: 'data' must be a non-empty list")
    query = request.data[0]
    # Validate here, before streaming starts: a bad value would otherwise fail
    # inside the generator after the 200 response headers were already sent.
    if not isinstance(query, str) or not query.strip():
        raise HTTPException(status_code=400, detail="Invalid input: 'data[0]' must be a non-empty string")
    return StreamingResponse(stream_response(query), media_type="text/event-stream")

# Liveness probe: always reports a fixed "healthy" status.
# (Documented with comments rather than a docstring so the generated API
# description stays unchanged.)
@app.get("/health")
async def health_check():
    status_payload = {"status": "healthy"}
    return status_payload

# Reports static metadata about the loaded LLM together with the embedding
# model name and the size/dimension of the CV retrieval data.
# (Documented with comments rather than a docstring so the generated API
# description stays unchanged.)
@app.get("/model_info")
async def model_info():
    chunk_count = len(cv_chunks)
    embedding_dim = cv_embeddings.shape[1]
    info = dict(
        model_name="deepcogito_cogito-v1-preview-llama-3B-GGUF",
        model_size="3B",
        quantization="Q4_K_M",
        embedding_model=sentence_transformer_model,
        faiss_index_size=chunk_count,
        faiss_index_dim=embedding_dim,
    )
    return info

@app.on_event("startup")
async def warm_up_model():
    """Run one throwaway query through the pipeline when the app starts.

    The generated output is discarded; the stream is only consumed so that
    the full ``stream_response`` path executes once before real traffic.
    """
    logger.info("Warming up the model...")
    warmup_stream = stream_response("Hi")
    # Drain the generator to completion, ignoring every chunk.
    for _chunk in warmup_stream:
        continue
    logger.info("Model warm-up complete.")