import os, torch, threading, uvicorn, time, traceback, zipfile, random

from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import hf_hub_download
from datetime import datetime

HF_TOKEN = os.getenv("HF_TOKEN")
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
os.environ["TORCH_HOME"] = "/app/.torch_cache"
os.makedirs("/app/.torch_cache", exist_ok=True)

MODEL_BASE = "TURKCELL/Turkcell-LLM-7b-v1"
USE_FINE_TUNE = False
FINE_TUNE_REPO = "UcsTurkey/trained-zips"
FINE_TUNE_ZIP = "trained_model_000_009.zip"
USE_SAMPLING = False
CONFIDENCE_THRESHOLD = -1.5
FALLBACK_ANSWERS = [
    "Bu konuda maalesef bilgim yok.",
    "Ne demek istediğinizi tam anlayamadım.",
    "Bu soruya şu an yanıt veremiyorum."
]
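
# NOTE: when the first generated token's max log-probability falls below
# CONFIDENCE_THRESHOLD, the /chat handler returns a random entry from
# FALLBACK_ANSWERS instead of the model output.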


def log(message):
    timestamp = time.strftime("%H:%M:%S")
    print(f"[{timestamp}] {message}", flush=True)


app = FastAPI()
chat_history = []
model = None
tokenizer = None


class Message(BaseModel):
    user_input: str


@app.get("/")
def health():
    return {"status": "ok"}


@app.get("/start", response_class=HTMLResponse)
def root():
    return """
    <html>
      <body>
        <h2>Turkcell-LLM-7b Chat</h2>
        <textarea id="input" rows="4" cols="60" placeholder="Write your instruction..."></textarea><br>
        <button onclick="send()">Gönder</button><br><br>
        <label>Model Cevabı:</label><br>
        <textarea id="output" rows="10" cols="80" readonly style="white-space: pre-wrap;"></textarea>
        <script>
          async function send() {
            const input = document.getElementById("input").value;
            const res = await fetch('/chat', {
              method: 'POST',
              headers: { 'Content-Type': 'application/json' },
              body: JSON.stringify({ user_input: input })
            });
            const data = await res.json();
            document.getElementById('output').value = data.answer || data.error || 'Hata oluştu.';
          }
        </script>
      </body>
    </html>
    """


@app.post("/chat")
def chat(msg: Message):
    global model, tokenizer
    try:
        if model is None or tokenizer is None:
            return {"error": "Model veya tokenizer henüz yüklenmedi."}

        user_input = msg.user_input.strip()
        if not user_input:
            return {"error": "Boş giriş"}

        messages = [{"role": "user", "content": user_input}]
        input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.to(model.device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long()
            inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        else:
            inputs = {k: v.to(model.device) for k, v in input_ids.items()}
            if "attention_mask" not in inputs:
                inputs["attention_mask"] = (inputs["input_ids"] != tokenizer.pad_token_id).long()

        generate_args = {
            "max_new_tokens": 128,
            "return_dict_in_generate": True,
            "output_scores": True,
            "do_sample": USE_SAMPLING,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "renormalize_logits": True
        }

        if USE_SAMPLING:
            generate_args.update({
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50
            })

        with torch.no_grad():
            output = model.generate(**inputs, **generate_args)

        # Decode only the newly generated tokens; slicing by prompt length is more
        # robust than stripping the re-decoded prompt text from the full output.
        prompt_len = inputs["input_ids"].shape[-1]
        answer = tokenizer.decode(output.sequences[0][prompt_len:], skip_special_tokens=True).strip()

        if output.scores:
            first_token_score = output.scores[0][0]
            if torch.isnan(first_token_score).any() or torch.isinf(first_token_score).any():
                log("⚠️ Geçersiz logit (NaN/Inf) tespit edildi.")
                return {"answer": random.choice(FALLBACK_ANSWERS)}
            max_score = torch.max(first_token_score).item()
            log(f"🔍 İlk token skoru: {max_score:.4f}")
            if max_score < CONFIDENCE_THRESHOLD:
                answer = random.choice(FALLBACK_ANSWERS)

        chat_history.append({"user": user_input, "bot": answer})
        log(f"Soru: {user_input} → Cevap: {answer[:60]}...")
        return {"answer": answer, "chat_history": chat_history}

    except Exception as e:
        log(f"❌ /chat hatası: {e}")
        traceback.print_exc()
        return {"error": str(e)}


def detect_env():
    return "cuda" if torch.cuda.is_available() else "cpu"


def setup_model():
    global model, tokenizer
    try:
        device = detect_env()
        dtype = torch.float32

        if USE_FINE_TUNE:
            log("📦 Fine-tune zip indiriliyor...")
            zip_path = hf_hub_download(
                repo_id=FINE_TUNE_REPO,
                filename=FINE_TUNE_ZIP,
                repo_type="model",
                token=HF_TOKEN
            )
            extract_dir = "/app/extracted"
            os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_dir)

            tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"), use_fast=False)
            base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
            model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output")).to(device)
        else:
            log("🧠 Ana model indiriliyor...")
            tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)

        # Ensure a pad token exists (falling back to EOS) so attention masks and
        # generation padding behave consistently.
        tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
        model.config.pad_token_id = tokenizer.pad_token_id
        model.eval()
        log("✅ Model başarıyla yüklendi.")
    except Exception as e:
        log(f"❌ Model yüklenirken hata: {e}")
        traceback.print_exc()
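
# To serve the fine-tuned adapter instead of the base model, set USE_FINE_TUNE = True;
# the downloaded zip is expected to contain an "output/" directory holding the
# tokenizer files and the PEFT adapter.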


def run_server():
    log("🌐 Uvicorn başlatılıyor...")
    uvicorn.run(app, host="0.0.0.0", port=7860)


log("===== Application Startup =====")
threading.Thread(target=setup_model, daemon=True).start()
threading.Thread(target=run_server, daemon=True).start()

# Both worker threads are daemons, so the main thread must stay alive
# to keep the process (and the server) running.
while True:
    time.sleep(60)