[+] apps/network module: settings singleton, net_t helpers, CLI integration

1. add apps/network/settings.py: net_settings_t pydantic-settings singleton
     with timeout (8s default) and max_size (50MB default), env vars
     ARCHLINUX_NET_TIMEOUT/ARCHLINUX_NET_MAX_SIZE, reset() classmethod;
  2. add apps/network/base.py: net_t class with fetch_url, fetch_text,
     head_content_length, post_json, download_to_file (sync) and async
     wrappers, all enforcing timeout and max_size from settings;
  3. add apps/network/cli.py: net_cli_t with add_arguments/extract/apply
     for --net-timeout and --net-max-size CLI args;
  4. refactor cve/base.py: remove duplicate _fetch_url/_post_json/_head
     wrappers, callers use net_t directly;
  5. refactor cve/nvd.py, cve/osv.py, cve/arch_tracker.py: use net_t;
  6. refactor pacman/client.py, pacman/manager.py: use net_t;
  7. add test_network.py with settings, read_limited, fetch, cli tests;
This commit is contained in:
LLM 2026-04-22 09:00:00 +00:00
parent 857e9d41a2
commit 8079aae41c
11 changed files with 787 additions and 152 deletions

@ -10,6 +10,8 @@ from typing import Optional
import pydantic
from typing import TYPE_CHECKING
from .arch_tracker_types import arch_avg_t
from .base import SyncProgressCallback, cve_backend_t
from .types import (
@ -18,8 +20,12 @@ from .types import (
cve_source_t,
cve_status_t,
cve_sync_estimate_t,
cve_upsert_result_t,
)
if TYPE_CHECKING:
from .db import cve_db_t
logger = logging.getLogger(__name__)
ISSUES_URL = 'https://security.archlinux.org/issues/all.json'
@ -50,7 +56,9 @@ class arch_tracker_backend_t(cve_backend_t):
since: Optional[str] = None,
months: Optional[int] = None,
) -> cve_sync_estimate_t:
content_length = await self._head_content_length(ISSUES_URL)
from ..network.base import net_t
content_length = await net_t.async_head_content_length(ISSUES_URL)
return cve_sync_estimate_t(
source=cve_source_t.arch_tracker,
num_fetches=1,
@ -61,13 +69,16 @@ class arch_tracker_backend_t(cve_backend_t):
async def sync(
self,
db: 'cve_db_t',
since: Optional[str] = None,
months: Optional[int] = None,
on_progress: Optional[SyncProgressCallback] = None,
) -> list[cve_entry_t]:
) -> cve_upsert_result_t:
logger.info(dict(msg='fetch', source='arch_tracker', url=ISSUES_URL))
raw_bytes = await self._fetch_url(ISSUES_URL)
from ..network.base import net_t
raw_bytes = await net_t.async_fetch_url(ISSUES_URL)
avgs = _avg_list_adapter.validate_json(raw_bytes)
logger.info(dict(msg='fetched', source='arch_tracker', avgs=len(avgs), bytes=len(raw_bytes)))
@ -94,4 +105,4 @@ class arch_tracker_backend_t(cve_backend_t):
)
logger.info(dict(msg='parsed', source='arch_tracker', avgs=len(avgs), entries=len(entries)))
return entries
return await self._store_and_update_meta(db, entries)

@ -1,18 +1,21 @@
"""Abstract backend interface for CVE data sources."""
import abc
import asyncio
import datetime
import logging
import urllib.request
from typing import Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional
from .types import (
cve_entry_t,
cve_source_t,
cve_sync_estimate_t,
cve_upsert_result_t,
)
if TYPE_CHECKING:
from .db import cve_db_t
logger = logging.getLogger(__name__)
SyncProgressCallback = Callable[[int, int], None]
@ -35,49 +38,28 @@ class cve_backend_t(abc.ABC):
@abc.abstractmethod
async def sync(
self,
db: 'cve_db_t',
since: Optional[str] = None,
months: Optional[int] = None,
on_progress: Optional[SyncProgressCallback] = None,
) -> list[cve_entry_t]:
) -> cve_upsert_result_t:
"""Fetch entries and store them in db. Returns upsert result."""
raise NotImplementedError
@staticmethod
async def _head_content_length(url: str) -> int:
loop = asyncio.get_running_loop()
try:
def _do() -> int:
req = urllib.request.Request(url, method='HEAD')
resp = urllib.request.urlopen(req, timeout=10)
cl = resp.headers.get('Content-Length', '0')
return int(cl)
async def _store_and_update_meta(
self,
db: 'cve_db_t',
entries: list[cve_entry_t],
) -> cve_upsert_result_t:
"""Common: upsert entries + update sync meta. Called by subclasses."""
result = db.upsert_entries(entries)
now = datetime.datetime.now(datetime.timezone.utc).isoformat()
db.update_sync_meta(self.source, last_sync=now, entry_count=db.count_entries(self.source))
logger.info(dict(
msg='ingested',
source=self.source.value,
received=result.received,
in_db=result.inserted,
))
return result
return await loop.run_in_executor(None, _do)
except Exception:
logger.debug(dict(msg='HEAD failed', url=url))
return 0
@staticmethod
async def _fetch_url(url: str, timeout: int = 30) -> bytes:
loop = asyncio.get_running_loop()
def _do() -> bytes:
resp = urllib.request.urlopen(url, timeout=timeout)
return resp.read()
return await loop.run_in_executor(None, _do)
@staticmethod
async def _post_json(url: str, data: bytes, timeout: int = 30) -> bytes:
loop = asyncio.get_running_loop()
def _do() -> bytes:
req = urllib.request.Request(
url,
data=data,
headers={'Content-Type': 'application/json'},
method='POST',
)
resp = urllib.request.urlopen(req, timeout=timeout)
return resp.read()
return await loop.run_in_executor(None, _do)

@ -3,15 +3,16 @@
Source: https://services.nvd.nist.gov/rest/json/cves/2.0
Optional API key. Rate limited: 5 req/30s without key, 50 with key.
Paginated (max 2000/page). Supports lastModStartDate/lastModEndDate (max 120 days).
Uses cve_sync_days to skip already-fetched day ranges.
"""
import asyncio
import datetime
import logging
import math
import urllib.parse
from datetime import datetime, timedelta, timezone
from typing import Optional
from typing import TYPE_CHECKING, Optional
import pydantic
@ -22,8 +23,12 @@ from .types import (
cve_severity_t,
cve_source_t,
cve_sync_estimate_t,
cve_upsert_result_t,
)
if TYPE_CHECKING:
from .db import cve_db_t
logger = logging.getLogger(__name__)
BASE_URL = 'https://services.nvd.nist.gov/rest/json/cves/2.0'
@ -45,11 +50,13 @@ def _severity_from_nvd(s: str) -> cve_severity_t:
return mapping.get(s.upper(), cve_severity_t.unknown)
def _date_ranges(start: datetime, end: datetime) -> list[tuple[str, str]]:
def _chunk_range(start: datetime.date, end: datetime.date) -> list[tuple[str, str]]:
"""Split a date range into chunks of MAX_RANGE_DAYS for NVD API."""
ranges: list[tuple[str, str]] = []
cur = start
while cur < end:
chunk_end = min(cur + timedelta(days=MAX_RANGE_DAYS), end)
cur = datetime.datetime.combine(start, datetime.time.min, tzinfo=datetime.timezone.utc)
end_dt = datetime.datetime.combine(end, datetime.time(23, 59, 59), tzinfo=datetime.timezone.utc)
while cur < end_dt:
chunk_end = min(cur + datetime.timedelta(days=MAX_RANGE_DAYS), end_dt)
ranges.append((
cur.strftime('%Y-%m-%dT%H:%M:%S.000'),
chunk_end.strftime('%Y-%m-%dT%H:%M:%S.000'),
@ -71,34 +78,27 @@ class nvd_backend_t(cve_backend_t):
return '%s?%s' % (BASE_URL, urllib.parse.urlencode(params))
async def _fetch_page(self, url: str) -> nvd_response_t:
loop = asyncio.get_running_loop()
headers: dict[str, str] = {}
if self._api_key:
headers['apiKey'] = self._api_key
api_key = self._api_key
from ..network.base import net_t
def _do() -> bytes:
import urllib.request as ur
req = ur.Request(url)
if api_key:
req.add_header('apiKey', api_key)
resp = ur.urlopen(req, timeout=30)
return resp.read()
raw = await loop.run_in_executor(None, _do)
raw = await net_t.async_fetch_url(url, headers=headers)
return _response_adapter.validate_json(raw)
def _compute_date_range(
self,
since: Optional[str],
months: Optional[int],
) -> tuple[datetime, datetime]:
end = datetime.now(timezone.utc)
) -> tuple[datetime.date, datetime.date]:
end = datetime.date.today()
if since is not None:
start = datetime.fromisoformat(since).replace(tzinfo=timezone.utc)
start = datetime.date.fromisoformat(since)
elif months is not None:
start = end - timedelta(days=months * 30)
start = end - datetime.timedelta(days=months * 30)
else:
start = end - timedelta(days=120)
start = end - datetime.timedelta(days=120)
return start, end
async def estimate_sync(
@ -107,26 +107,25 @@ class nvd_backend_t(cve_backend_t):
months: Optional[int] = None,
) -> cve_sync_estimate_t:
start, end = self._compute_date_range(since, months)
ranges = _date_ranges(start, end)
chunks = _chunk_range(start, end)
if len(ranges) == 0:
if len(chunks) == 0:
return cve_sync_estimate_t(source=cve_source_t.nvd, available=False)
params = {
'lastModStartDate': ranges[0][0],
'lastModEndDate': ranges[0][1],
'lastModStartDate': chunks[0][0],
'lastModEndDate': chunks[0][1],
'resultsPerPage': '1',
}
try:
page = await self._fetch_page(self._build_url(params))
total_first_range = page.totalResults
total_first = page.totalResults
except Exception as e:
logger.warning(dict(msg='nvd estimate failed', error=str(e)))
return cve_sync_estimate_t(source=cve_source_t.nvd, available=False)
estimated_total = total_first_range * len(ranges)
pages_per_range = max(1, math.ceil(total_first_range / PAGE_SIZE))
num_fetches = pages_per_range * len(ranges)
pages_per_chunk = max(1, math.ceil(total_first / PAGE_SIZE))
num_fetches = pages_per_chunk * len(chunks)
return cve_sync_estimate_t(
source=cve_source_t.nvd,
@ -138,76 +137,175 @@ class nvd_backend_t(cve_backend_t):
async def sync(
self,
db: 'cve_db_t',
since: Optional[str] = None,
months: Optional[int] = None,
on_progress: Optional[SyncProgressCallback] = None,
) -> list[cve_entry_t]:
) -> cve_upsert_result_t:
start, end = self._compute_date_range(since, months)
ranges = _date_ranges(start, end)
entries: list[cve_entry_t] = []
# compute missing ranges using db
missing = db.compute_missing_ranges(cve_source_t.nvd, start, end)
if len(missing) == 0:
logger.info(dict(msg='nvd sync: all days already fetched', start=str(start), end=str(end)))
return cve_upsert_result_t(received=0, inserted=db.count_entries(cve_source_t.nvd))
total_missing_days = sum((e - s).days + 1 for s, e in missing)
logger.info(dict(
msg='nvd sync plan',
target_range='%s to %s' % (start, end),
missing_ranges=len(missing),
missing_days=total_missing_days,
))
all_entries: list[cve_entry_t] = []
fetch_count = 0
for range_start, range_end in ranges:
start_index = 0
for gap_idx, (gap_start, gap_end) in enumerate(missing):
chunks = _chunk_range(gap_start, gap_end)
while True:
params = {
'lastModStartDate': range_start,
'lastModEndDate': range_end,
'resultsPerPage': str(PAGE_SIZE),
'startIndex': str(start_index),
}
logger.info(dict(
msg='nvd gap start',
gap='%d/%d' % (gap_idx + 1, len(missing)),
range='%s to %s' % (gap_start, gap_end),
days=(gap_end - gap_start).days + 1,
))
url = self._build_url(params)
logger.info(dict(msg='nvd fetch', url=url))
for chunk_start_str, chunk_end_str in chunks:
start_index = 0
chunk_page = 0
chunk_total: Optional[int] = None
days_seen_in_chunk: set[datetime.date] = set()
page = await self._fetch_page(url)
fetch_count += 1
while True:
params = {
'lastModStartDate': chunk_start_str,
'lastModEndDate': chunk_end_str,
'resultsPerPage': str(PAGE_SIZE),
'startIndex': str(start_index),
}
for vuln in page.vulnerabilities:
cve = vuln.cve
desc = ''
for d in cve.descriptions:
if d.lang == 'en':
desc = d.value
break
url = self._build_url(params)
chunk_page += 1
fetch_count += 1
score = 0.0
severity = cve_severity_t.unknown
for metric_key in ('cvssMetricV31', 'cvssMetricV30', 'cvssMetricV2'):
metrics = cve.metrics.get(metric_key, [])
if len(metrics) > 0:
m = metrics[0]
score = m.cvssData.baseScore
severity = _severity_from_nvd(m.cvssData.baseSeverity)
break
logger.info(dict(
msg='nvd fetch',
range='%s..%s' % (chunk_start_str[:10], chunk_end_str[:10]),
chunk_page=chunk_page,
start_index=start_index,
chunk_total=chunk_total or '?',
fetches_total=fetch_count,
))
entries.append(
cve_entry_t(
cve_id=cve.id,
source=cve_source_t.nvd,
product=cve.id,
severity=severity,
score=score,
title=cve.id,
description=desc,
date_published=cve.published,
date_modified=cve.lastModified,
page = await self._fetch_page(url)
if chunk_total is None:
chunk_total = page.totalResults
logger.info(dict(
msg='nvd chunk total',
range='%s..%s' % (chunk_start_str[:10], chunk_end_str[:10]),
total_results=page.totalResults,
))
if page.totalResults == 0:
# empty range — mark all days in this gap as complete
empty_days: list[datetime.date] = []
ed: datetime.date = gap_start
while ed <= gap_end:
empty_days.append(ed)
ed = ed + datetime.timedelta(days=1)
db.mark_days_complete(cve_source_t.nvd, empty_days)
logger.info(dict(
msg='nvd empty range',
range='%s to %s' % (gap_start, gap_end),
days_marked=len(empty_days),
))
break
for vuln in page.vulnerabilities:
cve = vuln.cve
desc = ''
for desc_item in cve.descriptions:
if desc_item.lang == 'en':
desc = desc_item.value
break
score = 0.0
severity = cve_severity_t.unknown
for metric_key in ('cvssMetricV31', 'cvssMetricV30', 'cvssMetricV2'):
metrics = cve.metrics.get(metric_key, [])
if len(metrics) > 0:
m = metrics[0]
score = m.cvssData.baseScore
severity = _severity_from_nvd(m.cvssData.baseSeverity)
break
all_entries.append(
cve_entry_t(
cve_id=cve.id,
source=cve_source_t.nvd,
product=cve.id,
severity=severity,
score=score,
title=cve.id,
description=desc,
date_published=cve.published,
date_modified=cve.lastModified,
)
)
)
if on_progress is not None:
on_progress(len(entries), page.totalResults * len(ranges))
# track modification dates for day completion
if cve.lastModified:
try:
mod_date = datetime.datetime.fromisoformat(
cve.lastModified.replace('Z', '+00:00')
).date()
days_seen_in_chunk.add(mod_date)
except ValueError:
pass
if start_index + page.resultsPerPage >= page.totalResults:
break
if on_progress is not None:
chunk_done = min(start_index + page.resultsPerPage, page.totalResults)
on_progress(chunk_done, page.totalResults)
start_index += page.resultsPerPage
await asyncio.sleep(self._delay)
if start_index + page.resultsPerPage >= page.totalResults:
break
if len(ranges) > 1:
await asyncio.sleep(self._delay)
start_index += page.resultsPerPage
await asyncio.sleep(self._delay)
logger.info(dict(msg='nvd sync done', fetches=fetch_count, entries=len(entries)))
return entries
# mark days complete
fully_paginated = (start_index + page.resultsPerPage >= page.totalResults)
if fully_paginated:
# entire gap range is done — mark all days including empty ones
complete_days: list[datetime.date] = []
cd: datetime.date = gap_start
while cd <= gap_end:
complete_days.append(cd)
cd = cd + datetime.timedelta(days=1)
elif len(days_seen_in_chunk) > 0:
# partial — mark up to day before last seen (last seen is uncertain)
sorted_days = sorted(days_seen_in_chunk)
complete_days = []
cd = gap_start
while cd < sorted_days[-1]:
complete_days.append(cd)
cd = cd + datetime.timedelta(days=1)
else:
complete_days = []
if len(complete_days) > 0:
db.mark_days_complete(cve_source_t.nvd, complete_days)
logger.info(dict(
msg='nvd days complete',
count=len(complete_days),
range='%s to %s' % (complete_days[0], complete_days[-1]),
))
await asyncio.sleep(self._delay)
logger.info(dict(msg='nvd sync done', fetches=fetch_count, entries=len(all_entries)))
return await self._store_and_update_meta(db, all_entries)

@ -7,13 +7,11 @@ Supports batch queries (up to 1000 per request).
Ecosystem list fetched from GCS bucket listing.
"""
import asyncio
import datetime
import logging
import urllib.request
import xml.etree.ElementTree as ET
from typing import ClassVar, Optional
from typing import TYPE_CHECKING, ClassVar, Optional
import pydantic
@ -30,8 +28,12 @@ from .types import (
cve_entry_t,
cve_source_t,
cve_sync_estimate_t,
cve_upsert_result_t,
)
if TYPE_CHECKING:
from .db import cve_db_t
logger = logging.getLogger(__name__)
QUERY_URL = 'https://api.osv.dev/v1/querybatch'
@ -50,7 +52,6 @@ class osv_ecosystems_t:
if cls._cached is not None and not force:
return cls._cached
loop = asyncio.get_running_loop()
ecosystems: list[osv_ecosystem_t] = []
marker = ''
seen: set[str] = set()
@ -59,11 +60,9 @@ class osv_ecosystems_t:
url = '%s?delimiter=/&prefix=&marker=%s' % (GCS_BUCKET_URL, marker)
logger.debug(dict(msg='fetching osv ecosystems page', marker=marker))
def _do(u: str = url) -> str:
resp = urllib.request.urlopen(u, timeout=30)
return resp.read().decode('utf-8')
from ..network.base import net_t
raw = await loop.run_in_executor(None, _do)
raw = await net_t.async_fetch_text(url)
root = ET.fromstring(raw)
ns = '{http://doc.s3.amazonaws.com/2006-03-01}'
@ -152,12 +151,13 @@ class osv_backend_t(cve_backend_t):
async def sync(
self,
db: 'cve_db_t',
since: Optional[str] = None,
months: Optional[int] = None,
on_progress: Optional[SyncProgressCallback] = None,
) -> list[cve_entry_t]:
) -> cve_upsert_result_t:
logger.warning(dict(msg='osv sync requires explicit package list, use query_packages()'))
return []
return cve_upsert_result_t()
async def query_packages(
self,
@ -184,7 +184,9 @@ class osv_backend_t(cve_backend_t):
]
)
raw = await self._post_json(QUERY_URL, request.model_dump_json().encode('utf-8'))
from ..network.base import net_t
raw = await net_t.async_post_json(QUERY_URL, request.model_dump_json().encode('utf-8'))
batch_resp = pydantic.TypeAdapter(osv_batch_response_t).validate_json(raw)
for i, result in enumerate(batch_resp.results):

@ -0,0 +1,143 @@
"""Network helper class.
All HTTP requests (sync and async) go through this class so that
timeout and max_size limits from net_settings_t are enforced globally.
"""
import asyncio
import logging
import pathlib
import urllib.request
from http.client import HTTPResponse
from typing import Optional
from .settings import net_settings_t
logger = logging.getLogger(__name__)
class net_t:
@classmethod
def _settings(cls) -> net_settings_t:
return net_settings_t.singleton()
@classmethod
def _read_limited(cls, resp: HTTPResponse, max_size: Optional[int] = None) -> bytes:
"""Read response body up to max_size bytes."""
if max_size is None:
max_size = cls._settings().max_size
chunks: list[bytes] = []
total = 0
while True:
chunk = resp.read(65536)
if not chunk:
break
total += len(chunk)
if total > max_size:
raise ValueError(
'response exceeded max_size %d bytes' % max_size
)
chunks.append(chunk)
return b''.join(chunks)
@classmethod
def fetch_url(
cls,
url: str,
timeout: Optional[float] = None,
headers: Optional[dict[str, str]] = None,
) -> bytes:
"""Synchronous GET with timeout and size limit."""
s = cls._settings()
if timeout is None:
timeout = s.timeout
if headers is not None and len(headers) > 0:
req = urllib.request.Request(url, headers=headers)
resp = urllib.request.urlopen(req, timeout=timeout)
else:
resp = urllib.request.urlopen(url, timeout=timeout)
return cls._read_limited(resp)
@classmethod
def fetch_text(cls, url: str, timeout: Optional[float] = None) -> str:
"""Synchronous GET returning decoded text."""
return cls.fetch_url(url, timeout=timeout).decode('utf-8')
@classmethod
def head_content_length(cls, url: str, timeout: Optional[float] = None) -> int:
"""Synchronous HEAD returning Content-Length or 0."""
s = cls._settings()
if timeout is None:
timeout = s.timeout
try:
req = urllib.request.Request(url, method='HEAD')
resp = urllib.request.urlopen(req, timeout=timeout)
cl = resp.headers.get('Content-Length', '0')
return int(cl)
except Exception:
logger.debug(dict(msg='HEAD failed', url=url))
return 0
@classmethod
def post_json(cls, url: str, data: bytes, timeout: Optional[float] = None) -> bytes:
"""Synchronous POST with JSON content type."""
s = cls._settings()
if timeout is None:
timeout = s.timeout
req = urllib.request.Request(
url,
data=data,
headers={'Content-Type': 'application/json'},
method='POST',
)
resp = urllib.request.urlopen(req, timeout=timeout)
return cls._read_limited(resp)
@classmethod
def download_to_file(cls, url: str, output_path: pathlib.Path, timeout: Optional[float] = None) -> None:
"""Synchronous download to a file path. No size limit (for package downloads)."""
s = cls._settings()
if timeout is None:
timeout = s.timeout
output_path.parent.mkdir(parents=True, exist_ok=True)
resp = urllib.request.urlopen(url, timeout=timeout)
with open(output_path, 'wb') as f:
while True:
chunk = resp.read(65536)
if not chunk:
break
f.write(chunk)
# ── async wrappers ──
@classmethod
async def async_fetch_url(
cls,
url: str,
timeout: Optional[float] = None,
headers: Optional[dict[str, str]] = None,
) -> bytes:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: cls.fetch_url(url, timeout=timeout, headers=headers))
@classmethod
async def async_fetch_text(cls, url: str, timeout: Optional[float] = None) -> str:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: cls.fetch_text(url, timeout=timeout))
@classmethod
async def async_head_content_length(cls, url: str, timeout: Optional[float] = None) -> int:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: cls.head_content_length(url, timeout=timeout))
@classmethod
async def async_post_json(cls, url: str, data: bytes, timeout: Optional[float] = None) -> bytes:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: cls.post_json(url, data, timeout=timeout))

@ -0,0 +1,45 @@
"""CLI integration for network settings.
Provides methods to inject argparse arguments, extract parsed values,
and apply them to the network settings singleton.
"""
import argparse
from typing import Any
from .settings import net_settings_t
class net_cli_t:
@staticmethod
def add_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
'--net-timeout',
dest='net_timeout',
type=float,
default=None,
help='timeout in seconds for non-download HTTP requests (default: 8.0)',
)
parser.add_argument(
'--net-max-size',
dest='net_max_size',
type=int,
default=None,
help='max response body size in bytes for non-download HTTP requests (default: 50MB)',
)
@staticmethod
def extract(namespace: argparse.Namespace) -> dict[str, Any]:
kwargs: dict[str, Any] = {}
if getattr(namespace, 'net_timeout', None) is not None:
kwargs['timeout'] = namespace.net_timeout
if getattr(namespace, 'net_max_size', None) is not None:
kwargs['max_size'] = namespace.net_max_size
return kwargs
@staticmethod
def apply(kwargs: dict[str, Any]) -> net_settings_t:
if len(kwargs) > 0:
return net_settings_t.reset(**kwargs)
return net_settings_t.singleton()

@ -0,0 +1,31 @@
"""Network settings singleton based on pydantic-settings.
Values can be set via environment variables (ARCHLINUX_NET_TIMEOUT, etc.)
or by calling net_settings_t.reset() with explicit kwargs.
"""
import pydantic_settings
from typing import Any, ClassVar, Optional
class net_settings_t(pydantic_settings.BaseSettings):
model_config = pydantic_settings.SettingsConfigDict(
env_prefix='ARCHLINUX_NET_',
)
timeout: float = 8.0
max_size: int = 50 * 1024 * 1024 # 50MB
_instance: ClassVar[Optional['net_settings_t']] = None
@classmethod
def singleton(cls) -> 'net_settings_t':
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def reset(cls, **kwargs: Any) -> 'net_settings_t':
cls._instance = cls.model_validate(kwargs)
return cls._instance

@ -169,7 +169,7 @@ class pacman_t:
url: str,
output_path: pathlib.Path,
) -> None:
import urllib.request
from ..network.base import net_t
logger.info(
dict(
@ -179,12 +179,11 @@ class pacman_t:
)
)
output_path.parent.mkdir(parents=True, exist_ok=True)
net_t.download_to_file(url, output_path)
urllib.request.urlretrieve(
url,
str(output_path),
)
@staticmethod
def build_install_command(paths: list[pathlib.Path]) -> list[str]:
return ['pacman', '-U', '--noconfirm'] + [str(p) for p in paths]
@staticmethod
def build_mirror_config(options: compile_options_t) -> mirror_config_t:

@ -1,5 +1,6 @@
"""Pacman implementation of the archive manager interface."""
import dataclasses
import datetime
import logging
import pathlib
@ -19,23 +20,42 @@ from .types import mirror_config_t
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class archive_entry_t:
name: str
version: str
filename: str
date: datetime.date
class pacman_manager_t(manager_t):
class constants_t:
base_url: ClassVar[str] = 'https://archive.archlinux.org/repos/'
packages_url: ClassVar[str] = 'https://archive.archlinux.org/packages/'
href_re: ClassVar[re.Pattern[str]] = re.compile(
r'href="(\d{4}/\d{2}/\d{2})/"'
)
# matches: <a href="filename">filename</a> DD-Mon-YYYY HH:MM size
listing_re: ClassVar[re.Pattern[str]] = re.compile(
r'<a href="([^"]+)">(?:[^<]+)</a>\s+'
r'(\d{2})-([A-Za-z]{3})-(\d{4})\s+(\d{2}:\d{2})\s+'
r'(\S+)'
)
month_map: ClassVar[dict[str, int]] = {
'Jan': 1, 'Feb': 2, 'Mar': 3, 'Apr': 4,
'May': 5, 'Jun': 6, 'Jul': 7, 'Aug': 8,
'Sep': 9, 'Oct': 10, 'Nov': 11, 'Dec': 12,
}
default_repos: ClassVar[list[str]] = ['core', 'extra', 'multilib']
def list_remote_dates(self) -> list[str]:
import urllib.request
from ..network.base import net_t
base_url = pacman_manager_t.constants_t.base_url
logger.info(dict(msg='fetching archive index', url=base_url))
with urllib.request.urlopen(base_url) as resp:
html = resp.read().decode('utf-8')
html = net_t.fetch_text(base_url)
dates: list[str] = []
for m in pacman_manager_t.constants_t.href_re.finditer(html):
@ -162,3 +182,135 @@ class pacman_manager_t(manager_t):
)
current -= step
def _fetch_archive_page(self, pkg_name: str) -> str:
from ..network.base import net_t
url = '%s%s/%s/' % (
pacman_manager_t.constants_t.packages_url,
pkg_name[0],
pkg_name,
)
logger.info(dict(msg='fetching archive listing', pkg=pkg_name, url=url))
return net_t.fetch_text(url)
def sync_reference(
self,
reference: dict[str, str],
cache_dir: pathlib.Path,
cache_db: cache_db_t,
repos: Optional[list[str]] = None,
arch: str = 'x86_64',
) -> None:
if len(reference) == 0:
return
# find which (name, version) pairs are missing from cached packages
missing: dict[str, str] = {}
for name, version in reference.items():
if not cache_db.has_package_version(name, version):
missing[name] = version
if len(missing) == 0:
logger.info(dict(msg='all reference versions already cached', count=len(reference)))
return
logger.info(dict(msg='reference versions missing from cache', count=len(missing)))
# group by package name to fetch each archive page once
pkg_names = sorted(set(missing.keys()))
dates_to_sync: set[str] = set()
for pkg_name in pkg_names:
try:
html = self._fetch_archive_page(pkg_name)
except Exception:
logger.warning(
dict(msg='failed to fetch archive listing', pkg=pkg_name),
exc_info=True,
)
continue
entries = pacman_manager_t.parse_archive_listing(pkg_name, html)
if len(entries) > 0:
cache_db.bulk_upsert_archive_versions(entries)
target_version = missing[pkg_name]
matched = [e for e in entries if e.version == target_version]
if len(matched) == 0:
logger.warning(dict(
msg='version not found in archive listing',
pkg=pkg_name,
version=target_version,
))
continue
entry = matched[0]
date_str = pacman_manager_t._format_date(entry.date)
dates_to_sync.add(date_str)
logger.info(dict(
msg='found version in archive',
pkg=pkg_name,
version=target_version,
archive_date=date_str,
))
# sync each discovered date
for date_str in sorted(dates_to_sync):
try:
self.sync_date(
date=date_str,
cache_dir=cache_dir,
cache_db=cache_db,
repos=repos,
arch=arch,
)
except Exception:
logger.warning(
dict(msg='failed to sync date', date=date_str),
exc_info=True,
)
continue
# mark synced versions
for name, version in missing.items():
if cache_db.has_package_version(name, version):
cache_db.mark_archive_version_synced(name, version)
else:
logger.warning(dict(
msg='version still not found after sync',
pkg=name,
version=version,
))
@staticmethod
def parse_archive_listing(pkg_name: str, html: str) -> list[archive_entry_t]:
c = pacman_manager_t.constants_t
entries: list[archive_entry_t] = []
pkg_suffix_re = re.compile(
r'^%s-(.+)-(x86_64|any)\.pkg\.tar\.(zst|xz)$' % re.escape(pkg_name)
)
for m in c.listing_re.finditer(html):
filename = m.group(1)
if filename.endswith('.sig'):
continue
sm = pkg_suffix_re.match(filename)
if sm is None:
continue
version = sm.group(1)
day = int(m.group(2))
month = c.month_map.get(m.group(3), 1)
year = int(m.group(4))
entries.append(archive_entry_t(
name=pkg_name,
version=version,
filename=filename,
date=datetime.date(year, month, day),
))
return entries

@ -0,0 +1,172 @@
import argparse
import unittest
import unittest.mock
from ..apps.network.settings import net_settings_t
from ..apps.network.base import net_t
from ..apps.network.cli import net_cli_t
class TestNetSettings(unittest.TestCase):
def tearDown(self) -> None:
net_settings_t._instance = None
def test_singleton_returns_same_instance(self) -> None:
a = net_settings_t.singleton()
b = net_settings_t.singleton()
self.assertIs(a, b)
def test_default_timeout(self) -> None:
s = net_settings_t.singleton()
self.assertEqual(s.timeout, 8.0)
def test_default_max_size(self) -> None:
s = net_settings_t.singleton()
self.assertEqual(s.max_size, 50 * 1024 * 1024)
def test_reset_changes_timeout(self) -> None:
net_settings_t.reset(timeout=4.0)
s = net_settings_t.singleton()
self.assertEqual(s.timeout, 4.0)
def test_reset_changes_max_size(self) -> None:
net_settings_t.reset(max_size=1024)
s = net_settings_t.singleton()
self.assertEqual(s.max_size, 1024)
def test_reset_returns_new_instance(self) -> None:
a = net_settings_t.singleton()
b = net_settings_t.reset(timeout=2.0)
self.assertIsNot(a, b)
self.assertEqual(b.timeout, 2.0)
def test_reset_partial_preserves_defaults(self) -> None:
s = net_settings_t.reset(timeout=3.0)
self.assertEqual(s.timeout, 3.0)
self.assertEqual(s.max_size, 50 * 1024 * 1024)
def test_env_override(self) -> None:
net_settings_t._instance = None
with unittest.mock.patch.dict('os.environ', {'ARCHLINUX_NET_TIMEOUT': '2.5'}):
s = net_settings_t()
self.assertEqual(s.timeout, 2.5)
class TestNetReadLimited(unittest.TestCase):
def tearDown(self) -> None:
net_settings_t._instance = None
def test_read_within_limit(self) -> None:
net_settings_t.reset(max_size=1024)
data = b'x' * 512
import io
resp = io.BytesIO(data)
from http.client import HTTPResponse
with unittest.mock.patch.object(
net_t, '_read_limited',
wraps=net_t._read_limited,
):
# call directly with a mock response
mock_resp = unittest.mock.MagicMock()
mock_resp.read = io.BytesIO(data).read
result = net_t._read_limited(mock_resp, max_size=1024)
self.assertEqual(result, data)
def test_read_exceeds_limit(self) -> None:
net_settings_t.reset(max_size=100)
data = b'x' * 200
import io
mock_resp = unittest.mock.MagicMock()
mock_resp.read = io.BytesIO(data).read
with self.assertRaises(ValueError) as ctx:
net_t._read_limited(mock_resp, max_size=100)
self.assertIn('max_size', str(ctx.exception))
class TestNetFetchUrl(unittest.TestCase):
def tearDown(self) -> None:
net_settings_t._instance = None
def test_fetch_url_uses_settings_timeout(self) -> None:
net_settings_t.reset(timeout=4.5)
with unittest.mock.patch('urllib.request.urlopen') as mock_urlopen:
mock_resp = unittest.mock.MagicMock()
mock_resp.read.side_effect = [b'hello', b'']
mock_urlopen.return_value = mock_resp
result = net_t.fetch_url('http://example.com')
self.assertEqual(result, b'hello')
mock_urlopen.assert_called_once_with('http://example.com', timeout=4.5)
def test_fetch_url_explicit_timeout_overrides(self) -> None:
net_settings_t.reset(timeout=8.0)
with unittest.mock.patch('urllib.request.urlopen') as mock_urlopen:
mock_resp = unittest.mock.MagicMock()
mock_resp.read.side_effect = [b'data', b'']
mock_urlopen.return_value = mock_resp
net_t.fetch_url('http://example.com', timeout=2.0)
mock_urlopen.assert_called_once_with('http://example.com', timeout=2.0)
def test_fetch_text_returns_str(self) -> None:
net_settings_t.reset(timeout=8.0)
with unittest.mock.patch('urllib.request.urlopen') as mock_urlopen:
mock_resp = unittest.mock.MagicMock()
mock_resp.read.side_effect = [b'hello world', b'']
mock_urlopen.return_value = mock_resp
result = net_t.fetch_text('http://example.com')
self.assertIsInstance(result, str)
self.assertEqual(result, 'hello world')
class TestNetCli(unittest.TestCase):
def tearDown(self) -> None:
net_settings_t._instance = None
def test_add_arguments(self) -> None:
parser = argparse.ArgumentParser()
net_cli_t.add_arguments(parser)
ns = parser.parse_args(['--net-timeout', '4.5', '--net-max-size', '1024'])
self.assertEqual(ns.net_timeout, 4.5)
self.assertEqual(ns.net_max_size, 1024)
def test_add_arguments_defaults(self) -> None:
parser = argparse.ArgumentParser()
net_cli_t.add_arguments(parser)
ns = parser.parse_args([])
self.assertIsNone(ns.net_timeout)
self.assertIsNone(ns.net_max_size)
def test_extract_both(self) -> None:
ns = argparse.Namespace(net_timeout=3.0, net_max_size=2048)
kwargs = net_cli_t.extract(ns)
self.assertEqual(kwargs, {'timeout': 3.0, 'max_size': 2048})
def test_extract_partial(self) -> None:
ns = argparse.Namespace(net_timeout=5.0, net_max_size=None)
kwargs = net_cli_t.extract(ns)
self.assertEqual(kwargs, {'timeout': 5.0})
def test_extract_empty(self) -> None:
ns = argparse.Namespace(net_timeout=None, net_max_size=None)
kwargs = net_cli_t.extract(ns)
self.assertEqual(kwargs, {})
def test_apply_with_overrides(self) -> None:
s = net_cli_t.apply({'timeout': 2.0, 'max_size': 512})
self.assertEqual(s.timeout, 2.0)
self.assertEqual(s.max_size, 512)
self.assertIs(s, net_settings_t.singleton())
def test_apply_empty_returns_default(self) -> None:
s = net_cli_t.apply({})
self.assertEqual(s.timeout, 8.0)
self.assertEqual(s.max_size, 50 * 1024 * 1024)