hoshoo21 commited on
Commit
bb9c96b
·
1 Parent(s): ad90168

changing to fast api

Browse files
Files changed (2) hide show
  1. app.py +44 -59
  2. persiststorage.db +0 -0
app.py CHANGED
@@ -1,76 +1,61 @@
1
- from flask import Flask, Response,request, jsonify
2
- from werkzeug.utils import secure_filename
3
- import os
4
  from rag_engine import RagEngine
5
- from flask_cors import CORS, cross_origin
 
 
 
6
 
7
- app = Flask(__name__)
8
- cors = CORS(app)
9
- app.config["CORS_HEADERS"]= 'Content-Type'
 
 
 
 
 
 
 
 
10
 
11
- app.config["UPLOAD_FOLDER"]= "uploads"
12
- os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
13
 
14
  rag = RagEngine()
15
 
16
- @app.route("/upload", methods=["POST"])
17
- @cross_origin()
18
- def upload_pdf():
19
- file = request.files.get('file')
20
  if not file.filename.endswith(".pdf"):
21
- return jsonify({"error":"Only pdf files are support"}),400
22
- filename = secure_filename(file.filename)
23
- filepath = os.path.join(app.config["UPLOAD_FOLDER"], filename)
24
- file.save(filepath)
25
-
26
- try :
 
 
 
 
 
27
  rag.index_pdf(filepath)
28
  except ValueError as ve:
29
- return jsonify({"error": str(ve)}), 400
30
-
31
- return jsonify({"message":f"file {filename} uploaded and indexed successfully"})
32
 
 
33
 
34
- @app.route ("/stream", methods=["POST"])
35
- @cross_origin()
36
- def stream_answer():
37
-
38
- question = request.json.get("question", "")
39
- print (question)
40
- if not question.strip():
41
- return jsonify({"error": "Empty question"}), 400
42
 
43
- def generate():
44
- for token in rag.ask_question_stream(question):
45
- yield token
46
 
47
- return Response(generate(), mimetype='text/plain')
48
-
49
-
50
- @app.route("/ask", methods=["POST"])
51
- @cross_origin()
52
- def ask():
53
- question = request.json.get("question", "")
54
  if not question.strip():
55
- return jsonify({"error": "Empty question"}), 400
56
- try :
57
- answer = rag.ask_question(question)
58
- except Exception as e:
59
- return jsonify({"error": str(e)}),500
60
- return jsonify({"message": answer})
61
 
62
- @app.route("/stream_answer",methods=["POST"])
63
- @cross_origin()
64
- def stream_question():
65
- data = request.get_json()
66
- question = data.get("question","")
67
- if not question:
68
- return jsonify({"error": "No question provided"}),400
69
- def event_stream():
70
- for token in rag.stream_answer(question=question):
71
  yield token
72
- return Response(event_stream(), content_type ="text/event-stream")
73
-
74
 
75
- if __name__ == "__main__":
76
- app.run(host="0.0.0.0", port=7860)
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import os
4
  from rag_engine import RagEngine
5
+ from starlette.responses import JSONResponse
6
+ from starlette.status import HTTP_400_BAD_REQUEST
7
+ from fastapi.responses import StreamingResponse
8
+ import asyncio
9
 
10
+ app = FastAPI()
11
+
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"], # or specify your allowed origins
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],)
18
+
19
+ UPLOAD_FOLDER = "uploads"
20
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
21
 
 
 
22
 
23
  rag = RagEngine()
24
 
25
+ @app.post("/upload")
26
+ async def upload_pdf(file: UploadFile = File(...)):
 
 
27
  if not file.filename.endswith(".pdf"):
28
+ raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Only pdf files are supported")
29
+
30
+ filename = os.path.basename(file.filename) # simple sanitization
31
+ filepath = os.path.join(UPLOAD_FOLDER, filename)
32
+
33
+ # Save uploaded file to disk
34
+ with open(filepath, "wb") as buffer:
35
+ content = await file.read()
36
+ buffer.write(content)
37
+
38
+ try:
39
  rag.index_pdf(filepath)
40
  except ValueError as ve:
41
+ raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=str(ve))
 
 
42
 
43
+ return JSONResponse(content={"message": f"file {filename} uploaded and indexed successfully"})
44
 
 
 
 
 
 
 
 
 
45
 
 
 
 
46
 
47
+ @app.post("/stream")
48
+ async def stream_answer(request:Request ):
49
+ data = await request.json()
50
+ question = data.get("question", "")
51
+ print(question)
 
 
52
  if not question.strip():
53
+ raise HTTPException(status_code=400, detail="Empty question")
 
 
 
 
 
54
 
55
+ async def generate():
56
+ # Assuming rag.ask_question_stream is a generator
57
+ for token in rag.stream_answer(question):
 
 
 
 
 
 
58
  yield token
59
+ await asyncio.sleep(0) # yield control to event loop
 
60
 
61
+ return StreamingResponse(generate(), media_type="text/plain")
 
persiststorage.db CHANGED
Binary files a/persiststorage.db and b/persiststorage.db differ