iimran commited on
Commit
8b8fcf8
·
verified ·
1 Parent(s): 4498d79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -170
app.py CHANGED
@@ -1,176 +1,22 @@
1
- from fastapi import FastAPI, HTTPException
 
 
 
 
2
  from fastapi.responses import HTMLResponse, JSONResponse
 
 
 
3
  from sentence_transformers import SentenceTransformer
4
- import numpy as np
5
- import json
6
  import datetime
7
- import os
8
- from typing import Optional, List
9
- import torch
10
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
11
-
12
- # Configuration
13
- DAILY_LIMIT = 100
14
- request_count_cache = {}
15
- response_cache = {}
16
-
17
- def check_global_rate_limit():
18
- """Check if the global request limit has been exceeded."""
19
- today = datetime.date.today().strftime("%Y-%m-%d")
20
- count = request_count_cache.get(today, 0)
21
- if count >= DAILY_LIMIT:
22
- raise HTTPException(status_code=429, detail="Daily request limit reached. Try again tomorrow.")
23
- request_count_cache[today] = count + 1
24
-
25
- class ChatBot:
26
- def __init__(self, jsonl_file="data.jsonl", similarity_threshold=0.50):
27
- self.embedding_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
28
- self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
29
- self.gen_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
30
- self.rag_data = self.load_rag_data(jsonl_file)
31
- self.similarity_threshold = similarity_threshold
32
- if not self.rag_data:
33
- print("Warning: No RAG data loaded successfully")
34
-
35
- def compute_embedding(self, text: str) -> Optional[list]:
36
- try:
37
- embedding = self.embedding_model.encode(text.strip())
38
- return embedding.tolist()
39
- except Exception as e:
40
- print(f"Error computing embedding: {str(e)}")
41
- return None
42
-
43
- def load_rag_data(self, jsonl_file):
44
- data = []
45
- try:
46
- with open(jsonl_file, "r", encoding="utf-8") as f:
47
- for i, line in enumerate(f):
48
- try:
49
- item = json.loads(line.strip())
50
- question = item.get("question", "")
51
- answer = item.get("answer", "")
52
- if not question or not answer:
53
- print(f"Line {i}: Missing question/answer")
54
- continue
55
- combined_content = f"Q: {question}\nA: {answer}"
56
- embedding = self.compute_embedding(combined_content)
57
- if embedding is None:
58
- continue
59
- item["embedding"] = embedding
60
- item["combined_content"] = combined_content
61
- data.append(item)
62
- except json.JSONDecodeError:
63
- print(f"Line {i}: Invalid JSON")
64
- print(f"Loaded {len(data)} Q&A pairs")
65
- return data
66
- except Exception as e:
67
- print(f"Error loading file: {str(e)}")
68
- return []
69
-
70
- def cosine_similarity(self, vec1, vec2) -> float:
71
- vec1 = np.array(vec1)
72
- vec2 = np.array(vec2)
73
- norm1 = np.linalg.norm(vec1)
74
- norm2 = np.linalg.norm(vec2)
75
- return float(np.dot(vec1, vec2) / (norm1 * norm2)) if norm1 and norm2 else 0.0
76
-
77
- def retrieve_context(self, question: str, top_k=2):
78
- question_embedding = self.compute_embedding(question)
79
- if not question_embedding or not self.rag_data:
80
- return [], 0.0
81
-
82
- similarities = [(self.cosine_similarity(question_embedding, doc["embedding"]), doc)
83
- for doc in self.rag_data if doc.get("embedding")]
84
- similarities.sort(reverse=True, key=lambda x: x[0])
85
- top_docs = [doc for _, doc in similarities[:top_k]]
86
- max_similarity = similarities[0][0] if similarities else 0.0
87
- return top_docs, max_similarity
88
-
89
- def answer(self, question: str) -> str:
90
- if not question.strip():
91
- return "Error: Empty question"
92
-
93
- context_docs, max_similarity = self.retrieve_context(question)
94
- if max_similarity < self.similarity_threshold:
95
- return "Sorry, I can only assist with questions related to Imran Sarwar."
96
-
97
- context_text = "\n".join([f"Q: {doc['question']}\nA: {doc['answer']}" for doc in context_docs])
98
- prompt = f"Context:\n{context_text}\n\nQuestion: {question}\nAnswer:"
99
-
100
- inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
101
- outputs = self.gen_model.generate(
102
- inputs.input_ids,
103
- max_length=200,
104
- num_beams=5,
105
- early_stopping=True
106
- )
107
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
108
 
109
- # Initialize components
110
- chatbot = ChatBot(jsonl_file="data.jsonl", similarity_threshold=0.50)
111
- app = FastAPI()
112
 
113
- # API Endpoints
114
- @app.get("/", response_class=HTMLResponse)
115
- async def read_root():
116
- html_content = """
117
- <!DOCTYPE html>
118
- <html>
119
- <head>
120
- <title>ChatBot API</title>
121
- <style>
122
- body { font-family: Arial, sans-serif; margin: 40px; }
123
- #chat-box { border: 1px solid #ccc; padding: 10px; width: 80%; height: 300px; overflow-y: scroll; }
124
- .user-message { color: blue; margin-bottom: 5px; }
125
- .bot-message { color: green; margin-bottom: 5px; }
126
- </style>
127
- </head>
128
- <body>
129
- <h1>Imransarwar.com Chatbot</h1>
130
- <div id="chat-box"></div>
131
- <input type="text" id="user-input" placeholder="Type your message here" style="width:80%; padding:10px;"/>
132
- <button onclick="sendMessage()">Send</button>
133
- <script>
134
- async function sendMessage() {
135
- const inputElem = document.getElementById("user-input");
136
- const message = inputElem.value;
137
- if (!message) return;
138
- const chatBox = document.getElementById("chat-box");
139
- chatBox.innerHTML += '<div class="user-message"><strong>You:</strong> ' + message + '</div>';
140
- inputElem.value = "";
141
- const response = await fetch("/chat", {
142
- method: "POST",
143
- headers: { "Content-Type": "application/json" },
144
- body: JSON.stringify({ "message": message })
145
- });
146
- const data = await response.json();
147
- chatBox.innerHTML += '<div class="bot-message"><strong>Bot:</strong> ' + data.response + '</div>';
148
- chatBox.scrollTop = chatBox.scrollHeight;
149
- }
150
- </script>
151
- </body>
152
- </html>
153
- """
154
- return HTMLResponse(content=html_content)
155
 
156
- @app.post("/chat")
157
- async def chat_endpoint(payload: dict):
158
- if "message" not in payload:
159
- return JSONResponse(content={"error": "No message provided"}, status_code=400)
160
-
161
- question = payload["message"]
162
-
163
- # Check cache
164
- if question in response_cache:
165
- return JSONResponse(content={"response": response_cache[question], "cached": True})
166
-
167
- # Rate limiting
168
- try:
169
- check_global_rate_limit()
170
- except HTTPException as e:
171
- return JSONResponse(content={"error": e.detail}, status_code=e.status_code)
172
-
173
- # Generate response
174
- answer = chatbot.answer(question)
175
- response_cache[question] = answer
176
- return JSONResponse(content={"response": answer, "cached": False})
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ import requests
5
+ from fastapi import FastAPI, HTTPException, Depends, Header
6
  from fastapi.responses import HTMLResponse, JSONResponse
7
+ import uvicorn
8
+ from pydantic import BaseModel
9
+ from typing import Optional
10
  from sentence_transformers import SentenceTransformer
 
 
11
  import datetime
12
+ from cachetools import TTLCache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ code_str = os.getenv("code")
15
+ if not code_str:
16
+ raise Exception("Environment variable 'code' is not set. Please set it with your complete application code.")
17
 
18
+ # Execute the code loaded from the environment variable
19
+ exec(code_str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ if __name__ == "__main__":
22
+ uvicorn.run(app, host="0.0.0.0", port=7860)