Initial commit
This commit is contained in:
250
core/base_repository.py
Normal file
250
core/base_repository.py
Normal file
@@ -0,0 +1,250 @@
|
||||
from __future__ import annotations
|
||||
"""
|
||||
BaseRepository: generic CRUD operations for all entities.
|
||||
Uses raw SQL via SQLAlchemy text() - no ORM models needed.
|
||||
Every method automatically filters is_deleted=false unless specified.
|
||||
"""
|
||||
|
||||
import re
|
||||
from uuid import UUID
|
||||
from datetime import date, datetime, timezone
|
||||
from typing import Any
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
_ISO_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
|
||||
_ISO_DATETIME_RE = re.compile(r"^\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}")
|
||||
|
||||
|
||||
def _coerce_value(value: Any) -> Any:
|
||||
"""Convert ISO date/datetime strings to Python date/datetime objects.
|
||||
asyncpg requires native Python types, not strings, for date columns."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
if _ISO_DATE_RE.match(value):
|
||||
try:
|
||||
return date.fromisoformat(value)
|
||||
except ValueError:
|
||||
pass
|
||||
if _ISO_DATETIME_RE.match(value):
|
||||
try:
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
pass
|
||||
return value
|
||||
|
||||
|
||||
class BaseRepository:
|
||||
def __init__(self, table: str, db: AsyncSession):
|
||||
self.table = table
|
||||
self.db = db
|
||||
|
||||
async def list(
|
||||
self,
|
||||
filters: dict | None = None,
|
||||
sort: str = "sort_order",
|
||||
sort_dir: str = "ASC",
|
||||
page: int = 1,
|
||||
per_page: int = 50,
|
||||
include_deleted: bool = False,
|
||||
) -> list[dict]:
|
||||
"""List rows with optional filtering, sorting, pagination."""
|
||||
where_clauses = []
|
||||
params: dict[str, Any] = {}
|
||||
|
||||
if not include_deleted:
|
||||
where_clauses.append("is_deleted = false")
|
||||
|
||||
if filters:
|
||||
for i, (key, value) in enumerate(filters.items()):
|
||||
if value is None:
|
||||
where_clauses.append(f"{key} IS NULL")
|
||||
elif value == "__notnull__":
|
||||
where_clauses.append(f"{key} IS NOT NULL")
|
||||
else:
|
||||
param_name = f"f_{i}"
|
||||
where_clauses.append(f"{key} = :{param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
|
||||
offset = (page - 1) * per_page
|
||||
|
||||
query = text(f"""
|
||||
SELECT * FROM {self.table}
|
||||
WHERE {where_sql}
|
||||
ORDER BY {sort} {sort_dir}
|
||||
LIMIT :limit OFFSET :offset
|
||||
""")
|
||||
params["limit"] = per_page
|
||||
params["offset"] = offset
|
||||
|
||||
result = await self.db.execute(query, params)
|
||||
return [dict(row._mapping) for row in result]
|
||||
|
||||
async def count(
|
||||
self,
|
||||
filters: dict | None = None,
|
||||
include_deleted: bool = False,
|
||||
) -> int:
|
||||
"""Count rows matching filters."""
|
||||
where_clauses = []
|
||||
params: dict[str, Any] = {}
|
||||
|
||||
if not include_deleted:
|
||||
where_clauses.append("is_deleted = false")
|
||||
|
||||
if filters:
|
||||
for i, (key, value) in enumerate(filters.items()):
|
||||
if value is None:
|
||||
where_clauses.append(f"{key} IS NULL")
|
||||
else:
|
||||
param_name = f"f_{i}"
|
||||
where_clauses.append(f"{key} = :{param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
|
||||
query = text(f"SELECT count(*) FROM {self.table} WHERE {where_sql}")
|
||||
result = await self.db.execute(query, params)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get(self, id: UUID | str) -> dict | None:
|
||||
"""Get a single row by ID."""
|
||||
id_str = str(id)
|
||||
# Validate UUID format to prevent asyncpg DataError
|
||||
try:
|
||||
UUID(id_str)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
query = text(f"SELECT * FROM {self.table} WHERE id = :id")
|
||||
result = await self.db.execute(query, {"id": id_str})
|
||||
row = result.first()
|
||||
return dict(row._mapping) if row else None
|
||||
|
||||
async def create(self, data: dict) -> dict:
|
||||
"""Insert a new row. Auto-sets created_at, updated_at, is_deleted."""
|
||||
data = {k: _coerce_value(v) for k, v in data.items() if v is not None or k in ("description", "notes", "body")}
|
||||
data.setdefault("is_deleted", False)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if "created_at" not in data:
|
||||
data["created_at"] = now
|
||||
if "updated_at" not in data:
|
||||
data["updated_at"] = now
|
||||
|
||||
columns = ", ".join(data.keys())
|
||||
placeholders = ", ".join(f":{k}" for k in data.keys())
|
||||
|
||||
query = text(f"""
|
||||
INSERT INTO {self.table} ({columns})
|
||||
VALUES ({placeholders})
|
||||
RETURNING *
|
||||
""")
|
||||
result = await self.db.execute(query, data)
|
||||
row = result.first()
|
||||
return dict(row._mapping) if row else data
|
||||
|
||||
async def update(self, id: UUID | str, data: dict) -> dict | None:
|
||||
"""Update a row by ID. Auto-sets updated_at."""
|
||||
data = {k: _coerce_value(v) for k, v in data.items()}
|
||||
data["updated_at"] = datetime.now(timezone.utc)
|
||||
|
||||
# Remove None values except for fields that should be nullable
|
||||
nullable_fields = {
|
||||
"description", "notes", "body", "area_id", "project_id",
|
||||
"parent_id", "parent_item_id", "release_id", "due_date", "deadline", "tags",
|
||||
"context", "folder_id", "meeting_id", "completed_at",
|
||||
"waiting_for_contact_id", "waiting_since", "color",
|
||||
"rationale", "decided_at", "superseded_by_id",
|
||||
"start_at", "end_at", "location", "agenda", "transcript", "notes_body",
|
||||
"priority", "recurrence", "mime_type",
|
||||
"category", "instructions", "expected_output", "estimated_days",
|
||||
"contact_id", "started_at",
|
||||
"weekly_hours", "effective_from",
|
||||
"task_id", "meeting_id",
|
||||
}
|
||||
clean_data = {}
|
||||
for k, v in data.items():
|
||||
if v is not None or k in nullable_fields:
|
||||
clean_data[k] = v
|
||||
|
||||
if not clean_data:
|
||||
return await self.get(id)
|
||||
|
||||
set_clauses = ", ".join(f"{k} = :{k}" for k in clean_data.keys())
|
||||
clean_data["id"] = str(id)
|
||||
|
||||
query = text(f"""
|
||||
UPDATE {self.table}
|
||||
SET {set_clauses}
|
||||
WHERE id = :id
|
||||
RETURNING *
|
||||
""")
|
||||
result = await self.db.execute(query, clean_data)
|
||||
row = result.first()
|
||||
return dict(row._mapping) if row else None
|
||||
|
||||
async def soft_delete(self, id: UUID | str) -> bool:
|
||||
"""Soft delete: set is_deleted=true, deleted_at=now()."""
|
||||
query = text(f"""
|
||||
UPDATE {self.table}
|
||||
SET is_deleted = true, deleted_at = :now, updated_at = :now
|
||||
WHERE id = :id AND is_deleted = false
|
||||
RETURNING id
|
||||
""")
|
||||
now = datetime.now(timezone.utc)
|
||||
result = await self.db.execute(query, {"id": str(id), "now": now})
|
||||
return result.first() is not None
|
||||
|
||||
async def restore(self, id: UUID | str) -> bool:
|
||||
"""Restore a soft-deleted row."""
|
||||
query = text(f"""
|
||||
UPDATE {self.table}
|
||||
SET is_deleted = false, deleted_at = NULL, updated_at = :now
|
||||
WHERE id = :id AND is_deleted = true
|
||||
RETURNING id
|
||||
""")
|
||||
now = datetime.now(timezone.utc)
|
||||
result = await self.db.execute(query, {"id": str(id), "now": now})
|
||||
return result.first() is not None
|
||||
|
||||
async def permanent_delete(self, id: UUID | str) -> bool:
|
||||
"""Hard delete. Admin only."""
|
||||
query = text(f"DELETE FROM {self.table} WHERE id = :id RETURNING id")
|
||||
result = await self.db.execute(query, {"id": str(id)})
|
||||
return result.first() is not None
|
||||
|
||||
async def bulk_soft_delete(self, ids: list[str]) -> int:
|
||||
"""Soft delete multiple rows."""
|
||||
if not ids:
|
||||
return 0
|
||||
now = datetime.now(timezone.utc)
|
||||
placeholders = ", ".join(f":id_{i}" for i in range(len(ids)))
|
||||
params = {f"id_{i}": str(id) for i, id in enumerate(ids)}
|
||||
params["now"] = now
|
||||
|
||||
query = text(f"""
|
||||
UPDATE {self.table}
|
||||
SET is_deleted = true, deleted_at = :now, updated_at = :now
|
||||
WHERE id IN ({placeholders}) AND is_deleted = false
|
||||
""")
|
||||
result = await self.db.execute(query, params)
|
||||
return result.rowcount
|
||||
|
||||
async def list_deleted(self) -> list[dict]:
|
||||
"""List all soft-deleted rows. Used by Admin > Trash."""
|
||||
query = text(f"""
|
||||
SELECT * FROM {self.table}
|
||||
WHERE is_deleted = true
|
||||
ORDER BY deleted_at DESC
|
||||
""")
|
||||
result = await self.db.execute(query)
|
||||
return [dict(row._mapping) for row in result]
|
||||
|
||||
async def reorder(self, id_order: list[str]) -> None:
|
||||
"""Update sort_order based on position in list."""
|
||||
for i, id in enumerate(id_order):
|
||||
await self.db.execute(
|
||||
text(f"UPDATE {self.table} SET sort_order = :order WHERE id = :id"),
|
||||
{"order": (i + 1) * 10, "id": str(id)}
|
||||
)
|
||||
Reference in New Issue
Block a user