import streamlit as st
import json
import os
from sentence_transformers import SentenceTransformer, util
import torch
from huggingface_hub import InferenceClient

# Load the Hugging Face token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is not set. Please set it before running the application.")
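# A minimal way to run this app locally (assumes this file is saved as app.py):
#   export HF_TOKEN=<your Hugging Face access token>
#   streamlit run app.py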

# Load the processed legal code data
@st.cache_resource
def load_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
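# Expected shape of each entry in processed_kodeksy.json, inferred from the field
# accesses further down (the real file may carry additional keys); values are illustrative:
# {
#   "text": "...tresc artykulu...",
#   "metadata": {"nazwa": "Kodeks cywilny", "article": "415"}
# }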

# Initialize the sentence transformer model
@st.cache_resource
def load_model():
    return SentenceTransformer('distiluse-base-multilingual-cased-v1')

def generate_keywords(query):
    client = InferenceClient(model="Qwen/Qwen2.5-72B-Instruct", token=HF_TOKEN)

    prompt = f"Na podstawie poniższego pytania, wygeneruj 3-5 słów kluczowych, które najlepiej opisują główne tematy i koncepcje prawne zawarte w pytaniu. Podaj tylko słowa kluczowe, oddzielone przecinkami.\n\nPytanie: {query}\n\nSłowa kluczowe:"

    # text_generation takes the prompt string as its first argument;
    # the model is configured on the client above
    response = client.text_generation(
        prompt,
        max_new_tokens=50,
        temperature=0.3,
        top_p=0.9
    )

    keywords = [keyword.strip() for keyword in response.split(',')]
    return keywords

def search_relevant_chunks(keywords, chunks, model, top_k=3):
    # Embed the keywords and all chunk texts, then rank chunks by their
    # average cosine similarity to the keyword set
    keyword_embedding = model.encode(keywords, convert_to_tensor=True)
    chunk_embeddings = model.encode([chunk['text'] for chunk in chunks], convert_to_tensor=True)

    cos_scores = util.pytorch_cos_sim(keyword_embedding, chunk_embeddings)
    top_results = torch.topk(cos_scores.mean(dim=0), k=min(top_k, len(chunks)))

    return [chunks[idx] for idx in top_results.indices]

def generate_ai_response(query, relevant_chunks):
    client = InferenceClient(model="Qwen/Qwen2.5-72B-Instruct", token=HF_TOKEN)

    context = "Kontekst prawny:\n\n"
    for chunk in relevant_chunks:
        context += f"{chunk['metadata']['nazwa']} - Artykuł {chunk['metadata']['article']}:\n"
        context += f"{chunk['text']}\n\n"

    messages = [
        {"role": "system", "content": "Jesteś asystentem prawniczym. Odpowiadaj na pytania na podstawie podanego kontekstu prawnego."},
        {"role": "user", "content": f"Kontekst: {context}\n\nPytanie: {query}"}
    ]

    # chat_completion (not text_generation) accepts a list of chat messages;
    # with stream=True it yields incremental deltas, re-yielded here as plain text
    for chunk in client.chat_completion(
        messages,
        max_tokens=2048,
        temperature=0.5,
        top_p=0.7,
        stream=True
    ):
        token = chunk.choices[0].delta.content
        if token:
            yield token

def main():
    st.title("Chatbot Prawny z AI")

    # Load data and model
    data_file = "processed_kodeksy.json"
    if not os.path.exists(data_file):
        st.error(f"Plik {data_file} nie istnieje. Najpierw przetwórz dane kodeksów.")
        return

    chunks = load_data(data_file)
    model = load_model()

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # User input
    if prompt := st.chat_input("Zadaj pytanie dotyczące prawa..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Generate keywords and search for relevant chunks
        with st.spinner("Analizuję pytanie i szukam odpowiednich informacji..."):
            keywords = generate_keywords(prompt)
            relevant_chunks = search_relevant_chunks(keywords, chunks, model)

        # Generate AI response
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            for chunk in generate_ai_response(prompt, relevant_chunks):
                full_response += chunk
                message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)
        
        st.session_state.messages.append({"role": "assistant", "content": full_response})

    # Sidebar for additional options
    with st.sidebar:
        st.subheader("Opcje")
        if st.button("Wyczyść historię czatu"):
            st.session_state.messages = []
            st.rerun()

        st.subheader("Informacje o bazie danych")
        st.write(f"Liczba chunków: {len(chunks)}")
        st.write(f"Przykładowy chunk:")
        st.json(chunks[0] if chunks else {})

if __name__ == "__main__":
    main()