from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
import numpy as np
app = FastAPI()
# Sentence-embedding model: used both for query embeddings (/modify_query*)
# and for semantic search over context passages (/answer_question).
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Extractive QA pipeline (SQuAD2-style span extraction over a context string).
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
# Abstractive summarization pipeline backing /t5answer.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
# Define request models
class ModifyQueryRequest(BaseModel):
    """Request body for /modify_query: one query string to embed."""
    query_string: str
# Define request models
class ModifyQueryRequest_v3(BaseModel):
    """Request body for /modify_query_v3: a batch of query strings to embed."""
    # BUG FIX: the original annotation was `[str]` — a list *literal*, not a
    # type. Pydantic needs a proper generic annotation to validate the field;
    # `list[str]` declares "list of strings" correctly.
    query_string_list: list[str]
class AnswerQuestionRequest(BaseModel):
    """Request body for /answer_question."""
    question: str
    # Candidate passages to search; joined text is fed to the QA pipeline.
    context: list
    # Parallel to `context`: locations[i] labels context[i] — presumably a
    # source identifier per passage; verify against the caller.
    locations: list
class T5QuestionRequest(BaseModel):
    """Request body for /t5answer: raw text to summarize."""
    context: str
class T5Response(BaseModel):
    """Response body for /t5answer: the generated summary text."""
    answer: str
# Define response models (if needed)
class ModifyQueryResponse(BaseModel):
    """Response for both /modify_query (single embedding as a flat list)
    and /modify_query_v3 (list of embeddings, one per input string)."""
    embeddings: list
class AnswerQuestionResponse(BaseModel):
    """Response for /answer_question: extracted answer plus the locations of
    the context passages that matched the semantic search."""
    answer: str
    locations: list
# Define API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Encode a single query string as a binary-precision embedding.

    Returns the embedding of the one input string as a flat list; any
    failure is surfaced as an HTTP 500 with the exception text.
    """
    try:
        encoded = model.encode([request.query_string], precision="binary")
        first_row = encoded[0]
        return ModifyQueryResponse(embeddings=first_row.tolist())
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Answer a question against the supplied context passages.

    Pipeline: embed question + passages, run semantic search, keep hits
    scoring above 0.4, concatenate the matching passages, and extract an
    answer span with the QA pipeline. If nothing clears the threshold,
    a fixed apology message is returned with an empty locations list.
    Failures are surfaced as HTTP 500 with the exception text.
    """
    try:
        passage_embeddings = model.encode(request.context, convert_to_tensor=True)
        question_embedding = model.encode(request.question, convert_to_tensor=True)
        top_hits = util.semantic_search(question_embedding, passage_embeddings)[0]

        matched_locations = []
        matched_passages = []
        for hit in top_hits:
            if hit['score'] > 0.4:  # relevance cutoff for a passage to count
                idx = hit['corpus_id']
                matched_locations.append(request.locations[idx])
                matched_passages.append(request.context[idx])

        if not matched_locations:
            answer_text = "Sorry, I couldn't find any results for your query. Please try again!"
        else:
            # Same concatenation as building "passage + ' '" per match:
            # every passage is followed by a single trailing space.
            combined = ''.join(p + ' ' for p in matched_passages)
            qa_result = nlp({
                'question': request.question,
                'context': combined.replace('\n', ' ')
            })
            answer_text = qa_result['answer']

        return AnswerQuestionResponse(answer=answer_text, locations=matched_locations)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
    """Summarize the supplied context with the BART summarization pipeline.

    Returns the pipeline's summary text (130-token cap, 30-token floor,
    greedy decoding).
    """
    try:
        resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
        return T5Response(answer=resp[0]["summary_text"])
    except Exception as e:
        # FIX: this was the only endpoint without error handling; match the
        # file's convention of surfacing failures as HTTP 500 with detail.
        raise HTTPException(status_code=500, detail=str(e))
# Define API endpoints
@app.post("/modify_query_v3", response_model=ModifyQueryResponse)
async def modify_query2(request: ModifyQueryRequest_v3):
    """Encode a batch of query strings; one embedding (as a list) per input.

    Failures are surfaced as HTTP 500 with the exception text.
    """
    try:
        vectors = model.encode(request.query_string_list)
        as_lists = [vec.tolist() for vec in vectors]
        return ModifyQueryResponse(embeddings=as_lists)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
if __name__ == "__main__":
    import uvicorn
    # Dev entry point: serve the API on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)