[+] cve cli/db fixes: sync returns upsert_result, signal handling
1. cve cli sync: backend.sync() now returns upsert_result directly,
remove redundant db.upsert_entries call;
2. add signal import for SIGINT/SIGTERM handling;
This commit is contained in:
parent
9b7046d6f0
commit
fc52280b43
@ -6,6 +6,7 @@ import datetime
|
||||
import enum
|
||||
import logging
|
||||
import pathlib
|
||||
import signal
|
||||
|
||||
from typing import Optional
|
||||
|
||||
@ -98,18 +99,8 @@ class cve_cli_t:
|
||||
def on_progress(done: int, total: int) -> None:
|
||||
logger.info(dict(msg='sync progress', source=src_val, done=done, total=total))
|
||||
|
||||
entries = await backend.sync(since=since, months=months, on_progress=on_progress)
|
||||
upsert = await backend.sync(db=db, since=since, months=months, on_progress=on_progress)
|
||||
|
||||
upsert = db.upsert_entries(entries)
|
||||
now = datetime.datetime.now(datetime.timezone.utc).isoformat()
|
||||
db.update_sync_meta(src, last_sync=now, entry_count=db.count_entries(src))
|
||||
|
||||
logger.info(dict(
|
||||
msg='ingested',
|
||||
source=src.value,
|
||||
received=upsert.received,
|
||||
in_db=upsert.inserted,
|
||||
))
|
||||
result.sources.append(sync_source_result_t(
|
||||
source=src.value, synced=upsert.received, total_in_db=upsert.inserted,
|
||||
))
|
||||
@ -125,6 +116,49 @@ class cve_cli_t:
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
async def _async_run_sync(
    cache_dir: pathlib.Path,
    source: str,
    since: Optional[str],
    months: Optional[int],
    nvd_api_key: Optional[str],
    timeout: Optional[int],
) -> sync_result_t:
    """Run ``_async_sync`` as a task, supervised for signals and timeout.

    The sync runs in its own task while this coroutine waits for whichever
    comes first: the sync finishing, SIGINT/SIGTERM arriving, or ``timeout``
    seconds elapsing.

    Args:
        cache_dir: forwarded to ``_async_sync``.
        source: forwarded to ``_async_sync``.
        since: forwarded to ``_async_sync``.
        months: forwarded to ``_async_sync``.
        nvd_api_key: forwarded to ``_async_sync``.
        timeout: maximum number of seconds to wait; ``None`` waits forever.

    Returns:
        The result of ``_async_sync`` when it finishes in time, otherwise a
        ``sync_result_t`` whose ``error`` is ``'interrupted'`` or
        ``'timeout after <N>s'``.

    Raises:
        Whatever ``_async_sync`` raised, when the task finished by raising.
    """
    loop = asyncio.get_running_loop()
    shutdown = asyncio.Event()

    # Remember which handlers were actually installed so that only those
    # are removed again on the way out.
    hooked_signals = []
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, shutdown.set)
        except (NotImplementedError, RuntimeError):
            # Not every event loop supports signal handlers (and they can
            # only be installed from the main thread); fall back to the
            # interpreter's default signal behaviour in that case.
            continue
        hooked_signals.append(sig)

    sync_task = asyncio.create_task(cve_cli_t._async_sync(
        cache_dir=cache_dir,
        source=source,
        since=since,
        months=months,
        nvd_api_key=nvd_api_key,
    ))
    shutdown_waiter = asyncio.create_task(shutdown.wait())

    try:
        # A single wait replaces a once-per-second polling loop: it wakes
        # as soon as the sync finishes or a signal arrives, and the timeout
        # is measured by the event loop clock instead of by counting
        # iterations.
        await asyncio.wait(
            {sync_task, shutdown_waiter},
            timeout=timeout,
            return_when=asyncio.FIRST_COMPLETED,
        )

        # A sync that already finished wins over a signal that arrived at
        # the same moment: its work is done, so report it.
        if sync_task.done():
            return sync_task.result()

        sync_task.cancel()
        if shutdown.is_set():
            logger.warning(dict(msg='sync interrupted'))
            outcome = sync_result_t(error='interrupted')
        else:
            # Neither task completed, so the wait can only have timed out,
            # which implies ``timeout`` is not None here.
            logger.warning(dict(msg='sync timeout', timeout=timeout))
            outcome = sync_result_t(error='timeout after %ds' % timeout)

        # Let the cancelled task unwind before returning, so its cleanup
        # has run by the time the caller sees the result.
        await asyncio.gather(sync_task, return_exceptions=True)
        return outcome
    finally:
        shutdown_waiter.cancel()
        for sig in hooked_signals:
            loop.remove_signal_handler(sig)
|
||||
|
||||
@staticmethod
|
||||
def sync(
|
||||
cache_dir: pathlib.Path,
|
||||
@ -135,26 +169,14 @@ class cve_cli_t:
|
||||
nvd_api_key: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> int:
|
||||
coro = cve_cli_t._async_sync(
|
||||
result = asyncio.run(cve_cli_t._async_run_sync(
|
||||
cache_dir=cache_dir,
|
||||
source=source,
|
||||
since=since,
|
||||
months=months,
|
||||
nvd_api_key=nvd_api_key,
|
||||
)
|
||||
|
||||
if timeout is not None:
|
||||
async def _with_timeout() -> sync_result_t:
|
||||
try:
|
||||
return await asyncio.wait_for(coro, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(dict(msg='sync timeout', timeout=timeout))
|
||||
return sync_result_t(error='timeout after %ds' % timeout)
|
||||
|
||||
result = asyncio.run(_with_timeout())
|
||||
else:
|
||||
result = asyncio.run(coro)
|
||||
|
||||
timeout=timeout,
|
||||
))
|
||||
render(result, fmt)
|
||||
return 2 if result.error is not None else 0
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ class cve_db_t(orm_module_t):
|
||||
|
||||
@classmethod
def schema_version(cls) -> int:
    """Return the schema version this class expects the database to be at.

    The value is 4, matching the newest migration step in ``migrate``
    (the ``from_version < 4`` branch).
    """
    return 4
|
||||
|
||||
@classmethod
|
||||
def migrate(cls, conn: sqlite3.Connection, from_version: int, to_version: int) -> None:
|
||||
@ -92,6 +92,28 @@ class cve_db_t(orm_module_t):
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
if from_version < 3:
|
||||
conn.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS cve_sync_days (
|
||||
source TEXT NOT NULL,
|
||||
day TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'complete',
|
||||
PRIMARY KEY (source, day)
|
||||
);
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
if from_version < 4:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO cve_sync_days (source, day, status)
|
||||
SELECT DISTINCT 'nvd', SUBSTR(date_modified, 1, 10), 'complete'
|
||||
FROM cve_entries
|
||||
WHERE source = 'nvd' AND date_modified != '' AND LENGTH(date_modified) >= 10
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def __init__(self, db_path_or_conn: 'pathlib.Path | sqlite3.Connection') -> None:
|
||||
if isinstance(db_path_or_conn, sqlite3.Connection):
|
||||
super().__init__(db_path_or_conn)
|
||||
@ -394,5 +416,58 @@ class cve_db_t(orm_module_t):
|
||||
row = self._conn.execute('SELECT COUNT(*) FROM cve_osv_ecosystems').fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
def mark_days_complete(self, source: cve_source_t, days: list[datetime.date]) -> None:
    """Record every day in *days* as ``'complete'`` for *source*.

    Existing ``(source, day)`` rows have their status overwritten with
    ``'complete'``; missing rows are inserted.

    The writes run inside the connection's transaction context manager, so
    they are committed together on success and rolled back together if any
    statement fails. Note that the rollback covers the whole open
    transaction on this connection, not only the rows written here.

    Args:
        source: source whose ``value`` is stored in the ``source`` column.
        days: dates to mark; each is stored in ISO ``YYYY-MM-DD`` form.
    """
    # Build the parameters up front so a bad element fails before any
    # statement has touched the database.
    params = [(source.value, d.isoformat()) for d in days]

    with self._conn:
        self._conn.executemany(
            '''
            INSERT INTO cve_sync_days (source, day, status)
            VALUES (?, ?, 'complete')
            ON CONFLICT(source, day) DO UPDATE SET status = 'complete'
            ''',
            params,
        )
|
||||
|
||||
def get_complete_days(self, source: cve_source_t) -> set[datetime.date]:
    """Return the days stored with status ``'complete'`` for *source*.

    Rows whose ``day`` value cannot be parsed as an ISO date are left out
    of the result.
    """
    query = "SELECT day FROM cve_sync_days WHERE source = ? AND status = 'complete'"
    fetched = self._conn.execute(query, (source.value,)).fetchall()

    days: set[datetime.date] = set()
    for record in fetched:
        try:
            parsed = datetime.date.fromisoformat(str(record[0]))
        except ValueError:
            # Best effort: an unparseable day is skipped, not fatal.
            continue
        days.add(parsed)
    return days
|
||||
|
||||
def compute_missing_ranges(
    self,
    source: cve_source_t,
    start: datetime.date,
    end: datetime.date,
) -> list[tuple[datetime.date, datetime.date]]:
    """Return the contiguous day ranges in [start, end] not yet synced.

    Each day from *start* to *end* (both inclusive) is checked against the
    days reported by ``get_complete_days(source)``. Consecutive days that
    are absent are merged into one inclusive ``(first, last)`` tuple. The
    result is in chronological order and is empty when every day is
    present or when *start* is after *end*.
    """
    synced = self.get_complete_days(source)
    one_day = datetime.timedelta(days=1)

    gaps: list[tuple[datetime.date, datetime.date]] = []
    run_start: Optional[datetime.date] = None
    cursor = start

    while cursor <= end:
        if cursor in synced:
            # A synced day closes the gap that was open, if any.
            if run_start is not None:
                gaps.append((run_start, cursor - one_day))
                run_start = None
        elif run_start is None:
            # First missing day after a synced one (or at the very start).
            run_start = cursor
        cursor += one_day

    # A gap still open after the loop extends to the end of the window.
    if run_start is not None:
        gaps.append((run_start, end))

    return gaps
|
||||
|
||||
|
||||
orm_registry_t.register(cve_db_t)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user