diff --git a/src/gui/app.py b/src/gui/app.py index b2b0b1d..b936b0e 100644 --- a/src/gui/app.py +++ b/src/gui/app.py @@ -202,16 +202,14 @@ if uploaded is not None: def _accept_all(): for g in result.match_groups: st.session_state["review_decisions"][g.group_id] = { - "action": True, - "survivor_idx": g.survivor_index, + "keep_indices": [g.survivor_index], "overrides": {}, } def _reject_all(): for g in result.match_groups: st.session_state["review_decisions"][g.group_id] = { - "action": False, - "survivor_idx": g.survivor_index, + "keep_indices": list(g.row_indices), "overrides": {}, } @@ -234,27 +232,46 @@ if uploaded is not None: # Show decision summary if decisions: st.divider() - accepted = sum( - 1 for v in decisions.values() - if isinstance(v, dict) and v.get("action") is True - ) - customized = sum( - 1 for v in decisions.values() - if isinstance(v, dict) and v.get("action") is True - and v.get("overrides") - ) - rejected = sum( - 1 for v in decisions.values() - if isinstance(v, dict) and v.get("action") is False - ) - pending = len(result.match_groups) - len(decisions) + merged = 0 + customized = 0 + split = 0 + kept_all = 0 + for v in decisions.values(): + if not isinstance(v, dict): + continue + ki = v.get("keep_indices", []) + # Find the matching group size + gid_for_v = next( + (gid for gid, d in decisions.items() if d is v), + None, + ) + group_size = next( + (len(g.row_indices) for g in result.match_groups + if g.group_id == gid_for_v), + 0, + ) + if len(ki) == group_size: + kept_all += 1 + elif len(ki) == 1: + if v.get("overrides"): + customized += 1 + else: + merged += 1 + else: + split += 1 - summary_parts = [f"{accepted} merged"] + pending = len(result.match_groups) - len(decisions) + parts = [] + if merged: + parts.append(f"{merged} merged") if customized: - summary_parts.append(f"{customized} customized") - summary_parts.append(f"{rejected} kept both") - summary_parts.append(f"{pending} pending") - st.caption("Decisions: " + ", ".join(summary_parts)) + parts.append(f"{customized} customized") + if split: + parts.append(f"{split} split") + if kept_all: + parts.append(f"{kept_all} kept all") + parts.append(f"{pending} pending") + st.caption("Decisions: " + ", ".join(parts)) # Apply decisions and offer download if st.button( diff --git a/src/gui/components.py b/src/gui/components.py index f962335..d319282 100644 --- a/src/gui/components.py +++ b/src/gui/components.py @@ -279,11 +279,12 @@ def match_group_card( ) -> None: """Render an expandable match group card with side-by-side diff. - Users can pick which row to keep and cherry-pick column values from - other rows. Decisions are stored in - ``st.session_state["review_decisions"]`` as dicts:: + Users select which rows to keep via checkboxes. When exactly one row + is kept they can also cherry-pick column values from the other rows. - {group_id: {"action": bool, "survivor_idx": int, "overrides": {col: val}}} + Decision format stored in ``st.session_state["review_decisions"]``:: + + {group_id: {"keep_indices": [int, ...], "overrides": {col: val}}} """ confidence = group.confidence matched_on = ", ".join(group.matched_on) @@ -293,7 +294,7 @@ def match_group_card( decisions = st.session_state.get("review_decisions", {}) has_decision = gid in decisions decision_dict = decisions.get(gid, {}) - action = decision_dict.get("action") if has_decision else None + keep_indices = decision_dict.get("keep_indices", []) if has_decision else [] overrides = decision_dict.get("overrides", {}) if has_decision else {} # Build label — append decision status if already decided @@ -302,12 +303,13 @@ def match_group_card( f"(confidence: {confidence:.0f}%) " f"[{matched_on}]" ) - if action is True and overrides: - label += " — Merged (customized)" - elif action is True: - label += " — Merged" - elif action is False: - label += " — Kept Both" + if has_decision: + if len(keep_indices) == n_rows: + label += " — Kept All" + elif len(keep_indices) == 1: + label += " — Merged (customized)" if overrides else " — Merged" + else: + label += f" — Split (kept {len(keep_indices)} of {n_rows})" # Decided groups collapse; undecided groups stay open expanded = not has_decision @@ -346,55 +348,58 @@ def match_group_card( st.dataframe(styled, use_container_width=True) if has_decision: - # Show current decision with option to undo - if action is True: + # --- Decided state: show summary + undo --- + if len(keep_indices) == n_rows: + st.info("Decision: Kept All") + elif len(keep_indices) == 1: msg = "Decision: Merge" if overrides: msg += f" ({len(overrides)} column(s) customized)" st.success(msg) else: - st.info("Decision: Keep Both") + kept = ", ".join(str(i + 1) for i in sorted(keep_indices)) + st.success( + f"Decision: Keep rows {kept} " + f"(removing {n_rows - len(keep_indices)})" + ) - def _undo(g=gid, diff=differing_cols): + def _undo(g=gid, indices=group.row_indices, diff=differing_cols): st.session_state["review_decisions"].pop(g, None) - st.session_state.pop(f"base_row_{g}", None) st.session_state.pop(f"customize_{g}", None) + for idx in indices: + st.session_state.pop(f"keep_{g}_{idx}", None) for c in diff: st.session_state.pop(f"col_{g}_{c}", None) st.button("Undo", key=f"undo_{gid}", on_click=_undo) else: - # --- Base row selector --- - default_base = ( - group.row_indices.index(group.survivor_index) - if group.survivor_index in group.row_indices - else 0 - ) + # --- Row selection checkboxes --- + st.caption("Select rows to keep:") + chk_cols = st.columns(n_rows) + for i, idx in enumerate(group.row_indices): + with chk_cols[i]: + st.checkbox( + f"Row {idx + 1}", + value=True, + key=f"keep_{gid}_{idx}", + ) - def _on_base_change(g=gid, diff=differing_cols): - """Reset column pickers when the base row changes.""" - for c in diff: - st.session_state.pop(f"col_{g}_{c}", None) + # Read current checkbox state + checked = [ + idx for idx in group.row_indices + if st.session_state.get(f"keep_{gid}_{idx}", True) + ] - selected_survivor = st.radio( - "Base row (keep)", - options=group.row_indices, - index=default_base, - format_func=lambda idx: f"Row {idx + 1}", - key=f"base_row_{gid}", - horizontal=True, - on_change=_on_base_change, - ) - - # --- Customize columns (progressive disclosure) --- - if differing_cols: + # --- Customize columns (only when exactly 1 row kept) --- + if len(checked) == 1 and differing_cols: customize = st.checkbox( f"Customize columns ({len(differing_cols)} differ)", key=f"customize_{gid}", value=False, ) if customize: - base_pos = group.row_indices.index(selected_survivor) + survivor_idx = checked[0] + base_pos = group.row_indices.index(survivor_idx) st.caption("Pick which row's value to use for each column:") for col in differing_cols: def _fmt(idx: int, c: str = col) -> str: @@ -411,43 +416,57 @@ def match_group_card( key=f"col_{gid}_{col}", ) - # --- Action buttons --- - def _on_merge( + # --- Status caption --- + if len(checked) == 0: + st.warning("Select at least one row to keep.") + elif len(checked) == n_rows: + st.caption("Keeping all rows (no duplicates removed from this group)") + elif len(checked) == 1: + st.caption(f"Will merge into Row {checked[0] + 1}, " + f"removing {n_rows - 1} row(s)") + else: + removed = n_rows - len(checked) + st.caption(f"Will keep {len(checked)} rows, " + f"removing {removed}") + + # --- Confirm button --- + def _on_confirm( g=gid, indices=group.row_indices, diff=differing_cols, ): - the_df = st.session_state["df"] - base_idx = st.session_state.get(f"base_row_{g}", indices[0]) + keep = [ + idx for idx in indices + if st.session_state.get(f"keep_{g}_{idx}", True) + ] + # Safety: never remove all rows + if not keep: + keep = list(indices) + ovr: dict[str, str] = {} - for c in diff: - col_key = f"col_{g}_{c}" - if col_key in st.session_state: - source_idx = st.session_state[col_key] - if source_idx != base_idx: - ovr[c] = str(the_df.iloc[source_idx].get(c, "")) + # Column overrides only apply for single-survivor merge + if len(keep) == 1: + the_df = st.session_state["df"] + base_idx = keep[0] + for c in diff: + col_key = f"col_{g}_{c}" + if col_key in st.session_state: + source_idx = st.session_state[col_key] + if source_idx != base_idx: + ovr[c] = str( + the_df.iloc[source_idx].get(c, "") + ) + st.session_state["review_decisions"][g] = { - "action": True, - "survivor_idx": base_idx, + "keep_indices": keep, "overrides": ovr, } - def _on_keep(g=gid): - st.session_state["review_decisions"][g] = { - "action": False, - "survivor_idx": group.survivor_index, - "overrides": {}, - } - - btn_left, btn_mid, _btn_right = st.columns(3) - with btn_left: - st.button( - "Merge", key=f"merge_{gid}", - type="primary", on_click=_on_merge, - ) - with btn_mid: - st.button( - "Keep Both", key=f"keep_{gid}", - on_click=_on_keep, - ) + st.button( + "Confirm", + key=f"confirm_{gid}", + type="primary", + on_click=_on_confirm, + disabled=(len(checked) == 0), + ) # --------------------------------------------------------------------------- @@ -510,8 +529,12 @@ def apply_review_decisions( ) -> tuple[pd.DataFrame, pd.DataFrame]: """Build final DataFrames by applying user review decisions. - Handles per-group survivor selection and column overrides without - re-running the deduplication engine. + Supports three modes per group: + + - **Merge** (1 row kept): single survivor with optional column overrides. + - **Split** (some rows kept): selected rows survive, others removed. + - **Keep all** (all rows kept): no rows removed. + - **No decision**: engine default (single survivor). Returns ``(deduplicated_df, removed_df)``. """ @@ -524,33 +547,28 @@ def apply_review_decisions( # No decision yet — accept with engine defaults if decision is None: - survivor_idx = group.survivor_index - for idx in group.row_indices: - if idx != survivor_idx: - remove_indices.add(idx) - continue - - # Keep both — skip this group entirely - if not decision.get("action", True): - continue - - # Merge with user's choices - survivor_idx = decision.get("survivor_idx", group.survivor_index) - ovr = decision.get("overrides", {}) + keep = {group.survivor_index} + else: + keep = set(decision.get("keep_indices", group.row_indices)) + # Safety: never remove all rows in a group + if not keep: + keep = set(group.row_indices) for idx in group.row_indices: - if idx != survivor_idx: + if idx not in keep: remove_indices.add(idx) - if ovr: - row_overrides[survivor_idx] = ovr + # Column overrides (only meaningful for single-survivor merge) + ovr = decision.get("overrides", {}) if decision else {} + if ovr and len(keep) == 1: + row_overrides[next(iter(keep))] = ovr # Build output DataFrames - keep_indices = [i for i in range(len(original_df)) if i not in remove_indices] + kept = [i for i in range(len(original_df)) if i not in remove_indices] if row_overrides: rows = [] - for i in keep_indices: + for i in kept: row = original_df.iloc[i].copy() if i in row_overrides: for col, val in row_overrides[i].items(): @@ -559,7 +577,7 @@ def apply_review_decisions( rows.append(row) deduped = pd.DataFrame(rows).reset_index(drop=True) else: - deduped = original_df.iloc[keep_indices].copy().reset_index(drop=True) + deduped = original_df.iloc[kept].copy().reset_index(drop=True) removed = ( original_df.iloc[sorted(remove_indices)].copy().reset_index(drop=True)