Akbartus commited on
Commit
9f2c4c2
·
verified ·
1 Parent(s): 5262eea

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -39
main.py CHANGED
@@ -1,47 +1,22 @@
 
 
1
  from pydantic import BaseModel
 
2
 
3
- from .ConfigEnv import config
4
- from fastapi.middleware.cors import CORSMiddleware
5
 
6
- from langchain.llms import Clarifai
7
- from langchain.chains import LLMChain
8
- from langchain.prompts import PromptTemplate
9
- from TextGen import app
10
 
11
- class Generate(BaseModel):
12
- text:str
13
 
14
- def generate_text(prompt: str):
15
- if prompt == "":
16
- return {"detail": "Please provide a prompt."}
17
- else:
18
- prompt = PromptTemplate(template=prompt, input_variables=['Prompt'])
19
- llm = Clarifai(
20
- pat = config.CLARIFAI_PAT,
21
- user_id = config.USER_ID,
22
- app_id = config.APP_ID,
23
- model_id = config.MODEL_ID,
24
- model_version_id=config.MODEL_VERSION_ID,
25
- )
26
- llmchain = LLMChain(
27
- prompt=prompt,
28
- llm=llm
29
- )
30
- llm_response = llmchain.run({"Prompt": prompt})
31
- return Generate(text=llm_response)
32
 
33
- app.add_middleware(
34
- CORSMiddleware,
35
- allow_origins=["*"],
36
- allow_credentials=True,
37
- allow_methods=["*"],
38
- allow_headers=["*"],
39
- )
40
 
41
- @app.get("/", tags=["Home"])
42
- def api_home():
43
- return {'detail': 'Welcome to FastAPI TextGen Tutorial!'}
44
 
45
- @app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=Generate)
46
- def inference(input_prompt: str):
47
- return generate_text(prompt=input_prompt)
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel
4
+ from transformers import pipeline
5
 
6
+ app = FastAPI(debug=True, title="Text classifier")
 
7
 
 
 
 
 
8
 
9
+ class Payload(BaseModel):
10
+ text: str
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ async def classify_text(text):
14
+ pipe = pipeline("text-classification",
15
+ model="SamLowe/roberta-base-go_emotions")
16
+ return pipe(text)
 
 
 
17
 
 
 
 
18
 
19
+ @app.post("/classify")
20
+ async def classify(payload: Payload):
21
+ result = await classify_text(payload.text)
22
+ return JSONResponse({"data": result})