import os, sys, time, torch, zipfile, threading, uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import hf_hub_download
from datetime import datetime
import random

# === Constants ===
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_BASE = "mistralai/Mistral-7B-Instruct-v0.2"
FINE_TUNE_ZIP = "trained_model_000_009.zip"
FINE_TUNE_REPO = "UcsTurkey/trained-zips"
USE_SAMPLING = False  # greedy decoding by default
CONFIDENCE_THRESHOLD = -1.5  # minimum max-logit of the first generated token before falling back
FALLBACK_ANSWERS = [
    "Bu konuda maalesef bilgim yok.",
    "Ne demek istediğinizi tam anlayamadım.",
    "Bu soruya şu an yanıt veremiyorum."
]

# === Logging
def log(message):
    timestamp = datetime.now().strftime("%H:%M:%S")
    print(f"[{timestamp}] {message}")
    sys.stdout.flush()  # flush immediately so messages appear in the container logs

# === FastAPI
app = FastAPI()
chat_history = []
model = None
tokenizer = None

class Message(BaseModel):
    user_input: str

@app.get("/")
def health():
    return {"status": "ok"}

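# Minimal HTML chat page that posts the user's question to /chat.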
@app.get("/start", response_class=HTMLResponse)
def root():
    return """
    <html>
    <body>
        <h2>Mistral 7B Chat</h2>
        <textarea id=\"input\" rows=\"4\" cols=\"60\" placeholder=\"SORU: ...\"></textarea><br>
        <button onclick=\"send()\">Gönder</button>
        <pre id=\"output\"></pre>
        <script>
        async function send() {
            const input = document.getElementById(\"input\").value;
            const res = await fetch('/chat', {
                method: 'POST',
                headers: { 'Content-Type': 'application/json' },
                body: JSON.stringify({ user_input: input })
            });
            const data = await res.json();
            document.getElementById('output').innerText = data.answer || data.error || 'Hata oluştu.';
        }
        </script>
    </body>
    </html>
    """

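# Chat endpoint: wraps the user input in the SORU/CEVAP prompt format, generates a
# response, and substitutes a fallback answer when the output looks unreliable.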
@app.post("/chat")
def chat(msg: Message):
    global model, tokenizer
    try:
        if model is None:
            return {"error": "Model yüklenmedi"}
        user_input = msg.user_input.strip()
        if not user_input:
            return {"error": "Boş giriş"}
        prompt = f"SORU: {user_input}\nCEVAP:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=USE_SAMPLING,
                temperature=0.7 if USE_SAMPLING else None,
                top_p=0.9 if USE_SAMPLING else None,
                top_k=50 if USE_SAMPLING else None,
                return_dict_in_generate=True,
                output_scores=True,
                suppress_tokens=[tokenizer.pad_token_id]
            )
        decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
        answer = decoded[len(prompt):].strip()

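        # Confidence heuristic: inspect the logits of the first generated token; NaN/Inf
        # values or a max score below CONFIDENCE_THRESHOLD triggers a canned fallback answer.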
        if output.scores and len(output.scores) > 0:
            first_token_score = output.scores[0][0]
            if torch.isnan(first_token_score).any() or torch.isinf(first_token_score).any():
                log("⚠️ Geçersiz logit (NaN/Inf) tespit edildi.")
                return {"answer": random.choice(FALLBACK_ANSWERS)}
            max_score = torch.max(first_token_score).item()
            log(f"🔍 İlk token skoru: {max_score:.4f}")
            if max_score < CONFIDENCE_THRESHOLD:
                answer = random.choice(FALLBACK_ANSWERS)

        chat_history.append({"user": user_input, "bot": answer})
        log(f"Soru: {user_input} → Cevap: {answer[:60]}...")
        return {"answer": answer, "chat_history": chat_history}
    except Exception as e:
        log(f"❌ /chat hatası: {e}")
        return {"error": str(e)}

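# Pick the compute device and check for bfloat16 support (compute capability >= 8).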
def detect_env():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    supports_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
    return device, supports_bf16

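# Download the fine-tuned zip from the Hugging Face Hub, extract it, load the tokenizer,
# and apply the PEFT (LoRA) adapter on top of the base Mistral model.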
def setup_model():
    global model, tokenizer
    try:
        log("📦 Zip indiriliyor...")
        zip_path = hf_hub_download(
            repo_id=FINE_TUNE_REPO,
            filename=FINE_TUNE_ZIP,
            repo_type="model",
            token=HF_TOKEN
        )
        extract_path = "/app/extracted"
        os.makedirs(extract_path, exist_ok=True)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_path)
        tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_path, "output"))
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        device, supports_bf16 = detect_env()
        dtype = torch.bfloat16 if supports_bf16 else torch.float32
        log(f"🧠 Ortam: {device.upper()}, dtype: {dtype}")
        base = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
        peft = PeftModel.from_pretrained(base, os.path.join(extract_path, "output"))
        model = peft.model.to(device)
        model.eval()
        log("✅ Model yüklendi.")
    except Exception as e:
        log(f"❌ Model setup hatası: {e}")

def run_server():
    log("🌐 Uvicorn başlatılıyor...")
    uvicorn.run(app, host="0.0.0.0", port=7860)

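# Startup: load the model and serve the API in background threads, then keep the
# main thread alive.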
log("🚀 Başlatılıyor...")
threading.Thread(target=setup_model, daemon=True).start()
threading.Thread(target=run_server, daemon=True).start()
while True:
    time.sleep(60)