[+] cve cli/db fixes: sync returns upsert_result, signal handling
1. cve cli sync: backend.sync() now returns upsert_result directly,
so the redundant db.upsert_entries call is removed;
2. add the signal import for SIGINT/SIGTERM handling.
This commit is contained in:
parent
9b7046d6f0
commit
fc52280b43
@ -6,6 +6,7 @@ import datetime
|
|||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import signal
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -98,18 +99,8 @@ class cve_cli_t:
|
|||||||
def on_progress(done: int, total: int) -> None:
|
def on_progress(done: int, total: int) -> None:
|
||||||
logger.info(dict(msg='sync progress', source=src_val, done=done, total=total))
|
logger.info(dict(msg='sync progress', source=src_val, done=done, total=total))
|
||||||
|
|
||||||
entries = await backend.sync(since=since, months=months, on_progress=on_progress)
|
upsert = await backend.sync(db=db, since=since, months=months, on_progress=on_progress)
|
||||||
|
|
||||||
upsert = db.upsert_entries(entries)
|
|
||||||
now = datetime.datetime.now(datetime.timezone.utc).isoformat()
|
|
||||||
db.update_sync_meta(src, last_sync=now, entry_count=db.count_entries(src))
|
|
||||||
|
|
||||||
logger.info(dict(
|
|
||||||
msg='ingested',
|
|
||||||
source=src.value,
|
|
||||||
received=upsert.received,
|
|
||||||
in_db=upsert.inserted,
|
|
||||||
))
|
|
||||||
result.sources.append(sync_source_result_t(
|
result.sources.append(sync_source_result_t(
|
||||||
source=src.value, synced=upsert.received, total_in_db=upsert.inserted,
|
source=src.value, synced=upsert.received, total_in_db=upsert.inserted,
|
||||||
))
|
))
|
||||||
@ -125,6 +116,49 @@ class cve_cli_t:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
async def _async_run_sync(
    cache_dir: pathlib.Path,
    source: str,
    since: Optional[str],
    months: Optional[int],
    nvd_api_key: Optional[str],
    timeout: Optional[int],
) -> sync_result_t:
    """Run ``_async_sync`` with SIGINT/SIGTERM handling and an optional deadline.

    The sync runs as a task and is raced against a shutdown event (set by
    SIGINT/SIGTERM) and against ``timeout`` seconds when ``timeout`` is not
    ``None``.

    Returns:
        The task's own result when it finishes first; otherwise a
        ``sync_result_t`` whose ``error`` is ``'interrupted'`` or
        ``'timeout after <N>s'``.

    Raises:
        Whatever exception the sync task itself raised, re-raised by
        ``sync_task.result()``.
    """
    loop = asyncio.get_running_loop()
    shutdown = asyncio.Event()

    installed: list[signal.Signals] = []
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, shutdown.set)
        except NotImplementedError:
            # Event loops without Unix signal support (e.g. on Windows)
            # cannot install handlers; run without graceful interruption.
            continue
        installed.append(sig)

    sync_task = asyncio.create_task(cve_cli_t._async_sync(
        cache_dir=cache_dir,
        source=source,
        since=since,
        months=months,
        nvd_api_key=nvd_api_key,
    ))
    shutdown_task = asyncio.create_task(shutdown.wait())

    try:
        # Wake up as soon as the sync finishes, a signal arrives, or the
        # deadline passes. No polling, so completion is noticed immediately
        # and the deadline is measured by the loop clock.
        await asyncio.wait(
            {sync_task, shutdown_task},
            timeout=timeout,
            return_when=asyncio.FIRST_COMPLETED,
        )

        # A finished sync always wins: its work is already done, so it must
        # not be reported as a timeout or an interruption.
        if sync_task.done():
            return sync_task.result()

        if shutdown.is_set():
            logger.warning(dict(msg='sync interrupted'))
            error = 'interrupted'
        else:
            logger.warning(dict(msg='sync timeout', timeout=timeout))
            error = 'timeout after %ds' % timeout

        sync_task.cancel()
        # Let the cancelled task unwind (its finally/cleanup blocks) before
        # reporting; return_exceptions keeps CancelledError from propagating.
        await asyncio.gather(sync_task, return_exceptions=True)
        return sync_result_t(error=error)
    finally:
        shutdown_task.cancel()
        for sig in installed:
            loop.remove_signal_handler(sig)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sync(
|
def sync(
|
||||||
cache_dir: pathlib.Path,
|
cache_dir: pathlib.Path,
|
||||||
@ -135,26 +169,14 @@ class cve_cli_t:
|
|||||||
nvd_api_key: Optional[str] = None,
|
nvd_api_key: Optional[str] = None,
|
||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
coro = cve_cli_t._async_sync(
|
result = asyncio.run(cve_cli_t._async_run_sync(
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
source=source,
|
source=source,
|
||||||
since=since,
|
since=since,
|
||||||
months=months,
|
months=months,
|
||||||
nvd_api_key=nvd_api_key,
|
nvd_api_key=nvd_api_key,
|
||||||
)
|
timeout=timeout,
|
||||||
|
))
|
||||||
if timeout is not None:
|
|
||||||
async def _with_timeout() -> sync_result_t:
|
|
||||||
try:
|
|
||||||
return await asyncio.wait_for(coro, timeout=timeout)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(dict(msg='sync timeout', timeout=timeout))
|
|
||||||
return sync_result_t(error='timeout after %ds' % timeout)
|
|
||||||
|
|
||||||
result = asyncio.run(_with_timeout())
|
|
||||||
else:
|
|
||||||
result = asyncio.run(coro)
|
|
||||||
|
|
||||||
render(result, fmt)
|
render(result, fmt)
|
||||||
return 2 if result.error is not None else 0
|
return 2 if result.error is not None else 0
|
||||||
|
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class cve_db_t(orm_module_t):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def schema_version(cls) -> int:
|
def schema_version(cls) -> int:
|
||||||
return 2
|
return 4
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def migrate(cls, conn: sqlite3.Connection, from_version: int, to_version: int) -> None:
|
def migrate(cls, conn: sqlite3.Connection, from_version: int, to_version: int) -> None:
|
||||||
@ -92,6 +92,28 @@ class cve_db_t(orm_module_t):
|
|||||||
""")
|
""")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
if from_version < 3:
|
||||||
|
conn.executescript("""
|
||||||
|
CREATE TABLE IF NOT EXISTS cve_sync_days (
|
||||||
|
source TEXT NOT NULL,
|
||||||
|
day TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL DEFAULT 'complete',
|
||||||
|
PRIMARY KEY (source, day)
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
if from_version < 4:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR IGNORE INTO cve_sync_days (source, day, status)
|
||||||
|
SELECT DISTINCT 'nvd', SUBSTR(date_modified, 1, 10), 'complete'
|
||||||
|
FROM cve_entries
|
||||||
|
WHERE source = 'nvd' AND date_modified != '' AND LENGTH(date_modified) >= 10
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
def __init__(self, db_path_or_conn: 'pathlib.Path | sqlite3.Connection') -> None:
|
def __init__(self, db_path_or_conn: 'pathlib.Path | sqlite3.Connection') -> None:
|
||||||
if isinstance(db_path_or_conn, sqlite3.Connection):
|
if isinstance(db_path_or_conn, sqlite3.Connection):
|
||||||
super().__init__(db_path_or_conn)
|
super().__init__(db_path_or_conn)
|
||||||
@ -394,5 +416,58 @@ class cve_db_t(orm_module_t):
|
|||||||
row = self._conn.execute('SELECT COUNT(*) FROM cve_osv_ecosystems').fetchone()
|
row = self._conn.execute('SELECT COUNT(*) FROM cve_osv_ecosystems').fetchone()
|
||||||
return row[0] if row else 0
|
return row[0] if row else 0
|
||||||
|
|
||||||
|
def mark_days_complete(self, source: cve_source_t, days: list[datetime.date]) -> None:
    """Record every day in ``days`` as completely synced for ``source``.

    Each day is upserted into ``cve_sync_days`` with status ``'complete'``;
    an existing ``(source, day)`` row has its status overwritten.

    Args:
        source: Source whose ``.value`` is stored in the ``source`` column.
        days: Calendar days to mark. ``datetime.datetime`` values are
            reduced to their date part so the stored key is always
            ``YYYY-MM-DD``.
    """
    params = [
        (
            source.value,
            # datetime is a date subclass; its isoformat() would include the
            # time and produce a key that date.fromisoformat() cannot read.
            (d.date() if isinstance(d, datetime.datetime) else d).isoformat(),
        )
        for d in days
    ]

    # The connection context manager commits on success and rolls back on
    # error, so a failure part-way through never leaves a partial batch.
    with self._conn:
        self._conn.executemany(
            '''
            INSERT INTO cve_sync_days (source, day, status)
            VALUES (?, ?, 'complete')
            ON CONFLICT(source, day) DO UPDATE SET status = 'complete'
            ''',
            params,
        )
|
||||||
|
|
||||||
|
def get_complete_days(self, source: cve_source_t) -> set[datetime.date]:
    """Return the days recorded as completely synced for ``source``.

    Reads the ``cve_sync_days`` rows for ``source.value`` whose status is
    ``'complete'`` and parses each ``day`` as an ISO date. Values that do
    not parse are left out of the result and reported in one aggregated
    warning rather than being dropped silently.
    """
    import logging  # local import: this block cannot assume a module-level logger

    rows = self._conn.execute(
        "SELECT day FROM cve_sync_days WHERE source = ? AND status = 'complete'",
        (source.value,),
    ).fetchall()

    result: set[datetime.date] = set()
    malformed: list[str] = []
    for row in rows:
        raw_day = str(row[0])
        try:
            result.add(datetime.date.fromisoformat(raw_day))
        except ValueError:
            malformed.append(raw_day)

    if malformed:
        # Still best-effort: bad rows never abort the read, but they are
        # made visible so a corrupted sync ledger can be noticed.
        logging.getLogger(__name__).warning(dict(
            msg='ignored malformed sync days',
            source=source.value,
            count=len(malformed),
            sample=malformed[:5],
        ))

    return result
|
||||||
|
|
||||||
|
def compute_missing_ranges(
    self,
    source: cve_source_t,
    start: datetime.date,
    end: datetime.date,
) -> list[tuple[datetime.date, datetime.date]]:
    """Return the contiguous unsynced gaps inside the inclusive window [start, end].

    Every day of the window is checked against ``get_complete_days(source)``.
    Consecutive days that are not complete are merged into one
    ``(first_day, last_day)`` tuple, and gaps are returned in chronological
    order. An empty window (``start`` after ``end``) yields an empty list.
    """
    synced = self.get_complete_days(source)
    one_day = datetime.timedelta(days=1)

    gaps: list[tuple[datetime.date, datetime.date]] = []
    open_from: Optional[datetime.date] = None

    for offset in range((end - start).days + 1):
        current = start + offset * one_day

        if current in synced:
            # A synced day closes whatever gap was being accumulated.
            if open_from is not None:
                gaps.append((open_from, current - one_day))
                open_from = None
            continue

        if open_from is None:
            open_from = current

    # A gap still open at the end of the window runs through ``end``.
    if open_from is not None:
        gaps.append((open_from, end))

    return gaps
|
||||||
|
|
||||||
|
|
||||||
orm_registry_t.register(cve_db_t)
|
orm_registry_t.register(cve_db_t)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user