Coots committed (verified)
Commit 24aaaa1 · Parent(s): befba3b

Update app.py

Files changed (1):
  1. app.py +15 -44
app.py CHANGED
@@ -1,52 +1,29 @@
 import os
-from pathlib import Path
 from fastapi import FastAPI, Request
-from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
+from fastapi.responses import JSONResponse, HTMLResponse
 from fastapi.staticfiles import StaticFiles
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
-
-# Setup Hugging Face cache directory
-cache_dir = os.getenv("TRANSFORMERS_CACHE", "/cache")
-os.makedirs(cache_dir, exist_ok=True)
-
-# Optional token (for private models)
-hf_token = os.getenv("HF_TOKEN")
+from pathlib import Path
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
-# Load model and tokenizer
+# Load T5 model for CPU-friendly inference
 model_id = "google/flan-t5-large"
-tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, cache_dir=cache_dir)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    token=hf_token,
-    cache_dir=cache_dir,
-    device_map="auto",
-    torch_dtype="auto"
-)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
 
-# Build generation pipeline
-pipe = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    max_new_tokens=256,
-    temperature=0.7,
-    top_p=0.9,
-    repetition_penalty=1.1
-)
+pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
 
-# Initialize FastAPI app
+# Init FastAPI app
 app = FastAPI()
 
-# Serve static files like script.js
+# Mount static directory (for script.js and index.html in root)
 app.mount("/static", StaticFiles(directory="."), name="static")
 
-# Route: Serve index.html at root
+# Serve the HTML page
 @app.get("/", response_class=HTMLResponse)
-async def serve_home():
-    html_path = Path("index.html")
-    return HTMLResponse(content=html_path.read_text(), status_code=200)
+async def serve_page():
+    return HTMLResponse(Path("index.html").read_text())
 
-# Route: Chat API
+# Chat API route
 @app.post("/api")
 async def ask_ai(request: Request):
     data = await request.json()
@@ -55,14 +32,8 @@ async def ask_ai(request: Request):
     if not question:
         return JSONResponse(content={"answer": "❗ Please enter a valid question."})
 
-    prompt = f"[INST] {question} [/INST]"
     try:
-        output = pipe(prompt)[0]["generated_text"]
-        return JSONResponse(content={"answer": output.strip()})
+        response = pipe(question, max_new_tokens=256)[0]["generated_text"]
+        return JSONResponse(content={"answer": response.strip()})
     except Exception as e:
         return JSONResponse(content={"answer": f"⚠️ Error: {str(e)}"})
-
-# Optional: Serve script.js if not using /static path in HTML
-@app.get("/script.js")
-async def serve_script():
-    return FileResponse("script.js", media_type="application/javascript")
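
For context on the substance of the change: google/flan-t5-large is an encoder-decoder (seq2seq) model, so the old AutoModelForCausalLM + "text-generation" pairing was the wrong model class and pipeline task for it; the commit switches to AutoModelForSeq2SeqLM with "text2text-generation" and drops the [INST] ... [/INST] wrapper, which is a chat-template convention for models like Mistral and does nothing useful for T5. A minimal smoke test of the new setup follows; it substitutes the smaller google/flan-t5-small checkpoint (an assumption for a quick CPU run, not the checkpoint the app loads):

# Sketch: verify the seq2seq pipeline works; flan-t5-small stands in for
# flan-t5-large so the download and CPU inference stay fast.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

model_id = "google/flan-t5-small"  # stand-in; the app uses google/flan-t5-large
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

# T5-style models take the plain instruction; no [INST] ... [/INST] wrapper.
print(pipe("Translate English to German: How old are you?", max_new_tokens=256)[0]["generated_text"])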
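
And a hedged client-side check of the /api route. The host, port, and the "question" field name are assumptions: the port depends on how uvicorn is launched, and the JSON key comes from the request-parsing lines that sit between the two hunks and are not shown in this diff.

# Sketch: call the chat endpoint; adjust the URL to your deployment
# (e.g. started with `uvicorn app:app --host 0.0.0.0 --port 7860`).
import requests

resp = requests.post(
    "http://localhost:7860/api",
    json={"question": "What is the capital of France?"},  # assumed field name
)
print(resp.json()["answer"])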