ash-98 committed on
Commit
83da16f
·
verified ·
1 Parent(s): f51015c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -0
app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+
8
# 1. Load model & tokenizer once at startup
# Identifier handed to both `from_pretrained` calls below.
MODEL_ID = "EQuIP-Queries/EQuIP_3B"

# These run at import time, so every request handled by this process shares
# the same tokenizer and model objects. No `cache_dir` argument is passed
# here, so the library's own default cache location is used.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# 2. Initialize FastAPI
# Application object that the route decorators in this module attach to.
app = FastAPI()
16
+
17
+ # 3. Define request schema
18
class GenerateRequest(BaseModel):
    # Request body accepted by the POST /generate endpoint.

    # Text that is tokenized and passed to the model as its input.
    prompt: str

    # Forwarded to `model.generate` as `max_new_tokens`; optional for
    # clients, falling back to 50 when omitted.
    max_new_tokens: int = 50
21
+
22
+ # 4. Inference endpoint
23
def _generate_text(prompt: str, max_new_tokens: int) -> str:
    """Run the blocking tokenize -> generate -> decode pipeline.

    Args:
        prompt: Raw text to tokenize and feed to the model.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``.

    Returns:
        The first returned sequence decoded to text with special tokens
        skipped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # NOTE(review): for causal LMs `generate` typically returns the prompt
    # tokens followed by the continuation, so the decoded text likely echoes
    # the prompt - confirm that this is what clients expect.
    return tokenizer.decode(ids[0], skip_special_tokens=True)


@app.post("/generate")
async def generate(req: GenerateRequest):
    """Generate text for ``req.prompt`` and return it as JSON.

    The tokenizer and model calls are synchronous and can run for a long
    time. Calling them directly inside an ``async def`` handler would block
    the event loop and stall every other request, so the work is handed to
    the loop's default executor and awaited instead.

    Args:
        req: Validated request body carrying the prompt and token budget.

    Returns:
        A dict of the form ``{"generated_text": <decoded text>}``.
    """
    # Local import keeps this block self-contained without editing the
    # file-level import section.
    import asyncio

    loop = asyncio.get_running_loop()
    # Executor `None` selects the loop's default thread pool. Requests that
    # overlap may therefore run generation concurrently in worker threads.
    text = await loop.run_in_executor(
        None, _generate_text, req.prompt, req.max_new_tokens
    )
    return {"generated_text": text}