Spaces:
Paused
Paused
File size: 6,249 Bytes
130c728 aa38253 130c728 aa38253 3db8a39 aa81522 3db8a39 130c728 9705109 130c728 fb20557 130c728 89c250e 130c728 4ebcca2 130c728 aa38253 89c250e 130c728 aa38253 52e369f aa38253 52e369f fcce9ea aa38253 130c728 aa38253 130c728 89c250e 130c728 89c250e 130c728 89c250e 130c728 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
import html
import json
import os
import re
import shutil

import gradio as gr
import lancedb
import pandas as pd
import requests
import torch
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
# --- Device and model ---------------------------------------------------------
# Prefer the GPU whenever torch can see one; otherwise everything runs on CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model_name = "PleIAs/Pleias-Rag"

# A Hugging Face access token is mandatory: abort before any download starts.
hf_token = os.environ.get('HF_TOKEN')
if hf_token is None or hf_token == "":
    raise ValueError("Please set the HF_TOKEN environment variable")

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
# Module.to() returns the module itself, so the placed model is bound directly.
model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(device)

# --- Tokenizer special tokens -------------------------------------------------
# Treat the end of the answer section as end-of-sequence.
tokenizer.eos_token = "<|answer_end|>"
# NOTE(review): eos_token_id is not read anywhere else in the visible file.
eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): presumably overrides the pad token set on the previous line;
# id 1 is what predict() passes to generate() — confirm which is intended.
tokenizer.pad_token_id = 1

# --- Generation settings ------------------------------------------------------
# NOTE(review): top_p is never passed to model.generate() in this file.
temperature, top_p, repetition_penalty = 0.0, 0.95, 1.0
min_new_tokens, max_new_tokens = 800, 1500
early_stopping = False

# --- Retrieval index ----------------------------------------------------------
db = lancedb.connect("content 5/lancedb_data")
table = db.open_table("sciencev4")
def hybrid_search(text, limit=6, search_table=None):
    """Run a hybrid LanceDB search and format the hits two ways.

    Args:
        text: The user's query string.
        limit: Maximum number of rows to retrieve (default 6).
        search_table: LanceDB table to query; when ``None`` the module-level
            ``table`` opened at start-up is used.

    Returns:
        ``(document, document_html)`` where ``document`` is the prompt-side
        source block (each hit prefixed by
        ``<|source_id_start|>hash<|source_id_end|>`` and separated by a blank
        line) and ``document_html`` is the HTML listing shown in the UI, one
        ``<div class="source" id="hash">`` per hit.
    """
    source = table if search_table is None else search_table
    results = source.search(text, query_type="hybrid").limit(limit).to_pandas()

    prompt_chunks = []
    html_chunks = []
    for _, row in results.iterrows():
        hash_id = str(row['hash'])
        content = row['text']
        # The model receives the raw source text...
        prompt_chunks.append(f"<|source_id_start|>{hash_id}<|source_id_end|>\n{content}")
        # ...but the browser must not interpret it as markup. The id is
        # escaped the same way format_references escapes citation targets,
        # so "#<id>" links still resolve.
        safe_id = html.escape(hash_id)
        html_chunks.append(
            f'<div class="source" id="{safe_id}"><p><b>{safe_id}</b> : <br>'
            f'{html.escape(str(content))}</p></div>'
        )

    document = "\n\n".join(prompt_chunks)
    document_html = '<div id="source_listing">' + "".join(html_chunks) + "</div>"
    return document, document_html
class CassandreChatBot:
    """RAG chat front-end: retrieve sources, prompt the model, render HTML.

    ``predict`` relies on module-level names defined elsewhere in this file:
    ``hybrid_search``, ``tokenizer``, ``model``, ``device``, the generation
    settings (``max_new_tokens``, ``min_new_tokens``, ...) and
    ``format_references``.
    """

    def __init__(self, system_prompt="Tu es Appli, un asistant de recherche qui donne des responses sourcées"):
        """Store the system prompt.

        NOTE(review): ``system_prompt`` is kept on the instance but is not
        included in the prompt that ``predict`` builds.
        """
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Answer ``user_message`` from retrieved sources.

        Args:
            user_message: The user's question or instruction.

        Returns:
            ``(answer_html, sources_html)``: the rendered answer (citations
            converted by ``format_references``) and the source listing, each
            preceded by a centred ``<h2>`` heading. Returns ``(None, None)``
            if generation or rendering raises; the error and traceback are
            printed.
        """
        fiches, fiches_html = hybrid_search(user_message)
        # Prompt layout: the query, the retrieved sources (each tagged with
        # its hash id), then an opened source-analysis section for the model
        # to continue from.
        detailed_prompt = f"""<|query_start|>{user_message}<|query_end|>\n### Source ###\n{fiches}\n\n<|source_analysis_start|>\n"""
        # A single unpadded sequence, so every position is attended to.
        input_ids = tokenizer.encode(detailed_prompt, return_tensors="pt").to(device)
        attention_mask = torch.ones_like(input_ids)
        try:
            output = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                # Greedy decoding; transformers only applies temperature when
                # sampling, so it has no effect here.
                do_sample=False,
                early_stopping=early_stopping,
                min_new_tokens=min_new_tokens,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            # Decode only the newly generated tokens, not the echoed prompt.
            generated_text = tokenizer.decode(output[0][len(input_ids[0]):])
            generated_text = '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">' + format_references(generated_text) + "</div>"
            fiches_html = '<h2 style="text-align:center">Sources</h2>\n' + fiches_html
            return generated_text, fiches_html
        except Exception as e:
            # Top-level UI boundary: report the failure and return empty
            # outputs rather than crashing the Gradio handler.
            print(f"Error during generation: {str(e)}")
            import traceback
            traceback.print_exc()
            return None, None
def format_references(text):
    """Turn the model's inline citations into numbered HTML tooltips.

    The model cites sources as ``<ref text="quoted passage">source_id</ref>``.
    Each such tag becomes a ``[n]`` link to ``#source_id`` wrapped in a
    ``<span class="tooltip">`` whose ``data-text`` attribute carries
    ``"source_id: quoted passage"`` (newlines in the passage become spaces).
    Citations are numbered in order of appearance, starting at 1.

    All model-generated text is HTML-escaped before insertion, so stray
    ``<``, ``&`` or quotes cannot break the markup. If the text ends inside an
    unterminated citation (e.g. generation stopped at ``max_new_tokens``), the
    leftover fragment is kept as escaped text instead of being dropped.

    Args:
        text: Decoded model output.

    Returns:
        The HTML string.
    """
    ref_start_marker = '<ref text="'
    ref_text_end = '">'
    ref_end_marker = '</ref>'
    parts = []
    current_pos = 0
    ref_number = 1
    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            parts.append(html.escape(text[current_pos:], quote=False))
            break
        parts.append(html.escape(text[current_pos:start_pos], quote=False))

        end_pos = text.find(ref_text_end, start_pos)
        ref_end_pos = -1 if end_pos == -1 else text.find(ref_end_marker, end_pos)
        if ref_end_pos == -1:
            # Unterminated citation: keep the remainder visible.
            parts.append(html.escape(text[start_pos:], quote=False))
            break

        ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
        ref_id = text[end_pos + len(ref_text_end):ref_end_pos].strip()
        # quote=True (the default) because both values land inside attributes.
        safe_id = html.escape(ref_id)
        safe_text = html.escape(ref_text)
        parts.append(
            f'<span class="tooltip" data-refid="{safe_id}" data-text="{safe_id}: {safe_text}">'
            f'<a href="#{safe_id}">[{ref_number}]</a></span>'
        )
        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1
    return ''.join(parts)
# Module-level bot instance, built with the default system_prompt;
# gradio_interface() routes every UI request through it.
cassandre_bot = CassandreChatBot()
# Stylesheet passed to gr.Blocks(css=...):
#   .generation         - side margins around the rendered answer
#   :target             - highlights the source <div> whose id matches the
#                         clicked citation link (href="#<hash>")
#   .source             - lays the retrieved sources out as floating columns
#   .tooltip / ::after  - superscript citation markers; the hover bubble
#                         shows the element's data-text attribute
# NOTE(review): hybrid_search fetches 6 hits and each .source takes up to
# 17% width + 2% margin (114% total), so the sixth source likely wraps onto
# a new row — confirm this is intended.
css = """
.generation {
margin-left:2em;
margin-right:2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float:left;
max-width:17%;
margin-left:2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
"""
# Gradio interface
def gradio_interface(user_message):
    """Gradio click handler: run the RAG pipeline for one question.

    Blank or whitespace-only input returns ``(None, None)``, the same value
    ``predict`` uses for "no result". It does so without retrieving sources
    or running the model. Otherwise generation would be forced to produce at
    least ``min_new_tokens`` tokens for an empty query.

    Args:
        user_message: Text from the question box.

    Returns:
        ``(answer_html, sources_html)`` for the two HTML output components.
    """
    if not user_message or not user_message.strip():
        return None, None
    response, sources = cassandre_bot.predict(user_message)
    return response, sources
# Create Gradio app: question box and button on the left, answer on the
# right, retrieved sources in a full-width row underneath.
demo = gr.Blocks(css=css)
with demo:
    gr.HTML("""<h1 style="text-align:center">pleias-RAG 1.0</h1>""")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
            text_button = gr.Button("Interroger pleias-RAG")
        with gr.Column(scale=3):
            text_output = gr.HTML(label="La réponse de pleias-RAG")
    with gr.Row():
        embedding_output = gr.HTML(label="Les sources utilisées")
    # Event listeners are registered inside the Blocks context.
    text_button.click(gradio_interface, inputs=text_input, outputs=[text_output, embedding_output])

# Launch the app only when run as a script.
if __name__ == "__main__":
    demo.launch()