Callmebowoo-22 commited on
Commit
e406ffb
·
verified ·
1 Parent(s): c219b92

Update utils/model.py

Browse files
Files changed (1) hide show
  1. utils/model.py +48 -19
utils/model.py CHANGED
@@ -1,32 +1,47 @@
1
  from transformers import pipeline, AutoModelForTimeSeriesPrediction
2
  import torch
3
  import numpy as np
 
4
 
5
  device = "cuda" if torch.cuda.is_available() else "cpu"
6
 
7
  def predict_umkm(data):
8
  try:
9
- # ===== 1. GRANITE-TTM =====
10
- # Pastikan kolom 'demand' ada dan tidak kosong
11
- if 'demand' not in data or len(data['demand']) == 0:
12
- raise ValueError("Kolom 'demand' tidak ditemukan atau kosong")
 
 
 
 
13
 
14
- # Konversi data ke format yang diterima GRANITE
15
- values = torch.tensor(data['demand'].values, dtype=torch.float32).unsqueeze(0).to(device)
 
 
 
 
16
 
17
- # Load model
18
  ttm_model = AutoModelForTimeSeriesPrediction.from_pretrained(
19
- "ibm/granite-timeseries-ttm-r2"
 
20
  ).to(device)
21
 
22
- # Prediksi (pastikan input shape [batch_size, sequence_length])
23
  with torch.no_grad():
24
- predictions = ttm_model.generate(input_values=values, max_length=7)
 
 
 
25
 
26
- # Konversi output ke list
27
- demand_pred = predictions.cpu().numpy().flatten().tolist()
 
 
28
 
29
- # ===== 2. CHRONOS-T5 =====
30
  chronos = pipeline(
31
  "text-generation",
32
  model="amazon/chronos-t5-small",
@@ -34,13 +49,27 @@ def predict_umkm(data):
34
  )
35
 
36
  prompt = f"""
37
- Data UMKM:
38
- - Prediksi 7 hari: {demand_pred}
39
- - Stok saat ini: {data['supply'].iloc[-1]}
40
- Berikan rekomendasi dalam 1 kalimat:
 
 
 
 
 
 
 
41
  """
42
 
43
- return chronos(prompt, max_length=100)[0]['generated_text']
 
 
 
 
 
 
 
44
 
45
  except Exception as e:
46
- return f"⚠️ Error: {str(e)}"
 
1
  from transformers import pipeline, AutoModelForTimeSeriesPrediction
2
  import torch
3
  import numpy as np
4
+ import pandas as pd
5
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
 
8
  def predict_umkm(data):
9
  try:
10
+ # ===== 1. Validasi Input =====
11
+ if not isinstance(data, pd.DataFrame):
12
+ data = pd.DataFrame(data)
13
+
14
+ required_cols = ['demand', 'supply']
15
+ for col in required_cols:
16
+ if col not in data.columns:
17
+ raise ValueError(f"Kolom {col} tidak ditemukan")
18
 
19
+ # ===== 2. GRANITE-TTM Forecasting =====
20
+ # Konversi data ke tensor dengan format khusus
21
+ values = torch.tensor(
22
+ data['demand'].astype(float).values,
23
+ dtype=torch.float32
24
+ ).unsqueeze(0).unsqueeze(-1).to(device) # Shape: [1, seq_len, 1]
25
 
26
+ # Load model (pastikan nama model benar)
27
  ttm_model = AutoModelForTimeSeriesPrediction.from_pretrained(
28
+ "ibm/granite-ttm-r2",
29
+ trust_remote_code=True
30
  ).to(device)
31
 
32
+ # Generate predictions
33
  with torch.no_grad():
34
+ predictions = ttm_model.generate(
35
+ inputs=values,
36
+ max_length=min(7, len(data)+3) # Prediksi maks 7 hari
37
+ )
38
 
39
+ # Post-processing output
40
+ demand_pred = predictions.cpu().numpy().squeeze().tolist()
41
+ if isinstance(demand_pred, float):
42
+ demand_pred = [demand_pred] # Convert single value to list
43
 
44
+ # ===== 3. Chronos-T5 Decision =====
45
  chronos = pipeline(
46
  "text-generation",
47
  model="amazon/chronos-t5-small",
 
49
  )
50
 
51
  prompt = f"""
52
+ [INSTRUCTION]
53
+ Berikan rekomendasi inventory untuk UMKM berdasarkan:
54
+ - Prediksi demand 7 hari: {demand_pred[:7]}
55
+ - Stok saat ini: {data['supply'].iloc[-1]:.0f}
56
+ - Trend: {'naik' if demand_pred[-1] > demand_pred[0] else 'turun'}
57
+
58
+ [FORMAT]
59
+ - Gunakan bahasa Indonesia
60
+ - Maksimal 1 kalimat
61
+ - Sertakan angka konkret
62
+ [/FORMAT]
63
  """
64
 
65
+ recommendation = chronos(
66
+ prompt,
67
+ max_new_tokens=50,
68
+ do_sample=True,
69
+ temperature=0.7
70
+ )[0]['generated_text']
71
+
72
+ return recommendation.split("[/FORMAT]")[-1].strip()
73
 
74
  except Exception as e:
75
+ return f"🚨 Error: {str(e)}"