lenawilli committed on
Commit
aeb1acc
·
verified ·
1 Parent(s): 606351b

Update src/streamlit_app.py
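
Removes the SentenceTransformer and LLM Model options along with their helper functions and imports; the Knowledge Graphs branch is reduced to a not-implemented warning, and the results view now always renders the per-article breakdown.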

Files changed (1):
  1. src/streamlit_app.py +8 -255
src/streamlit_app.py CHANGED
@@ -7,16 +7,6 @@ from transformers import AutoTokenizer, AutoModel
  from sklearn.metrics.pairwise import cosine_similarity
  import torch
  import re
- from typing import List, Dict, Any
- from openai import OpenAI
- from dotenv import load_dotenv
- import os
- from sentence_transformers import SentenceTransformer
- from rdflib import Graph, Namespace, URIRef, Literal, RDF, RDFS, XSD
- import os
- import networkx as nx
- from pyvis.network import Network
- import streamlit.components.v1 as components

  # ---------------------------
  # LegalBERT-based compliance checker
@@ -60,7 +50,6 @@ class GDPRComplianceChecker:
              full_text = f"Article {number}: {title}. {body}"
              gdpr_map[number] = {"title": title, "text": full_text}
              texts.append(full_text)
-
          embeddings = self.get_embeddings(texts)
          return gdpr_map, embeddings

@@ -115,36 +104,7 @@ def chunk_policy_text(text, chunk_size=500):
          chunks.append(current.strip())
      return [chunk for chunk in chunks if len(chunk) > 50]

- def prepare_article_text(article: Dict[str, Any]) -> str:
-     body = " ".join(
-         " ".join(sec.values()) if isinstance(sec, dict) else str(sec)
-         for sec in article.get("sections", [])
-     )
-     return f"Art. {article['article_number']} – {article['article_title']} {body}"
-
- def get_embedding(text: str) -> List[float]:
-     # If input is a list of strings, clean each string
-     if isinstance(text, list):
-         cleaned_text = [t.replace("\n", " ") for t in text]
-     else:  # single string
-         cleaned_text = text.replace("\n", " ")
-     resp = client.embeddings.create(model=EMBED_MODEL, input=cleaned_text)
-     if isinstance(cleaned_text, list):
-         return [item.embedding for item in resp.data]
-     else:
-         return resp.data[0].embedding
-
- def rdflib_to_networkx(rdflib_graph):
-     nx_graph = nx.MultiDiGraph()
-     for s, p, o in rdflib_graph:
-         nx_graph.add_edge(str(s), str(o), label=str(p))
-     return nx_graph

- def draw_pyvis_graph(nx_graph):
-     net = Network(height="600px", width="100%", directed=True, notebook=False)
-     net.from_nx(nx_graph)
-     net.repulsion(node_distance=200, central_gravity=0.33, spring_length=100, spring_strength=0.10, damping=0.95)
-     return net
  # ---------------------------
  # Streamlit interface
  # ---------------------------
@@ -159,7 +119,7 @@ with st.sidebar:
  if gdpr_file and policy_file:
      model_choice = st.selectbox(
          "Choose the model to use:",
-         ["Logistic Regression", "MultinomialNB", "LegalBERT (Eurlex)", "SentenceTransformer", "LLM Model", "Knowledge Graphs"]
+         ["Logistic Regression", "MultinomialNB", "LegalBERT (Eurlex)", "Knowledge Graphs"]
      )

      gdpr_data = json.load(gdpr_file)
@@ -216,194 +176,9 @@ if gdpr_file and policy_file:
              "article_scores": dict(article_scores)
          }

-     elif model_choice == "SentenceTransformer":
-         model = joblib.load("sentence_transformer_model.joblib")
-         gdpr_texts = []
-         gdpr_map = {}
-         for article in gdpr_data:
-             number, title = article["article_number"], article["article_title"]
-             body = " ".join([f"{k} {v}" for sec in article["sections"] for k, v in sec.items()])
-             full_text = f"Article {number}: {title}. {body}"
-             gdpr_map[number] = {
-                 "title": title,
-                 "text": full_text
-             }
-             gdpr_texts.append(full_text)
-
-         gdpr_embeddings = model.encode(gdpr_texts, convert_to_numpy=True)
-
-         chunks = chunk_policy_text(policy_text)
-         chunk_embeddings = model.encode(chunks, convert_to_numpy=True)
-
-         sim_matrix = cosine_similarity(gdpr_embeddings, chunk_embeddings)
-
-         article_scores = {}
-         presence_threshold = 0.35
-         total_score, counted_articles = 0, 0
-
-         for i, (art_num, art_data) in enumerate(gdpr_map.items()):
-             max_sim = np.max(sim_matrix[i])
-             best_idx = np.argmax(sim_matrix[i])
-
-             if max_sim < presence_threshold:
-                 continue
-
-             score_pct = min(100, max(0, (max_sim - presence_threshold) / (1 - presence_threshold) * 100))
-             article_scores[art_num] = {
-                 "article_title": art_data["title"],
-                 "compliance_percentage": round(score_pct, 2),
-                 "similarity_score": round(max_sim, 4),
-                 "matched_text_snippet": chunks[best_idx][:300] + "..."
-             }
-             total_score += score_pct
-             counted_articles += 1
-
-         overall = round(total_score / counted_articles, 2) if counted_articles else 0
-         result = {
-             "overall_compliance_percentage": overall,
-             "relevant_articles_analyzed": counted_articles,
-             "total_policy_chunks": len(chunks),
-             "article_scores": article_scores
-         }
-
-     elif model_choice == "LLM Model":
-         load_dotenv()
-         api_key = os.getenv("OPENAI_API_KEY")
-         client = OpenAI(api_key=api_key)
-         EMBED_MODEL = "text-embedding-3-small"
-         gdpr_embeddings = {}
-         gdpr_map = {}
-         for art in gdpr_data:
-             number, title = art["article_number"], art["article_title"]
-             art_text = prepare_article_text(art)
-             gdpr_embeddings[art["article_number"]] = {
-                 "embedding": get_embedding(art_text),
-                 "title": art["article_title"]
-             }
-             gdpr_map[number] = {"title": title, "text": art_text}
-         chunks = chunk_policy_text(policy_text)
-         chunk_embeddings = get_embedding(chunks)
-         gdpr_embedding_vectors = [v["embedding"] for v in gdpr_embeddings.values()]
-         sim_matrix = cosine_similarity(gdpr_embedding_vectors, chunk_embeddings)
-
-         article_scores = {}
-         presence_threshold = 0.35
-         total_score, counted_articles = 0, 0
-
-         for i, (art_num, art_data) in enumerate(gdpr_map.items()):
-             max_sim = np.max(sim_matrix[i])
-             best_idx = np.argmax(sim_matrix[i])
-
-             if max_sim < presence_threshold:
-                 continue
-
-             score_pct = min(100, max(0, (max_sim - presence_threshold) / (1 - presence_threshold) * 100))
-             article_scores[art_num] = {
-                 "article_title": art_data["title"],
-                 "compliance_percentage": round(score_pct, 2),
-                 "similarity_score": round(max_sim, 4),
-                 "matched_text_snippet": chunks[best_idx][:300] + "..."
-             }
-             total_score += score_pct
-             counted_articles += 1
-
-         overall = round(total_score / counted_articles, 2) if counted_articles else 0
-         result = {
-             "overall_compliance_percentage": overall,
-             "relevant_articles_analyzed": counted_articles,
-             "total_policy_chunks": len(chunks),
-             "article_scores": article_scores
-         }
      elif model_choice == "Knowledge Graphs":
-         EMBED_MODEL = "all-MiniLM-L6-v2"
-         model = SentenceTransformer(EMBED_MODEL)
-         TOP_N = 1
-         BASE_URI = "http://example.org/gdpr#"
-         gdpr_embeddings = {}
-         gdpr_map = {}
-         for art in gdpr_data:
-             number, title = art["article_number"], art["article_title"]
-             art_text = prepare_article_text(art)
-             gdpr_embeddings[art["article_number"]] = {
-                 "embedding": model.encode(art_text),
-                 "title": art["article_title"],
-                 "uri": URIRef(f"{BASE_URI}Article{art['article_number']}")
-             }
-             gdpr_map[number] = {"title": title, "text": art_text}
-         g = Graph()
-         EX = Namespace(BASE_URI)
-         g.bind("ex", EX)
-
-         # Add article nodes
-         for num, info in gdpr_embeddings.items():
-             g.add((info["uri"], RDF.type, EX.Article))
-             g.add((info["uri"], RDFS.label, Literal(f"Article {num}: {info['title']}")))
-         # Extract GDPR article vectors
-         article_nums = list(gdpr_embeddings.keys())
-         article_vectors = np.array([gdpr_embeddings[num]["embedding"] for num in article_nums])
-
-         # Score tracking
-         total_score = 0
-         counted_sections = 0
-         chunks = chunk_policy_text(policy_text)
-         report = []
-         presence_threshold = 0.35
-
-         # Process each policy chunk
-         for idx, text in enumerate(chunks, start=1):
-             if not text.strip():
-                 continue
-
-             # RDF section node
-             sec_uri = URIRef(f"{BASE_URI}PolicySection{idx}")
-             g.add((sec_uri, RDF.type, EX.PolicySection))
-             g.add((sec_uri, RDFS.label, Literal(f"Section {idx}")))
-
-             # Embed section
-             sec_emb = model.encode(text)
-
-             # Similarities to all articles
-             sims = []
-             for i, art_num in enumerate(article_nums):
-                 art_emb = article_vectors[i]
-                 sim = cosine_similarity([sec_emb], [art_emb])[0][0]
-                 sims.append({
-                     "article": art_num,
-                     "title": gdpr_embeddings[art_num]["title"],
-                     "similarity": round(sim, 4),
-                     "uri": gdpr_embeddings[art_num]["uri"],
-                     "text": gdpr_map[art_num]["text"]
-                 })
-
-             # Sort and pick best match
-             sims.sort(key=lambda x: x["similarity"], reverse=True)
-             top_match = sims[0]
-
-             # Threshold filtering
-             if top_match["similarity"] < presence_threshold:
-                 continue
-
-             # Compliance score
-             score_pct = min(100, max(0, (top_match["similarity"] - presence_threshold) / (1 - presence_threshold) * 100))
-
-             # Add RDF triples
-             g.add((sec_uri, EX.relatesTo, top_match["uri"]))
-             g.add((sec_uri, EX.similarityScore, Literal(top_match["similarity"], datatype=XSD.float)))
-
-             g.serialize(destination="gdpr_policy_graph.ttl", format="turtle")
-
-             total_score += score_pct
-             counted_sections += 1
-
-         # Final summary
-         overall = round(total_score / counted_sections, 2) if counted_sections else 0
-         result = {
-             "overall_compliance_percentage": overall,
-             "relevant_sections_analyzed": counted_sections,
-             "total_policy_sections": len(chunks),
-             "ttl": True
-         }
+         st.warning("Knowledge Graphs model is not implemented yet.")
+         result = {}

      else:
          result = {}
@@ -412,31 +187,9 @@ if gdpr_file and policy_file:
      st.subheader(f"✅ Overall Compliance Score: {result['overall_compliance_percentage']}%")
      st.markdown("---")
      st.subheader("📋 Detailed Article Breakdown")
-     ttl_file_path = "gdpr_policy_graph.ttl"
-     if result.get('article_scores'):
-         for art_num, data in sorted(result['article_scores'].items(), key=lambda x: -x[1]['compliance_percentage']):
-             with st.expander(f"Article {art_num} - {data['article_title']} ({data['compliance_percentage']}%)"):
-                 st.write(f"**Similarity Score**: {data['similarity_score']}")
-                 st.write(f"**Matched Text**:\n\n{data['matched_text_snippet']}")
-     elif result.get("ttl") and os.path.exists(ttl_file_path):
-         st.markdown("---")
-         st.subheader("🧠 Interactive RDF Graph Visualization")
-
-         g = Graph()
-         g.parse(ttl_file_path, format="ttl")
-
-         nx_graph = rdflib_to_networkx(g)
-         net = draw_pyvis_graph(nx_graph)
-
-         # Save the interactive graph temporarily
-         net.save_graph("rdf_graph.html")
-         HtmlFile = open("rdf_graph.html", "r", encoding="utf-8").read()
-
-         # Display interactive graph inside Streamlit
-         components.html(HtmlFile, height=650, scrolling=True)
-
-     else:
-         st.info("No article scores or RDF graph to display.")
-
+     for art_num, data in sorted(result['article_scores'].items(), key=lambda x: -x[1]['compliance_percentage']):
+         with st.expander(f"Article {art_num} - {data['article_title']} ({data['compliance_percentage']}%)"):
+             st.write(f"**Similarity Score**: {data['similarity_score']}")
+             st.write(f"**Matched Text**:\n\n{data['matched_text_snippet']}")
  else:
-     st.info("Please upload both a GDPR JSON file and a company policy text file to begin.")
+     st.info("Please upload both a GDPR JSON file and a company policy text file to begin.")