File size: 3,308 Bytes
9b74ec6
 
06f0356
9b74ec6
5464450
a3a9074
9b74ec6
 
 
 
 
 
 
a7d6d41
a3a9074
9b74ec6
 
 
 
bd2f642
 
 
 
9b74ec6
 
65a2535
 
9b74ec6
a3a9074
74c6866
a3a9074
 
 
 
65a2535
9b74ec6
 
 
 
 
1ba0543
9b74ec6
 
 
 
 
 
 
 
 
 
 
 
 
1ba0543
 
65a2535
1ba0543
 
4113730
18416fb
1ba0543
65a2535
 
1ba0543
18416fb
1ba0543
 
 
65a2535
1ba0543
 
 
65a2535
9b74ec6
 
 
a3a9074
 
a7d6d41
 
a3a9074
5464450
 
bd2f642
 
5464450
bd2f642
 
5464450
 
 
 
 
9b74ec6
 
 
a3a9074
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
import numpy as np


app = FastAPI()

# Sentence-embedding model shared by all endpoints (encoding queries and corpus passages).
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Extractive question-answering pipeline; the same checkpoint supplies model and tokenizer.
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)

# Abstractive summarization pipeline used by the /t5answer endpoint.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# Define request models
class ModifyQueryRequest(BaseModel):
    """Request body for /modify_query: a single query string to embed."""
    query_string: str

# Define request models
class ModifyQueryRequest_v3(BaseModel):
    """Request body for /modify_query_v3: a batch of query strings to embed.

    Fix: the annotation was `[str]` — a list *literal*, not a type — which
    pydantic cannot interpret as a field type. `list[str]` is the correct
    annotation for a list of strings.
    """
    query_string_list: list[str]

class AnswerQuestionRequest(BaseModel):
    """Request body for /answer_question."""
    question: str
    # Candidate passages to search; items are used as strings by the endpoint.
    context: list
    # Parallel to `context`: locations[i] identifies where context[i] came from.
    locations: list

class T5QuestionRequest(BaseModel):
    """Request body for /t5answer: free text to summarize."""
    context: str

class T5Response(BaseModel):
    """Response body for /t5answer: the generated summary."""
    answer: str

# Define response models (if needed)
class ModifyQueryResponse(BaseModel):
    """Response for the embedding endpoints: embedding values as plain lists."""
    embeddings: list

class AnswerQuestionResponse(BaseModel):
    """Response for /answer_question: the extracted answer plus matched locations."""
    answer: str
    # Locations of the context passages that scored above the search threshold.
    locations: list

# Define API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Encode one query string into a binary-precision embedding."""
    try:
        encoded = model.encode([request.query_string], precision="binary")
        first_vector = encoded[0]
        return ModifyQueryResponse(embeddings=first_vector.tolist())
    except Exception as exc:
        # Surface any encoding failure as a 500 with the error text.
        raise HTTPException(status_code=500, detail=str(exc))

@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Semantic-search the supplied passages, then run extractive QA on the matches.

    Passages scoring above 0.4 are concatenated into the QA context; if none
    qualify, a canned apology is returned with an empty location list.
    """
    try:
        corpus_vectors = model.encode(request.context, convert_to_tensor=True)
        question_vector = model.encode(request.question, convert_to_tensor=True)
        matches = util.semantic_search(question_vector, corpus_vectors)[0]

        matched_locations = []
        matched_passages = []
        for match in matches:
            if match['score'] <= .4:
                continue  # below the relevance threshold
            idx = match['corpus_id']
            matched_locations.append(request.locations[idx])
            matched_passages.append(request.context[idx])

        if not matched_locations:
            answer = "Sorry, I couldn't find any results for your query. Please try again!"
        else:
            # Join passages space-separated (trailing space preserved, as before)
            # and flatten newlines so the QA model sees one continuous context.
            combined = ''.join(passage + ' ' for passage in matched_passages)
            result = nlp({
                'question': request.question,
                'context': combined.replace('\n', ' ')
            })
            answer = result['answer']
        return AnswerQuestionResponse(answer=answer, locations=matched_locations)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
    """Summarize the supplied context with the BART summarization pipeline.

    Fix: wrap the pipeline call in try/except and convert failures to a 500
    HTTPException, matching the error handling of every other endpoint in
    this service (previously any summarizer error leaked as a raw exception).
    """
    try:
        resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
        return T5Response(answer=resp[0]["summary_text"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# Define API endpoints
@app.post("/modify_query_v3", response_model=ModifyQueryResponse)
async def modify_query2(request: ModifyQueryRequest_v3):
    """Encode a batch of query strings into full-precision embeddings."""
    try:
        vectors = model.encode(request.query_string_list)
        as_lists = [vector.tolist() for vector in vectors]
        return ModifyQueryResponse(embeddings=as_lists)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))



# Launch a development server when run directly (production would use an ASGI runner).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)