Coots committed
Commit 072b9fb · verified · 1 Parent(s): 1b75a9a

Update app.py

Files changed (1)
  1. app.py +30 -12
app.py CHANGED
@@ -1,19 +1,20 @@
+import os
+from pathlib import Path
 from fastapi import FastAPI, Request
-from fastapi.responses import HTMLResponse, JSONResponse
+from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
 from fastapi.staticfiles import StaticFiles
-from pathlib import Path
-import os
 from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
-# Set Hugging Face cache dir
+# Setup Hugging Face cache directory
 cache_dir = os.getenv("TRANSFORMERS_CACHE", "/cache")
 os.makedirs(cache_dir, exist_ok=True)
 
-# Token for private models
+# Optional token (for private models)
 hf_token = os.getenv("HF_TOKEN")
 
-# Load model
+# Load model and tokenizer
 model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+
 tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, cache_dir=cache_dir)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
@@ -23,17 +24,26 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype="auto"
 )
 
-pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)
+# Build generation pipeline
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_new_tokens=256,
+    temperature=0.7,
+    top_p=0.9,
+    repetition_penalty=1.1
+)
 
-# Create FastAPI app
+# Initialize FastAPI app
 app = FastAPI()
 
-# Serve static files (JS, CSS, etc.)
+# Serve static files like script.js
 app.mount("/static", StaticFiles(directory="."), name="static")
 
 # Route: Serve index.html at root
 @app.get("/", response_class=HTMLResponse)
-async def root():
+async def serve_home():
     html_path = Path("index.html")
     return HTMLResponse(content=html_path.read_text(), status_code=200)
 
@@ -47,5 +57,13 @@ async def ask_ai(request: Request):
         return JSONResponse(content={"answer": "❗ Please enter a valid question."})
 
     prompt = f"[INST] {question} [/INST]"
-    output = pipe(prompt)[0]["generated_text"]
-    return JSONResponse(content={"answer": output.strip()})
+    try:
+        output = pipe(prompt)[0]["generated_text"]
+        return JSONResponse(content={"answer": output.strip()})
+    except Exception as e:
+        return JSONResponse(content={"answer": f"⚠️ Error: {str(e)}"})
+
+# Optional: Serve script.js if not using /static path in HTML
+@app.get("/script.js")
+async def serve_script():
+    return FileResponse("script.js", media_type="application/javascript")
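
For reference, a minimal sketch of exercising the new generation pipeline directly, outside the FastAPI routes. The prompt string matches the [INST] ... [/INST] format built in ask_ai; return_full_text=False is a standard transformers text-generation option (not used in this commit) so the result holds only the completion rather than prompt plus completion, and the question text is purely illustrative.

# Sketch only: calling the same pipeline configuration directly.
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
)

# Same prompt shape as app.py; return_full_text=False keeps only the generated answer.
prompt = "[INST] What is FastAPI? [/INST]"
result = pipe(prompt, return_full_text=False)
print(result[0]["generated_text"].strip())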
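The route decorator and request parsing for ask_ai sit above the changed lines and are not shown in this diff, so the path, HTTP method, and request key below are assumptions for illustration only; only the {"answer": ...} response shape is confirmed by the changed lines. Assuming a POST endpoint at /ask that reads a JSON body with a "question" field, and the default Hugging Face Spaces port 7860, a client call might look like:

# Sketch only: the /ask path, POST method, "question" key, and port are assumed, not shown in this diff.
import requests

resp = requests.post(
    "http://localhost:7860/ask",
    json={"question": "What does this Space do?"},
)
print(resp.json()["answer"])  # both success and error responses carry an "answer" field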