Tim Luka Horstmann commited on
Commit
588cb6a
·
1 Parent(s): a7ba255

Updated app to use ctransformers and gemma without token

Browse files
Files changed (1) hide show
  1. app.py +70 -40
app.py CHANGED
@@ -1,60 +1,86 @@
1
  import json
2
  import numpy as np
3
  from sentence_transformers import SentenceTransformer
4
- from transformers import pipeline, TextIteratorStreamer
5
- from threading import Thread
6
  import torch
7
  import torch.nn.functional as F
8
  from fastapi import FastAPI, HTTPException
9
  from fastapi.responses import StreamingResponse
10
  from pydantic import BaseModel
 
 
 
 
 
 
 
 
11
 
12
  app = FastAPI()
13
 
14
- # Load precomputed CV embeddings
15
- with open("cv_embeddings.json", "r", encoding="utf-8") as f:
16
- cv_data = json.load(f)
17
- cv_chunks = [item["chunk"] for item in cv_data]
18
- cv_embeddings = np.array([item["embedding"] for item in cv_data])
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- cv_embeddings_tensor = torch.tensor(cv_embeddings)
 
 
 
21
 
22
- embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
 
 
 
 
 
 
 
 
23
 
24
- generator = pipeline(
25
- "text-generation",
26
- model="distilgpt2",
27
- device=-1,
28
- )
29
 
30
  def retrieve_context(query, top_k=3):
31
- query_embedding = embedder.encode(query, convert_to_tensor=True).unsqueeze(0)
32
- similarities = F.cosine_similarity(query_embedding, cv_embeddings_tensor, dim=1)
33
- top_k = min(top_k, len(similarities))
34
- top_indices = torch.topk(similarities, k=top_k).indices.cpu().numpy()
35
- return "\n".join([cv_chunks[i] for i in top_indices])
 
 
 
 
36
 
37
  def stream_response(query):
38
- context = retrieve_context(query)
39
- prompt = (
40
- f"I am Tim Luka Horstmann, a German Computer Scientist. Based on my CV:\n{context}\n\n"
41
- f"Question: {query}\nAnswer:"
42
- )
43
-
44
- streamer = TextIteratorStreamer(generator.tokenizer, skip_prompt=True, skip_special_tokens=True)
45
- generation_kwargs = {
46
- "text_inputs": prompt,
47
- "max_new_tokens": 200,
48
- "do_sample": False,
49
- "streamer": streamer,
50
- }
51
-
52
- thread = Thread(target=generator, kwargs=generation_kwargs)
53
- thread.start()
54
-
55
- for token in streamer:
56
- yield f"data: {token}\n\n"
57
- yield "data: [DONE]\n\n"
58
 
59
  class QueryRequest(BaseModel):
60
  data: list
@@ -64,4 +90,8 @@ async def predict(request: QueryRequest):
64
  if not request.data or not isinstance(request.data, list) or len(request.data) < 1:
65
  raise HTTPException(status_code=400, detail="Invalid input: 'data' must be a non-empty list")
66
  query = request.data[0]
67
- return StreamingResponse(stream_response(query), media_type="text/event-stream")
 
 
 
 
 
1
  import json
2
  import numpy as np
3
  from sentence_transformers import SentenceTransformer
 
 
4
  import torch
5
  import torch.nn.functional as F
6
  from fastapi import FastAPI, HTTPException
7
  from fastapi.responses import StreamingResponse
8
  from pydantic import BaseModel
9
+ from ctransformers import AutoModelForCausalLM
10
+ from huggingface_hub import login
11
+ import logging
12
+ import os
13
+
14
+ # Set up logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
 
18
  app = FastAPI()
19
 
20
+ # Authenticate with Hugging Face
21
+ hf_token = os.getenv("HF_TOKEN")
22
+ if not hf_token:
23
+ logger.error("HF_TOKEN environment variable not set. Required for gated models.")
24
+ raise ValueError("HF_TOKEN not set")
25
+ login(token=hf_token) # Set token for huggingface_hub
26
+
27
+ try:
28
+ # Load precomputed CV embeddings
29
+ logger.info("Loading CV embeddings from cv_embeddings.json")
30
+ with open("cv_embeddings.json", "r", encoding="utf-8") as f:
31
+ cv_data = json.load(f)
32
+ cv_chunks = [item["chunk"] for item in cv_data]
33
+ cv_embeddings = np.array([item["embedding"] for item in cv_data])
34
+ cv_embeddings_tensor = torch.tensor(cv_embeddings)
35
+ logger.info("CV embeddings loaded successfully")
36
 
37
+ # Load embedding model
38
+ logger.info("Loading SentenceTransformer model")
39
+ embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
40
+ logger.info("SentenceTransformer model loaded")
41
 
42
+ # Load Gemma 3 model with ctransformers
43
+ logger.info("Loading Gemma 3 model")
44
+ generator = AutoModelForCausalLM.from_pretrained(
45
+ "google/gemma-3-12b-it-qat-q4_0-gguf",
46
+ local_files_only=False,
47
+ model_type="gemma",
48
+ model_file="gemma-3-12b-it-q4_0.gguf",
49
+ )
50
+ logger.info("Gemma 3 model loaded")
51
 
52
+ except Exception as e:
53
+ logger.error(f"Startup error: {str(e)}", exc_info=True)
54
+ raise
 
 
55
 
56
  def retrieve_context(query, top_k=3):
57
+ try:
58
+ query_embedding = embedder.encode(query, convert_to_tensor=True).unsqueeze(0)
59
+ similarities = F.cosine_similarity(query_embedding, cv_embeddings_tensor, dim=1)
60
+ top_k = min(top_k, len(similarities))
61
+ top_indices = torch.topk(similarities, k=top_k).indices.cpu().numpy()
62
+ return "\n".join([cv_chunks[i] for i in top_indices])
63
+ except Exception as e:
64
+ logger.error(f"Error in retrieve_context: {str(e)}")
65
+ raise
66
 
67
  def stream_response(query):
68
+ try:
69
+ logger.info(f"Processing query: {query}")
70
+ context = retrieve_context(query)
71
+ prompt = (
72
+ f"I am Tim Luka Horstmann, a German Computer Scientist. Based on my CV:\n{context}\n\n"
73
+ f"Question: {query}\nAnswer:"
74
+ )
75
+
76
+ # Stream response with ctransformers
77
+ for token in generator(prompt, max_new_tokens=512, stream=True):
78
+ yield f"data: {token}\n\n"
79
+ yield "data: [DONE]\n\n"
80
+ except Exception as e:
81
+ logger.error(f"Error in stream_response: {str(e)}")
82
+ yield f"data: Error: {str(e)}\n\n"
83
+ yield "data: [DONE]\n\n"
 
 
 
 
84
 
85
  class QueryRequest(BaseModel):
86
  data: list
 
90
  if not request.data or not isinstance(request.data, list) or len(request.data) < 1:
91
  raise HTTPException(status_code=400, detail="Invalid input: 'data' must be a non-empty list")
92
  query = request.data[0]
93
+ return StreamingResponse(stream_response(query), media_type="text/event-stream")
94
+
95
+ @app.get("/health")
96
+ async def health_check():
97
+ return {"status": "healthy"}