acecalisto3 committed
Commit 303a80b · verified · 1 Parent(s): af05e7c

Update app.py

Files changed (1)
  1. app.py +218 -27
app.py CHANGED
@@ -2,7 +2,10 @@ import io
 import os
 import re
 import time
-from typing import Any, Dict, List
+import requests
+from typing import Any, Dict, List, Optional, Set, Union
+from difflib import get_close_matches
+from pathlib import Path
 from itertools import islice
 from functools import partial
 from multiprocessing.pool import ThreadPool
@@ -209,6 +212,28 @@ with gr.Blocks(css=css) as demo:
     gr.Markdown("Save datasets to your account")
     gr.LoginButton()
     select_namespace_dropdown = gr.Dropdown(choices=[NAMESPACE], value=NAMESPACE, label="Select user or organization", visible=False)
+
+    gr.Markdown("Dataset Refinement Settings")
+    refinement_mode = gr.Radio(
+        ["sourceless", "sourced"],
+        value="sourceless",
+        label="Refinement Mode",
+        info="Choose between AI-only refinement or source-based refinement"
+    )
+
+    with gr.Group(visible=False) as source_group:
+        source_type = gr.Dropdown(
+            choices=["csv_url", "xlsx_url", "local_csv", "local_xlsx"],
+            value="csv_url",
+            label="Source Type"
+        )
+        source_path = gr.Textbox(
+            label="Source Path/URL",
+            placeholder="Enter URL or local file path"
+        )
+        load_source_button = gr.Button("Load Source")
+        source_status = gr.Markdown("")
+
     gr.Markdown("Save datasets as public or private datasets")
     visibility_radio = gr.Radio(["public", "private"], value="public", container=False, interactive=False)
     with gr.Column(visible=False) as dataset_page:
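The source controls above stay hidden until the refinement mode flips to "sourced"; the wiring appears in the event-handler hunk further down. As a standalone illustration (not part of the commit), a minimal sketch of the same show/hide pattern, assuming Gradio 4.x:

import gradio as gr

with gr.Blocks() as sketch:
    mode = gr.Radio(["sourceless", "sourced"], value="sourceless", label="Refinement Mode")
    with gr.Group(visible=False) as group:
        gr.Textbox(label="Source Path/URL")

    # Returning a component constructor acts as a visibility update in Gradio 4.x
    @mode.change(inputs=[mode], outputs=[group])
    def toggle(m):
        return gr.Group(visible=(m == "sourced"))

sketch.launch()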
@@ -411,35 +436,193 @@ with gr.Blocks(css=css) as demo:
     ]
 
 
-    def refine_data_generic(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    # Knowledge base storage
+    class KnowledgeBase:
+        def __init__(self):
+            self.materials: Set[str] = {'Metal', 'Wood', 'Plastic', 'Aluminum', 'Bronze', 'Steel', 'Glass', 'Leather', 'Fabric'}
+            self.colors: Set[str] = {'Red', 'Black', 'White', 'Silver', 'Bronze', 'Yellow', 'Blue', 'Green', 'Gray', 'Brown'}
+            self.patterns: Dict[str, List[str]] = {}
+            self.source_data: Dict[str, Any] = {}
+
+        def load_source(self, source_type: str, source_path: str) -> None:
+            """Load data from various sources into the knowledge base"""
+            try:
+                if source_type == 'csv_url':
+                    response = requests.get(source_path)
+                    df = pd.read_csv(io.StringIO(response.text))
+                elif source_type == 'xlsx_url':
+                    response = requests.get(source_path)
+                    df = pd.read_excel(io.BytesIO(response.content))
+                elif source_type == 'local_csv':
+                    df = pd.read_csv(source_path)
+                elif source_type == 'local_xlsx':
+                    df = pd.read_excel(source_path)
+                else:
+                    raise ValueError(f"Unsupported source type: {source_type}")
+
+                # Extract patterns and common values
+                self._extract_knowledge(df)
+
+                # Store source data
+                self.source_data[source_path] = df.to_dict('records')
+
+            except Exception as e:
+                print(f"Error loading source {source_path}: {str(e)}")
+
+        def _extract_knowledge(self, df: pd.DataFrame) -> None:
+            """Extract patterns and common values from dataframe"""
+            for column in df.columns:
+                if 'material' in column.lower():
+                    values = df[column].dropna().unique()
+                    self.materials.update(v.title() for v in values if isinstance(v, str))
+                elif 'color' in column.lower():
+                    values = df[column].dropna().unique()
+                    self.colors.update(v.title() for v in values if isinstance(v, str))
+
+                # Store column patterns
+                if df[column].dtype == 'object':
+                    patterns = df[column].dropna().astype(str).tolist()
+                    self.patterns[column] = patterns
+
+        def get_closest_match(self, value: str, field_type: str) -> Optional[str]:
+            """Find closest match from known values"""
+            if field_type == 'material':
+                matches = get_close_matches(value.title(), list(self.materials), n=1, cutoff=0.8)
+            elif field_type == 'color':
+                matches = get_close_matches(value.title(), list(self.colors), n=1, cutoff=0.8)
+            else:
+                return None
+            return matches[0] if matches else None
+
+    # Initialize knowledge base
+    knowledge_base = KnowledgeBase()
+
+    def refine_data_generic(dataset: List[Dict[str, Any]], mode: str = 'sourceless', kb: Optional[KnowledgeBase] = None) -> List[Dict[str, Any]]:
         """
-        Universally refine any dataset.
-        Works on list of dicts. Detects field types and applies general cleanup.
+        Enhanced universal dataset refinement with source-aware and sourceless modes.
+        Args:
+            dataset: List of dictionary records
+            mode: 'sourceless' or 'sourced'
+            kb: Optional reference data for sourced mode
         """
-        def normalize_value(value):
-            if isinstance(value, str):
-                # Trim, title-case common descriptors, remove duplicate whitespace
-                value = re.sub(r'\s+', ' ', value.strip())
-                value = value.replace('_', ' ')
-                if any(k in value.lower() for k in ['color', 'material', 'type', 'status']):
-                    value = value.title()
+        def split_compound_field(field: str) -> List[str]:
+            """Split compound fields like 'Metal, worn' into separate values"""
+            parts = re.split(r'[,;\n]+', field)
+            parts = [part.strip().title() for part in parts if part.strip()]
+            return list(set(parts))  # Remove duplicates
+
+        def normalize_value(value: Any, field_name: str) -> Any:
+            """Smart value normalization with field context"""
+            if not isinstance(value, str):
+                return value
+
+            # Basic cleanup
+            value = re.sub(r'\s+', ' ', value.strip())
+            value = value.replace('_', ' ')
+
+            # Field-specific processing with knowledge base
+            if 'material' in field_name.lower():
+                parts = split_compound_field(value)
+                if mode == 'sourced' and kb:
+                    known = [kb.get_closest_match(p, 'material') or p.title() for p in parts]
+                else:
+                    known = [m for m in parts if m in kb.materials] if kb else parts
+                if known:
+                    return known[0] if len(known) == 1 else known
+                return value.title()
+
+            if 'color' in field_name.lower():
+                parts = split_compound_field(value)
+                if mode == 'sourced' and kb:
+                    known = [kb.get_closest_match(p, 'color') or p.title() for p in parts]
+                else:
+                    known = [c for c in parts if c in kb.colors] if kb else parts
+                if known:
+                    return known[0] if len(known) == 1 else known
+                return value.title()
+
+            if any(term in field_name.lower() for term in ['date', 'time']):
+                # Add date normalization logic here
+                return value
+
+            # Default titlecase for descriptive fields
+            if any(term in field_name.lower() for term in ['type', 'status', 'category', 'description']):
+                return value.title()
+
             return value
 
-        def clean_record(record):
+        def clean_record(record: Dict[str, Any]) -> Dict[str, Any]:
+            """Enhanced record cleaning with compound field detection"""
             cleaned = {}
+            compound_fields = {}
+
+            # First pass: Basic cleaning and compound field detection
             for key, value in record.items():
-                # Normalize key and value
                 clean_key = key.strip().lower().replace(" ", "_")
+
+                # Flag string fields that mention a known material (e.g. "metal, worn")
+                if isinstance(value, str):
+                    for material in (kb.materials if kb else knowledge_base.materials):
+                        if material.lower() in value.lower():
+                            compound_fields[clean_key] = value
+                            break
+
                 if isinstance(value, list):
-                    cleaned[clean_key] = [normalize_value(v) for v in value]
+                    cleaned[clean_key] = [normalize_value(v, clean_key) for v in value]
                 elif isinstance(value, dict):
                     cleaned[clean_key] = clean_record(value)
                 else:
-                    cleaned[clean_key] = normalize_value(value)
+                    cleaned[clean_key] = normalize_value(value, clean_key)
+
+            # Second pass: Split compound fields
+            for key, value in compound_fields.items():
+                parts = split_compound_field(value)
+                materials = [p for p in parts if p in (kb.materials if kb else knowledge_base.materials)]
+                if materials:
+                    cleaned['material'] = materials[0] if len(materials) == 1 else materials
+                    # Store remaining info in wear/condition field
+                    remaining = [p for p in parts if p not in materials]
+                    if remaining:
+                        cleaned['condition'] = ' '.join(remaining)
+
             return cleaned
 
+        # Use knowledge base patterns in sourced mode
+        if mode == 'sourced' and kb and kb.patterns:
+            for record in dataset:
+                for field, patterns in kb.patterns.items():
+                    if field in record:
+                        value = str(record[field])
+                        matches = get_close_matches(value, patterns, n=1, cutoff=0.8)
+                        if matches:
+                            record[field] = matches[0]
+
+
         return [clean_record(entry) for entry in dataset]
 
+    def refine_preview_data(df: pd.DataFrame, mode: str = 'sourceless') -> pd.DataFrame:
+        """Refine preview data with the selected mode"""
+        # Remove dummy "id" columns first
+        for column_name, values in df.to_dict(orient="series").items():
+            try:
+                if [int(v) for v in values] == list(range(len(df))):
+                    df = df.drop(columns=column_name)
+                if [int(v) for v in values] == list(range(1, len(df) + 1)):
+                    df = df.drop(columns=column_name)
+            except Exception:
+                pass
+
+        # Convert to records for refinement
+        records = df.to_dict('records')
+
+        # Apply refinement with current mode and knowledge base
+        refined_records = refine_data_generic(records, mode=mode, kb=knowledge_base)
+
+        # Convert back to DataFrame
+        refined_df = pd.DataFrame(refined_records)
+
+        return refined_df
+
     def detect_anomalies(record: Dict[str, Any]) -> List[str]:
         """
         Detect potential anomalies in a record.
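Not part of the commit, but to make the data flow concrete: a short sketch of how the entry points above behave on a hypothetical record (field values are invented for illustration):

sample = [{"Item Name": "office_chair", "Material": "metal, worn", "Color": "blak"}]

# Sourceless mode: whitespace/underscore cleanup plus compound-field splitting;
# "metal, worn" becomes material="Metal" with the leftover moved to "condition".
print(refine_data_generic(sample, mode="sourceless"))

# Sourced mode additionally fuzzy-corrects against the knowledge base via
# difflib.get_close_matches, e.g. "blak" -> "Black" (ratio ~0.89 clears the 0.8 cutoff).
print(refine_data_generic(sample, mode="sourced", kb=knowledge_base))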
@@ -621,20 +804,16 @@ with gr.Blocks(css=css) as demo:
         return gr.Column(visible=True), gr.Column(visible=False)
 
 
-    @generate_full_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, select_namespace_dropdown, visibility_radio], outputs=[dataset_dataframe, generate_full_dataset_button, save_dataset_button])
-    def generate_full_dataset(title, content, search_query, namespace, visability):
+    @generate_full_dataset_button.click(
+        inputs=[dataset_title, dataset_content, search_bar, select_namespace_dropdown, visibility_radio, refinement_mode],
+        outputs=[dataset_dataframe, generate_full_dataset_button, save_dataset_button]
+    )
+    def generate_full_dataset(title, content, search_query, namespace, visibility, mode):
         dataset_name, tags = title.strip("# ").split("\ntags:", 1)
         dataset_name, tags = dataset_name.strip(), tags.strip()
         csv_header, preview_df = parse_preview_df(content)
-        # Remove dummy "id" columns
-        for column_name, values in preview_df.to_dict(orient="series").items():
-            try:
-                if [int(v) for v in values] == list(range(len(preview_df))):
-                    preview_df = preview_df.drop(columns=column_name)
-                if [int(v) for v in values] == list(range(1, len(preview_df) + 1)):
-                    preview_df = preview_df.drop(columns=column_name)
-            except Exception:
-                pass
+        # Clean and refine the preview data
+        preview_df = refine_preview_data(preview_df, mode)
         columns = list(preview_df)
         output: list[Optional[dict]] = [None] * NUM_ROWS
         output[:len(preview_df)] = [{"idx": i, **x} for i, x in enumerate(preview_df.to_dict(orient="records"))]
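Again outside the commit: a quick sketch of the dummy-"id" cleanup that generate_full_dataset now delegates to refine_preview_data, run on an invented frame:

import pandas as pd

df = pd.DataFrame({
    "id": [0, 1, 2],                    # matches range(len(df)), so it is dropped
    "name": ["bolt", "nut", "washer"],
})
refined = refine_preview_data(df, mode="sourceless")
print(list(refined.columns))            # ['name']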
@@ -688,7 +867,19 @@ with gr.Blocks(css=css) as demo:
             visible=True,
         )
 
-    @demo.load(outputs=show_dataset_outputs + button_groups + buttons + [generated_texts_state] + [select_namespace_dropdown, visibility_radio])
+    @refinement_mode.change(inputs=[refinement_mode], outputs=[source_group])
+    def toggle_source_group(mode):
+        return gr.Group(visible=(mode == "sourced"))
+
+    @load_source_button.click(inputs=[source_type, source_path], outputs=[source_status])
+    def load_knowledge_source(source_type, source_path):
+        try:
+            knowledge_base.load_source(source_type, source_path)
+            return gr.Markdown("✅ Source loaded successfully", visible=True)
+        except Exception as e:
+            return gr.Markdown(f"❌ Error loading source: {str(e)}", visible=True)
+
+    @demo.load(outputs=show_dataset_outputs + button_groups + buttons + [generated_texts_state] + [select_namespace_dropdown, visibility_radio, source_group])
     def load_app(request: gr.Request, oauth_token: Optional[gr.OAuthToken]):
         if oauth_token:
             user_info = whoami(oauth_token.token)
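Both get_closest_match and the sourced-mode pattern pass lean on difflib.get_close_matches from the standard library; a closing sketch (not from the commit) of how the 0.8 cutoff behaves:

from difflib import get_close_matches

# n=1 keeps only the best candidate; cutoff=0.8 rejects loose matches.
print(get_close_matches("Alumnium", ["Aluminum", "Bronze"], n=1, cutoff=0.8))  # ['Aluminum']
print(get_close_matches("Wron", ["Wood", "Iron"], n=1, cutoff=0.8))            # [] (too dissimilar)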
 