250 lines
8.9 KiB
Python
250 lines
8.9 KiB
Python
from __future__ import annotations
|
|
"""
|
|
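BaseRepository: generic CRUD operations for all entities.

Uses raw SQL via SQLAlchemy text() - no ORM models needed.

list() and count() exclude soft-deleted rows (is_deleted = false) unless
include_deleted=True is passed. get(), update(), permanent_delete() and
reorder() act on a row regardless of its deletion state; soft_delete() only
touches active rows and restore() only touches deleted ones.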
|
|
"""
|
|
|
|
import re
|
|
from uuid import UUID
|
|
from datetime import date, datetime, timezone
|
|
from typing import Any
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
# Whole-string "YYYY-MM-DD" dates.
_ISO_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")

# Strings that start with "YYYY-MM-DD", a "T" or space, then "HH:MM"; the
# remainder (seconds, fraction, UTC offset) is validated by
# datetime.fromisoformat().
_ISO_DATETIME_RE = re.compile(r"^\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}")


def _coerce_value(value: Any) -> Any:
    """Convert ISO date/datetime strings to Python date/datetime objects.

    asyncpg requires native Python types, not strings, for date columns, so
    string values that look like ISO-8601 dates or datetimes are parsed
    before being bound as query parameters.

    Args:
        value: Any value about to be bound as a query parameter.

    Returns:
        A ``date`` for ``"YYYY-MM-DD"`` strings, a ``datetime`` for ISO
        datetime strings (a trailing ``"Z"`` is treated as UTC), and
        ``value`` unchanged for everything else -- including strings that
        match a pattern but are not valid dates (e.g. ``"2024-13-45"``).
    """
    if not isinstance(value, str):
        return value

    if _ISO_DATE_RE.match(value):
        try:
            return date.fromisoformat(value)
        except ValueError:
            pass  # Date-shaped but invalid; fall through and keep the string.

    if _ISO_DATETIME_RE.match(value):
        # datetime.fromisoformat() only accepts a "Z" (UTC) suffix from
        # Python 3.11 on. Normalising it to "+00:00" makes timestamps such as
        # "2024-01-02T03:04:05.000Z" parse on older interpreters too, and
        # yields the same result on newer ones.
        candidate = value[:-1] + "+00:00" if value.endswith("Z") else value
        try:
            return datetime.fromisoformat(candidate)
        except ValueError:
            pass

    return value
|
|
|
|
|
|
class BaseRepository:
    """Generic raw-SQL CRUD repository for a single table.

    All statements are built with SQLAlchemy ``text()``. Values are always
    sent as bound parameters. Identifiers (table, column and sort names)
    cannot be bound, so they are interpolated into the SQL text; names
    supplied at call time are therefore validated as plain SQL identifiers
    first, and ``self.table`` must only ever come from trusted code.
    """

    # Filter value meaning "column IS NOT NULL" in list()/count().
    _NOT_NULL = "__notnull__"

    # Plain, unquoted SQL identifier.
    _IDENT_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

    # Accepted ORDER BY directions (sort_dir is upper-cased before the check).
    _SORT_DIRECTIONS = frozenset({"ASC", "DESC"})

    # create(): keys that are kept even when their value is None.
    _NULLABLE_CREATE_FIELDS = frozenset({"description", "notes", "body"})

    # update(): keys that may be explicitly set to NULL. For any other key a
    # None value is dropped, leaving that column unchanged.
    _NULLABLE_UPDATE_FIELDS = frozenset({
        "description", "notes", "body", "area_id", "project_id",
        "parent_id", "parent_item_id", "release_id", "due_date", "deadline", "tags",
        "context", "folder_id", "meeting_id", "completed_at",
        "waiting_for_contact_id", "waiting_since", "color",
        "rationale", "decided_at", "superseded_by_id",
        "start_at", "end_at", "location", "agenda", "transcript", "notes_body",
        "priority", "recurrence", "mime_type",
        "category", "instructions", "expected_output", "estimated_days",
        "contact_id", "started_at",
        "weekly_hours", "effective_from",
    })

    def __init__(self, table: str, db: AsyncSession):
        # ``table`` is interpolated verbatim into every statement: it must be a
        # code-supplied name, never user input.
        self.table = table
        self.db = db

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _checked_id(value: UUID | str) -> str | None:
        """Return ``str(value)`` if it parses as a UUID, otherwise ``None``.

        Passing a malformed ID to the database makes asyncpg raise a
        DataError, so ID-based methods treat an unparseable ID as "no such
        row" instead.
        """
        id_str = str(value)
        try:
            UUID(id_str)
        except (ValueError, AttributeError):
            return None
        return id_str

    @classmethod
    def _check_identifier(cls, name: str, *, qualified: bool = False) -> str:
        """Return ``name`` unchanged if it is a safe SQL identifier.

        Args:
            name: Column (or, with ``qualified=True``, ``table.column``) name.
            qualified: Also accept dot-separated qualified names.

        Raises:
            ValueError: if ``name`` is not a string or not a valid identifier.
        """
        if isinstance(name, str):
            parts = name.split(".") if qualified else [name]
            if all(cls._IDENT_RE.fullmatch(part) for part in parts):
                return name
        raise ValueError(f"Invalid SQL identifier: {name!r}")

    @classmethod
    def _build_where(
        cls, filters: dict | None, include_deleted: bool
    ) -> tuple[str, dict[str, Any]]:
        """Build the WHERE expression and parameters shared by list() and count().

        Filter semantics: ``None`` -> ``col IS NULL``; the ``"__notnull__"``
        marker -> ``col IS NOT NULL``; any other value -> ``col = :f_<index>``.

        Returns:
            ``(where_sql, params)``; ``where_sql`` is ``"1=1"`` when there
            are no conditions.

        Raises:
            ValueError: if a filter key is not a valid SQL identifier.
        """
        clauses = [] if include_deleted else ["is_deleted = false"]
        params: dict[str, Any] = {}
        for index, (key, value) in enumerate((filters or {}).items()):
            column = cls._check_identifier(key, qualified=True)
            if value is None:
                clauses.append(f"{column} IS NULL")
            elif isinstance(value, str) and value == cls._NOT_NULL:
                clauses.append(f"{column} IS NOT NULL")
            else:
                param_name = f"f_{index}"
                clauses.append(f"{column} = :{param_name}")
                params[param_name] = value
        where_sql = " AND ".join(clauses) if clauses else "1=1"
        return where_sql, params

    @classmethod
    def _build_order_by(cls, sort: str, sort_dir: str) -> str:
        """Validate ``sort``/``sort_dir`` and return the ORDER BY expression.

        ``sort`` may be one column or a comma-separated list of columns. As in
        plain SQL (``ORDER BY a, b DESC``), the direction applies to the last
        column only. ``sort_dir`` is case-insensitive.

        Raises:
            ValueError: on an unknown direction or an invalid column name.
        """
        direction = sort_dir.strip().upper() if isinstance(sort_dir, str) else ""
        if direction not in cls._SORT_DIRECTIONS:
            raise ValueError(f"Invalid sort direction: {sort_dir!r}")
        if not isinstance(sort, str):
            raise ValueError(f"Invalid sort column: {sort!r}")
        columns = [
            cls._check_identifier(part.strip(), qualified=True)
            for part in sort.split(",")
        ]
        return f"{', '.join(columns)} {direction}"

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    async def list(
        self,
        filters: dict | None = None,
        sort: str = "sort_order",
        sort_dir: str = "ASC",
        page: int = 1,
        per_page: int = 50,
        include_deleted: bool = False,
    ) -> list[dict]:
        """List rows with optional filtering, sorting and pagination.

        Args:
            filters: Column -> value equality filters; ``None`` matches NULL
                and ``"__notnull__"`` matches any non-NULL value.
            sort: Column name (or comma-separated names) to order by.
            sort_dir: ``"ASC"`` or ``"DESC"`` (case-insensitive).
            page: 1-based page number; values below 1 are treated as 1.
            per_page: Page size; negative values are treated as 0.
            include_deleted: Also return soft-deleted rows when True.

        Returns:
            One dict per row, keyed by column name.

        Raises:
            ValueError: if a filter key, sort column or sort direction is invalid.
        """
        where_sql, params = self._build_where(filters, include_deleted)
        order_sql = self._build_order_by(sort, sort_dir)

        # Negative LIMIT/OFFSET values are rejected by the database.
        page = max(page, 1)
        per_page = max(per_page, 0)
        params["limit"] = per_page
        params["offset"] = (page - 1) * per_page

        query = text(f"""
            SELECT * FROM {self.table}
            WHERE {where_sql}
            ORDER BY {order_sql}
            LIMIT :limit OFFSET :offset
        """)
        result = await self.db.execute(query, params)
        return [dict(row._mapping) for row in result]

    async def count(
        self,
        filters: dict | None = None,
        include_deleted: bool = False,
    ) -> int:
        """Count rows matching ``filters``, using the same rules as list().

        Raises:
            ValueError: if a filter key is not a valid SQL identifier.
        """
        where_sql, params = self._build_where(filters, include_deleted)
        query = text(f"SELECT count(*) FROM {self.table} WHERE {where_sql}")
        result = await self.db.execute(query, params)
        return result.scalar() or 0

    async def get(self, id: UUID | str) -> dict | None:
        """Get a single row by ID (soft-deleted rows included).

        Returns:
            The row as a dict, or ``None`` if it does not exist or ``id`` is
            not a valid UUID.
        """
        id_str = self._checked_id(id)
        if id_str is None:
            return None
        query = text(f"SELECT * FROM {self.table} WHERE id = :id")
        result = await self.db.execute(query, {"id": id_str})
        row = result.first()
        return dict(row._mapping) if row else None

    # ------------------------------------------------------------------
    # Mutations
    # ------------------------------------------------------------------

    async def create(self, data: dict) -> dict:
        """Insert a new row. Auto-sets created_at, updated_at, is_deleted.

        ``None`` values are dropped except for keys in
        ``_NULLABLE_CREATE_FIELDS``; other values pass through
        ``_coerce_value``.

        Returns:
            The inserted row from ``RETURNING *``, or the submitted values
            if no row came back.

        Raises:
            ValueError: if a key in ``data`` is not a valid SQL identifier.
        """
        row_data = {
            key: _coerce_value(value)
            for key, value in data.items()
            if value is not None or key in self._NULLABLE_CREATE_FIELDS
        }
        row_data.setdefault("is_deleted", False)
        now = datetime.now(timezone.utc)
        row_data.setdefault("created_at", now)
        row_data.setdefault("updated_at", now)

        # Keys become both column names and bind-parameter names.
        for key in row_data:
            self._check_identifier(key)
        columns = ", ".join(row_data)
        placeholders = ", ".join(f":{key}" for key in row_data)

        query = text(f"""
            INSERT INTO {self.table} ({columns})
            VALUES ({placeholders})
            RETURNING *
        """)
        result = await self.db.execute(query, row_data)
        row = result.first()
        return dict(row._mapping) if row else row_data

    async def update(self, id: UUID | str, data: dict) -> dict | None:
        """Update a row by ID. Auto-sets updated_at.

        Values pass through ``_coerce_value``. A ``None`` value is written as
        NULL only for keys in ``_NULLABLE_UPDATE_FIELDS``; for other keys it
        is dropped so the column stays unchanged. ``updated_at`` is always
        set to the current UTC time, overriding any supplied value.

        Returns:
            The updated row, or ``None`` if no row has this ID or ``id`` is
            not a valid UUID.

        Raises:
            ValueError: if a key in ``data`` is not a valid SQL identifier.
        """
        id_str = self._checked_id(id)
        if id_str is None:
            return None

        values: dict[str, Any] = {}
        for key, raw in data.items():
            value = _coerce_value(raw)
            if value is not None or key in self._NULLABLE_UPDATE_FIELDS:
                values[key] = value
        values["updated_at"] = datetime.now(timezone.utc)

        for key in values:
            self._check_identifier(key)
        set_sql = ", ".join(f"{key} = :{key}" for key in values)

        query = text(f"""
            UPDATE {self.table}
            SET {set_sql}
            WHERE id = :id
            RETURNING *
        """)
        # "id" is bound last so the path ID always wins over an "id" key in data.
        result = await self.db.execute(query, {**values, "id": id_str})
        row = result.first()
        return dict(row._mapping) if row else None

    async def soft_delete(self, id: UUID | str) -> bool:
        """Soft delete: set is_deleted=true, deleted_at=now().

        Returns:
            True if an active row was marked deleted; False if none matched
            or ``id`` is not a valid UUID.
        """
        id_str = self._checked_id(id)
        if id_str is None:
            return False
        query = text(f"""
            UPDATE {self.table}
            SET is_deleted = true, deleted_at = :now, updated_at = :now
            WHERE id = :id AND is_deleted = false
            RETURNING id
        """)
        now = datetime.now(timezone.utc)
        result = await self.db.execute(query, {"id": id_str, "now": now})
        return result.first() is not None

    async def restore(self, id: UUID | str) -> bool:
        """Restore a soft-deleted row.

        Returns:
            True if a deleted row was restored; False if none matched or
            ``id`` is not a valid UUID.
        """
        id_str = self._checked_id(id)
        if id_str is None:
            return False
        query = text(f"""
            UPDATE {self.table}
            SET is_deleted = false, deleted_at = NULL, updated_at = :now
            WHERE id = :id AND is_deleted = true
            RETURNING id
        """)
        now = datetime.now(timezone.utc)
        result = await self.db.execute(query, {"id": id_str, "now": now})
        return result.first() is not None

    async def permanent_delete(self, id: UUID | str) -> bool:
        """Hard delete. Admin only.

        Returns:
            True if a row was deleted; False if none matched or ``id`` is not
            a valid UUID.
        """
        id_str = self._checked_id(id)
        if id_str is None:
            return False
        query = text(f"DELETE FROM {self.table} WHERE id = :id RETURNING id")
        result = await self.db.execute(query, {"id": id_str})
        return result.first() is not None

    async def bulk_soft_delete(self, ids: list[str]) -> int:
        """Soft delete multiple rows.

        IDs that are not valid UUIDs are ignored, since they cannot match a row.

        Returns:
            The driver-reported number of rows updated (0 for no valid IDs).
        """
        valid_ids = [row_id for row_id in map(self._checked_id, ids) if row_id is not None]
        if not valid_ids:
            return 0

        id_params = {f"id_{index}": row_id for index, row_id in enumerate(valid_ids)}
        placeholders = ", ".join(f":{name}" for name in id_params)

        query = text(f"""
            UPDATE {self.table}
            SET is_deleted = true, deleted_at = :now, updated_at = :now
            WHERE id IN ({placeholders}) AND is_deleted = false
        """)
        now = datetime.now(timezone.utc)
        result = await self.db.execute(query, {**id_params, "now": now})
        return result.rowcount

    async def list_deleted(self) -> list[dict]:
        """List all soft-deleted rows. Used by Admin > Trash."""
        query = text(f"""
            SELECT * FROM {self.table}
            WHERE is_deleted = true
            ORDER BY deleted_at DESC
        """)
        result = await self.db.execute(query)
        return [dict(row._mapping) for row in result]

    async def reorder(self, id_order: list[str]) -> None:
        """Update sort_order based on position in list.

        The row at position ``i`` (0-based) gets ``sort_order = (i + 1) * 10``.
        Invalid UUIDs are skipped, but their position still counts. One
        UPDATE is issued per ID.
        """
        query = text(f"UPDATE {self.table} SET sort_order = :order WHERE id = :id")
        for position, raw_id in enumerate(id_order, start=1):
            row_id = self._checked_id(raw_id)
            if row_id is None:
                continue
            await self.db.execute(query, {"order": position * 10, "id": row_id})
|