[+] add ORM registry with migration support, migrate cache/db.py

1. add apps/orm/registry.py with orm_module_t base class and orm_registry_t singleton;
  2. singleton per db path, thread-safe, tracks registered ORM classes;
  3. orm_schema_versions table for per-module version tracking;
  4. classmethods table_prefix(), schema_version(), migrate() for schema management;
  5. registry.module(cls) returns typed ORM instance, cached per registry;
  6. migrate cache_db_t to extend orm_module_t, move schema into classmethod migrate();
  7. cache_db_t constructor accepts Path (legacy, uses registry) or Connection (from registry);
  8. orm_registry_t.register(cache_db_t) at module load time;
  9. add test_orm.py with 12 tests: singleton, migration, multi-module, incremental, failure;
This commit is contained in:
LLM 2026-04-13 09:00:00 +00:00
parent 1e1cd6c1c0
commit 687df29dfe
4 changed files with 545 additions and 107 deletions

@ -24,6 +24,8 @@ from ...models import (
package_index_t, package_index_t,
) )
from ..orm.registry import orm_module_t, orm_registry_t
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_T = TypeVar('_T', bound=pydantic.BaseModel) _T = TypeVar('_T', bound=pydantic.BaseModel)
@ -117,10 +119,8 @@ def _fetch_one(
return model.model_validate(dict(zip(columns, raw))) return model.model_validate(dict(zip(columns, raw)))
class cache_db_t: class cache_db_t(orm_module_t):
class constants_t: class constants_t:
schema_version: ClassVar[int] = 1
list_relation_types: ClassVar[dict[str, str]] = { list_relation_types: ClassVar[dict[str, str]] = {
'license': 'license', 'license': 'license',
'depends': 'depends', 'depends': 'depends',
@ -133,33 +133,20 @@ class cache_db_t:
'groups': 'groups', 'groups': 'groups',
} }
def __init__(self, db_path: pathlib.Path) -> None: # ── orm_module_t interface ──
self._db_path = db_path
self._conn = sqlite3.connect(str(db_path))
self._conn.execute('PRAGMA journal_mode=WAL')
self._conn.execute('PRAGMA foreign_keys=ON')
self._ensure_schema()
def close(self) -> None: @classmethod
self._conn.close() def table_prefix(cls) -> str:
return 'cache'
def _ensure_schema(self) -> None: @classmethod
cur = self._conn.cursor() def schema_version(cls) -> int:
return 1
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='schema_meta'") @classmethod
if cur.fetchone() is None: def migrate(cls, conn: sqlite3.Connection, from_version: int, to_version: int) -> None:
self._create_schema(cur) if from_version < 1:
self._conn.commit() conn.executescript("""
return
cur.execute('SELECT version FROM schema_meta LIMIT 1')
row = cur.fetchone()
if row is None or row[0] < cache_db_t.constants_t.schema_version:
self._create_schema(cur)
self._conn.commit()
def _create_schema(self, cur: sqlite3.Cursor) -> None:
cur.executescript("""
CREATE TABLE IF NOT EXISTS schema_meta ( CREATE TABLE IF NOT EXISTS schema_meta (
version INTEGER NOT NULL version INTEGER NOT NULL
); );
@ -231,12 +218,22 @@ class cache_db_t:
CREATE INDEX IF NOT EXISTS idx_local_packages_name_version CREATE INDEX IF NOT EXISTS idx_local_packages_name_version
ON local_packages(name, version); ON local_packages(name, version);
""") """)
conn.commit()
cur.execute('DELETE FROM schema_meta') # ── constructors ──
cur.execute(
'INSERT INTO schema_meta (version) VALUES (?)', def __init__(self, db_path_or_conn: 'pathlib.Path | sqlite3.Connection') -> None:
(cache_db_t.constants_t.schema_version,), if isinstance(db_path_or_conn, sqlite3.Connection):
) # from ORM registry
super().__init__(db_path_or_conn)
else:
# legacy: standalone usage, goes through registry
registry = orm_registry_t.get(db_path_or_conn)
super().__init__(registry.conn)
def close(self) -> None:
# no-op when managed by registry; caller should use registry.close()
pass
# ── helpers ── # ── helpers ──
@ -581,6 +578,7 @@ class cache_db_t:
filename=ppkg.filename, filename=ppkg.filename,
repo=pidx.name, repo=pidx.name,
sha256sum=ppkg.sha256sum, sha256sum=ppkg.sha256sum,
csize=ppkg.csize,
depends=[pacman_constraint_t.parse(d) for d in ppkg.depends], depends=[pacman_constraint_t.parse(d) for d in ppkg.depends],
provides=[pacman_constraint_t.parse(p) for p in ppkg.provides], provides=[pacman_constraint_t.parse(p) for p in ppkg.provides],
conflicts=[pacman_constraint_t.parse(c) for c in ppkg.conflicts], conflicts=[pacman_constraint_t.parse(c) for c in ppkg.conflicts],
@ -599,7 +597,7 @@ class cache_db_t:
cur = self._conn.cursor() cur = self._conn.cursor()
cur.execute( cur.execute(
''' '''
SELECT p.id, p.name, p.version, p.filename, p.sha256sum, s.repo SELECT p.id, p.name, p.version, p.filename, p.sha256sum, p.csize, s.repo
FROM packages p FROM packages p
JOIN snapshots s ON s.id = p.snapshot_id JOIN snapshots s ON s.id = p.snapshot_id
WHERE p.id IN (SELECT MIN(id) FROM packages GROUP BY name, version) WHERE p.id IN (SELECT MIN(id) FROM packages GROUP BY name, version)
@ -608,12 +606,13 @@ class cache_db_t:
pkg_by_id: dict[int, package_desc_t] = {} pkg_by_id: dict[int, package_desc_t] = {}
repo_of: dict[int, str] = {} repo_of: dict[int, str] = {}
for pid, name, version, filename, sha256sum, repo in cur.fetchall(): for pid, name, version, filename, sha256sum, csize, repo in cur.fetchall():
pkg_by_id[pid] = package_desc_t( pkg_by_id[pid] = package_desc_t(
name=name, name=name,
version=version, version=version,
filename=filename, filename=filename,
sha256sum=sha256sum, sha256sum=sha256sum,
csize=csize,
) )
repo_of[pid] = repo repo_of[pid] = repo
@ -629,6 +628,7 @@ class cache_db_t:
filename=ppkg.filename, filename=ppkg.filename,
repo=repo_of[pid], repo=repo_of[pid],
sha256sum=ppkg.sha256sum, sha256sum=ppkg.sha256sum,
csize=ppkg.csize,
depends=[pacman_constraint_t.parse(d) for d in ppkg.depends], depends=[pacman_constraint_t.parse(d) for d in ppkg.depends],
provides=[pacman_constraint_t.parse(p) for p in ppkg.provides], provides=[pacman_constraint_t.parse(p) for p in ppkg.provides],
conflicts=[pacman_constraint_t.parse(c) for c in ppkg.conflicts], conflicts=[pacman_constraint_t.parse(c) for c in ppkg.conflicts],
@ -776,3 +776,6 @@ class cache_db_t:
cur.execute('SELECT COUNT(*) FROM snapshots') cur.execute('SELECT COUNT(*) FROM snapshots')
row = cur.fetchone() row = cur.fetchone()
return row is not None and row[0] > 0 return row is not None and row[0] > 0
orm_registry_t.register(cache_db_t)

@ -0,0 +1,154 @@
"""ORM registry — singleton that manages sqlite connection and registered ORM modules.
Each ORM module (cache, cve, etc.) is a class with:
- classmethod schema_version() -> int
- classmethod migrate(conn, from_version, to_version) -> None
- classmethod table_prefix() -> str
The registry:
1. Holds one sqlite connection per db path (singleton per path).
2. Tracks registered ORM classes.
3. On connect, iterates registered classes and runs migrations if needed.
4. Provides typed access to ORM instances and raw cursor fallback.
"""
import logging
import pathlib
import sqlite3
import threading
from typing import (
ClassVar,
Optional,
Type,
)
logger = logging.getLogger(__name__)
class orm_module_t:
"""Base class for ORM modules. Subclass this and implement the classmethods."""
@classmethod
def table_prefix(cls) -> str:
raise NotImplementedError
@classmethod
def schema_version(cls) -> int:
raise NotImplementedError
@classmethod
def migrate(cls, conn: sqlite3.Connection, from_version: int, to_version: int) -> None:
raise NotImplementedError
def __init__(self, conn: sqlite3.Connection) -> None:
self._conn = conn
class orm_registry_t:
"""Singleton registry per db path."""
_instances: ClassVar[dict[str, 'orm_registry_t']] = {}
_lock: ClassVar[threading.Lock] = threading.Lock()
_registered_classes: ClassVar[list[Type[orm_module_t]]] = []
@classmethod
def register(cls, module_class: Type[orm_module_t]) -> Type[orm_module_t]:
"""Register an ORM module class. Call at module load time."""
if module_class not in cls._registered_classes:
cls._registered_classes.append(module_class)
return module_class
@classmethod
def get(cls, db_path: pathlib.Path) -> 'orm_registry_t':
"""Get or create registry singleton for a db path."""
key = str(db_path.resolve())
with cls._lock:
if key not in cls._instances:
cls._instances[key] = cls(db_path)
return cls._instances[key]
@classmethod
def reset(cls, db_path: Optional[pathlib.Path] = None) -> None:
"""Close and remove singleton(s). For testing."""
with cls._lock:
if db_path is not None:
key = str(db_path.resolve())
inst = cls._instances.pop(key, None)
if inst is not None:
inst._conn.close()
else:
for inst in cls._instances.values():
inst._conn.close()
cls._instances.clear()
def __init__(self, db_path: pathlib.Path) -> None:
db_path.parent.mkdir(parents=True, exist_ok=True)
self._conn = sqlite3.connect(str(db_path))
self._conn.execute('PRAGMA journal_mode=WAL')
self._modules: dict[Type[orm_module_t], orm_module_t] = {}
self._ensure_meta_table()
self._run_migrations()
def _ensure_meta_table(self) -> None:
self._conn.execute(
'''
CREATE TABLE IF NOT EXISTS orm_schema_versions (
module_prefix TEXT PRIMARY KEY,
version INTEGER NOT NULL
)
'''
)
self._conn.commit()
def _get_current_version(self, prefix: str) -> int:
row = self._conn.execute(
'SELECT version FROM orm_schema_versions WHERE module_prefix = ?',
(prefix,),
).fetchone()
return row[0] if row is not None else 0
def _set_version(self, prefix: str, version: int) -> None:
self._conn.execute(
'''
INSERT INTO orm_schema_versions (module_prefix, version)
VALUES (?, ?)
ON CONFLICT(module_prefix) DO UPDATE SET version = ?
''',
(prefix, version, version),
)
self._conn.commit()
def _run_migrations(self) -> None:
for cls in self._registered_classes:
prefix = cls.table_prefix()
current = self._get_current_version(prefix)
target = cls.schema_version()
if current < target:
logger.info(dict(
msg='migrating',
module=prefix,
from_version=current,
to_version=target,
))
cls.migrate(self._conn, current, target)
self._set_version(prefix, target)
logger.info(dict(msg='migration done', module=prefix, version=target))
@property
def conn(self) -> sqlite3.Connection:
return self._conn
def cursor(self) -> sqlite3.Cursor:
return self._conn.cursor()
def module(self, cls: Type[orm_module_t]) -> orm_module_t:
"""Get an ORM module instance. Created once per registry."""
if cls not in self._modules:
self._modules[cls] = cls(self._conn)
return self._modules[cls]
def close(self) -> None:
self._conn.close()

@ -0,0 +1,281 @@
"""Tests for apps/orm/registry.py
Test matrix:
- registry singleton: same path returns same instance, different path returns different
- reset: clears singleton, closes connection
- schema versioning: new module gets migrated, existing at target version skipped
- multi-module: two modules registered, both get migrated independently
- re-open: close and re-open same db, versions persist, no re-migration
- module access: registry.module() returns typed instance, same instance on repeat call
- raw cursor: registry.cursor() works for ad-hoc queries
- migration ordering: from_version=0 on first run, from_version=N on upgrade
- migration failure: exception in migrate rolls back cleanly (no partial version bump)
"""
import pathlib
import sqlite3
import tempfile
import unittest
from ..apps.orm.registry import orm_module_t, orm_registry_t
class items_orm_t(orm_module_t):
@classmethod
def table_prefix(cls) -> str:
return 'items'
@classmethod
def schema_version(cls) -> int:
return 1
@classmethod
def migrate(cls, conn: sqlite3.Connection, from_version: int, to_version: int) -> None:
if from_version < 1:
conn.execute(
'CREATE TABLE IF NOT EXISTS items_data (id INTEGER PRIMARY KEY, name TEXT NOT NULL)'
)
conn.commit()
def insert(self, name: str) -> None:
self._conn.execute('INSERT INTO items_data (name) VALUES (?)', (name,))
self._conn.commit()
def list_all(self) -> list[tuple[int, str]]:
return self._conn.execute('SELECT id, name FROM items_data').fetchall()
class tags_orm_t(orm_module_t):
@classmethod
def table_prefix(cls) -> str:
return 'tags'
@classmethod
def schema_version(cls) -> int:
return 2
@classmethod
def migrate(cls, conn: sqlite3.Connection, from_version: int, to_version: int) -> None:
if from_version < 1:
conn.execute(
'CREATE TABLE IF NOT EXISTS tags_data (id INTEGER PRIMARY KEY, label TEXT NOT NULL)'
)
conn.commit()
if from_version < 2:
conn.execute('ALTER TABLE tags_data ADD COLUMN color TEXT DEFAULT ""')
conn.commit()
def insert(self, label: str, color: str = '') -> None:
self._conn.execute('INSERT INTO tags_data (label, color) VALUES (?, ?)', (label, color))
self._conn.commit()
def list_all(self) -> list[tuple[int, str, str]]:
return self._conn.execute('SELECT id, label, color FROM tags_data').fetchall()
class broken_orm_t(orm_module_t):
@classmethod
def table_prefix(cls) -> str:
return 'broken'
@classmethod
def schema_version(cls) -> int:
return 1
@classmethod
def migrate(cls, conn: sqlite3.Connection, from_version: int, to_version: int) -> None:
raise RuntimeError('intentional migration failure')
class TestRegistrySingleton(unittest.TestCase):
def setUp(self) -> None:
self.tmpdir = tempfile.mkdtemp()
orm_registry_t._registered_classes.clear()
orm_registry_t._instances.clear()
def tearDown(self) -> None:
orm_registry_t.reset()
def test_same_path_same_instance(self) -> None:
orm_registry_t.register(items_orm_t)
p = pathlib.Path(self.tmpdir) / 'a.db'
r1 = orm_registry_t.get(p)
r2 = orm_registry_t.get(p)
self.assertIs(r1, r2)
def test_different_path_different_instance(self) -> None:
orm_registry_t.register(items_orm_t)
r1 = orm_registry_t.get(pathlib.Path(self.tmpdir) / 'a.db')
r2 = orm_registry_t.get(pathlib.Path(self.tmpdir) / 'b.db')
self.assertIsNot(r1, r2)
def test_reset_clears_singleton(self) -> None:
orm_registry_t.register(items_orm_t)
p = pathlib.Path(self.tmpdir) / 'a.db'
r1 = orm_registry_t.get(p)
orm_registry_t.reset(p)
r2 = orm_registry_t.get(p)
self.assertIsNot(r1, r2)
class TestMigration(unittest.TestCase):
def setUp(self) -> None:
self.tmpdir = tempfile.mkdtemp()
orm_registry_t._registered_classes.clear()
orm_registry_t._instances.clear()
def tearDown(self) -> None:
orm_registry_t.reset()
def test_first_migration_creates_table(self) -> None:
orm_registry_t.register(items_orm_t)
p = pathlib.Path(self.tmpdir) / 'test.db'
reg = orm_registry_t.get(p)
# table should exist
rows = reg.conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='items_data'"
).fetchall()
self.assertEqual(len(rows), 1)
def test_version_persisted(self) -> None:
orm_registry_t.register(items_orm_t)
p = pathlib.Path(self.tmpdir) / 'test.db'
reg = orm_registry_t.get(p)
ver = reg._get_current_version('items')
self.assertEqual(ver, 1)
def test_no_remigration_on_reopen(self) -> None:
orm_registry_t.register(items_orm_t)
p = pathlib.Path(self.tmpdir) / 'test.db'
reg = orm_registry_t.get(p)
mod = reg.module(items_orm_t)
assert isinstance(mod, items_orm_t)
mod.insert('hello')
orm_registry_t.reset(p)
# reopen — data should survive, no re-migration
reg2 = orm_registry_t.get(p)
mod2 = reg2.module(items_orm_t)
assert isinstance(mod2, items_orm_t)
self.assertEqual(len(mod2.list_all()), 1)
def test_skips_if_at_target(self) -> None:
orm_registry_t.register(items_orm_t)
p = pathlib.Path(self.tmpdir) / 'test.db'
reg = orm_registry_t.get(p)
# manually set version ahead — should not crash
reg._set_version('items', 999)
orm_registry_t.reset(p)
# re-register with version 1, which is below 999 — no migration
reg2 = orm_registry_t.get(p)
self.assertEqual(reg2._get_current_version('items'), 999)
class TestMultiModule(unittest.TestCase):
def setUp(self) -> None:
self.tmpdir = tempfile.mkdtemp()
orm_registry_t._registered_classes.clear()
orm_registry_t._instances.clear()
def tearDown(self) -> None:
orm_registry_t.reset()
def test_two_modules_migrated_independently(self) -> None:
orm_registry_t.register(items_orm_t)
orm_registry_t.register(tags_orm_t)
p = pathlib.Path(self.tmpdir) / 'test.db'
reg = orm_registry_t.get(p)
self.assertEqual(reg._get_current_version('items'), 1)
self.assertEqual(reg._get_current_version('tags'), 2)
items = reg.module(items_orm_t)
tags = reg.module(tags_orm_t)
assert isinstance(items, items_orm_t)
assert isinstance(tags, tags_orm_t)
items.insert('pkg')
tags.insert('urgent', 'red')
self.assertEqual(len(items.list_all()), 1)
self.assertEqual(len(tags.list_all()), 1)
self.assertEqual(tags.list_all()[0][2], 'red')
def test_incremental_migration(self) -> None:
"""tags_orm_t v2 adds color column via from_version < 2 branch."""
orm_registry_t.register(tags_orm_t)
p = pathlib.Path(self.tmpdir) / 'test.db'
# simulate v1 already applied
conn = sqlite3.connect(str(p))
conn.execute(
'CREATE TABLE IF NOT EXISTS orm_schema_versions (module_prefix TEXT PRIMARY KEY, version INTEGER NOT NULL)'
)
conn.execute("INSERT INTO orm_schema_versions VALUES ('tags', 1)")
conn.execute('CREATE TABLE tags_data (id INTEGER PRIMARY KEY, label TEXT NOT NULL)')
conn.commit()
conn.close()
reg = orm_registry_t.get(p)
# should have migrated from 1 -> 2 (added color column)
self.assertEqual(reg._get_current_version('tags'), 2)
tags = reg.module(tags_orm_t)
assert isinstance(tags, tags_orm_t)
tags.insert('test', 'blue')
row = tags.list_all()[0]
self.assertEqual(row[2], 'blue')
class TestModuleAccess(unittest.TestCase):
def setUp(self) -> None:
self.tmpdir = tempfile.mkdtemp()
orm_registry_t._registered_classes.clear()
orm_registry_t._instances.clear()
def tearDown(self) -> None:
orm_registry_t.reset()
def test_same_instance_on_repeat(self) -> None:
orm_registry_t.register(items_orm_t)
p = pathlib.Path(self.tmpdir) / 'test.db'
reg = orm_registry_t.get(p)
m1 = reg.module(items_orm_t)
m2 = reg.module(items_orm_t)
self.assertIs(m1, m2)
def test_raw_cursor(self) -> None:
orm_registry_t.register(items_orm_t)
p = pathlib.Path(self.tmpdir) / 'test.db'
reg = orm_registry_t.get(p)
cur = reg.cursor()
cur.execute("INSERT INTO items_data (name) VALUES ('raw')")
reg.conn.commit()
rows = cur.execute('SELECT name FROM items_data').fetchall()
self.assertEqual(rows, [('raw',)])
class TestMigrationFailure(unittest.TestCase):
def setUp(self) -> None:
self.tmpdir = tempfile.mkdtemp()
orm_registry_t._registered_classes.clear()
orm_registry_t._instances.clear()
def tearDown(self) -> None:
orm_registry_t.reset()
def test_failed_migration_no_version_bump(self) -> None:
orm_registry_t.register(broken_orm_t)
p = pathlib.Path(self.tmpdir) / 'test.db'
with self.assertRaises(RuntimeError):
orm_registry_t.get(p)
# version should not have been bumped
conn = sqlite3.connect(str(p))
row = conn.execute(
"SELECT version FROM orm_schema_versions WHERE module_prefix = 'broken'"
).fetchone()
# either no row or version 0
if row is not None:
self.assertEqual(row[0], 0)
conn.close()