Hjgugugjhuhjggg commited on
Commit
4864d22
verified
1 Parent(s): c340762

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -0
app.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import torch
4
+ import asyncio
5
+ from fastapi import FastAPI, HTTPException
6
+ from fastapi.responses import StreamingResponse, JSONResponse
7
+ from pydantic import BaseModel
8
+ from transformers import (
9
+ AutoConfig,
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer,
12
+ GenerationConfig
13
+ )
14
+ import uvicorn
15
+ from duckduckgo_search import ddg # pip install duckduckgo-search
16
+
17
+ # Nombre del modelo fijo
18
+ MODEL_NAME = "lilmeaty/my_xdd"
19
+
20
+ # Variables globales para almacenar el modelo y el tokenizador
21
+ global_model = None
22
+ global_tokenizer = None
23
+
24
+ # Funci贸n as铆ncrona para limpiar la memoria (RAM y cach茅 CUDA)
25
+ async def cleanup_memory(device: str):
26
+ gc.collect()
27
+ if device == "cuda":
28
+ torch.cuda.empty_cache()
29
+ await asyncio.sleep(0.01)
30
+
31
+ # Request para la generaci贸n de texto
32
+ class GenerateRequest(BaseModel):
33
+ input_text: str = ""
34
+ max_new_tokens: int = 200 # l铆mite total de tokens generados (puede ser muy alto)
35
+ temperature: float = 1.0
36
+ top_p: float = 1.0
37
+ top_k: int = 50
38
+ repetition_penalty: float = 1.0
39
+ do_sample: bool = True
40
+ stream: bool = False
41
+ # L铆mite de tokens por chunk en modo streaming (si se excede, se emite el chunk acumulado)
42
+ chunk_token_limit: int = 20
43
+ # Secuencias que, si se detectan, hacen que se detenga la generaci贸n
44
+ stop_sequences: list[str] = []
45
+ # Si se desea incluir Duckasgo en la respuesta final
46
+ include_duckasgo: bool = False
47
+
48
+ # Request para b煤squedas independientes con Duckasgo
49
+ class DuckasgoRequest(BaseModel):
50
+ query: str
51
+
52
+ # Inicializar la aplicaci贸n FastAPI
53
+ app = FastAPI()
54
+
55
+ # Evento de startup: cargar el modelo y tokenizador globalmente
56
+ @app.on_event("startup")
57
+ async def load_global_model():
58
+ global global_model, global_tokenizer
59
+ try:
60
+ config = AutoConfig.from_pretrained(MODEL_NAME)
61
+ global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
62
+ # Se usa torch_dtype=torch.float16 para reducir la huella de memoria (si es posible)
63
+ global_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, torch_dtype=torch.float16)
64
+ # Configurar token de padding si es necesario
65
+ if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
66
+ global_tokenizer.pad_token_id = config.pad_token_id or global_tokenizer.eos_token_id
67
+ device = "cuda" if torch.cuda.is_available() else "cpu"
68
+ global_model.to(device)
69
+ print(f"Modelo {MODEL_NAME} cargado correctamente en {device}.")
70
+ except Exception as e:
71
+ print("Error al cargar el modelo:", e)
72
+
73
+ # Funci贸n para realizar b煤squeda con Duckasgo de forma as铆ncrona
74
+ async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
75
+ results = await asyncio.to_thread(ddg, query, max_results=max_results)
76
+ if not results:
77
+ return "No se encontraron resultados en Duckasgo."
78
+ summary = "\nResultados de b煤squeda (Duckasgo):\n"
79
+ for idx, res in enumerate(results, start=1):
80
+ title = res.get("title", "Sin t铆tulo")
81
+ url = res.get("href", "Sin URL")
82
+ snippet = res.get("body", "")
83
+ summary += f"{idx}. {title}\n URL: {url}\n {snippet}\n"
84
+ return summary
85
+
86
+ # Funci贸n para generar texto en modo streaming, dividiendo en chunks ilimitados
87
+ # y deteniendo la generaci贸n si se detectan secuencias de parada.
88
+ async def stream_text(request: GenerateRequest, device: str):
89
+ global global_model, global_tokenizer
90
+
91
+ # Codificar la entrada y obtener la longitud inicial
92
+ encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
93
+ initial_input_len = encoded_input.input_ids.shape[-1]
94
+ input_ids = encoded_input.input_ids
95
+
96
+ # Variables para acumular texto
97
+ accumulated_text = "" # todo el texto generado
98
+ current_chunk = "" # chunk actual
99
+ chunk_token_count = 0
100
+
101
+ # Configurar la generaci贸n seg煤n los par谩metros recibidos
102
+ gen_config = GenerationConfig(
103
+ temperature=request.temperature,
104
+ max_new_tokens=request.max_new_tokens,
105
+ top_p=request.top_p,
106
+ top_k=request.top_k,
107
+ repetition_penalty=request.repetition_penalty,
108
+ do_sample=request.do_sample,
109
+ )
110
+
111
+ past_key_values = None
112
+
113
+ # Usamos un bucle while para permitir generaci贸n ilimitada en chunks
114
+ for _ in range(request.max_new_tokens):
115
+ with torch.no_grad():
116
+ outputs = global_model(
117
+ input_ids,
118
+ past_key_values=past_key_values,
119
+ use_cache=True,
120
+ return_dict=True
121
+ )
122
+ logits = outputs.logits[:, -1, :]
123
+ past_key_values = outputs.past_key_values
124
+
125
+ if gen_config.do_sample:
126
+ logits = logits / gen_config.temperature
127
+ if gen_config.top_k and gen_config.top_k > 0:
128
+ topk_values, _ = torch.topk(logits, k=gen_config.top_k)
129
+ logits[logits < topk_values[:, [-1]]] = -float('Inf')
130
+ probs = torch.nn.functional.softmax(logits, dim=-1)
131
+ next_token = torch.multinomial(probs, num_samples=1)
132
+ else:
133
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
134
+
135
+ token_id = next_token.item()
136
+ token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
137
+
138
+ accumulated_text += token_text
139
+ current_chunk += token_text
140
+ chunk_token_count += 1
141
+
142
+ # Verificar si se alcanz贸 alguna de las secuencias de parada
143
+ if request.stop_sequences:
144
+ for stop_seq in request.stop_sequences:
145
+ if stop_seq in accumulated_text:
146
+ # Si se detecta el stop, emitir el chunk actual y terminar
147
+ yield current_chunk
148
+ await cleanup_memory(device)
149
+ return
150
+
151
+ # Si se supera el l铆mite de tokens por chunk, enviar el chunk acumulado
152
+ if chunk_token_count >= request.chunk_token_limit:
153
+ yield current_chunk
154
+ current_chunk = ""
155
+ chunk_token_count = 0
156
+
157
+ # Permitir que otras tareas se ejecuten
158
+ await asyncio.sleep(0)
159
+
160
+ # Actualizar el input para la siguiente iteraci贸n
161
+ input_ids = next_token
162
+
163
+ # Si se ha generado el token de finalizaci贸n, se emite el chunk y se termina
164
+ if token_id == global_tokenizer.eos_token_id:
165
+ break
166
+
167
+ # Emitir el 煤ltimo chunk (si no est谩 vac铆o)
168
+ if current_chunk:
169
+ yield current_chunk
170
+
171
+ # Si se solicit贸 incluir Duckasgo, realizar la b煤squeda y agregarla como chunk final
172
+ if request.include_duckasgo:
173
+ search_summary = await perform_duckasgo_search(request.input_text)
174
+ yield "\n" + search_summary
175
+
176
+ await cleanup_memory(device)
177
+
178
+ # Endpoint para la generaci贸n de texto (modo streaming o no-streaming)
179
+ @app.post("/generate")
180
+ async def generate_text(request: GenerateRequest):
181
+ global global_model, global_tokenizer
182
+ if global_model is None or global_tokenizer is None:
183
+ raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.")
184
+
185
+ device = "cuda" if torch.cuda.is_available() else "cpu"
186
+ gen_config = GenerationConfig(
187
+ temperature=request.temperature,
188
+ max_new_tokens=request.max_new_tokens,
189
+ top_p=request.top_p,
190
+ top_k=request.top_k,
191
+ repetition_penalty=request.repetition_penalty,
192
+ do_sample=request.do_sample,
193
+ )
194
+
195
+ try:
196
+ if request.stream:
197
+ # Modo streaming: se env铆an m煤ltiples chunks seg煤n se generan
198
+ generator = stream_text(request, device)
199
+ return StreamingResponse(generator, media_type="text/plain")
200
+ else:
201
+ # Modo no-streaming: generaci贸n completa en una sola respuesta
202
+ encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
203
+ with torch.no_grad():
204
+ output = global_model.generate(
205
+ **encoded_input,
206
+ generation_config=gen_config,
207
+ return_dict_in_generate=True,
208
+ output_scores=True
209
+ )
210
+ input_length = encoded_input["input_ids"].shape[-1]
211
+ generated_text = global_tokenizer.decode(
212
+ output.sequences[0][input_length:],
213
+ skip_special_tokens=True
214
+ )
215
+ # Si se han definido secuencias de parada, se corta la respuesta en el primer match
216
+ if request.stop_sequences:
217
+ for stop_seq in request.stop_sequences:
218
+ if stop_seq in generated_text:
219
+ generated_text = generated_text.split(stop_seq)[0]
220
+ break
221
+ # Si se solicit贸 incluir Duckasgo, se realiza la b煤squeda y se agrega al final
222
+ if request.include_duckasgo:
223
+ search_summary = await perform_duckasgo_search(request.input_text)
224
+ generated_text += "\n" + search_summary
225
+ await cleanup_memory(device)
226
+ return {"generated_text": generated_text}
227
+ except Exception as e:
228
+ raise HTTPException(status_code=500, detail=f"Error durante la generaci贸n: {e}")
229
+
230
+ # Endpoint independiente para b煤squedas con Duckasgo
231
+ @app.post("/duckasgo")
232
+ async def duckasgo_search(request: DuckasgoRequest):
233
+ try:
234
+ results = await asyncio.to_thread(ddg, request.query, max_results=10)
235
+ if results is None:
236
+ results = []
237
+ return JSONResponse(content={"query": request.query, "results": results})
238
+ except Exception as e:
239
+ raise HTTPException(status_code=500, detail=f"Error en la b煤squeda: {e}")
240
+
241
+ if __name__ == "__main__":
242
+ uvicorn.run(app, host="0.0.0.0", port=7860)