feat: multi-row survivor support in match group review

Replace radio + Merge/Keep Both buttons with per-row checkboxes
and a single Confirm button. Users can now:

- Keep all rows (not duplicates) — check all, confirm
- Merge to one row — uncheck all but one, optionally customize columns
- Split a group — keep some rows, remove others (new capability)

Decision format changed from {action, survivor_idx, overrides} to
{keep_indices, overrides}. apply_review_decisions() updated to handle
all three modes. Batch actions updated accordingly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-28 23:52:45 +00:00
parent debb0cb516
commit 863fe89f2c
2 changed files with 148 additions and 113 deletions

View File

@@ -202,16 +202,14 @@ if uploaded is not None:
def _accept_all(): def _accept_all():
for g in result.match_groups: for g in result.match_groups:
st.session_state["review_decisions"][g.group_id] = { st.session_state["review_decisions"][g.group_id] = {
"action": True, "keep_indices": [g.survivor_index],
"survivor_idx": g.survivor_index,
"overrides": {}, "overrides": {},
} }
def _reject_all(): def _reject_all():
for g in result.match_groups: for g in result.match_groups:
st.session_state["review_decisions"][g.group_id] = { st.session_state["review_decisions"][g.group_id] = {
"action": False, "keep_indices": list(g.row_indices),
"survivor_idx": g.survivor_index,
"overrides": {}, "overrides": {},
} }
@@ -234,27 +232,46 @@ if uploaded is not None:
# Show decision summary # Show decision summary
if decisions: if decisions:
st.divider() st.divider()
accepted = sum( merged = 0
1 for v in decisions.values() customized = 0
if isinstance(v, dict) and v.get("action") is True split = 0
kept_all = 0
for v in decisions.values():
if not isinstance(v, dict):
continue
ki = v.get("keep_indices", [])
# Find the matching group size
gid_for_v = next(
(gid for gid, d in decisions.items() if d is v),
None,
) )
customized = sum( group_size = next(
1 for v in decisions.values() (len(g.row_indices) for g in result.match_groups
if isinstance(v, dict) and v.get("action") is True if g.group_id == gid_for_v),
and v.get("overrides") 0,
) )
rejected = sum( if len(ki) == group_size:
1 for v in decisions.values() kept_all += 1
if isinstance(v, dict) and v.get("action") is False elif len(ki) == 1:
) if v.get("overrides"):
pending = len(result.match_groups) - len(decisions) customized += 1
else:
merged += 1
else:
split += 1
summary_parts = [f"{accepted} merged"] pending = len(result.match_groups) - len(decisions)
parts = []
if merged:
parts.append(f"{merged} merged")
if customized: if customized:
summary_parts.append(f"{customized} customized") parts.append(f"{customized} customized")
summary_parts.append(f"{rejected} kept both") if split:
summary_parts.append(f"{pending} pending") parts.append(f"{split} split")
st.caption("Decisions: " + ", ".join(summary_parts)) if kept_all:
parts.append(f"{kept_all} kept all")
parts.append(f"{pending} pending")
st.caption("Decisions: " + ", ".join(parts))
# Apply decisions and offer download # Apply decisions and offer download
if st.button( if st.button(

View File

@@ -279,11 +279,12 @@ def match_group_card(
) -> None: ) -> None:
"""Render an expandable match group card with side-by-side diff. """Render an expandable match group card with side-by-side diff.
Users can pick which row to keep and cherry-pick column values from Users select which rows to keep via checkboxes. When exactly one row
other rows. Decisions are stored in is kept they can also cherry-pick column values from the other rows.
``st.session_state["review_decisions"]`` as dicts::
{group_id: {"action": bool, "survivor_idx": int, "overrides": {col: val}}} Decision format stored in ``st.session_state["review_decisions"]``::
{group_id: {"keep_indices": [int, ...], "overrides": {col: val}}}
""" """
confidence = group.confidence confidence = group.confidence
matched_on = ", ".join(group.matched_on) matched_on = ", ".join(group.matched_on)
@@ -293,7 +294,7 @@ def match_group_card(
decisions = st.session_state.get("review_decisions", {}) decisions = st.session_state.get("review_decisions", {})
has_decision = gid in decisions has_decision = gid in decisions
decision_dict = decisions.get(gid, {}) decision_dict = decisions.get(gid, {})
action = decision_dict.get("action") if has_decision else None keep_indices = decision_dict.get("keep_indices", []) if has_decision else []
overrides = decision_dict.get("overrides", {}) if has_decision else {} overrides = decision_dict.get("overrides", {}) if has_decision else {}
# Build label — append decision status if already decided # Build label — append decision status if already decided
@@ -302,12 +303,13 @@ def match_group_card(
f"(confidence: {confidence:.0f}%) " f"(confidence: {confidence:.0f}%) "
f"[{matched_on}]" f"[{matched_on}]"
) )
if action is True and overrides: if has_decision:
label += " — Merged (customized)" if len(keep_indices) == n_rows:
elif action is True: label += " — Kept All"
label += " — Merged" elif len(keep_indices) == 1:
elif action is False: label += " — Merged (customized)" if overrides else " — Merged"
label += " — Kept Both" else:
label += f" — Split (kept {len(keep_indices)} of {n_rows})"
# Decided groups collapse; undecided groups stay open # Decided groups collapse; undecided groups stay open
expanded = not has_decision expanded = not has_decision
@@ -346,55 +348,58 @@ def match_group_card(
st.dataframe(styled, use_container_width=True) st.dataframe(styled, use_container_width=True)
if has_decision: if has_decision:
# Show current decision with option to undo # --- Decided state: show summary + undo ---
if action is True: if len(keep_indices) == n_rows:
st.info("Decision: Kept All")
elif len(keep_indices) == 1:
msg = "Decision: Merge" msg = "Decision: Merge"
if overrides: if overrides:
msg += f" ({len(overrides)} column(s) customized)" msg += f" ({len(overrides)} column(s) customized)"
st.success(msg) st.success(msg)
else: else:
st.info("Decision: Keep Both") kept = ", ".join(str(i + 1) for i in sorted(keep_indices))
st.success(
f"Decision: Keep rows {kept} "
f"(removing {n_rows - len(keep_indices)})"
)
def _undo(g=gid, diff=differing_cols): def _undo(g=gid, indices=group.row_indices, diff=differing_cols):
st.session_state["review_decisions"].pop(g, None) st.session_state["review_decisions"].pop(g, None)
st.session_state.pop(f"base_row_{g}", None)
st.session_state.pop(f"customize_{g}", None) st.session_state.pop(f"customize_{g}", None)
for idx in indices:
st.session_state.pop(f"keep_{g}_{idx}", None)
for c in diff: for c in diff:
st.session_state.pop(f"col_{g}_{c}", None) st.session_state.pop(f"col_{g}_{c}", None)
st.button("Undo", key=f"undo_{gid}", on_click=_undo) st.button("Undo", key=f"undo_{gid}", on_click=_undo)
else: else:
# --- Base row selector --- # --- Row selection checkboxes ---
default_base = ( st.caption("Select rows to keep:")
group.row_indices.index(group.survivor_index) chk_cols = st.columns(n_rows)
if group.survivor_index in group.row_indices for i, idx in enumerate(group.row_indices):
else 0 with chk_cols[i]:
st.checkbox(
f"Row {idx + 1}",
value=True,
key=f"keep_{gid}_{idx}",
) )
def _on_base_change(g=gid, diff=differing_cols): # Read current checkbox state
"""Reset column pickers when the base row changes.""" checked = [
for c in diff: idx for idx in group.row_indices
st.session_state.pop(f"col_{g}_{c}", None) if st.session_state.get(f"keep_{gid}_{idx}", True)
]
selected_survivor = st.radio( # --- Customize columns (only when exactly 1 row kept) ---
"Base row (keep)", if len(checked) == 1 and differing_cols:
options=group.row_indices,
index=default_base,
format_func=lambda idx: f"Row {idx + 1}",
key=f"base_row_{gid}",
horizontal=True,
on_change=_on_base_change,
)
# --- Customize columns (progressive disclosure) ---
if differing_cols:
customize = st.checkbox( customize = st.checkbox(
f"Customize columns ({len(differing_cols)} differ)", f"Customize columns ({len(differing_cols)} differ)",
key=f"customize_{gid}", key=f"customize_{gid}",
value=False, value=False,
) )
if customize: if customize:
base_pos = group.row_indices.index(selected_survivor) survivor_idx = checked[0]
base_pos = group.row_indices.index(survivor_idx)
st.caption("Pick which row's value to use for each column:") st.caption("Pick which row's value to use for each column:")
for col in differing_cols: for col in differing_cols:
def _fmt(idx: int, c: str = col) -> str: def _fmt(idx: int, c: str = col) -> str:
@@ -411,42 +416,56 @@ def match_group_card(
key=f"col_{gid}_{col}", key=f"col_{gid}_{col}",
) )
# --- Action buttons --- # --- Status caption ---
def _on_merge( if len(checked) == 0:
st.warning("Select at least one row to keep.")
elif len(checked) == n_rows:
st.caption("Keeping all rows (no duplicates removed from this group)")
elif len(checked) == 1:
st.caption(f"Will merge into Row {checked[0] + 1}, "
f"removing {n_rows - 1} row(s)")
else:
removed = n_rows - len(checked)
st.caption(f"Will keep {len(checked)} rows, "
f"removing {removed}")
# --- Confirm button ---
def _on_confirm(
g=gid, indices=group.row_indices, diff=differing_cols, g=gid, indices=group.row_indices, diff=differing_cols,
): ):
the_df = st.session_state["df"] keep = [
base_idx = st.session_state.get(f"base_row_{g}", indices[0]) idx for idx in indices
if st.session_state.get(f"keep_{g}_{idx}", True)
]
# Safety: never remove all rows
if not keep:
keep = list(indices)
ovr: dict[str, str] = {} ovr: dict[str, str] = {}
# Column overrides only apply for single-survivor merge
if len(keep) == 1:
the_df = st.session_state["df"]
base_idx = keep[0]
for c in diff: for c in diff:
col_key = f"col_{g}_{c}" col_key = f"col_{g}_{c}"
if col_key in st.session_state: if col_key in st.session_state:
source_idx = st.session_state[col_key] source_idx = st.session_state[col_key]
if source_idx != base_idx: if source_idx != base_idx:
ovr[c] = str(the_df.iloc[source_idx].get(c, "")) ovr[c] = str(
the_df.iloc[source_idx].get(c, "")
)
st.session_state["review_decisions"][g] = { st.session_state["review_decisions"][g] = {
"action": True, "keep_indices": keep,
"survivor_idx": base_idx,
"overrides": ovr, "overrides": ovr,
} }
def _on_keep(g=gid):
st.session_state["review_decisions"][g] = {
"action": False,
"survivor_idx": group.survivor_index,
"overrides": {},
}
btn_left, btn_mid, _btn_right = st.columns(3)
with btn_left:
st.button( st.button(
"Merge", key=f"merge_{gid}", "Confirm",
type="primary", on_click=_on_merge, key=f"confirm_{gid}",
) type="primary",
with btn_mid: on_click=_on_confirm,
st.button( disabled=(len(checked) == 0),
"Keep Both", key=f"keep_{gid}",
on_click=_on_keep,
) )
@@ -510,8 +529,12 @@ def apply_review_decisions(
) -> tuple[pd.DataFrame, pd.DataFrame]: ) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Build final DataFrames by applying user review decisions. """Build final DataFrames by applying user review decisions.
Handles per-group survivor selection and column overrides without Supports three decision modes per group, plus a default for undecided groups:
re-running the deduplication engine.
- **Merge** (1 row kept): single survivor with optional column overrides.
- **Split** (some rows kept): selected rows survive, others removed.
- **Keep all** (all rows kept): no rows removed.
- **No decision**: engine default (single survivor).
Returns ``(deduplicated_df, removed_df)``. Returns ``(deduplicated_df, removed_df)``.
""" """
@@ -524,33 +547,28 @@ def apply_review_decisions(
# No decision yet — accept with engine defaults # No decision yet — accept with engine defaults
if decision is None: if decision is None:
survivor_idx = group.survivor_index keep = {group.survivor_index}
for idx in group.row_indices: else:
if idx != survivor_idx: keep = set(decision.get("keep_indices", group.row_indices))
remove_indices.add(idx) # Safety: never remove all rows in a group
continue if not keep:
keep = set(group.row_indices)
# Keep both — skip this group entirely
if not decision.get("action", True):
continue
# Merge with user's choices
survivor_idx = decision.get("survivor_idx", group.survivor_index)
ovr = decision.get("overrides", {})
for idx in group.row_indices: for idx in group.row_indices:
if idx != survivor_idx: if idx not in keep:
remove_indices.add(idx) remove_indices.add(idx)
if ovr: # Column overrides (only meaningful for single-survivor merge)
row_overrides[survivor_idx] = ovr ovr = decision.get("overrides", {}) if decision else {}
if ovr and len(keep) == 1:
row_overrides[next(iter(keep))] = ovr
# Build output DataFrames # Build output DataFrames
keep_indices = [i for i in range(len(original_df)) if i not in remove_indices] kept = [i for i in range(len(original_df)) if i not in remove_indices]
if row_overrides: if row_overrides:
rows = [] rows = []
for i in keep_indices: for i in kept:
row = original_df.iloc[i].copy() row = original_df.iloc[i].copy()
if i in row_overrides: if i in row_overrides:
for col, val in row_overrides[i].items(): for col, val in row_overrides[i].items():
@@ -559,7 +577,7 @@ def apply_review_decisions(
rows.append(row) rows.append(row)
deduped = pd.DataFrame(rows).reset_index(drop=True) deduped = pd.DataFrame(rows).reset_index(drop=True)
else: else:
deduped = original_df.iloc[keep_indices].copy().reset_index(drop=True) deduped = original_df.iloc[kept].copy().reset_index(drop=True)
removed = ( removed = (
original_df.iloc[sorted(remove_indices)].copy().reset_index(drop=True) original_df.iloc[sorted(remove_indices)].copy().reset_index(drop=True)