File size: 4,302 Bytes
0fd380e
 
d11e1fe
 
bec7021
dbd9820
0fd380e
1737ef1
0fd380e
 
037a839
11c5d58
 
0fd380e
ec7f6a1
 
 
0fd380e
ec7f6a1
 
d11e1fe
 
0fd380e
7f617c9
ec7f6a1
 
 
 
 
0fd380e
ec7f6a1
7f617c9
ec7f6a1
 
 
 
 
 
0fd380e
ec7f6a1
 
0fd380e
d11e1fe
 
0fd380e
 
2583cf2
 
 
 
 
d11e1fe
 
ec7f6a1
0fd380e
2c6bd00
0fd380e
ec7f6a1
ecd203a
ec7f6a1
 
 
 
 
 
63ed81d
0fd380e
ecd203a
0fd380e
ec7f6a1
 
 
037a839
0fd380e
 
1737ef1
 
d11e1fe
0fd380e
359c625
63ed81d
359c625
 
 
 
0fd380e
 
 
63ed81d
359c625
0fd380e
 
 
 
2c6bd00
0fd380e
bec7021
 
 
 
 
 
ec7f6a1
0fd380e
bec7021
5ac0022
d11e1fe
0fd380e
d11e1fe
d00f6f0
 
 
d5e5243
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# ✅ API FastAPI de chunking sémantique intelligent avec fallback automatique

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional

# LlamaIndex (>= 0.10.0)
from llama_index.core import Document
from llama_index.core.settings import Settings
from llama_index.core.node_parser import SemanticSplitterNodeParser, RecursiveTextSplitter
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.core.base.llms.base import BaseLLM

# Embedding local (transformers + torch)
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

import os

app = FastAPI()

# ✅ Configuration des caches pour Hugging Face dans le container
CACHE_DIR = "/app/cache"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_MODULES_CACHE"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR

# ✅ Modèle d'embedding local (dense vector)
MODEL_NAME = "BAAI/bge-small-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = AutoModel.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

def get_embedding(text: str):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state[:, 0]  # CLS token
        return F.normalize(embeddings, p=2, dim=1).squeeze().tolist()

# ✅ Format des données entrantes de l'API
class ChunkRequest(BaseModel):
    text: str
    max_tokens: Optional[int] = 1000
    overlap: Optional[int] = 350
    source_id: Optional[str] = None
    titre: Optional[str] = None
    source: Optional[str] = None
    type: Optional[str] = None

@app.post("/chunk")
async def chunk_text(data: ChunkRequest):
    try:
        print(f"\n✅ Texte reçu ({len(data.text)} caractères) : {data.text[:200]}...", flush=True)

        # ✅ Chargement du modèle GGUF distant via llama-cpp
        llm = LlamaCPP(
            model_url="https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF/resolve/main/codellama-7b-instruct.Q4_K_M.gguf",
            temperature=0.1,
            max_new_tokens=512,
            context_window=2048,
            generate_kwargs={"top_p": 0.95},
            model_kwargs={"n_gpu_layers": 1},
        )

        print("✅ Modèle LLM chargé avec succès !")

        # ✅ Wrapper pour l'embedding local compatible avec LlamaIndex
        class SimpleEmbedding:
            def get_text_embedding(self, text: str):
                return get_embedding(text)

        # ✅ Configuration du moteur LLM et de l'embedding dans LlamaIndex
        assert isinstance(llm, BaseLLM), "❌ L'objet LLM n'est pas compatible avec Settings.llm"
        Settings.llm = llm
        Settings.embed_model = SimpleEmbedding()

        print("✅ Configuration du LLM et de l'embedding terminée. On initialise le Semantic Splitter...", flush=True)

        parser = SemanticSplitterNodeParser.from_defaults(llm=llm)
        doc = Document(text=data.text)

        try:
            nodes = parser.get_nodes_from_documents([doc])
            print(f"✅ Semantic Splitter : {len(nodes)} chunks générés")
            if not nodes:
                raise ValueError("Aucun chunk produit par le Semantic Splitter")

        except Exception as e:
            print(f"⚠️ Fallback vers RecursiveTextSplitter suite à : {e}")
            splitter = RecursiveTextSplitter(chunk_size=data.max_tokens, chunk_overlap=data.overlap)
            nodes = splitter.get_nodes_from_documents([doc])
            print(f"♻️ Recursive Splitter : {len(nodes)} chunks générés")

        # ✅ Construction de la réponse JSON pour n8n ou autre client HTTP
        return {
            "chunks": [node.text for node in nodes],
            "metadatas": [node.metadata for node in nodes],
            "source_id": data.source_id,
            "titre": data.titre,
            "source": data.source,
            "type": data.type,
            "error": None  # n8n utilise cette clé pour détecter les erreurs
        }

    except Exception as e:
        print(f"❌ Erreur critique : {e}")
        return {"error": str(e)}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)