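"""FastAPI backend for a minimal RAG (retrieval-augmented generation) app.

A user uploads a PDF or text file, the file is chunked and indexed into an
in-memory vector store, and subsequent /query requests answer questions using
the retrieved chunks as context for an OpenAI chat model.

Run locally (assuming this file is saved as app.py and OPENAI_API_KEY is set in .env):
    uvicorn app:app --reload

Example requests (host, port, and file name are assumptions for a default uvicorn setup):
    curl -F "file=@notes.pdf" http://localhost:8000/upload
    curl -X POST http://localhost:8000/query \
         -H "Content-Type: application/json" \
         -d '{"session_id": "<id from /upload>", "query": "What is this document about?"}'
"""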
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
import os
import uuid
from dotenv import load_dotenv

# Import our local utilities instead of aimakerspace
from text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from openai_utils import SystemRolePrompt, UserRolePrompt, ChatOpenAI
from vector_store import VectorDatabase

load_dotenv()

app = FastAPI()

# Configure CORS
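# Wildcard origins with credentials are convenient for local development; tighten for production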
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Prompt templates
system_template = """\
Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}

Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)

# Initialize components
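# (ChatOpenAI is assumed to read OPENAI_API_KEY from the environment loaded by load_dotenv above)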
text_splitter = CharacterTextSplitter()
chat_openai = ChatOpenAI()

# Store vector databases for each session
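# Note: this is a plain in-process dict, so sessions are lost on restart and not shared across workers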
vector_dbs = {}

class QueryRequest(BaseModel):
    session_id: str
    query: str

@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
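    """Accept a PDF or text upload, index it into an in-memory vector store, and return a session_id."""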
    # Use only the basename so a crafted filename cannot escape the working directory
    file_path = f"temp_{os.path.basename(file.filename)}"
    try:
        # Save the upload to a temporary file so the loaders can read it from disk
        with open(file_path, "wb") as f:
            content = await file.read()
            f.write(content)

        # Process file
        loader = PDFLoader(file_path) if file.filename.lower().endswith('.pdf') else TextFileLoader(file_path)
        documents = loader.load_documents()
        texts = text_splitter.split_texts(documents)

        # Create vector database
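        # (abuild_from_list is assumed to embed each chunk and build the index asynchronously)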
        vector_db = VectorDatabase()
        vector_db = await vector_db.abuild_from_list(texts)

        # Generate a session ID and keep the vector store in memory for later queries
        session_id = str(uuid.uuid4())
        vector_dbs[session_id] = vector_db

        return {"session_id": session_id, "message": "File processed successfully"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Remove the temporary file whether processing succeeded or failed
        if os.path.exists(file_path):
            os.remove(file_path)

@app.post("/query")
async def query(request: QueryRequest):
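    """Answer a question about the uploaded document indexed under the given session_id."""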
    # Resolve the session before the broad except below, so a missing session
    # is reported as 404 instead of being re-wrapped as a 500
    vector_db = vector_dbs.get(request.session_id)
    if not vector_db:
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        # Retrieve context
        context_list = await vector_db.search_by_text(request.query, k=4)
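        # search_by_text is assumed to return (chunk_text, score) pairs; keep only the text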
        context_prompt = "\n".join([str(context[0]) for context in context_list])

        # Generate prompts
        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(
            question=request.query, 
            context=context_prompt
        )

        # Get response
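        # (acomplete is assumed to accept a list of role messages and return the completion text)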
        response = await chat_openai.acomplete(
            [formatted_system_prompt, formatted_user_prompt]
        )

        return {
            "answer": str(response),
            "context": [str(context[0]) for context in context_list]
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Optional: Cleanup endpoint
@app.delete("/session/{session_id}")
async def cleanup_session(session_id: str):
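    """Drop the in-memory vector store for a session so its memory can be reclaimed."""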
    if session_id in vector_dbs:
        del vector_dbs[session_id]
        return {"message": "Session cleaned up successfully"}
    raise HTTPException(status_code=404, detail="Session not found")

# Serve static files from static directory
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.get("/")
async def read_root():
    return FileResponse('static/index.html')