"""Reusable Streamlit widgets for the DataTools GUI.""" from __future__ import annotations import io from typing import Optional import pandas as pd import streamlit as st from src.core.dedup import ( Algorithm, ColumnMatchStrategy, DeduplicationResult, MatchResult, MatchStrategy, SurvivorRule, ) from src.core.config import ( ColumnStrategyConfig, DeduplicationConfig, StrategyConfig, ) from src.core.normalizers import NormalizerType # --------------------------------------------------------------------------- # App chrome — hide Streamlit default UI for app-like feel # --------------------------------------------------------------------------- _HIDE_CHROME_CSS = """ """ def hide_streamlit_chrome() -> None: """Inject CSS to hide Streamlit's default header, menu, and footer.""" st.markdown(_HIDE_CHROME_CSS, unsafe_allow_html=True) # --------------------------------------------------------------------------- # Config panel (advanced options) # --------------------------------------------------------------------------- def config_panel(df: pd.DataFrame) -> dict: """Render the Advanced Options expander. Returns a settings dict. Keys returned: strategies: list[MatchStrategy] | None survivor_rule: SurvivorRule date_column: str | None merge: bool """ columns = list(df.columns) with st.expander("Advanced Options"): col_left, col_right = st.columns(2) with col_left: subset_cols = st.multiselect( "Match on columns", columns, default=[], help="Leave empty to auto-detect based on column names.", ) key_cols = st.multiselect( "Strong keys", columns, default=[], help="Columns that uniquely identify records (e.g., EIN, SKU). Each is an independent exact-match strategy.", ) fuzzy_cols = st.multiselect( "Fuzzy columns", columns, default=[], help="Columns to fuzzy-match. Others use exact matching.", ) with col_right: algorithm = st.selectbox( "Fuzzy algorithm", ["jaro_winkler", "levenshtein", "token_set_ratio"], index=0, help="jaro_winkler: best for names. levenshtein: best for typos. token_set_ratio: best for addresses.", ) threshold = st.slider( "Similarity threshold", min_value=50, max_value=100, value=85, help="Lower = more matches but more false positives.", ) survivor = st.selectbox( "Survivor rule", ["first", "last", "most-complete", "most-recent"], index=0, help="Which row to keep when duplicates are found.", ) # Second row of options col_a, col_b = st.columns(2) with col_a: normalize_options = {c: "auto" for c in columns} normalizer_types = ["auto", "email", "phone", "name", "address", "string", "none"] normalize_map: dict[str, str] = {} if fuzzy_cols or subset_cols: target_cols = fuzzy_cols or subset_cols st.markdown("**Per-column normalizers**") for col_name in target_cols: norm = st.selectbox( f"Normalizer for '{col_name}'", normalizer_types, index=0, key=f"norm_{col_name}", ) if norm not in ("auto", "none"): normalize_map[col_name] = norm with col_b: merge = st.checkbox( "Merge mode", value=False, help="Fill missing fields in the surviving row from removed duplicates.", ) date_column: Optional[str] = None if survivor == "most-recent": date_column = st.selectbox( "Date column", columns, help="Required for most-recent survivor rule.", ) # Config save/load st.divider() cfg_left, cfg_right = st.columns(2) with cfg_left: config_file = st.file_uploader( "Load config profile", type=["json"], help="Load previously saved settings.", key="config_upload", ) if config_file is not None: import json try: data = json.loads(config_file.read()) loaded = DeduplicationConfig.from_dict(data) st.session_state["loaded_config"] = loaded st.success("Config loaded.") except Exception as e: st.error(f"Failed to load config: {e}") with cfg_right: if st.button("Save current settings"): cfg = _build_config( subset_cols, key_cols, fuzzy_cols, algorithm, threshold, normalize_map, survivor, date_column, merge, ) cfg_json = cfg.to_dict() import json st.download_button( "Download config JSON", data=json.dumps(cfg_json, indent=2), file_name="dedup_config.json", mime="application/json", ) # Build strategies from selections strategies = _build_strategies( subset_cols, key_cols, fuzzy_cols, algorithm, threshold, normalize_map, ) # Survivor rule mapping survivor_map = { "first": SurvivorRule.KEEP_FIRST, "last": SurvivorRule.KEEP_LAST, "most-complete": SurvivorRule.KEEP_MOST_COMPLETE, "most-recent": SurvivorRule.KEEP_MOST_RECENT, } return { "strategies": strategies, "survivor_rule": survivor_map[survivor], "date_column": date_column, "merge": merge, } def _build_strategies( subset_cols: list[str], key_cols: list[str], fuzzy_cols: list[str], algorithm: str, threshold: int, normalize_map: dict[str, str], ) -> Optional[list[MatchStrategy]]: """Build MatchStrategy list from GUI selections. Returns None for auto-detect.""" strategies: list[MatchStrategy] = [] # If user selected columns explicitly, build from those if subset_cols or fuzzy_cols: target_cols = subset_cols if subset_cols else fuzzy_cols fuzzy_set = set(fuzzy_cols) col_strats: list[ColumnMatchStrategy] = [] for col in target_cols: norm = None if col in normalize_map: norm = NormalizerType(normalize_map[col]) if col in fuzzy_set: algo = Algorithm(algorithm) thresh = float(threshold) else: algo = Algorithm.EXACT thresh = 100.0 col_strats.append(ColumnMatchStrategy( column=col, algorithm=algo, threshold=thresh, normalizer=norm, )) strategies.append(MatchStrategy(column_strategies=col_strats)) # Add strong key strategies if key_cols: for col in key_cols: strategies.append(MatchStrategy(column_strategies=[ ColumnMatchStrategy(column=col, algorithm=Algorithm.EXACT, threshold=100.0) ])) return strategies if strategies else None def _build_config( subset_cols, key_cols, fuzzy_cols, algorithm, threshold, normalize_map, survivor, date_column, merge, ) -> DeduplicationConfig: """Build a DeduplicationConfig from GUI state.""" cfg = DeduplicationConfig( survivor_rule=survivor.replace("-", "_"), date_column=date_column, merge=merge, subset_columns=subset_cols or None, fuzzy_columns=fuzzy_cols or None, default_algorithm=algorithm, default_threshold=float(threshold), normalize_map=normalize_map or None, ) strategies = _build_strategies( subset_cols, key_cols, fuzzy_cols, algorithm, threshold, normalize_map, ) if strategies: cfg.strategies = [ StrategyConfig(columns=[ ColumnStrategyConfig( column=cs.column, algorithm=cs.algorithm.value, threshold=cs.threshold, normalizer=cs.normalizer.value if cs.normalizer else None, ) for cs in s.column_strategies ]) for s in strategies ] return cfg # --------------------------------------------------------------------------- # Match group review card # --------------------------------------------------------------------------- def _find_differing_cols( group: MatchResult, df: pd.DataFrame, display_cols: list[str], ) -> list[str]: """Return columns where values differ across rows in the group.""" differing = [] for col in display_cols: values = set() for idx in group.row_indices: values.add(str(df.iloc[idx].get(col, "")).strip()) if len(values) > 1: differing.append(col) return differing def match_group_card( group: MatchResult, df: pd.DataFrame, group_num: int, ) -> None: """Render an expandable match group card with side-by-side diff. Users select which rows to keep via checkboxes. When exactly one row is kept they can also cherry-pick column values from the other rows. Decision format stored in ``st.session_state["review_decisions"]``:: {group_id: {"keep_indices": [int, ...], "overrides": {col: val}}} """ confidence = group.confidence matched_on = ", ".join(group.matched_on) n_rows = len(group.row_indices) gid = group.group_id decisions = st.session_state.get("review_decisions", {}) has_decision = gid in decisions decision_dict = decisions.get(gid, {}) keep_indices = decision_dict.get("keep_indices", []) if has_decision else [] overrides = decision_dict.get("overrides", {}) if has_decision else {} # Build label — append decision status if already decided label = ( f"Group {group_num}: {n_rows} rows " f"(confidence: {confidence:.0f}%) " f"[{matched_on}]" ) if has_decision: if len(keep_indices) == n_rows: label += " — Kept All" elif len(keep_indices) == 1: label += " — Merged (customized)" if overrides else " — Merged" else: label += f" — Split (kept {len(keep_indices)} of {n_rows})" # Decided groups collapse; undecided groups stay open expanded = not has_decision display_cols = [c for c in df.columns if not str(c).startswith("_norm_")] differing_cols = _find_differing_cols(group, df, display_cols) with st.expander(label, expanded=expanded): if has_decision: # --- Decided state: read-only table with diff highlighting --- rows_data = [] for idx in group.row_indices: row = {"Row": idx + 1} for col in display_cols: row[col] = df.iloc[idx].get(col, "") rows_data.append(row) compare_df = pd.DataFrame(rows_data).set_index("Row") def _highlight_diffs(s: pd.Series) -> list[str]: styles = [] first_val = str(s.iloc[0]).strip() if len(s) > 0 else "" for val in s: val_str = str(val).strip() if val_str != first_val and val_str and first_val: styles.append( "background-color: rgba(245, 166, 35, 0.2)" ) elif not val_str and first_val: styles.append( "background-color: rgba(240, 82, 82, 0.1)" ) else: styles.append("") return styles styled = compare_df.style.apply(_highlight_diffs, axis=0) st.dataframe(styled, use_container_width=True) if len(keep_indices) == n_rows: st.info("Decision: Kept All") elif len(keep_indices) == 1: msg = "Decision: Merge" if overrides: msg += f" ({len(overrides)} column(s) customized)" st.success(msg) else: kept = ", ".join(str(i + 1) for i in sorted(keep_indices)) st.success( f"Decision: Keep rows {kept} " f"(removing {n_rows - len(keep_indices)})" ) def _undo(g=gid): st.session_state["review_decisions"].pop(g, None) st.session_state.pop(f"editor_{g}", None) st.button("Undo", key=f"undo_{gid}", on_click=_undo) else: # --- Undecided: interactive editor with inline checkboxes & dropdowns --- editor_rows = [] for idx in group.row_indices: row_data = {"Keep": idx == group.survivor_index, "Row": idx + 1} for col in display_cols: row_data[col] = str(df.iloc[idx].get(col, "")) editor_rows.append(row_data) editor_df = pd.DataFrame(editor_rows) col_config = { "Keep": st.column_config.CheckboxColumn( "Keep", default=True, width="small", ), "Row": st.column_config.NumberColumn("Row", width="small"), } for col in differing_cols: vals = [] for idx in group.row_indices: v = str(df.iloc[idx].get(col, "")).strip() if v not in vals: vals.append(v) if "" not in vals: vals.append("") col_config[col] = st.column_config.SelectboxColumn( col, options=vals, required=False, ) disabled_cols = ["Row"] + [ c for c in display_cols if c not in differing_cols ] edited = st.data_editor( editor_df, column_config=col_config, disabled=disabled_cols, use_container_width=True, hide_index=True, key=f"editor_{gid}", ) # Read which rows are checked checked = [ idx for i, idx in enumerate(group.row_indices) if edited.iloc[i]["Keep"] ] if differing_cols: st.caption( f"Columns with differences (editable): " f"{', '.join(differing_cols)}" ) # Status + surviving rows preview if len(checked) == 0: st.warning("Select at least one row to keep.") else: if len(checked) == n_rows: st.caption("Keeping all rows (no duplicates removed)") elif len(checked) == 1: st.caption( f"Merging into Row {checked[0] + 1}, " f"removing {n_rows - 1} row(s)" ) else: st.caption( f"Keeping {len(checked)} rows, " f"removing {n_rows - len(checked)}" ) # Build preview of surviving rows with edits applied checked_positions = [ i for i, idx in enumerate(group.row_indices) if idx in checked ] preview = edited.iloc[checked_positions].drop( columns=["Keep"], ).reset_index(drop=True) st.markdown("**Surviving rows preview:**") st.dataframe(preview, use_container_width=True, hide_index=True) # Confirm def _on_confirm( g=gid, indices=list(group.row_indices), diff=differing_cols, surv=group.survivor_index, ): editor_state = st.session_state.get(f"editor_{g}", {}) ed_rows = editor_state.get("edited_rows", {}) # Determine which rows to keep keep = [] for i, idx in enumerate(indices): changes = ed_rows.get(i, {}) default_keep = idx == surv if changes.get("Keep", default_keep): keep.append(idx) if not keep: keep = list(indices) # Column overrides (single-survivor merge only) ovr: dict[str, str] = {} if len(keep) == 1: surv_idx = keep[0] surv_pos = indices.index(surv_idx) surv_changes = ed_rows.get(surv_pos, {}) the_df = st.session_state["df"] for c in diff: if c in surv_changes: new_val = ( str(surv_changes[c]) if surv_changes[c] is not None else "" ) orig = str( the_df.iloc[surv_idx].get(c, "") ).strip() if new_val.strip() != orig: ovr[c] = new_val st.session_state["review_decisions"][g] = { "keep_indices": keep, "overrides": ovr, } st.button( "Confirm", key=f"confirm_{gid}", type="primary", on_click=_on_confirm, disabled=(len(checked) == 0), ) # --------------------------------------------------------------------------- # Results summary + downloads # --------------------------------------------------------------------------- def results_summary( result: DeduplicationResult, original_df: pd.DataFrame, ) -> None: """Render summary stats and download buttons.""" removed = result.original_row_count - len(result.deduplicated_df) # Summary metrics col1, col2, col3, col4 = st.columns(4) col1.metric("Rows In", result.original_row_count) col2.metric("Rows Out", len(result.deduplicated_df)) col3.metric("Removed", removed) col4.metric("Groups", len(result.match_groups)) st.divider() # Download buttons dl_left, dl_mid, dl_right = st.columns(3) with dl_left: csv_bytes = result.deduplicated_df.to_csv(index=False).encode("utf-8-sig") st.download_button( "Download Deduplicated CSV", data=csv_bytes, file_name="deduplicated.csv", mime="text/csv", ) with dl_mid: if not result.removed_df.empty: removed_bytes = result.removed_df.to_csv(index=False).encode("utf-8-sig") st.download_button( "Download Removed Rows", data=removed_bytes, file_name="removed_rows.csv", mime="text/csv", ) with dl_right: if result.match_groups: groups_data = _build_match_groups_csv(result, original_df) st.download_button( "Download Match Groups Report", data=groups_data, file_name="match_groups.csv", mime="text/csv", ) def apply_review_decisions( original_df: pd.DataFrame, match_groups: list[MatchResult], decisions: dict, ) -> tuple[pd.DataFrame, pd.DataFrame]: """Build final DataFrames by applying user review decisions. Supports three modes per group: - **Merge** (1 row kept): single survivor with optional column overrides. - **Split** (some rows kept): selected rows survive, others removed. - **Keep all** (all rows kept): no rows removed. - **No decision**: engine default (single survivor). Returns ``(deduplicated_df, removed_df)``. """ remove_indices: set[int] = set() row_overrides: dict[int, dict[str, str]] = {} for group in match_groups: gid = group.group_id decision = decisions.get(gid) # No decision yet — accept with engine defaults if decision is None: keep = {group.survivor_index} else: keep = set(decision.get("keep_indices", group.row_indices)) # Safety: never remove all rows in a group if not keep: keep = set(group.row_indices) for idx in group.row_indices: if idx not in keep: remove_indices.add(idx) # Column overrides (only meaningful for single-survivor merge) ovr = decision.get("overrides", {}) if decision else {} if ovr and len(keep) == 1: row_overrides[next(iter(keep))] = ovr # Build output DataFrames kept = [i for i in range(len(original_df)) if i not in remove_indices] if row_overrides: rows = [] for i in kept: row = original_df.iloc[i].copy() if i in row_overrides: for col, val in row_overrides[i].items(): if col in row.index: row[col] = val rows.append(row) deduped = pd.DataFrame(rows).reset_index(drop=True) else: deduped = original_df.iloc[kept].copy().reset_index(drop=True) removed = ( original_df.iloc[sorted(remove_indices)].copy().reset_index(drop=True) if remove_indices else pd.DataFrame() ) return deduped, removed def _build_match_groups_csv( result: DeduplicationResult, original_df: pd.DataFrame, ) -> bytes: """Build the match groups audit CSV as bytes.""" rows = [] for g in result.match_groups: for idx in g.row_indices: row_data = { "_group_id": g.group_id + 1, "_is_survivor": idx == g.survivor_index, "_confidence": g.confidence, "_matched_on": ", ".join(g.matched_on), "_original_row": idx + 1, } for col in original_df.columns: if not str(col).startswith("_norm_"): row_data[col] = original_df.iloc[idx].get(col, "") if idx < len(original_df) else "" rows.append(row_data) groups_df = pd.DataFrame(rows) return groups_df.to_csv(index=False).encode("utf-8-sig")