import transformers
import re
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
from vllm import LLM, SamplingParams
import torch
import gradio as gr
import json
import os
import shutil
import requests
import lancedb
import pandas as pd
# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Generation parameters
temperature = 0.7
max_new_tokens = 3000
top_p = 0.95
repetition_penalty = 1.2

model_name = "PleIAs/Cassandre-RAG"

# Initialize vLLM
llm = LLM(model_name, max_model_len=8128)

# Connect to the LanceDB database
db = lancedb.connect("content/lancedb_data")
table = db.open_table("scientific_documents")
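# Note: the hybrid query below combines vector and full-text search, so the
# "scientific_documents" table is assumed to have been built with an embedding
# function and a full-text index. A minimal sketch of the FTS part, if it were
# missing (the "text" column name matches the field used below), might be:
#
#     table.create_fts_index("text", replace=True)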
def hybrid_search(text):
    results = table.search(text, query_type="hybrid").limit(6).to_pandas()

    document = []
    document_html = []
    for _, row in results.iterrows():
        hash_id = str(row['hash'])
        title = row['section']
        #content = row['text'][:100] + "..."  # Truncate the text for preview
        content = row['text']
        document.append(f"**{hash_id}**\n{title}\n{content}")
        document_html.append(f'<div class="source" id="{hash_id}"><p><b>{hash_id}</b> : {title}<br>{content}</p></div>')

    document = "\n\n".join(document)
    document_html = '<div id="source_listing">' + "".join(document_html) + "</div>"
    return document, document_html
class CassandreChatBot:
    # System prompt (in French): "You are Cassandre, the national education
    # chatbot that gives sourced answers."
    def __init__(self, system_prompt="Tu es Cassandre, le chatbot de l'Éducation nationale qui donne des réponses sourcées."):
        self.system_prompt = system_prompt

    def predict(self, user_message):
        fiches, fiches_html = hybrid_search(user_message)

        # Note: the repetition_penalty value is passed to vLLM as presence_penalty.
        sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens, presence_penalty=repetition_penalty, stop=["#END#"])

        detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{fiches}\n\n### Answer ###\n"""

        prompts = [detailed_prompt]
        outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text

        generated_text = '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">' + format_references(generated_text) + "</div>"
        fiches_html = '<h2 style="text-align:center">Sources</h2>\n' + fiches_html

        return generated_text, fiches_html
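# The helper below assumes the model cites its sources inline with tags of the
# form <ref text="quoted passage">hash_id</ref>; each tag is rewritten as a
# numbered link pointing at the matching source <div> (whose id is the hash),
# with a hover tooltip carrying the quoted passage.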
def format_references(text):
    ref_start_marker = '<ref text="'
    ref_end_marker = '</ref>'

    parts = []
    current_pos = 0
    ref_number = 1

    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            parts.append(text[current_pos:])
            break

        parts.append(text[current_pos:start_pos])

        end_pos = text.find('">', start_pos)
        if end_pos == -1:
            break

        ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
        ref_text_encoded = ref_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")

        ref_end_pos = text.find(ref_end_marker, end_pos)
        if ref_end_pos == -1:
            break

        ref_id = text[end_pos + 2:ref_end_pos].strip()
        tooltip_html = f'<span class="tooltip" data-refid="{ref_id}" data-text="{ref_id}: {ref_text_encoded}"><a href="#{ref_id}">[{ref_number}]</a></span>'
        parts.append(tooltip_html)

        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number = ref_number + 1

    return ''.join(parts)
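# Illustrative example (hypothetical values) of the rewriting performed above:
#   format_references('See <ref text="a quoted excerpt">abc123</ref>.')
#   -> 'See <span class="tooltip" data-refid="abc123" data-text="abc123: a quoted excerpt"><a href="#abc123">[1]</a></span>.'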
# Initialize the CassandreChatBot
cassandre_bot = CassandreChatBot()

# CSS for styling
css = """
.generation {
    margin-left: 2em;
    margin-right: 2em;
}

:target {
    background-color: #CCF3DF;
}

.source {
    float: left;
    max-width: 17%;
    margin-left: 2%;
}

.tooltip {
    position: relative;
    cursor: pointer;
    font-variant-position: super;
    color: #97999b;
}

.tooltip:hover::after {
    content: attr(data-text);
    position: absolute;
    left: 0;
    top: 120%;
    white-space: pre-wrap;
    width: 500px;
    max-width: 500px;
    z-index: 1;
    background-color: #f9f9f9;
    color: #000;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 5px;
    display: block;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
"""
# Gradio interface
def gradio_interface(user_message):
    response, sources = cassandre_bot.predict(user_message)
    return response, sources

# Create Gradio app
demo = gr.Blocks(css=css)

with demo:
    gr.HTML("""<h1 style="text-align:center">Cassandre</h1>""")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
            text_button = gr.Button("Interroger Cassandre")
        with gr.Column(scale=3):
            text_output = gr.HTML(label="La réponse de Cassandre")
    with gr.Row():
        embedding_output = gr.HTML(label="Les sources utilisées")
    text_button.click(gradio_interface, inputs=text_input, outputs=[text_output, embedding_output])

# Launch the app
if __name__ == "__main__":
    demo.launch()