Callmebowoo-22 committed (verified)
Commit 2a17517 · 1 Parent(s): afcb330

Update utils/model.py

Files changed (1)
  1. utils/model.py +25 -45
utils/model.py CHANGED
@@ -1,47 +1,35 @@
-from transformers import pipeline, AutoModelForTimeSeriesPrediction
+from transformers import AutoModelForTimeSeriesPrediction, pipeline
 import torch
 import numpy as np
-import pandas as pd
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def predict_umkm(data):
     try:
-        # ===== 1. Input Validation =====
-        if not isinstance(data, pd.DataFrame):
-            data = pd.DataFrame(data)
-
-        required_cols = ['demand', 'supply', 'harga']
-        for col in required_cols:
-            if col not in data.columns:
-                raise ValueError(f"Kolom {col} tidak ditemukan")
+        # ===== 1. Data Validation =====
+        demand_values = data['demand'].values.astype(float)
+        if len(demand_values) < 3:
+            raise ValueError("Data historis terlalu pendek (min 3 titik)")
 
         # ===== 2. GRANITE-TTM Forecasting =====
-        # Convert the data to a tensor in the expected format
-        values = torch.tensor(
-            data['demand'].astype(float).values,
-            dtype=torch.float32
-        ).unsqueeze(0).unsqueeze(-1).to(device)  # Shape: [1, seq_len, 1]
-
-        # Load the model (make sure the model name is correct)
-        ttm_model = AutoModelForTimeSeriesPrediction.from_pretrained(
+        # Input format expected by the IBM time-series model
+        inputs = {
+            "past_values": torch.tensor(demand_values, dtype=torch.float32).unsqueeze(0).to(device),
+            "static_categorical_features": torch.zeros(1, 1, dtype=torch.long).to(device)
+        }
+
+        # Load the model with the correct config
+        model = AutoModelForTimeSeriesPrediction.from_pretrained(
             "ibm/granite-timeseries-ttm-r2",
             trust_remote_code=True
         ).to(device)
 
-        # Generate predictions
+        # Generate the forecast
         with torch.no_grad():
-            predictions = ttm_model.generate(
-                inputs=values,
-                max_length=min(7, len(data)+3)  # predict at most 7 days ahead
-            )
-
-        # Post-process the output
-        demand_pred = predictions.cpu().numpy().squeeze().tolist()
-        if isinstance(demand_pred, float):
-            demand_pred = [demand_pred]  # convert a single value to a list
+            outputs = model(**inputs)
+            predictions = outputs.last_hidden_state.mean(dim=1).squeeze()
 
-        # ===== 3. Chronos-T5 Decision =====
+        # ===== 3. Formatting for Chronos-T5 =====
         chronos = pipeline(
             "text-generation",
             model="amazon/chronos-t5-small",
@@ -50,26 +38,18 @@ def predict_umkm(data):
 
         prompt = f"""
         [INSTRUCTION]
-        Berikan rekomendasi inventory untuk UMKM berdasarkan:
-        - Prediksi demand 7 hari: {demand_pred[:7]}
-        - Stok saat ini: {data['supply'].iloc[-1]:.0f}
-        - Trend: {'naik' if demand_pred[-1] > demand_pred[0] else 'turun'}
+        Berikan rekomendasi stok untuk 7 hari ke depan berdasarkan:
+        - Prediksi demand: {predictions.cpu().numpy().tolist()[:7]}
+        - Stok saat ini: {data['supply'].iloc[-1]}
+        - Tren: {'naik' if predictions[-1] > predictions[0] else 'turun'}
 
         [FORMAT]
-        - Gunakan bahasa Indonesia
-        - Maksimal 1 kalimat
-        - Sertakan angka konkret
+        1 kalimat dengan angka spesifik
        [/FORMAT]
        """
 
-        recommendation = chronos(
-            prompt,
-            max_new_tokens=50,
-            do_sample=True,
-            temperature=0.7
-        )[0]['generated_text']
-
-        return recommendation.split("[/FORMAT]")[-1].strip()
+        result = chronos(prompt, max_new_tokens=50)[0]['generated_text']
+        return result.split("[/FORMAT]")[-1].strip()
 
     except Exception as e:
-        return f"🚨 Error: {str(e)}"
+        return f"⚠️ Kesalahan sistem: {str(e)}"