habulaj committed on
Commit
aa6da07
·
verified ·
1 Parent(s): c407c3e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Query, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer
4
+ from peft import AutoPeftModelForCausalLM
5
+ import torch
6
+ import re
7
+ import json
8
+ import os
9
+
10
# -------- CONFIG --------
MODEL_NAME = "habulaj/filter"  # identifier handed to both from_pretrained calls
DEVICE = "cpu"                 # used as device_map and as the target of .to()
DTYPE = torch.float32


# -------- LOAD MODEL --------
def _load_model_and_tokenizer():
    """Load the tokenizer and PEFT causal LM for MODEL_NAME.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to
        eval mode.

    Any exception raised by the underlying loaders propagates to the caller.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # The inference code tokenizes with padding=True, which needs a pad token;
    # reuse the EOS token when the tokenizer does not define one.
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    # NOTE(review): trust_remote_code=True allows code shipped with the model
    # repository to run at load time - confirm the repository is trusted.
    loaded_model = AutoPeftModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map=DEVICE,
        torch_dtype=DTYPE,
        trust_remote_code=True,
    )
    loaded_model.eval()
    return loaded_model, loaded_tokenizer


print("🔁 Loading model and tokenizer...")
try:
    model, tokenizer = _load_model_and_tokenizer()
    print("✅ Model loaded successfully.")
except Exception as e:
    # Best effort: keep the process importable so the API can still start and
    # report "model_loaded": False instead of crashing at import time.
    print(f"❌ Error loading model: {e}")
    model, tokenizer = None, None

# -------- FASTAPI --------
app = FastAPI(title="News Filter JSON API")
40
+
41
+ # -------- ROOT ENDPOINT --------
42
@app.get("/")
def read_root():
    """Return a small status payload for the API root.

    The payload contains a fixed greeting, whether the module-level ``model``
    is available, the docs path, and the advertised endpoint paths.
    """
    is_model_ready = model is not None
    advertised_endpoints = ["/filter", "/health"]
    return {
        "message": "News Filter JSON API is running!",
        "model_loaded": is_model_ready,
        "docs": "/docs",
        "endpoints": advertised_endpoints,
    }
50
+
51
+ # -------- INFERENCE FUNCTION --------
52
def _extract_json_object(text: str) -> "str | None":
    """Return the JSON-looking part of *text*, or None if it contains no braces.

    The object starting at the first ``{`` is parsed with the standard JSON
    decoder so that trailing text containing ``}`` is not swallowed. If that
    parse fails, the widest ``{...}`` span is returned unchanged so the caller
    can still report the raw, invalid output.
    """
    widest = re.search(r"\{.*\}", text, re.DOTALL)
    if widest is None:
        return None
    start = widest.start()
    try:
        _, end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return widest.group(0)
    return text[start:end]


def generate_json_filter(title: str, content: str) -> str:
    """Generate the JSON filter text for a news item using the loaded model.

    Args:
        title: Title of the news item, inserted into the prompt.
        content: Body of the news item, inserted into the prompt.

    Returns:
        A string expected to contain a JSON object: the JSON found in the
        model output, or a minimal fallback object (status, first 50
        characters of the title, content length) when the output has no
        ``{...}`` span. The returned text is not guaranteed to be valid JSON
        when the model produced malformed output.

    Raises:
        ValueError: If the model or tokenizer is not loaded, or if
            tokenization, generation or decoding fails.
    """
    if model is None or tokenizer is None:
        raise ValueError("Model not loaded")

    prompt = (
        "Analyze the news title and content, and return the filters in JSON "
        "format with the defined fields.\n"
        "\n"
        "Please respond ONLY with the JSON filter, do NOT add any "
        "explanations, system messages, or extra text.\n"
        "\n"
        f'Title: "{title}"\n'
        f'Content: "{content}"\n'
    )

    try:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        input_ids = inputs["input_ids"].to(DEVICE)
        attention_mask = inputs["attention_mask"].to(DEVICE)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=128,
                temperature=1.2,
                top_p=0.9,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Strip the prompt by token count, not by character count: the prompt
        # may have been truncated to 512 tokens and decoding does not always
        # reproduce the input text exactly, so len(prompt) is not a reliable
        # offset into the decoded string.
        prompt_token_count = input_ids.shape[1]
        new_tokens = outputs[0][prompt_token_count:]
        generated = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        raise ValueError(f"Error during generation: {e}") from e

    # Extract the JSON object from the generated text.
    extracted = _extract_json_object(generated)
    if extracted is not None:
        return extracted

    # Fallback: return a simple JSON object when none was found. json.dumps
    # escapes quotes and backslashes in the title so the result always parses.
    return json.dumps(
        {
            "status": "processed",
            "title": title[:50],
            "content_length": len(content),
        },
        ensure_ascii=False,
    )
100
+
101
+ # -------- API ROUTE --------
102
@app.get("/filter")
def get_filter(
    title: str = Query(..., description="Title of the news"),
    content: str = Query(..., description="Content of the news"),
):
    """Run the model on a news item and return the parsed JSON filter.

    Args:
        title: Required query parameter with the news title.
        content: Required query parameter with the news content.

    Returns:
        ``{"filter": <parsed JSON>}`` when the generated text parses, or
        ``{"error": ..., "raw_output": ...}`` when it is not valid JSON.

    Raises:
        HTTPException: With status 422 for any other failure; the detail is
            the string form of the underlying exception.
    """
    try:
        raw_filter = generate_json_filter(title, content)
        # Parsed with json.loads (never eval) because the text is model output.
        return {"filter": json.loads(raw_filter)}
    except json.JSONDecodeError:
        return {"error": "Invalid JSON generated", "raw_output": raw_filter}
    except Exception as exc:
        raise HTTPException(status_code=422, detail=str(exc))
116
+
117
+ # -------- HEALTH CHECK --------
118
@app.get("/health")
def health_check():
    """Return a health payload for the service.

    Reports a fixed "healthy" status, whether the module-level ``model`` is
    available, the configured device, and the installed torch version.
    """
    is_model_ready = model is not None
    return dict(
        status="healthy",
        model_loaded=is_model_ready,
        device=DEVICE,
        torch_version=torch.__version__,
    )