File size: 9,724 Bytes
4864d22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import os
import gc
import torch
import asyncio
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig
)
import uvicorn
from duckduckgo_search import ddg  # pip install duckduckgo-search

# Nombre del modelo fijo
MODEL_NAME = "lilmeaty/my_xdd"

# Variables globales para almacenar el modelo y el tokenizador
global_model = None
global_tokenizer = None

# Funci贸n as铆ncrona para limpiar la memoria (RAM y cach茅 CUDA)
async def cleanup_memory(device: str):
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    await asyncio.sleep(0.01)

# Request para la generaci贸n de texto
class GenerateRequest(BaseModel):
    input_text: str = ""
    max_new_tokens: int = 200  # l铆mite total de tokens generados (puede ser muy alto)
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    do_sample: bool = True
    stream: bool = False
    # L铆mite de tokens por chunk en modo streaming (si se excede, se emite el chunk acumulado)
    chunk_token_limit: int = 20
    # Secuencias que, si se detectan, hacen que se detenga la generaci贸n
    stop_sequences: list[str] = []
    # Si se desea incluir Duckasgo en la respuesta final
    include_duckasgo: bool = False

# Request para b煤squedas independientes con Duckasgo
class DuckasgoRequest(BaseModel):
    query: str

# Inicializar la aplicaci贸n FastAPI
app = FastAPI()

# Evento de startup: cargar el modelo y tokenizador globalmente
@app.on_event("startup")
async def load_global_model():
    global global_model, global_tokenizer
    try:
        config = AutoConfig.from_pretrained(MODEL_NAME)
        global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
        # Se usa torch_dtype=torch.float16 para reducir la huella de memoria (si es posible)
        global_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, torch_dtype=torch.float16)
        # Configurar token de padding si es necesario
        if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
            global_tokenizer.pad_token_id = config.pad_token_id or global_tokenizer.eos_token_id
        device = "cuda" if torch.cuda.is_available() else "cpu"
        global_model.to(device)
        print(f"Modelo {MODEL_NAME} cargado correctamente en {device}.")
    except Exception as e:
        print("Error al cargar el modelo:", e)

# Funci贸n para realizar b煤squeda con Duckasgo de forma as铆ncrona
async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
    results = await asyncio.to_thread(ddg, query, max_results=max_results)
    if not results:
        return "No se encontraron resultados en Duckasgo."
    summary = "\nResultados de b煤squeda (Duckasgo):\n"
    for idx, res in enumerate(results, start=1):
        title = res.get("title", "Sin t铆tulo")
        url = res.get("href", "Sin URL")
        snippet = res.get("body", "")
        summary += f"{idx}. {title}\n   URL: {url}\n   {snippet}\n"
    return summary

# Funci贸n para generar texto en modo streaming, dividiendo en chunks ilimitados
# y deteniendo la generaci贸n si se detectan secuencias de parada.
async def stream_text(request: GenerateRequest, device: str):
    global global_model, global_tokenizer

    # Codificar la entrada y obtener la longitud inicial
    encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
    initial_input_len = encoded_input.input_ids.shape[-1]
    input_ids = encoded_input.input_ids

    # Variables para acumular texto
    accumulated_text = ""  # todo el texto generado
    current_chunk = ""     # chunk actual
    chunk_token_count = 0

    # Configurar la generaci贸n seg煤n los par谩metros recibidos
    gen_config = GenerationConfig(
        temperature=request.temperature,
        max_new_tokens=request.max_new_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        do_sample=request.do_sample,
    )

    past_key_values = None

    # Usamos un bucle while para permitir generaci贸n ilimitada en chunks
    for _ in range(request.max_new_tokens):
        with torch.no_grad():
            outputs = global_model(
                input_ids,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True
            )
        logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values

        if gen_config.do_sample:
            logits = logits / gen_config.temperature
            if gen_config.top_k and gen_config.top_k > 0:
                topk_values, _ = torch.topk(logits, k=gen_config.top_k)
                logits[logits < topk_values[:, [-1]]] = -float('Inf')
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

        token_id = next_token.item()
        token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)

        accumulated_text += token_text
        current_chunk += token_text
        chunk_token_count += 1

        # Verificar si se alcanz贸 alguna de las secuencias de parada
        if request.stop_sequences:
            for stop_seq in request.stop_sequences:
                if stop_seq in accumulated_text:
                    # Si se detecta el stop, emitir el chunk actual y terminar
                    yield current_chunk
                    await cleanup_memory(device)
                    return

        # Si se supera el l铆mite de tokens por chunk, enviar el chunk acumulado
        if chunk_token_count >= request.chunk_token_limit:
            yield current_chunk
            current_chunk = ""
            chunk_token_count = 0

        # Permitir que otras tareas se ejecuten
        await asyncio.sleep(0)

        # Actualizar el input para la siguiente iteraci贸n
        input_ids = next_token

        # Si se ha generado el token de finalizaci贸n, se emite el chunk y se termina
        if token_id == global_tokenizer.eos_token_id:
            break

    # Emitir el 煤ltimo chunk (si no est谩 vac铆o)
    if current_chunk:
        yield current_chunk

    # Si se solicit贸 incluir Duckasgo, realizar la b煤squeda y agregarla como chunk final
    if request.include_duckasgo:
        search_summary = await perform_duckasgo_search(request.input_text)
        yield "\n" + search_summary

    await cleanup_memory(device)

# Endpoint para la generaci贸n de texto (modo streaming o no-streaming)
@app.post("/generate")
async def generate_text(request: GenerateRequest):
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen_config = GenerationConfig(
        temperature=request.temperature,
        max_new_tokens=request.max_new_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        do_sample=request.do_sample,
    )

    try:
        if request.stream:
            # Modo streaming: se env铆an m煤ltiples chunks seg煤n se generan
            generator = stream_text(request, device)
            return StreamingResponse(generator, media_type="text/plain")
        else:
            # Modo no-streaming: generaci贸n completa en una sola respuesta
            encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
            with torch.no_grad():
                output = global_model.generate(
                    **encoded_input,
                    generation_config=gen_config,
                    return_dict_in_generate=True,
                    output_scores=True
                )
            input_length = encoded_input["input_ids"].shape[-1]
            generated_text = global_tokenizer.decode(
                output.sequences[0][input_length:],
                skip_special_tokens=True
            )
            # Si se han definido secuencias de parada, se corta la respuesta en el primer match
            if request.stop_sequences:
                for stop_seq in request.stop_sequences:
                    if stop_seq in generated_text:
                        generated_text = generated_text.split(stop_seq)[0]
                        break
            # Si se solicit贸 incluir Duckasgo, se realiza la b煤squeda y se agrega al final
            if request.include_duckasgo:
                search_summary = await perform_duckasgo_search(request.input_text)
                generated_text += "\n" + search_summary
            await cleanup_memory(device)
            return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error durante la generaci贸n: {e}")

# Endpoint independiente para b煤squedas con Duckasgo
@app.post("/duckasgo")
async def duckasgo_search(request: DuckasgoRequest):
    try:
        results = await asyncio.to_thread(ddg, request.query, max_results=10)
        if results is None:
            results = []
        return JSONResponse(content={"query": request.query, "results": results})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error en la b煤squeda: {e}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)