Update app.py
app.py CHANGED
@@ -2,7 +2,10 @@ import io
 import os
 import re
 import time
-from typing import Any, Dict, List, Optional
+import requests
+from typing import Any, Dict, List, Optional, Set, Union
+from difflib import get_close_matches
+from pathlib import Path
 from itertools import islice
 from functools import partial
 from multiprocessing.pool import ThreadPool
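The new difflib import powers the fuzzy matching used by the refinement code added below. A minimal sketch of how get_close_matches behaves at the 0.8 cutoff this commit uses; the sample values are illustrative, not from app.py:

from difflib import get_close_matches

known_materials = ["Metal", "Wood", "Plastic", "Aluminum"]

# Best match above the similarity cutoff, or an empty list if nothing is close enough.
print(get_close_matches("Aluminium", known_materials, n=1, cutoff=0.8))  # ['Aluminum']
print(get_close_matches("Granite", known_materials, n=1, cutoff=0.8))    # []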
@@ -209,6 +212,28 @@ with gr.Blocks(css=css) as demo:
     gr.Markdown("Save datasets to your account")
     gr.LoginButton()
     select_namespace_dropdown = gr.Dropdown(choices=[NAMESPACE], value=NAMESPACE, label="Select user or organization", visible=False)
+
+    gr.Markdown("Dataset Refinement Settings")
+    refinement_mode = gr.Radio(
+        ["sourceless", "sourced"],
+        value="sourceless",
+        label="Refinement Mode",
+        info="Choose between AI-only refinement or source-based refinement"
+    )
+
+    with gr.Group(visible=False) as source_group:
+        source_type = gr.Dropdown(
+            choices=["csv_url", "xlsx_url", "local_csv", "local_xlsx"],
+            value="csv_url",
+            label="Source Type"
+        )
+        source_path = gr.Textbox(
+            label="Source Path/URL",
+            placeholder="Enter URL or local file path"
+        )
+        load_source_button = gr.Button("Load Source")
+        source_status = gr.Markdown("")
+
     gr.Markdown("Save datasets as public or private datasets")
     visibility_radio = gr.Radio(["public", "private"], value="public", container=False, interactive=False)
     with gr.Column(visible=False) as dataset_page:
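These components are wired up in the event-handler hunk further down; as a standalone sketch, the show/hide pattern used for source_group looks like this (component names here are hypothetical):

import gradio as gr

with gr.Blocks() as sketch:
    mode = gr.Radio(["sourceless", "sourced"], value="sourceless")
    with gr.Group(visible=False) as grp:
        gr.Textbox(label="Source Path/URL")
    # Returning a component with updated properties applies it as an update.
    mode.change(lambda m: gr.Group(visible=(m == "sourced")), inputs=mode, outputs=grp)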
@@ -411,35 +436,193 @@ with gr.Blocks(css=css) as demo:
     ]


-
+    # Knowledge base storage
+    class KnowledgeBase:
+        def __init__(self):
+            self.materials: Set[str] = {'Metal', 'Wood', 'Plastic', 'Aluminum', 'Bronze', 'Steel', 'Glass', 'Leather', 'Fabric'}
+            self.colors: Set[str] = {'Red', 'Black', 'White', 'Silver', 'Bronze', 'Yellow', 'Blue', 'Green', 'Gray', 'Brown'}
+            self.patterns: Dict[str, List[str]] = {}
+            self.source_data: Dict[str, Any] = {}
+
+        def load_source(self, source_type: str, source_path: str) -> None:
+            """Load data from various sources into the knowledge base"""
+            try:
+                if source_type == 'csv_url':
+                    response = requests.get(source_path)
+                    df = pd.read_csv(io.StringIO(response.text))
+                elif source_type == 'xlsx_url':
+                    response = requests.get(source_path)
+                    df = pd.read_excel(io.BytesIO(response.content))
+                elif source_type == 'local_csv':
+                    df = pd.read_csv(source_path)
+                elif source_type == 'local_xlsx':
+                    df = pd.read_excel(source_path)
+                else:
+                    raise ValueError(f"Unsupported source type: {source_type}")
+
+                # Extract patterns and common values
+                self._extract_knowledge(df)
+
+                # Store source data
+                self.source_data[source_path] = df.to_dict('records')
+
+            except Exception as e:
+                print(f"Error loading source {source_path}: {str(e)}")
+
+        def _extract_knowledge(self, df: pd.DataFrame) -> None:
+            """Extract patterns and common values from dataframe"""
+            for column in df.columns:
+                if 'material' in column.lower():
+                    values = df[column].dropna().unique()
+                    self.materials.update(v.title() for v in values if isinstance(v, str))
+                elif 'color' in column.lower():
+                    values = df[column].dropna().unique()
+                    self.colors.update(v.title() for v in values if isinstance(v, str))
+
+                # Store column patterns
+                if df[column].dtype == 'object':
+                    patterns = df[column].dropna().astype(str).tolist()
+                    self.patterns[column] = patterns
+
+        def get_closest_match(self, value: str, field_type: str) -> Optional[str]:
+            """Find closest match from known values"""
+            if field_type == 'material':
+                matches = get_close_matches(value.title(), list(self.materials), n=1, cutoff=0.8)
+            elif field_type == 'color':
+                matches = get_close_matches(value.title(), list(self.colors), n=1, cutoff=0.8)
+            else:
+                return None
+            return matches[0] if matches else None
+
+    # Initialize knowledge base
+    knowledge_base = KnowledgeBase()
+
+    def refine_data_generic(dataset: List[Dict[str, Any]], mode: str = 'sourceless', kb: Optional[KnowledgeBase] = None) -> List[Dict[str, Any]]:
         """
-
-
+        Enhanced universal dataset refinement with source-aware and sourceless modes.
+        Args:
+            dataset: List of dictionary records
+            mode: 'sourceless' or 'sourced'
+            kb: Optional reference data for sourced mode
         """
-        def normalize_value(value):
-
-
-
-
-
-
+        def split_compound_field(field: str) -> List[str]:
+            """Split compound fields like materialwear into separate values"""
+            parts = re.split(r'[,;\n]+', field)
+            parts = [part.strip().title() for part in parts if part.strip()]
+            return list(set(parts))  # Remove duplicates
+
+        def normalize_value(value: Any, field_name: str) -> Any:
+            """Smart value normalization with field context"""
+            if not isinstance(value, str):
+                return value
+
+            # Basic cleanup
+            value = re.sub(r'\s+', ' ', value.strip())
+            value = value.replace('_', ' ')
+
+            # Field-specific processing with knowledge base
+            if any(term in field_name.lower() for term in ['material']):
+                parts = split_compound_field(value)
+                if mode == 'sourced' and kb:
+                    known = [kb.get_closest_match(p, 'material') or p.title() for p in parts]
+                else:
+                    known = [m for m in parts if m in kb.materials] if kb else parts
+                if known:
+                    return known[0] if len(known) == 1 else known
+                return value.title()
+
+            if any(term in field_name.lower() for term in ['color']):
+                parts = split_compound_field(value)
+                if mode == 'sourced' and kb:
+                    known = [kb.get_closest_match(p, 'color') or p.title() for p in parts]
+                else:
+                    known = [c for c in parts if c in kb.colors] if kb else parts
+                if known:
+                    return known[0] if len(known) == 1 else known
+                return value.title()
+
+            if any(term in field_name.lower() for term in ['date', 'time']):
+                # Add date normalization logic here
+                return value
+
+            # Default titlecase for descriptive fields
+            if any(term in field_name.lower() for term in ['type', 'status', 'category', 'description']):
+                return value.title()
+
             return value

-        def clean_record(record):
+        def clean_record(record: Dict[str, Any]) -> Dict[str, Any]:
+            """Enhanced record cleaning with compound field detection"""
             cleaned = {}
+            compound_fields = {}
+
+            # First pass: Basic cleaning and compound field detection
             for key, value in record.items():
-                # Normalize key and value
                 clean_key = key.strip().lower().replace(" ", "_")
+
+                # Handle compound fields (e.g., materialwear)
+                if isinstance(value, str):
+                    for material in COMMON_MATERIALS:
+                        if material.lower() in value.lower():
+                            compound_fields[clean_key] = value
+                            break
+
                 if isinstance(value, list):
-                    cleaned[clean_key] = [normalize_value(v) for v in value]
+                    cleaned[clean_key] = [normalize_value(v, clean_key) for v in value]
                 elif isinstance(value, dict):
                     cleaned[clean_key] = clean_record(value)
                 else:
-                    cleaned[clean_key] = normalize_value(value)
+                    cleaned[clean_key] = normalize_value(value, clean_key)
+
+            # Second pass: Split compound fields
+            for key, value in compound_fields.items():
+                parts = split_compound_field(value)
+                materials = [p for p in parts if p in COMMON_MATERIALS]
+                if materials:
+                    cleaned['material'] = materials[0] if len(materials) == 1 else materials
+                    # Store remaining info in wear/condition field
+                    remaining = [p for p in parts if p not in materials]
+                    if remaining:
+                        cleaned['condition'] = ' '.join(remaining)
+
             return cleaned

+        # Use knowledge base patterns in sourced mode
+        if mode == 'sourced' and kb and kb.patterns:
+            for record in dataset:
+                for field, patterns in kb.patterns.items():
+                    if field in record:
+                        value = str(record[field])
+                        matches = get_close_matches(value, patterns, n=1, cutoff=0.8)
+                        if matches:
+                            record[field] = matches[0]
+
+
         return [clean_record(entry) for entry in dataset]

+    def refine_preview_data(df: pd.DataFrame, mode: str = 'sourceless') -> pd.DataFrame:
+        """Refine preview data with the selected mode"""
+        # Remove dummy "id" columns first
+        for column_name, values in df.to_dict(orient="series").items():
+            try:
+                if [int(v) for v in values] == list(range(len(df))):
+                    df = df.drop(columns=column_name)
+                if [int(v) for v in values] == list(range(1, len(df) + 1)):
+                    df = df.drop(columns=column_name)
+            except Exception:
+                pass
+
+        # Convert to records for refinement
+        records = df.to_dict('records')
+
+        # Apply refinement with current mode and knowledge base
+        refined_records = refine_data_generic(records, mode=mode, kb=knowledge_base)
+
+        # Convert back to DataFrame
+        refined_df = pd.DataFrame(refined_records)
+
+        return refined_df
+
     def detect_anomalies(record: Dict[str, Any]) -> List[str]:
         """
         Detect potential anomalies in a record.
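Illustrative only: how the new helpers compose outside the UI. The URL is hypothetical, and clean_record also consults a module-level COMMON_MATERIALS constant that is not shown in this diff, so exact output depends on its contents:

kb = KnowledgeBase()
kb.load_source("csv_url", "https://example.com/reference.csv")  # hypothetical source

records = [{"Item Color": "siLver", "Material": "alumnium"}]
refined = refine_data_generic(records, mode="sourced", kb=kb)
# Keys are snake_cased and values snapped to the known vocabulary, e.g.
# {"item_color": "Silver", "material": "Aluminum"}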
@@ -621,20 +804,16 @@ with gr.Blocks(css=css) as demo:
     return gr.Column(visible=True), gr.Column(visible=False)


-    @generate_full_dataset_button.click(
-    def generate_full_dataset(title, content, search_query, namespace, visibility):
+    @generate_full_dataset_button.click(
+        inputs=[dataset_title, dataset_content, search_bar, select_namespace_dropdown, visibility_radio, refinement_mode],
+        outputs=[dataset_dataframe, generate_full_dataset_button, save_dataset_button]
+    )
+    def generate_full_dataset(title, content, search_query, namespace, visibility, mode):
         dataset_name, tags = title.strip("# ").split("\ntags:", 1)
         dataset_name, tags = dataset_name.strip(), tags.strip()
         csv_header, preview_df = parse_preview_df(content)
-        #
-        for column_name, values in preview_df.to_dict(orient="series").items():
-            try:
-                if [int(v) for v in values] == list(range(len(preview_df))):
-                    preview_df = preview_df.drop(columns=column_name)
-                if [int(v) for v in values] == list(range(1, len(preview_df) + 1)):
-                    preview_df = preview_df.drop(columns=column_name)
-            except Exception:
-                pass
+        # Clean and refine the preview data
+        preview_df = refine_preview_data(preview_df, mode)
         columns = list(preview_df)
         output: list[Optional[dict]] = [None] * NUM_ROWS
         output[:len(preview_df)] = [{"idx": i, **x} for i, x in enumerate(preview_df.to_dict(orient="records"))]
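The refine_preview_data call replaces the inline cleanup removed above. The dummy-column test it inherits is simple: a column whose values equal 0..n-1 (or 1..n) merely mirrors the row position, so it is dropped. A small sanity check with made-up data:

import pandas as pd

df = pd.DataFrame({"id": [0, 1, 2], "name": ["a", "b", "c"]})
assert [int(v) for v in df["id"]] == list(range(len(df)))        # would be dropped
assert [int(v) for v in df["id"]] != list(range(1, len(df) + 1))  # 1-based variant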
@@ -688,7 +867,19 @@ with gr.Blocks(css=css) as demo:
         visible=True,
     )

-    @demo.load(outputs=show_dataset_outputs + button_groups + buttons + [generated_texts_state] + [select_namespace_dropdown, visibility_radio])
+    @refinement_mode.change(inputs=refinement_mode, outputs=[source_group])
+    def toggle_source_group(mode):
+        return gr.Group(visible=(mode == "sourced"))
+
+    @load_source_button.click(inputs=[source_type, source_path], outputs=[source_status])
+    def load_knowledge_source(source_type, source_path):
+        try:
+            knowledge_base.load_source(source_type, source_path)
+            return gr.Markdown("✅ Source loaded successfully", visible=True)
+        except Exception as e:
+            return gr.Markdown(f"❌ Error loading source: {str(e)}", visible=True)
+
+    @demo.load(outputs=show_dataset_outputs + button_groups + buttons + [generated_texts_state] + [select_namespace_dropdown, visibility_radio, source_group])
     def load_app(request: gr.Request, oauth_token: Optional[gr.OAuthToken]):
         if oauth_token:
             user_info = whoami(oauth_token.token)
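The source-loading path can also be exercised without the UI. Note that KnowledgeBase.load_source catches its own exceptions and prints them rather than raising, so the handler's ❌ branch is largely defensive. A hypothetical call:

knowledge_base.load_source("local_csv", "reference.csv")  # hypothetical local file
print(len(knowledge_base.materials), "known materials")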