[+] cve cli/db fixes: sync returns upsert_result, signal handling

1. cve cli sync: backend.sync() now returns upsert_result directly,
   so the redundant db.upsert_entries call is removed;
2. add signal import for SIGINT/SIGTERM handling.
This commit is contained in:
LLM 2026-04-22 09:00:00 +00:00
parent 9b7046d6f0
commit fc52280b43
2 changed files with 124 additions and 27 deletions

@ -6,6 +6,7 @@ import datetime
import enum import enum
import logging import logging
import pathlib import pathlib
import signal
from typing import Optional from typing import Optional
@ -98,18 +99,8 @@ class cve_cli_t:
def on_progress(done: int, total: int) -> None: def on_progress(done: int, total: int) -> None:
logger.info(dict(msg='sync progress', source=src_val, done=done, total=total)) logger.info(dict(msg='sync progress', source=src_val, done=done, total=total))
entries = await backend.sync(since=since, months=months, on_progress=on_progress) upsert = await backend.sync(db=db, since=since, months=months, on_progress=on_progress)
upsert = db.upsert_entries(entries)
now = datetime.datetime.now(datetime.timezone.utc).isoformat()
db.update_sync_meta(src, last_sync=now, entry_count=db.count_entries(src))
logger.info(dict(
msg='ingested',
source=src.value,
received=upsert.received,
in_db=upsert.inserted,
))
result.sources.append(sync_source_result_t( result.sources.append(sync_source_result_t(
source=src.value, synced=upsert.received, total_in_db=upsert.inserted, source=src.value, synced=upsert.received, total_in_db=upsert.inserted,
)) ))
@ -125,6 +116,49 @@ class cve_cli_t:
return result return result
@staticmethod
async def _async_run_sync(
    cache_dir: pathlib.Path,
    source: str,
    since: Optional[str],
    months: Optional[int],
    nvd_api_key: Optional[str],
    timeout: Optional[int],
) -> sync_result_t:
    """Run ``_async_sync`` with SIGINT/SIGTERM handling and an optional timeout.

    Args:
        cache_dir, source, since, months, nvd_api_key: forwarded unchanged
            to ``cve_cli_t._async_sync``.
        timeout: wall-clock limit in seconds, or ``None`` to wait
            indefinitely.

    Returns:
        The result of ``_async_sync`` when it finishes first; otherwise a
        ``sync_result_t`` whose ``error`` is ``'interrupted'`` (a signal
        arrived) or ``'timeout after <N>s'`` (the limit elapsed).

    Raises:
        Whatever exception ``_async_sync`` raised, re-raised through
        ``Task.result()``.
    """
    loop = asyncio.get_running_loop()
    shutdown = asyncio.Event()

    # add_signal_handler is Unix-only (NotImplementedError elsewhere) and
    # raises RuntimeError outside the main thread.  In both cases the sync
    # still runs, just without graceful signal handling.
    installed: list[signal.Signals] = []
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, shutdown.set)
        except (NotImplementedError, RuntimeError):
            continue
        installed.append(sig)

    sync_task = asyncio.create_task(cve_cli_t._async_sync(
        cache_dir=cache_dir,
        source=source,
        since=since,
        months=months,
        nvd_api_key=nvd_api_key,
    ))
    shutdown_task = asyncio.create_task(shutdown.wait())

    try:
        # Wake as soon as the sync finishes, a signal arrives, or the
        # timeout elapses -- no fixed-interval polling.
        await asyncio.wait(
            {sync_task, shutdown_task},
            timeout=timeout,
            return_when=asyncio.FIRST_COMPLETED,
        )

        # A finished sync always wins, even if a signal or the deadline
        # landed at the same moment: the work is done, so report it.
        if sync_task.done():
            return sync_task.result()

        # Cancel and then wait for the task so its cleanup has run before
        # we report back; return_exceptions swallows the CancelledError.
        sync_task.cancel()
        await asyncio.gather(sync_task, return_exceptions=True)

        if shutdown.is_set():
            logger.warning(dict(msg='sync interrupted'))
            return sync_result_t(error='interrupted')

        # Neither finished nor interrupted: asyncio.wait can only have
        # returned because the (non-None) timeout expired.
        logger.warning(dict(msg='sync timeout', timeout=timeout))
        return sync_result_t(error='timeout after %ds' % timeout)
    finally:
        shutdown_task.cancel()
        for sig in installed:
            loop.remove_signal_handler(sig)
@staticmethod @staticmethod
def sync( def sync(
cache_dir: pathlib.Path, cache_dir: pathlib.Path,
@ -135,26 +169,14 @@ class cve_cli_t:
nvd_api_key: Optional[str] = None, nvd_api_key: Optional[str] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
) -> int: ) -> int:
coro = cve_cli_t._async_sync( result = asyncio.run(cve_cli_t._async_run_sync(
cache_dir=cache_dir, cache_dir=cache_dir,
source=source, source=source,
since=since, since=since,
months=months, months=months,
nvd_api_key=nvd_api_key, nvd_api_key=nvd_api_key,
) timeout=timeout,
))
if timeout is not None:
async def _with_timeout() -> sync_result_t:
try:
return await asyncio.wait_for(coro, timeout=timeout)
except asyncio.TimeoutError:
logger.warning(dict(msg='sync timeout', timeout=timeout))
return sync_result_t(error='timeout after %ds' % timeout)
result = asyncio.run(_with_timeout())
else:
result = asyncio.run(coro)
render(result, fmt) render(result, fmt)
return 2 if result.error is not None else 0 return 2 if result.error is not None else 0

@ -39,7 +39,7 @@ class cve_db_t(orm_module_t):
@classmethod @classmethod
def schema_version(cls) -> int: def schema_version(cls) -> int:
return 2 return 4
@classmethod @classmethod
def migrate(cls, conn: sqlite3.Connection, from_version: int, to_version: int) -> None: def migrate(cls, conn: sqlite3.Connection, from_version: int, to_version: int) -> None:
@ -92,6 +92,28 @@ class cve_db_t(orm_module_t):
""") """)
conn.commit() conn.commit()
if from_version < 3:
conn.executescript("""
CREATE TABLE IF NOT EXISTS cve_sync_days (
source TEXT NOT NULL,
day TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'complete',
PRIMARY KEY (source, day)
);
""")
conn.commit()
if from_version < 4:
conn.execute(
"""
INSERT OR IGNORE INTO cve_sync_days (source, day, status)
SELECT DISTINCT 'nvd', SUBSTR(date_modified, 1, 10), 'complete'
FROM cve_entries
WHERE source = 'nvd' AND date_modified != '' AND LENGTH(date_modified) >= 10
"""
)
conn.commit()
def __init__(self, db_path_or_conn: 'pathlib.Path | sqlite3.Connection') -> None: def __init__(self, db_path_or_conn: 'pathlib.Path | sqlite3.Connection') -> None:
if isinstance(db_path_or_conn, sqlite3.Connection): if isinstance(db_path_or_conn, sqlite3.Connection):
super().__init__(db_path_or_conn) super().__init__(db_path_or_conn)
@ -394,5 +416,58 @@ class cve_db_t(orm_module_t):
row = self._conn.execute('SELECT COUNT(*) FROM cve_osv_ecosystems').fetchone() row = self._conn.execute('SELECT COUNT(*) FROM cve_osv_ecosystems').fetchone()
return row[0] if row else 0 return row[0] if row else 0
def mark_days_complete(self, source: cve_source_t, days: list[datetime.date]) -> None:
    """Upsert one ``cve_sync_days`` row per day with status ``'complete'``.

    Rows are keyed by ``(source.value, day.isoformat())``; an existing row
    for the same key has its status overwritten.  The connection is
    committed once after all rows have been written.
    """
    upsert_sql = '''
        INSERT INTO cve_sync_days (source, day, status)
        VALUES (?, ?, 'complete')
        ON CONFLICT(source, day) DO UPDATE SET status = 'complete'
    '''
    cursor = self._conn.cursor()
    # Generator keeps parameter evaluation lazy, one row at a time.
    cursor.executemany(
        upsert_sql,
        ((source.value, one_day.isoformat()) for one_day in days),
    )
    self._conn.commit()
def get_complete_days(self, source: cve_source_t) -> set[datetime.date]:
    """Return the days recorded as ``'complete'`` for ``source``.

    Reads ``cve_sync_days`` rows matching ``source.value`` and parses each
    ``day`` column as an ISO date.  Values that do not parse are skipped.
    """
    fetched = self._conn.execute(
        "SELECT day FROM cve_sync_days WHERE source = ? AND status = 'complete'",
        (source.value,),
    ).fetchall()

    complete_days: set[datetime.date] = set()
    for row in fetched:
        try:
            parsed = datetime.date.fromisoformat(str(row[0]))
        except ValueError:
            # Unparseable day value: leave it out of the result.
            continue
        complete_days.add(parsed)
    return complete_days
def compute_missing_ranges(
    self,
    source: cve_source_t,
    start: datetime.date,
    end: datetime.date,
) -> list[tuple[datetime.date, datetime.date]]:
    """Return the contiguous gaps in ``[start, end]`` that are not yet synced.

    Each gap is an inclusive ``(first_day, last_day)`` pair, in ascending
    order.  A day counts as synced when it is in
    ``self.get_complete_days(source)``.  The result is empty when every day
    is synced or when ``start`` is after ``end``.
    """
    one_day = datetime.timedelta(days=1)
    synced = self.get_complete_days(source)

    gaps: list[tuple[datetime.date, datetime.date]] = []
    open_gap: Optional[datetime.date] = None
    cursor = start
    while cursor <= end:
        is_missing = cursor not in synced
        if is_missing and open_gap is None:
            # First unsynced day after a synced stretch opens a new gap.
            open_gap = cursor
        elif not is_missing and open_gap is not None:
            # A synced day closes the gap on the previous day.
            gaps.append((open_gap, cursor - one_day))
            open_gap = None
        cursor += one_day

    # A gap still open at the end of the window runs through ``end``.
    if open_gap is not None:
        gaps.append((open_gap, end))
    return gaps
orm_registry_t.register(cve_db_t) orm_registry_t.register(cve_db_t)