import os

import gradio as gr
import lancedb
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Run on GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "PleIAs/Pleias-Rag"

# Get the Hugging Face token from the environment
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("Please set the HF_TOKEN environment variable")
# Initialize the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
model.to(device)

# Configure the tokenizer: generation stops at the model's answer delimiter
tokenizer.eos_token = "<|answer_end|>"
eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = 1  # hardcoded pad id, overriding the eos-based default set just above
# Generation settings. Note that with do_sample=False (greedy decoding),
# temperature and top_p have no effect, and early_stopping only applies
# to beam search.
temperature = 0.0
max_new_tokens = 1500
top_p = 0.95
repetition_penalty = 1.0
min_new_tokens = 800
early_stopping = False
# Connect to the LanceDB database
db = lancedb.connect("content 5/lancedb_data")
table = db.open_table("sciencev4")
def hybrid_search(text):
    # Combine vector and full-text retrieval, keeping the six best chunks
    results = table.search(text, query_type="hybrid").limit(6).to_pandas()

    document = []
    document_html = []

    for _, row in results.iterrows():
        hash_id = str(row["hash"])
        title = row["section"]
        content = row["text"]

        document.append(f"**{hash_id}**\n{title}\n{content}")
        document_html.append(
            f'<div class="source" id="{hash_id}"><p><b>{hash_id}</b> : {title}<br>{content}</p></div>'
        )

    document = "\n\n".join(document)
    document_html = '<div id="source_listing">' + "".join(document_html) + "</div>"
    return document, document_html
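# Note: a hybrid query combines vector and full-text retrieval, so the
# "sciencev4" table is assumed to have been built with an embedding function
# and a full-text index on the "text" column, e.g. (hypothetical setup step):
#
#     table.create_fts_index("text")
#
# alongside the "hash", "section" and "text" columns read above.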
class CassandreChatBot:
    def __init__(self, system_prompt="Tu es Appli, un assistant de recherche qui donne des réponses sourcées"):
        # Note: the system prompt is stored but not currently injected into the generation prompt below
        self.system_prompt = system_prompt

    def predict(self, user_message):
        # Retrieve the supporting sources with hybrid search before generating
        fiches, fiches_html = hybrid_search(user_message)

        detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{fiches}\n\n### Analysis ###\n"""

        # Convert the prompt to tensors
        input_ids = tokenizer.encode(detailed_prompt, return_tensors="pt").to(device)
        attention_mask = torch.ones_like(input_ids)
        try:
            # Debug print: prompt length in tokens
            print("Input length:", len(input_ids[0]))

            output = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                early_stopping=early_stopping,
                min_new_tokens=min_new_tokens,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                # return_dict_in_generate=True exposes the full generation info
                return_dict_in_generate=True,
                output_scores=True,
            )

            # Debug info about the generated output
            print("Output sequence length:", len(output.sequences[0]))
            print("New tokens generated:", len(output.sequences[0]) - len(input_ids[0]))

            # Decode only the newly generated tokens, skipping the prompt
            generated_text = tokenizer.decode(output.sequences[0][len(input_ids[0]):])

            generated_text = (
                '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">'
                + format_references(generated_text)
                + "</div>"
            )
            fiches_html = '<h2 style="text-align:center">Sources</h2>\n' + fiches_html

            return generated_text, fiches_html

        except Exception as e:
            print(f"Error during generation: {str(e)}")
            import traceback
            traceback.print_exc()
            return None, None
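# Example round-trip (assuming the model and database above are loaded;
# the question is purely illustrative):
#   answer_html, sources_html = CassandreChatBot().predict("Quelle est la composition de l'atmosphère ?")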
def format_references(text):
    """Replace <ref text="...">id</ref> citations with numbered tooltip links."""
    ref_start_marker = '<ref text="'
    ref_end_marker = '</ref>'

    parts = []
    current_pos = 0
    ref_number = 1

    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            # No more references: keep the remaining text and stop
            parts.append(text[current_pos:])
            break

        parts.append(text[current_pos:start_pos])

        end_pos = text.find('">', start_pos)
        if end_pos == -1:
            break

        # Extract the quoted passage and HTML-escape it for the tooltip attribute
        ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
        ref_text_encoded = ref_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")

        ref_end_pos = text.find(ref_end_marker, end_pos)
        if ref_end_pos == -1:
            break

        ref_id = text[end_pos + 2:ref_end_pos].strip()
        tooltip_html = f'<span class="tooltip" data-refid="{ref_id}" data-text="{ref_id}: {ref_text_encoded}"><a href="#{ref_id}">[{ref_number}]</a></span>'
        parts.append(tooltip_html)

        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1

    return ''.join(parts)
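# Example: format_references rewrites an inline citation such as
#   'voir <ref text="passage cité">abc123</ref>'
# into a numbered tooltip link:
#   'voir <span class="tooltip" data-refid="abc123" data-text="abc123: passage cité"><a href="#abc123">[1]</a></span>'
# so each citation links back to its source card in the HTML listing.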
# Initialize the CassandreChatBot
cassandre_bot = CassandreChatBot()
# CSS for styling
css = """
.generation {
    margin-left: 2em;
    margin-right: 2em;
}

/* Highlight the source card targeted by a clicked citation link */
:target {
    background-color: #CCF3DF;
}

.source {
    float: left;
    max-width: 17%;
    margin-left: 2%;
}

.tooltip {
    position: relative;
    cursor: pointer;
    font-variant-position: super;
    color: #97999b;
}

.tooltip:hover::after {
    content: attr(data-text);
    position: absolute;
    left: 0;
    top: 120%;
    white-space: pre-wrap;
    width: 500px;
    max-width: 500px;
    z-index: 1;
    background-color: #f9f9f9;
    color: #000;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 5px;
    display: block;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
"""
# Gradio interface
def gradio_interface(user_message):
    response, sources = cassandre_bot.predict(user_message)
    return response, sources

# Create the Gradio app
demo = gr.Blocks(css=css)

with demo:
    gr.HTML("""<h1 style="text-align:center">Cassandre</h1>""")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
            text_button = gr.Button("Interroger Cassandre")
        with gr.Column(scale=3):
            text_output = gr.HTML(label="La réponse de Cassandre")
    with gr.Row():
        embedding_output = gr.HTML(label="Les sources utilisées")

    text_button.click(gradio_interface, inputs=text_input, outputs=[text_output, embedding_output])

# Launch the app
if __name__ == "__main__":
    demo.launch()