JeetSuthar commited on
Commit
950f514
·
verified ·
1 Parent(s): b15241a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -36
app.py CHANGED
@@ -1,55 +1,33 @@
1
- # from fastapi import FastAPI
2
-
3
- # app = FastAPI()
4
-
5
- # @app.get("/")
6
- # def greet_json():
7
- # return {"Hello": "World!"}
8
-
9
-
10
  from fastapi import FastAPI, HTTPException
11
- from pydantic import BaseModel
12
  from transformers import AutoModelForCausalLM, AutoTokenizer
13
  import sqlite3
14
  import torch
15
 
16
  app = FastAPI()
17
 
18
- # Load the DeepSeek model and tokenizer
19
  MODEL_NAME = "deepseek-ai/deepseek-coder-1.3b-instruct"
20
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to("cpu") # Use "cuda" if available
22
-
23
-
24
- class ChatRequest(BaseModel):
25
- message: str
26
 
27
- def generate_sql_query(user_input: str) -> str:
28
- """
29
- Generate an SQL query from a natural language query using the DeepSeek model.
30
- """
31
- inputs = tokenizer(user_input, return_tensors="pt")
32
- outputs = model.generate(**inputs, max_length=100)
33
- sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
34
- return sql_query
35
 
 
 
 
 
 
36
 
37
  @app.post("/chat")
38
- def chat(request: ChatRequest):
39
- user_input = request.message
40
-
 
 
41
  sql_query = generate_sql_query(user_input)
42
  print(f"Generated SQL Query: {sql_query}")
43
-
44
  return {"response": sql_query}
45
 
46
  @app.get("/")
47
  def home():
48
  return {"message": "DeepSeek SQL Query API is running"}
49
-
50
-
51
- # Run the API
52
- # if __name__ == "__main__":
53
- # import uvicorn
54
- # uvicorn.run(app, host="0.0.0.0", port=8000)
55
-
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, HTTPException
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import sqlite3
4
  import torch
5
 
6
  app = FastAPI()
7
 
8
+ # Load Model & Tokenizer
9
  MODEL_NAME = "deepseek-ai/deepseek-coder-1.3b-instruct"
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
11
 
12
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)
 
 
 
 
 
 
14
 
15
+ def generate_sql_query(user_input):
16
+ """ Convert natural language input into an SQL query """
17
+ inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True).to(device)
18
+ outputs = model.generate(**inputs, max_length=100, do_sample=False, num_beams=2)
19
+ return tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
20
 
21
  @app.post("/chat")
22
+ def chat(request: dict):
23
+ user_input = request.get("message", "")
24
+ if not user_input:
25
+ raise HTTPException(status_code=400, detail="Message cannot be empty")
26
+
27
  sql_query = generate_sql_query(user_input)
28
  print(f"Generated SQL Query: {sql_query}")
 
29
  return {"response": sql_query}
30
 
31
  @app.get("/")
32
  def home():
33
  return {"message": "DeepSeek SQL Query API is running"}