Update utils/model.py
Browse files- utils/model.py +48 -19
utils/model.py
CHANGED
@@ -1,32 +1,47 @@
|
|
1 |
from transformers import pipeline, AutoModelForTimeSeriesPrediction
|
2 |
import torch
|
3 |
import numpy as np
|
|
|
4 |
|
5 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
6 |
|
7 |
def predict_umkm(data):
|
8 |
try:
|
9 |
-
# ===== 1.
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
#
|
15 |
-
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
# Load model
|
18 |
ttm_model = AutoModelForTimeSeriesPrediction.from_pretrained(
|
19 |
-
"ibm/granite-
|
|
|
20 |
).to(device)
|
21 |
|
22 |
-
#
|
23 |
with torch.no_grad():
|
24 |
-
predictions = ttm_model.generate(
|
|
|
|
|
|
|
25 |
|
26 |
-
#
|
27 |
-
demand_pred = predictions.cpu().numpy().
|
|
|
|
|
28 |
|
29 |
-
# =====
|
30 |
chronos = pipeline(
|
31 |
"text-generation",
|
32 |
model="amazon/chronos-t5-small",
|
@@ -34,13 +49,27 @@ def predict_umkm(data):
|
|
34 |
)
|
35 |
|
36 |
prompt = f"""
|
37 |
-
|
38 |
-
|
39 |
-
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
"""
|
42 |
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
except Exception as e:
|
46 |
-
return f"
|
|
|
1 |
from transformers import pipeline, AutoModelForTimeSeriesPrediction
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
|
6 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
7 |
|
8 |
def predict_umkm(data):
|
9 |
try:
|
10 |
+
# ===== 1. Validasi Input =====
|
11 |
+
if not isinstance(data, pd.DataFrame):
|
12 |
+
data = pd.DataFrame(data)
|
13 |
+
|
14 |
+
required_cols = ['demand', 'supply']
|
15 |
+
for col in required_cols:
|
16 |
+
if col not in data.columns:
|
17 |
+
raise ValueError(f"Kolom {col} tidak ditemukan")
|
18 |
|
19 |
+
# ===== 2. GRANITE-TTM Forecasting =====
|
20 |
+
# Konversi data ke tensor dengan format khusus
|
21 |
+
values = torch.tensor(
|
22 |
+
data['demand'].astype(float).values,
|
23 |
+
dtype=torch.float32
|
24 |
+
).unsqueeze(0).unsqueeze(-1).to(device) # Shape: [1, seq_len, 1]
|
25 |
|
26 |
+
# Load model (pastikan nama model benar)
|
27 |
ttm_model = AutoModelForTimeSeriesPrediction.from_pretrained(
|
28 |
+
"ibm/granite-ttm-r2",
|
29 |
+
trust_remote_code=True
|
30 |
).to(device)
|
31 |
|
32 |
+
# Generate predictions
|
33 |
with torch.no_grad():
|
34 |
+
predictions = ttm_model.generate(
|
35 |
+
inputs=values,
|
36 |
+
max_length=min(7, len(data)+3) # Prediksi maks 7 hari
|
37 |
+
)
|
38 |
|
39 |
+
# Post-processing output
|
40 |
+
demand_pred = predictions.cpu().numpy().squeeze().tolist()
|
41 |
+
if isinstance(demand_pred, float):
|
42 |
+
demand_pred = [demand_pred] # Convert single value to list
|
43 |
|
44 |
+
# ===== 3. Chronos-T5 Decision =====
|
45 |
chronos = pipeline(
|
46 |
"text-generation",
|
47 |
model="amazon/chronos-t5-small",
|
|
|
49 |
)
|
50 |
|
51 |
prompt = f"""
|
52 |
+
[INSTRUCTION]
|
53 |
+
Berikan rekomendasi inventory untuk UMKM berdasarkan:
|
54 |
+
- Prediksi demand 7 hari: {demand_pred[:7]}
|
55 |
+
- Stok saat ini: {data['supply'].iloc[-1]:.0f}
|
56 |
+
- Trend: {'naik' if demand_pred[-1] > demand_pred[0] else 'turun'}
|
57 |
+
|
58 |
+
[FORMAT]
|
59 |
+
- Gunakan bahasa Indonesia
|
60 |
+
- Maksimal 1 kalimat
|
61 |
+
- Sertakan angka konkret
|
62 |
+
[/FORMAT]
|
63 |
"""
|
64 |
|
65 |
+
recommendation = chronos(
|
66 |
+
prompt,
|
67 |
+
max_new_tokens=50,
|
68 |
+
do_sample=True,
|
69 |
+
temperature=0.7
|
70 |
+
)[0]['generated_text']
|
71 |
+
|
72 |
+
return recommendation.split("[/FORMAT]")[-1].strip()
|
73 |
|
74 |
except Exception as e:
|
75 |
+
return f"🚨 Error: {str(e)}"
|