Akbartus committed on
Commit
a132d67
·
verified ·
1 Parent(s): 9f1432d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +39 -15
main.py CHANGED
@@ -1,23 +1,47 @@
1
- from transformers import pipeline, set_seed
2
- from fastapi import FastAPI
3
  from pydantic import BaseModel
4
- import uvicorn
5
 
 
 
6
 
 
 
 
 
7
 
8
- app = FastAPI()
 
9
 
10
- class Item(BaseModel):
11
- prompt: str
12
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- generator = pipeline('text-generation', model='gpt2')
15
- set_seed(42)
 
 
 
 
 
16
 
17
- def generate(item: Item):
18
- generator(item.prompt, max_length=30, num_return_sequences=5)
19
-
20
 
21
- @app.post("/generate/")
22
- async def generate_text(item: Item):
23
- return {"response": generate(item)}
 
 
 
 
 
 
1
  from pydantic import BaseModel
 
2
 
3
+ from .ConfigEnv import config
4
+ from fastapi.middleware.cors import CORSMiddleware
5
 
6
+ from langchain.llms import Clarifai
7
+ from langchain.chains import LLMChain
8
+ from langchain.prompts import PromptTemplate
9
+ from TextGen import app
10
 
11
class Generate(BaseModel):
    """Response schema wrapping the generated text returned to the client."""

    text: str
13
 
14
def generate_text(prompt: str):
    """Generate a completion for *prompt* using the Clarifai-hosted LLM.

    Parameters
    ----------
    prompt : str
        The user's raw prompt text.

    Returns
    -------
    Generate
        Model wrapping the LLM response text, or a ``{"detail": ...}``
        dict when the prompt is empty/missing.
    """
    # Guard on truthiness so None/empty both get the friendly message
    # instead of crashing inside the chain.
    if not prompt:
        return {"detail": "Please provide a prompt."}

    # BUG FIX: the original rebound `prompt` to the PromptTemplate and then
    # passed that template object into llmchain.run(), so the chain never
    # received the user's actual text. Keep the string and the template in
    # separate names.
    template = PromptTemplate(template=prompt, input_variables=['Prompt'])
    llm = Clarifai(
        pat=config.CLARIFAI_PAT,
        user_id=config.USER_ID,
        app_id=config.APP_ID,
        model_id=config.MODEL_ID,
        model_version_id=config.MODEL_VERSION_ID,
    )
    llmchain = LLMChain(
        prompt=template,
        llm=llm,
    )
    llm_response = llmchain.run({"Prompt": prompt})
    return Generate(text=llm_response)
32
 
33
# Enable CORS so browser-based clients served from other origins can call
# this API. NOTE(review): allow_origins=["*"] together with
# allow_credentials=True is maximally permissive — confirm this is intended
# before deploying beyond a demo.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
40
 
 
 
 
41
 
42
def api_home():
    """Return the static welcome payload for the API root."""
    welcome = {'detail': 'Welcome to FastAPI TextGen Tutorial!'}
    return welcome
44
+
45
+
46
def inference(input_prompt: str):
    """Thin wrapper: forward *input_prompt* to generate_text and return its result."""
    result = generate_text(prompt=input_prompt)
    return result