File size: 4,266 Bytes
cf51ebb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import List, Dict, Literal
import json
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from app.core.config import settings
from app.models.quiz import Question
import logging

logger = logging.getLogger(__name__)

class QuizGenerator:
    def __init__(self, provider: Literal["openai", "deepseek"] = "openai"):
        self.embeddings = OpenAIEmbeddings(openai_api_key=settings.OPENAI_API_KEY)
        self.vector_store = FAISS.load_local(
            settings.VECTOR_DB_PATH, 
            self.embeddings,
            allow_dangerous_deserialization=True
        )
        
        # Configuration selon le provider
        if provider == "deepseek":
            self.llm = ChatOpenAI(
                model_name="deepseek-chat",
                temperature=0.7,
                openai_api_key=settings.DEEPSEEK_API_KEY,
                base_url="https://api.deepseek.com"
            )
        else:  # openai
            self.llm = ChatOpenAI(
                model_name=settings.MODEL_NAME,
                temperature=0.7,
                openai_api_key=settings.OPENAI_API_KEY
            )

    def clean_json_response(self, response_text: str) -> str:
        """Nettoie la réponse du LLM pour obtenir un JSON valide."""
        # Supprime les backticks markdown et le mot 'json'
        cleaned = response_text.replace('```json', '').replace('```', '').strip()
        # Supprime les espaces et sauts de ligne superflus au début et à la fin
        return cleaned.strip()


    async def generate_quiz(self, theme: str, num_questions: int = 5) -> List[Question]:
        logger.debug(f"Génération de quiz - Thème: {theme}, Nb questions: {num_questions}")

        try:
            # Récupérer le contexte pertinent
            docs = self.vector_store.similarity_search(theme, k=3)
            context = "\n".join([doc.page_content for doc in docs])

            # Template pour la génération de questions
            prompt = ChatPromptTemplate.from_template("""

            Tu es un générateur de quiz intelligent.



            Si tu as du contexte pertinent, utilise-le : {context}

            Sinon, utilise tes connaissances générales.



            Génère {num_questions} questions de quiz sur le thème: {theme}



            IMPORTANT: Réponds UNIQUEMENT avec un JSON valide sans backticks ni formatage markdown.

            Format exact attendu:

            {{

                "questions": [

                    {{

                        "question": "La question",

                        "options": ["Option A", "Option B", "Option C"],

                        "correct_answer": "La bonne réponse (qui doit être une des options)"

                    }}

                ]

            }}

            """)

            # Générer les questions
            response = await self.llm.agenerate([
                prompt.format_messages(
                    context=context,
                    theme=theme,
                    num_questions=num_questions
                )
            ])

            # Parser le JSON et créer les objets Question
            response_text = response.generations[0][0].text
            cleaned_response = self.clean_json_response(response_text)
            
            try:
                response_json = json.loads(cleaned_response)
            except json.JSONDecodeError as e:
                logger.error(f"Réponse brute: {response_text}")
                logger.error(f"Réponse nettoyée: {cleaned_response}")
                raise Exception(f"Erreur de parsing JSON: {str(e)}")
            
            questions = []
            for q in response_json["questions"]:
                questions.append(Question(
                    question=q["question"],
                    options=q["options"],
                    correct_answer=q["correct_answer"]
                ))
            
            return questions
            
        except Exception as e:
            print(f"Erreur dans generate_quiz: {str(e)}")
            raise e