KJ24 commited on
Commit
0fd380e
·
verified ·
1 Parent(s): 11c5d58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -37
app.py CHANGED
@@ -1,34 +1,34 @@
 
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from typing import Optional
4
 
5
- # ✅ Modules LlamaIndex – version >= 0.10.0+
6
- from llama_index.core.settings import Settings
7
  from llama_index.core import Document
 
 
8
  from llama_index.llms.llama_cpp import LlamaCPP
9
- from llama_index.core.node_parser import SemanticSplitterNodeParser
10
-
11
  from llama_index.core.base.llms.base import BaseLLM
12
 
13
- # Pour l'embedding LOCAL via transformers
14
  from transformers import AutoTokenizer, AutoModel
15
  import torch
16
  import torch.nn.functional as F
 
17
  import os
18
 
19
  app = FastAPI()
20
 
21
- # ✅ Configuration locale du cache HF pour Hugging Face
22
- # ✅ Définir un chemin autorisé pour le cache (à l'intérieur du container Hugging Face)
23
  CACHE_DIR = "/app/cache"
24
  os.environ["HF_HOME"] = CACHE_DIR
25
  os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
26
  os.environ["HF_MODULES_CACHE"] = CACHE_DIR
27
  os.environ["HF_HUB_CACHE"] = CACHE_DIR
28
 
29
- # ✅ Configuration du modèle dembedding local (ex: BGE / Nomic / GTE etc.)
30
  MODEL_NAME = "BAAI/bge-small-en-v1.5"
31
-
32
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
33
  model = AutoModel.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
34
 
@@ -36,12 +36,14 @@ def get_embedding(text: str):
36
  with torch.no_grad():
37
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
38
  outputs = model(**inputs)
39
- embeddings = outputs.last_hidden_state[:, 0]
40
  return F.normalize(embeddings, p=2, dim=1).squeeze().tolist()
41
 
42
- # ✅ Données entrantes du POST
43
  class ChunkRequest(BaseModel):
44
  text: str
 
 
45
  source_id: Optional[str] = None
46
  titre: Optional[str] = None
47
  source: Optional[str] = None
@@ -50,11 +52,9 @@ class ChunkRequest(BaseModel):
50
  @app.post("/chunk")
51
  async def chunk_text(data: ChunkRequest):
52
  try:
53
- # Vérification du texte reçu
54
- print(f"✅ Texte reçu ({len(data.text)} caractères) : {data.text[:200]}...")
55
- print("✅ ✔️ Reçu – On passe à la configuration du modèle LLM...")
56
 
57
- # ✅ Chargement du modèle LLM depuis Hugging Face (GGUF distant)
58
  llm = LlamaCPP(
59
  model_url="https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF/resolve/main/codellama-7b-instruct.Q4_K_M.gguf",
60
  temperature=0.1,
@@ -64,44 +64,36 @@ async def chunk_text(data: ChunkRequest):
64
  model_kwargs={"n_gpu_layers": 1},
65
  )
66
 
 
67
 
68
-
69
- print("✅✅ Le modèle CodeLlama-7B-Instruct Q4_K_M a été chargé sans erreur...")
70
-
71
-
72
- print("✅ ✔️ Modèle LLM chargé sans erreur on continue...")
73
-
74
- # ✅ Définition d’un wrapper simple pour l’embedding local
75
  class SimpleEmbedding:
76
  def get_text_embedding(self, text: str):
77
  return get_embedding(text)
78
 
79
- assert isinstance(llm, BaseLLM), "❌ Ce LLM n’est pas compatible avec Settings.llm"
80
-
81
- # ✅ Nouvelle configuration (⚠️ ne plus utiliser ServiceContext)
82
  Settings.llm = llm
83
  Settings.embed_model = SimpleEmbedding()
84
 
85
- print("✅ LLM et embedding configurés - prêt pour le split")
86
- print("✅ Début du split sémantique...", flush=True)
87
 
88
- # ✅ Utilisation du Semantic Splitter avec le LLM actuel
89
  parser = SemanticSplitterNodeParser.from_defaults(llm=llm)
90
-
91
  doc = Document(text=data.text)
92
 
93
  try:
94
  nodes = parser.get_nodes_from_documents([doc])
95
- print(f"✅ Nombre de chunks générés : {len(nodes)}")
96
- print(f"🧩 Exemple chunk : {nodes[0].text[:100]}...")
 
97
 
98
  except Exception as e:
99
- import traceback
100
- traceback.print_exc()
101
- print(f"❌ Erreur lors du split sémantique : {e}")
102
- return {"error": str(e)}
103
 
104
- # ✅ Résultat complet pour l’API
105
  return {
106
  "chunks": [node.text for node in nodes],
107
  "metadatas": [node.metadata for node in nodes],
@@ -109,10 +101,11 @@ async def chunk_text(data: ChunkRequest):
109
  "titre": data.titre,
110
  "source": data.source,
111
  "type": data.type,
112
- "error": None # essentiel pour que n8n voie "rien à signaler"
113
  }
114
 
115
  except Exception as e:
 
116
  return {"error": str(e)}
117
 
118
  if __name__ == "__main__":
 
1
+ # ✅ API FastAPI de chunking sémantique intelligent avec fallback automatique
2
+
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  from typing import Optional
6
 
7
+ # LlamaIndex (>= 0.10.0)
 
8
  from llama_index.core import Document
9
+ from llama_index.core.settings import Settings
10
+ from llama_index.core.node_parser import SemanticSplitterNodeParser, RecursiveTextSplitter
11
  from llama_index.llms.llama_cpp import LlamaCPP
 
 
12
  from llama_index.core.base.llms.base import BaseLLM
13
 
14
+ # Embedding local (transformers + torch)
15
  from transformers import AutoTokenizer, AutoModel
16
  import torch
17
  import torch.nn.functional as F
18
+
19
  import os
20
 
21
  app = FastAPI()
22
 
23
+ # ✅ Configuration des caches pour Hugging Face dans le container
 
24
  CACHE_DIR = "/app/cache"
25
  os.environ["HF_HOME"] = CACHE_DIR
26
  os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
27
  os.environ["HF_MODULES_CACHE"] = CACHE_DIR
28
  os.environ["HF_HUB_CACHE"] = CACHE_DIR
29
 
30
+ # ✅ Modèle d'embedding local (dense vector)
31
  MODEL_NAME = "BAAI/bge-small-en-v1.5"
 
32
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
33
  model = AutoModel.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
34
 
 
36
  with torch.no_grad():
37
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
38
  outputs = model(**inputs)
39
+ embeddings = outputs.last_hidden_state[:, 0] # CLS token
40
  return F.normalize(embeddings, p=2, dim=1).squeeze().tolist()
41
 
42
+ # ✅ Format des données entrantes de l'API
43
  class ChunkRequest(BaseModel):
44
  text: str
45
+ max_tokens: Optional[int] = 1000
46
+ overlap: Optional[int] = 350
47
  source_id: Optional[str] = None
48
  titre: Optional[str] = None
49
  source: Optional[str] = None
 
52
  @app.post("/chunk")
53
  async def chunk_text(data: ChunkRequest):
54
  try:
55
+ print(f"\nTexte reçu ({len(data.text)} caractères) : {data.text[:200]}...", flush=True)
 
 
56
 
57
+ # ✅ Chargement du modèle GGUF distant via llama-cpp
58
  llm = LlamaCPP(
59
  model_url="https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF/resolve/main/codellama-7b-instruct.Q4_K_M.gguf",
60
  temperature=0.1,
 
64
  model_kwargs={"n_gpu_layers": 1},
65
  )
66
 
67
+ print("✅ Modèle LLM chargé avec succès !")
68
 
69
+ # ✅ Wrapper pour l'embedding local compatible avec LlamaIndex
 
 
 
 
 
 
70
  class SimpleEmbedding:
71
  def get_text_embedding(self, text: str):
72
  return get_embedding(text)
73
 
74
+ # Configuration du moteur LLM et de l'embedding dans LlamaIndex
75
+ assert isinstance(llm, BaseLLM), "❌ L'objet LLM n'est pas compatible avec Settings.llm"
 
76
  Settings.llm = llm
77
  Settings.embed_model = SimpleEmbedding()
78
 
79
+ print("✅ Configuration du LLM et de l'embedding terminée. On initialise le Semantic Splitter...", flush=True)
 
80
 
 
81
  parser = SemanticSplitterNodeParser.from_defaults(llm=llm)
 
82
  doc = Document(text=data.text)
83
 
84
  try:
85
  nodes = parser.get_nodes_from_documents([doc])
86
+ print(f"✅ Semantic Splitter : {len(nodes)} chunks générés")
87
+ if not nodes:
88
+ raise ValueError("Aucun chunk produit par le Semantic Splitter")
89
 
90
  except Exception as e:
91
+ print(f"⚠️ Fallback vers RecursiveTextSplitter suite à : {e}")
92
+ splitter = RecursiveTextSplitter(chunk_size=data.max_tokens, chunk_overlap=data.overlap)
93
+ nodes = splitter.get_nodes_from_documents([doc])
94
+ print(f"♻️ Recursive Splitter : {len(nodes)} chunks générés")
95
 
96
+ # ✅ Construction de la réponse JSON pour n8n ou autre client HTTP
97
  return {
98
  "chunks": [node.text for node in nodes],
99
  "metadatas": [node.metadata for node in nodes],
 
101
  "titre": data.titre,
102
  "source": data.source,
103
  "type": data.type,
104
+ "error": None # n8n utilise cette clé pour détecter les erreurs
105
  }
106
 
107
  except Exception as e:
108
+ print(f"❌ Erreur critique : {e}")
109
  return {"error": str(e)}
110
 
111
  if __name__ == "__main__":