File size: 4,545 Bytes
0fd380e
 
d11e1fe
 
bec7021
dbd9820
42cc3fb
1737ef1
0fd380e
42cc3fb
 
037a839
11c5d58
 
42cc3fb
ec7f6a1
 
 
 
 
42cc3fb
d11e1fe
 
42cc3fb
7f617c9
ec7f6a1
 
 
 
 
42cc3fb
ec7f6a1
7f617c9
ec7f6a1
 
42cc3fb
ec7f6a1
 
 
 
42cc3fb
ec7f6a1
 
42cc3fb
d11e1fe
 
0fd380e
 
2583cf2
 
 
 
 
42cc3fb
d11e1fe
 
ec7f6a1
0fd380e
2c6bd00
42cc3fb
ec7f6a1
ecd203a
ec7f6a1
 
 
 
 
 
63ed81d
0fd380e
ecd203a
42cc3fb
ec7f6a1
 
 
037a839
42cc3fb
 
1737ef1
 
d11e1fe
42cc3fb
359c625
63ed81d
359c625
 
 
 
0fd380e
 
42cc3fb
359c625
0fd380e
 
 
 
2c6bd00
42cc3fb
bec7021
 
 
 
 
 
ec7f6a1
42cc3fb
bec7021
5ac0022
d11e1fe
0fd380e
d11e1fe
d00f6f0
42cc3fb
d00f6f0
 
d5e5243
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# ✅ API FastAPI de chunking sémantique intelligent avec fallback automatique

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional

# ✅ Modules LlamaIndex (version >= 0.10.0)
from llama_index.core import Document
from llama_index.core.settings import Settings
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.core.text_splitter import RecursiveTextSplitter
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.core.base.llms.base import BaseLLM

# ✅ Embedding local (transformers + torch)
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import os

# ✅ Initialisation de l'app FastAPI
app = FastAPI()

# ✅ Configuration du cache Hugging Face (important pour HF Spaces)
CACHE_DIR = "/app/cache"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_MODULES_CACHE"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR

# ✅ Choix du modèle d'embedding dense (ex : BGE-small)
MODEL_NAME = "BAAI/bge-small-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = AutoModel.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

# ✅ Fonction d'embedding normalisé (vectorisation dense)
def get_embedding(text: str):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state[:, 0]  # On prend le token [CLS]
        return F.normalize(embeddings, p=2, dim=1).squeeze().tolist()

# ✅ Format des données envoyées à l’API
class ChunkRequest(BaseModel):
    text: str
    max_tokens: Optional[int] = 1000
    overlap: Optional[int] = 350
    source_id: Optional[str] = None
    titre: Optional[str] = None
    source: Optional[str] = None
    type: Optional[str] = None

# ✅ Route de l’API pour le chunking sémantique
@app.post("/chunk")
async def chunk_text(data: ChunkRequest):
    try:
        print(f"\n✅ Texte reçu ({len(data.text)} caractères) : {data.text[:200]}...", flush=True)

        # ✅ Chargement du modèle GGUF distant avec LlamaCPP (CPU friendly)
        llm = LlamaCPP(
            model_url="https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF/resolve/main/codellama-7b-instruct.Q4_K_M.gguf",
            temperature=0.1,
            max_new_tokens=512,
            context_window=2048,
            generate_kwargs={"top_p": 0.95},
            model_kwargs={"n_gpu_layers": 1},
        )

        print("✅ Modèle LLM chargé avec succès !")

        # ✅ Wrapper embedding compatible avec LlamaIndex
        class SimpleEmbedding:
            def get_text_embedding(self, text: str):
                return get_embedding(text)

        # ✅ Configuration globale de LlamaIndex
        assert isinstance(llm, BaseLLM), "❌ L’objet LLM n’est pas compatible avec LlamaIndex"
        Settings.llm = llm
        Settings.embed_model = SimpleEmbedding()

        print("✅ Configuration du LLM et de l'embedding terminée. On initialise le Semantic Splitter...")

        parser = SemanticSplitterNodeParser.from_defaults(llm=llm)
        doc = Document(text=data.text)

        try:
            nodes = parser.get_nodes_from_documents([doc])
            print(f"✅ Semantic Splitter : {len(nodes)} chunks générés")
            if not nodes:
                raise ValueError("Aucun chunk produit par SemanticSplitter")
        except Exception as e:
            print(f"⚠️ Fallback vers RecursiveTextSplitter suite à : {e}")
            splitter = RecursiveTextSplitter(chunk_size=data.max_tokens, chunk_overlap=data.overlap)
            nodes = splitter.get_nodes_from_documents([doc])
            print(f"♻️ Recursive Splitter : {len(nodes)} chunks générés")

        # ✅ Résultat structuré pour n8n ou autre client HTTP
        return {
            "chunks": [node.text for node in nodes],
            "metadatas": [node.metadata for node in nodes],
            "source_id": data.source_id,
            "titre": data.titre,
            "source": data.source,
            "type": data.type,
            "error": None  # ← utilisé par n8n pour signaler "pas d'erreur"
        }

    except Exception as e:
        print(f"❌ Erreur critique : {e}")
        return {"error": str(e)}

# ✅ Lancement local (facultatif pour HF Spaces)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)