Update app.py
Browse files
app.py
CHANGED
@@ -65,34 +65,15 @@ def generate(req: UserInputRequest):
|
|
65 |
{"role": "user", "content": req.user_input}
|
66 |
]
|
67 |
|
68 |
-
|
69 |
-
chat_template_raw = tokenizer.apply_chat_template(
|
70 |
messages,
|
71 |
add_generation_prompt=True,
|
72 |
-
return_tensors=
|
73 |
-
)
|
74 |
-
|
75 |
-
# Güvenlik: None veya beklenmedik tip gelirse zorla string'e çevir
|
76 |
-
if chat_template_raw is None:
|
77 |
-
chat_template_str = ""
|
78 |
-
elif isinstance(chat_template_raw, str):
|
79 |
-
chat_template_str = chat_template_raw
|
80 |
-
else:
|
81 |
-
chat_template_str = str(chat_template_raw)
|
82 |
-
|
83 |
-
# === Sonra tokenizer() ile input_ids + attention_mask hazırla
|
84 |
-
tokenized_inputs = tokenizer(
|
85 |
-
chat_template_str,
|
86 |
-
return_tensors="pt",
|
87 |
-
padding=True
|
88 |
).to(model.device)
|
89 |
|
90 |
-
|
91 |
-
attention_mask = tokenized_inputs['attention_mask']
|
92 |
-
|
93 |
-
input_len = input_ids.shape[-1]
|
94 |
total_ctx = model.config.max_position_embeddings if hasattr(model.config, 'max_position_embeddings') else 4096
|
95 |
-
max_new_tokens = max(1, total_ctx - input_len)
|
96 |
|
97 |
log(f"ℹ️ Input uzunluğu: {input_len}, max_new_tokens ayarlandı: {max_new_tokens}")
|
98 |
|
@@ -102,8 +83,7 @@ def generate(req: UserInputRequest):
|
|
102 |
]
|
103 |
|
104 |
outputs = model.generate(
|
105 |
-
input_ids=input_ids,
|
106 |
-
attention_mask=attention_mask,
|
107 |
max_new_tokens=max_new_tokens,
|
108 |
eos_token_id=terminators
|
109 |
)
|
|
|
65 |
{"role": "user", "content": req.user_input}
|
66 |
]
|
67 |
|
68 |
+
chat_input = tokenizer.apply_chat_template(
|
|
|
69 |
messages,
|
70 |
add_generation_prompt=True,
|
71 |
+
return_tensors="pt"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
).to(model.device)
|
73 |
|
74 |
+
input_len = chat_input.shape[-1]
|
|
|
|
|
|
|
75 |
total_ctx = model.config.max_position_embeddings if hasattr(model.config, 'max_position_embeddings') else 4096
|
76 |
+
max_new_tokens = min(512, max(1, total_ctx - input_len))
|
77 |
|
78 |
log(f"ℹ️ Input uzunluğu: {input_len}, max_new_tokens ayarlandı: {max_new_tokens}")
|
79 |
|
|
|
83 |
]
|
84 |
|
85 |
outputs = model.generate(
|
86 |
+
input_ids=chat_input,
|
|
|
87 |
max_new_tokens=max_new_tokens,
|
88 |
eos_token_id=terminators
|
89 |
)
|