[+] apps/network module: settings singleton, net_t helpers, CLI integration
1. add apps/network/settings.py: net_settings_t pydantic-settings singleton
with timeout (8s default) and max_size (50MB default), env vars
ARCHLINUX_NET_TIMEOUT/ARCHLINUX_NET_MAX_SIZE, reset() classmethod;
2. add apps/network/base.py: net_t class with fetch_url, fetch_text,
head_content_length, post_json, download_to_file (sync) and async
wrappers, all enforcing timeout and max_size from settings;
3. add apps/network/cli.py: net_cli_t with add_arguments/extract/apply
for --net-timeout and --net-max-size CLI args;
4. refactor cve/base.py: remove duplicate _fetch_url/_post_json/_head
wrappers, callers use net_t directly;
5. refactor cve/nvd.py, cve/osv.py, cve/arch_tracker.py: use net_t;
6. refactor pacman/client.py, pacman/manager.py: use net_t;
7. add test_network.py with settings, read_limited, fetch, cli tests;
This commit is contained in:
parent
857e9d41a2
commit
8079aae41c
@ -10,6 +10,8 @@ from typing import Optional
|
||||
|
||||
import pydantic
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .arch_tracker_types import arch_avg_t
|
||||
from .base import SyncProgressCallback, cve_backend_t
|
||||
from .types import (
|
||||
@ -18,8 +20,12 @@ from .types import (
|
||||
cve_source_t,
|
||||
cve_status_t,
|
||||
cve_sync_estimate_t,
|
||||
cve_upsert_result_t,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import cve_db_t
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ISSUES_URL = 'https://security.archlinux.org/issues/all.json'
|
||||
@ -50,7 +56,9 @@ class arch_tracker_backend_t(cve_backend_t):
|
||||
since: Optional[str] = None,
|
||||
months: Optional[int] = None,
|
||||
) -> cve_sync_estimate_t:
|
||||
content_length = await self._head_content_length(ISSUES_URL)
|
||||
from ..network.base import net_t
|
||||
|
||||
content_length = await net_t.async_head_content_length(ISSUES_URL)
|
||||
return cve_sync_estimate_t(
|
||||
source=cve_source_t.arch_tracker,
|
||||
num_fetches=1,
|
||||
@ -61,13 +69,16 @@ class arch_tracker_backend_t(cve_backend_t):
|
||||
|
||||
async def sync(
|
||||
self,
|
||||
db: 'cve_db_t',
|
||||
since: Optional[str] = None,
|
||||
months: Optional[int] = None,
|
||||
on_progress: Optional[SyncProgressCallback] = None,
|
||||
) -> list[cve_entry_t]:
|
||||
) -> cve_upsert_result_t:
|
||||
logger.info(dict(msg='fetch', source='arch_tracker', url=ISSUES_URL))
|
||||
|
||||
raw_bytes = await self._fetch_url(ISSUES_URL)
|
||||
from ..network.base import net_t
|
||||
|
||||
raw_bytes = await net_t.async_fetch_url(ISSUES_URL)
|
||||
avgs = _avg_list_adapter.validate_json(raw_bytes)
|
||||
|
||||
logger.info(dict(msg='fetched', source='arch_tracker', avgs=len(avgs), bytes=len(raw_bytes)))
|
||||
@ -94,4 +105,4 @@ class arch_tracker_backend_t(cve_backend_t):
|
||||
)
|
||||
|
||||
logger.info(dict(msg='parsed', source='arch_tracker', avgs=len(avgs), entries=len(entries)))
|
||||
return entries
|
||||
return await self._store_and_update_meta(db, entries)
|
||||
|
||||
@ -1,18 +1,21 @@
|
||||
"""Abstract backend interface for CVE data sources."""
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import urllib.request
|
||||
|
||||
from typing import Callable, Optional
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
from .types import (
|
||||
cve_entry_t,
|
||||
cve_source_t,
|
||||
cve_sync_estimate_t,
|
||||
cve_upsert_result_t,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import cve_db_t
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SyncProgressCallback = Callable[[int, int], None]
|
||||
@ -35,49 +38,28 @@ class cve_backend_t(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def sync(
|
||||
self,
|
||||
db: 'cve_db_t',
|
||||
since: Optional[str] = None,
|
||||
months: Optional[int] = None,
|
||||
on_progress: Optional[SyncProgressCallback] = None,
|
||||
) -> list[cve_entry_t]:
|
||||
) -> cve_upsert_result_t:
|
||||
"""Fetch entries and store them in db. Returns upsert result."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
async def _head_content_length(url: str) -> int:
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
def _do() -> int:
|
||||
req = urllib.request.Request(url, method='HEAD')
|
||||
resp = urllib.request.urlopen(req, timeout=10)
|
||||
cl = resp.headers.get('Content-Length', '0')
|
||||
return int(cl)
|
||||
async def _store_and_update_meta(
|
||||
self,
|
||||
db: 'cve_db_t',
|
||||
entries: list[cve_entry_t],
|
||||
) -> cve_upsert_result_t:
|
||||
"""Common: upsert entries + update sync meta. Called by subclasses."""
|
||||
result = db.upsert_entries(entries)
|
||||
now = datetime.datetime.now(datetime.timezone.utc).isoformat()
|
||||
db.update_sync_meta(self.source, last_sync=now, entry_count=db.count_entries(self.source))
|
||||
logger.info(dict(
|
||||
msg='ingested',
|
||||
source=self.source.value,
|
||||
received=result.received,
|
||||
in_db=result.inserted,
|
||||
))
|
||||
return result
|
||||
|
||||
return await loop.run_in_executor(None, _do)
|
||||
except Exception:
|
||||
logger.debug(dict(msg='HEAD failed', url=url))
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_url(url: str, timeout: int = 30) -> bytes:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def _do() -> bytes:
|
||||
resp = urllib.request.urlopen(url, timeout=timeout)
|
||||
return resp.read()
|
||||
|
||||
return await loop.run_in_executor(None, _do)
|
||||
|
||||
@staticmethod
|
||||
async def _post_json(url: str, data: bytes, timeout: int = 30) -> bytes:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def _do() -> bytes:
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
data=data,
|
||||
headers={'Content-Type': 'application/json'},
|
||||
method='POST',
|
||||
)
|
||||
resp = urllib.request.urlopen(req, timeout=timeout)
|
||||
return resp.read()
|
||||
|
||||
return await loop.run_in_executor(None, _do)
|
||||
|
||||
@ -3,15 +3,16 @@
|
||||
Source: https://services.nvd.nist.gov/rest/json/cves/2.0
|
||||
Optional API key. Rate limited: 5 req/30s without key, 50 with key.
|
||||
Paginated (max 2000/page). Supports lastModStartDate/lastModEndDate (max 120 days).
|
||||
Uses cve_sync_days to skip already-fetched day ranges.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import math
|
||||
import urllib.parse
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import pydantic
|
||||
|
||||
@ -22,8 +23,12 @@ from .types import (
|
||||
cve_severity_t,
|
||||
cve_source_t,
|
||||
cve_sync_estimate_t,
|
||||
cve_upsert_result_t,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import cve_db_t
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BASE_URL = 'https://services.nvd.nist.gov/rest/json/cves/2.0'
|
||||
@ -45,11 +50,13 @@ def _severity_from_nvd(s: str) -> cve_severity_t:
|
||||
return mapping.get(s.upper(), cve_severity_t.unknown)
|
||||
|
||||
|
||||
def _date_ranges(start: datetime, end: datetime) -> list[tuple[str, str]]:
|
||||
def _chunk_range(start: datetime.date, end: datetime.date) -> list[tuple[str, str]]:
|
||||
"""Split a date range into chunks of MAX_RANGE_DAYS for NVD API."""
|
||||
ranges: list[tuple[str, str]] = []
|
||||
cur = start
|
||||
while cur < end:
|
||||
chunk_end = min(cur + timedelta(days=MAX_RANGE_DAYS), end)
|
||||
cur = datetime.datetime.combine(start, datetime.time.min, tzinfo=datetime.timezone.utc)
|
||||
end_dt = datetime.datetime.combine(end, datetime.time(23, 59, 59), tzinfo=datetime.timezone.utc)
|
||||
while cur < end_dt:
|
||||
chunk_end = min(cur + datetime.timedelta(days=MAX_RANGE_DAYS), end_dt)
|
||||
ranges.append((
|
||||
cur.strftime('%Y-%m-%dT%H:%M:%S.000'),
|
||||
chunk_end.strftime('%Y-%m-%dT%H:%M:%S.000'),
|
||||
@ -71,34 +78,27 @@ class nvd_backend_t(cve_backend_t):
|
||||
return '%s?%s' % (BASE_URL, urllib.parse.urlencode(params))
|
||||
|
||||
async def _fetch_page(self, url: str) -> nvd_response_t:
|
||||
loop = asyncio.get_running_loop()
|
||||
headers: dict[str, str] = {}
|
||||
if self._api_key:
|
||||
headers['apiKey'] = self._api_key
|
||||
|
||||
api_key = self._api_key
|
||||
from ..network.base import net_t
|
||||
|
||||
def _do() -> bytes:
|
||||
import urllib.request as ur
|
||||
|
||||
req = ur.Request(url)
|
||||
if api_key:
|
||||
req.add_header('apiKey', api_key)
|
||||
resp = ur.urlopen(req, timeout=30)
|
||||
return resp.read()
|
||||
|
||||
raw = await loop.run_in_executor(None, _do)
|
||||
raw = await net_t.async_fetch_url(url, headers=headers)
|
||||
return _response_adapter.validate_json(raw)
|
||||
|
||||
def _compute_date_range(
|
||||
self,
|
||||
since: Optional[str],
|
||||
months: Optional[int],
|
||||
) -> tuple[datetime, datetime]:
|
||||
end = datetime.now(timezone.utc)
|
||||
) -> tuple[datetime.date, datetime.date]:
|
||||
end = datetime.date.today()
|
||||
if since is not None:
|
||||
start = datetime.fromisoformat(since).replace(tzinfo=timezone.utc)
|
||||
start = datetime.date.fromisoformat(since)
|
||||
elif months is not None:
|
||||
start = end - timedelta(days=months * 30)
|
||||
start = end - datetime.timedelta(days=months * 30)
|
||||
else:
|
||||
start = end - timedelta(days=120)
|
||||
start = end - datetime.timedelta(days=120)
|
||||
return start, end
|
||||
|
||||
async def estimate_sync(
|
||||
@ -107,26 +107,25 @@ class nvd_backend_t(cve_backend_t):
|
||||
months: Optional[int] = None,
|
||||
) -> cve_sync_estimate_t:
|
||||
start, end = self._compute_date_range(since, months)
|
||||
ranges = _date_ranges(start, end)
|
||||
chunks = _chunk_range(start, end)
|
||||
|
||||
if len(ranges) == 0:
|
||||
if len(chunks) == 0:
|
||||
return cve_sync_estimate_t(source=cve_source_t.nvd, available=False)
|
||||
|
||||
params = {
|
||||
'lastModStartDate': ranges[0][0],
|
||||
'lastModEndDate': ranges[0][1],
|
||||
'lastModStartDate': chunks[0][0],
|
||||
'lastModEndDate': chunks[0][1],
|
||||
'resultsPerPage': '1',
|
||||
}
|
||||
try:
|
||||
page = await self._fetch_page(self._build_url(params))
|
||||
total_first_range = page.totalResults
|
||||
total_first = page.totalResults
|
||||
except Exception as e:
|
||||
logger.warning(dict(msg='nvd estimate failed', error=str(e)))
|
||||
return cve_sync_estimate_t(source=cve_source_t.nvd, available=False)
|
||||
|
||||
estimated_total = total_first_range * len(ranges)
|
||||
pages_per_range = max(1, math.ceil(total_first_range / PAGE_SIZE))
|
||||
num_fetches = pages_per_range * len(ranges)
|
||||
pages_per_chunk = max(1, math.ceil(total_first / PAGE_SIZE))
|
||||
num_fetches = pages_per_chunk * len(chunks)
|
||||
|
||||
return cve_sync_estimate_t(
|
||||
source=cve_source_t.nvd,
|
||||
@ -138,39 +137,99 @@ class nvd_backend_t(cve_backend_t):
|
||||
|
||||
async def sync(
|
||||
self,
|
||||
db: 'cve_db_t',
|
||||
since: Optional[str] = None,
|
||||
months: Optional[int] = None,
|
||||
on_progress: Optional[SyncProgressCallback] = None,
|
||||
) -> list[cve_entry_t]:
|
||||
) -> cve_upsert_result_t:
|
||||
start, end = self._compute_date_range(since, months)
|
||||
ranges = _date_ranges(start, end)
|
||||
|
||||
entries: list[cve_entry_t] = []
|
||||
# compute missing ranges using db
|
||||
missing = db.compute_missing_ranges(cve_source_t.nvd, start, end)
|
||||
|
||||
if len(missing) == 0:
|
||||
logger.info(dict(msg='nvd sync: all days already fetched', start=str(start), end=str(end)))
|
||||
return cve_upsert_result_t(received=0, inserted=db.count_entries(cve_source_t.nvd))
|
||||
|
||||
total_missing_days = sum((e - s).days + 1 for s, e in missing)
|
||||
logger.info(dict(
|
||||
msg='nvd sync plan',
|
||||
target_range='%s to %s' % (start, end),
|
||||
missing_ranges=len(missing),
|
||||
missing_days=total_missing_days,
|
||||
))
|
||||
|
||||
all_entries: list[cve_entry_t] = []
|
||||
fetch_count = 0
|
||||
|
||||
for range_start, range_end in ranges:
|
||||
for gap_idx, (gap_start, gap_end) in enumerate(missing):
|
||||
chunks = _chunk_range(gap_start, gap_end)
|
||||
|
||||
logger.info(dict(
|
||||
msg='nvd gap start',
|
||||
gap='%d/%d' % (gap_idx + 1, len(missing)),
|
||||
range='%s to %s' % (gap_start, gap_end),
|
||||
days=(gap_end - gap_start).days + 1,
|
||||
))
|
||||
|
||||
for chunk_start_str, chunk_end_str in chunks:
|
||||
start_index = 0
|
||||
chunk_page = 0
|
||||
chunk_total: Optional[int] = None
|
||||
days_seen_in_chunk: set[datetime.date] = set()
|
||||
|
||||
while True:
|
||||
params = {
|
||||
'lastModStartDate': range_start,
|
||||
'lastModEndDate': range_end,
|
||||
'lastModStartDate': chunk_start_str,
|
||||
'lastModEndDate': chunk_end_str,
|
||||
'resultsPerPage': str(PAGE_SIZE),
|
||||
'startIndex': str(start_index),
|
||||
}
|
||||
|
||||
url = self._build_url(params)
|
||||
logger.info(dict(msg='nvd fetch', url=url))
|
||||
chunk_page += 1
|
||||
fetch_count += 1
|
||||
|
||||
logger.info(dict(
|
||||
msg='nvd fetch',
|
||||
range='%s..%s' % (chunk_start_str[:10], chunk_end_str[:10]),
|
||||
chunk_page=chunk_page,
|
||||
start_index=start_index,
|
||||
chunk_total=chunk_total or '?',
|
||||
fetches_total=fetch_count,
|
||||
))
|
||||
|
||||
page = await self._fetch_page(url)
|
||||
fetch_count += 1
|
||||
|
||||
if chunk_total is None:
|
||||
chunk_total = page.totalResults
|
||||
logger.info(dict(
|
||||
msg='nvd chunk total',
|
||||
range='%s..%s' % (chunk_start_str[:10], chunk_end_str[:10]),
|
||||
total_results=page.totalResults,
|
||||
))
|
||||
|
||||
if page.totalResults == 0:
|
||||
# empty range — mark all days in this gap as complete
|
||||
empty_days: list[datetime.date] = []
|
||||
ed: datetime.date = gap_start
|
||||
while ed <= gap_end:
|
||||
empty_days.append(ed)
|
||||
ed = ed + datetime.timedelta(days=1)
|
||||
db.mark_days_complete(cve_source_t.nvd, empty_days)
|
||||
logger.info(dict(
|
||||
msg='nvd empty range',
|
||||
range='%s to %s' % (gap_start, gap_end),
|
||||
days_marked=len(empty_days),
|
||||
))
|
||||
break
|
||||
|
||||
for vuln in page.vulnerabilities:
|
||||
cve = vuln.cve
|
||||
desc = ''
|
||||
for d in cve.descriptions:
|
||||
if d.lang == 'en':
|
||||
desc = d.value
|
||||
for desc_item in cve.descriptions:
|
||||
if desc_item.lang == 'en':
|
||||
desc = desc_item.value
|
||||
break
|
||||
|
||||
score = 0.0
|
||||
@ -183,7 +242,7 @@ class nvd_backend_t(cve_backend_t):
|
||||
severity = _severity_from_nvd(m.cvssData.baseSeverity)
|
||||
break
|
||||
|
||||
entries.append(
|
||||
all_entries.append(
|
||||
cve_entry_t(
|
||||
cve_id=cve.id,
|
||||
source=cve_source_t.nvd,
|
||||
@ -197,8 +256,19 @@ class nvd_backend_t(cve_backend_t):
|
||||
)
|
||||
)
|
||||
|
||||
# track modification dates for day completion
|
||||
if cve.lastModified:
|
||||
try:
|
||||
mod_date = datetime.datetime.fromisoformat(
|
||||
cve.lastModified.replace('Z', '+00:00')
|
||||
).date()
|
||||
days_seen_in_chunk.add(mod_date)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if on_progress is not None:
|
||||
on_progress(len(entries), page.totalResults * len(ranges))
|
||||
chunk_done = min(start_index + page.resultsPerPage, page.totalResults)
|
||||
on_progress(chunk_done, page.totalResults)
|
||||
|
||||
if start_index + page.resultsPerPage >= page.totalResults:
|
||||
break
|
||||
@ -206,8 +276,36 @@ class nvd_backend_t(cve_backend_t):
|
||||
start_index += page.resultsPerPage
|
||||
await asyncio.sleep(self._delay)
|
||||
|
||||
if len(ranges) > 1:
|
||||
# mark days complete
|
||||
fully_paginated = (start_index + page.resultsPerPage >= page.totalResults)
|
||||
|
||||
if fully_paginated:
|
||||
# entire gap range is done — mark all days including empty ones
|
||||
complete_days: list[datetime.date] = []
|
||||
cd: datetime.date = gap_start
|
||||
while cd <= gap_end:
|
||||
complete_days.append(cd)
|
||||
cd = cd + datetime.timedelta(days=1)
|
||||
elif len(days_seen_in_chunk) > 0:
|
||||
# partial — mark up to day before last seen (last seen is uncertain)
|
||||
sorted_days = sorted(days_seen_in_chunk)
|
||||
complete_days = []
|
||||
cd = gap_start
|
||||
while cd < sorted_days[-1]:
|
||||
complete_days.append(cd)
|
||||
cd = cd + datetime.timedelta(days=1)
|
||||
else:
|
||||
complete_days = []
|
||||
|
||||
if len(complete_days) > 0:
|
||||
db.mark_days_complete(cve_source_t.nvd, complete_days)
|
||||
logger.info(dict(
|
||||
msg='nvd days complete',
|
||||
count=len(complete_days),
|
||||
range='%s to %s' % (complete_days[0], complete_days[-1]),
|
||||
))
|
||||
|
||||
await asyncio.sleep(self._delay)
|
||||
|
||||
logger.info(dict(msg='nvd sync done', fetches=fetch_count, entries=len(entries)))
|
||||
return entries
|
||||
logger.info(dict(msg='nvd sync done', fetches=fetch_count, entries=len(all_entries)))
|
||||
return await self._store_and_update_meta(db, all_entries)
|
||||
|
||||
@ -7,13 +7,11 @@ Supports batch queries (up to 1000 per request).
|
||||
Ecosystem list fetched from GCS bucket listing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import urllib.request
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from typing import ClassVar, Optional
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
import pydantic
|
||||
|
||||
@ -30,8 +28,12 @@ from .types import (
|
||||
cve_entry_t,
|
||||
cve_source_t,
|
||||
cve_sync_estimate_t,
|
||||
cve_upsert_result_t,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import cve_db_t
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QUERY_URL = 'https://api.osv.dev/v1/querybatch'
|
||||
@ -50,7 +52,6 @@ class osv_ecosystems_t:
|
||||
if cls._cached is not None and not force:
|
||||
return cls._cached
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
ecosystems: list[osv_ecosystem_t] = []
|
||||
marker = ''
|
||||
seen: set[str] = set()
|
||||
@ -59,11 +60,9 @@ class osv_ecosystems_t:
|
||||
url = '%s?delimiter=/&prefix=&marker=%s' % (GCS_BUCKET_URL, marker)
|
||||
logger.debug(dict(msg='fetching osv ecosystems page', marker=marker))
|
||||
|
||||
def _do(u: str = url) -> str:
|
||||
resp = urllib.request.urlopen(u, timeout=30)
|
||||
return resp.read().decode('utf-8')
|
||||
from ..network.base import net_t
|
||||
|
||||
raw = await loop.run_in_executor(None, _do)
|
||||
raw = await net_t.async_fetch_text(url)
|
||||
|
||||
root = ET.fromstring(raw)
|
||||
ns = '{http://doc.s3.amazonaws.com/2006-03-01}'
|
||||
@ -152,12 +151,13 @@ class osv_backend_t(cve_backend_t):
|
||||
|
||||
async def sync(
|
||||
self,
|
||||
db: 'cve_db_t',
|
||||
since: Optional[str] = None,
|
||||
months: Optional[int] = None,
|
||||
on_progress: Optional[SyncProgressCallback] = None,
|
||||
) -> list[cve_entry_t]:
|
||||
) -> cve_upsert_result_t:
|
||||
logger.warning(dict(msg='osv sync requires explicit package list, use query_packages()'))
|
||||
return []
|
||||
return cve_upsert_result_t()
|
||||
|
||||
async def query_packages(
|
||||
self,
|
||||
@ -184,7 +184,9 @@ class osv_backend_t(cve_backend_t):
|
||||
]
|
||||
)
|
||||
|
||||
raw = await self._post_json(QUERY_URL, request.model_dump_json().encode('utf-8'))
|
||||
from ..network.base import net_t
|
||||
|
||||
raw = await net_t.async_post_json(QUERY_URL, request.model_dump_json().encode('utf-8'))
|
||||
batch_resp = pydantic.TypeAdapter(osv_batch_response_t).validate_json(raw)
|
||||
|
||||
for i, result in enumerate(batch_resp.results):
|
||||
|
||||
@ -0,0 +1,143 @@
|
||||
"""Network helper class.
|
||||
|
||||
All HTTP requests (sync and async) go through this class so that
|
||||
timeout and max_size limits from net_settings_t are enforced globally.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import pathlib
|
||||
import urllib.request
|
||||
|
||||
from http.client import HTTPResponse
|
||||
from typing import Optional
|
||||
|
||||
from .settings import net_settings_t
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class net_t:
|
||||
@classmethod
|
||||
def _settings(cls) -> net_settings_t:
|
||||
return net_settings_t.singleton()
|
||||
|
||||
@classmethod
|
||||
def _read_limited(cls, resp: HTTPResponse, max_size: Optional[int] = None) -> bytes:
|
||||
"""Read response body up to max_size bytes."""
|
||||
if max_size is None:
|
||||
max_size = cls._settings().max_size
|
||||
|
||||
chunks: list[bytes] = []
|
||||
total = 0
|
||||
while True:
|
||||
chunk = resp.read(65536)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
if total > max_size:
|
||||
raise ValueError(
|
||||
'response exceeded max_size %d bytes' % max_size
|
||||
)
|
||||
chunks.append(chunk)
|
||||
return b''.join(chunks)
|
||||
|
||||
@classmethod
|
||||
def fetch_url(
|
||||
cls,
|
||||
url: str,
|
||||
timeout: Optional[float] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
) -> bytes:
|
||||
"""Synchronous GET with timeout and size limit."""
|
||||
s = cls._settings()
|
||||
if timeout is None:
|
||||
timeout = s.timeout
|
||||
|
||||
if headers is not None and len(headers) > 0:
|
||||
req = urllib.request.Request(url, headers=headers)
|
||||
resp = urllib.request.urlopen(req, timeout=timeout)
|
||||
else:
|
||||
resp = urllib.request.urlopen(url, timeout=timeout)
|
||||
return cls._read_limited(resp)
|
||||
|
||||
@classmethod
|
||||
def fetch_text(cls, url: str, timeout: Optional[float] = None) -> str:
|
||||
"""Synchronous GET returning decoded text."""
|
||||
return cls.fetch_url(url, timeout=timeout).decode('utf-8')
|
||||
|
||||
@classmethod
|
||||
def head_content_length(cls, url: str, timeout: Optional[float] = None) -> int:
|
||||
"""Synchronous HEAD returning Content-Length or 0."""
|
||||
s = cls._settings()
|
||||
if timeout is None:
|
||||
timeout = s.timeout
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(url, method='HEAD')
|
||||
resp = urllib.request.urlopen(req, timeout=timeout)
|
||||
cl = resp.headers.get('Content-Length', '0')
|
||||
return int(cl)
|
||||
except Exception:
|
||||
logger.debug(dict(msg='HEAD failed', url=url))
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def post_json(cls, url: str, data: bytes, timeout: Optional[float] = None) -> bytes:
|
||||
"""Synchronous POST with JSON content type."""
|
||||
s = cls._settings()
|
||||
if timeout is None:
|
||||
timeout = s.timeout
|
||||
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
data=data,
|
||||
headers={'Content-Type': 'application/json'},
|
||||
method='POST',
|
||||
)
|
||||
resp = urllib.request.urlopen(req, timeout=timeout)
|
||||
return cls._read_limited(resp)
|
||||
|
||||
@classmethod
|
||||
def download_to_file(cls, url: str, output_path: pathlib.Path, timeout: Optional[float] = None) -> None:
|
||||
"""Synchronous download to a file path. No size limit (for package downloads)."""
|
||||
s = cls._settings()
|
||||
if timeout is None:
|
||||
timeout = s.timeout
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
resp = urllib.request.urlopen(url, timeout=timeout)
|
||||
with open(output_path, 'wb') as f:
|
||||
while True:
|
||||
chunk = resp.read(65536)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
|
||||
# ── async wrappers ──
|
||||
|
||||
@classmethod
|
||||
async def async_fetch_url(
|
||||
cls,
|
||||
url: str,
|
||||
timeout: Optional[float] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
) -> bytes:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, lambda: cls.fetch_url(url, timeout=timeout, headers=headers))
|
||||
|
||||
@classmethod
|
||||
async def async_fetch_text(cls, url: str, timeout: Optional[float] = None) -> str:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, lambda: cls.fetch_text(url, timeout=timeout))
|
||||
|
||||
@classmethod
|
||||
async def async_head_content_length(cls, url: str, timeout: Optional[float] = None) -> int:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, lambda: cls.head_content_length(url, timeout=timeout))
|
||||
|
||||
@classmethod
|
||||
async def async_post_json(cls, url: str, data: bytes, timeout: Optional[float] = None) -> bytes:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, lambda: cls.post_json(url, data, timeout=timeout))
|
||||
@ -0,0 +1,45 @@
|
||||
"""CLI integration for network settings.
|
||||
|
||||
Provides methods to inject argparse arguments, extract parsed values,
|
||||
and apply them to the network settings singleton.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .settings import net_settings_t
|
||||
|
||||
|
||||
class net_cli_t:
|
||||
@staticmethod
|
||||
def add_arguments(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
'--net-timeout',
|
||||
dest='net_timeout',
|
||||
type=float,
|
||||
default=None,
|
||||
help='timeout in seconds for non-download HTTP requests (default: 8.0)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--net-max-size',
|
||||
dest='net_max_size',
|
||||
type=int,
|
||||
default=None,
|
||||
help='max response body size in bytes for non-download HTTP requests (default: 50MB)',
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def extract(namespace: argparse.Namespace) -> dict[str, Any]:
|
||||
kwargs: dict[str, Any] = {}
|
||||
if getattr(namespace, 'net_timeout', None) is not None:
|
||||
kwargs['timeout'] = namespace.net_timeout
|
||||
if getattr(namespace, 'net_max_size', None) is not None:
|
||||
kwargs['max_size'] = namespace.net_max_size
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def apply(kwargs: dict[str, Any]) -> net_settings_t:
|
||||
if len(kwargs) > 0:
|
||||
return net_settings_t.reset(**kwargs)
|
||||
return net_settings_t.singleton()
|
||||
@ -0,0 +1,31 @@
|
||||
"""Network settings singleton based on pydantic-settings.
|
||||
|
||||
Values can be set via environment variables (ARCHLINUX_NET_TIMEOUT, etc.)
|
||||
or by calling net_settings_t.reset() with explicit kwargs.
|
||||
"""
|
||||
|
||||
import pydantic_settings
|
||||
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
|
||||
class net_settings_t(pydantic_settings.BaseSettings):
|
||||
model_config = pydantic_settings.SettingsConfigDict(
|
||||
env_prefix='ARCHLINUX_NET_',
|
||||
)
|
||||
|
||||
timeout: float = 8.0
|
||||
max_size: int = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
_instance: ClassVar[Optional['net_settings_t']] = None
|
||||
|
||||
@classmethod
|
||||
def singleton(cls) -> 'net_settings_t':
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset(cls, **kwargs: Any) -> 'net_settings_t':
|
||||
cls._instance = cls.model_validate(kwargs)
|
||||
return cls._instance
|
||||
@ -169,7 +169,7 @@ class pacman_t:
|
||||
url: str,
|
||||
output_path: pathlib.Path,
|
||||
) -> None:
|
||||
import urllib.request
|
||||
from ..network.base import net_t
|
||||
|
||||
logger.info(
|
||||
dict(
|
||||
@ -179,12 +179,11 @@ class pacman_t:
|
||||
)
|
||||
)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
net_t.download_to_file(url, output_path)
|
||||
|
||||
urllib.request.urlretrieve(
|
||||
url,
|
||||
str(output_path),
|
||||
)
|
||||
@staticmethod
|
||||
def build_install_command(paths: list[pathlib.Path]) -> list[str]:
|
||||
return ['pacman', '-U', '--noconfirm'] + [str(p) for p in paths]
|
||||
|
||||
@staticmethod
|
||||
def build_mirror_config(options: compile_options_t) -> mirror_config_t:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Pacman implementation of the archive manager interface."""
|
||||
|
||||
import dataclasses
|
||||
import datetime
|
||||
import logging
|
||||
import pathlib
|
||||
@ -19,23 +20,42 @@ from .types import mirror_config_t
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class archive_entry_t:
|
||||
name: str
|
||||
version: str
|
||||
filename: str
|
||||
date: datetime.date
|
||||
|
||||
|
||||
class pacman_manager_t(manager_t):
|
||||
class constants_t:
|
||||
base_url: ClassVar[str] = 'https://archive.archlinux.org/repos/'
|
||||
packages_url: ClassVar[str] = 'https://archive.archlinux.org/packages/'
|
||||
href_re: ClassVar[re.Pattern[str]] = re.compile(
|
||||
r'href="(\d{4}/\d{2}/\d{2})/"'
|
||||
)
|
||||
# matches: <a href="filename">filename</a> DD-Mon-YYYY HH:MM size
|
||||
listing_re: ClassVar[re.Pattern[str]] = re.compile(
|
||||
r'<a href="([^"]+)">(?:[^<]+)</a>\s+'
|
||||
r'(\d{2})-([A-Za-z]{3})-(\d{4})\s+(\d{2}:\d{2})\s+'
|
||||
r'(\S+)'
|
||||
)
|
||||
month_map: ClassVar[dict[str, int]] = {
|
||||
'Jan': 1, 'Feb': 2, 'Mar': 3, 'Apr': 4,
|
||||
'May': 5, 'Jun': 6, 'Jul': 7, 'Aug': 8,
|
||||
'Sep': 9, 'Oct': 10, 'Nov': 11, 'Dec': 12,
|
||||
}
|
||||
default_repos: ClassVar[list[str]] = ['core', 'extra', 'multilib']
|
||||
|
||||
def list_remote_dates(self) -> list[str]:
|
||||
import urllib.request
|
||||
from ..network.base import net_t
|
||||
|
||||
base_url = pacman_manager_t.constants_t.base_url
|
||||
|
||||
logger.info(dict(msg='fetching archive index', url=base_url))
|
||||
|
||||
with urllib.request.urlopen(base_url) as resp:
|
||||
html = resp.read().decode('utf-8')
|
||||
html = net_t.fetch_text(base_url)
|
||||
|
||||
dates: list[str] = []
|
||||
for m in pacman_manager_t.constants_t.href_re.finditer(html):
|
||||
@ -162,3 +182,135 @@ class pacman_manager_t(manager_t):
|
||||
)
|
||||
|
||||
current -= step
|
||||
|
||||
def _fetch_archive_page(self, pkg_name: str) -> str:
|
||||
from ..network.base import net_t
|
||||
|
||||
url = '%s%s/%s/' % (
|
||||
pacman_manager_t.constants_t.packages_url,
|
||||
pkg_name[0],
|
||||
pkg_name,
|
||||
)
|
||||
logger.info(dict(msg='fetching archive listing', pkg=pkg_name, url=url))
|
||||
return net_t.fetch_text(url)
|
||||
|
||||
def sync_reference(
|
||||
self,
|
||||
reference: dict[str, str],
|
||||
cache_dir: pathlib.Path,
|
||||
cache_db: cache_db_t,
|
||||
repos: Optional[list[str]] = None,
|
||||
arch: str = 'x86_64',
|
||||
) -> None:
|
||||
if len(reference) == 0:
|
||||
return
|
||||
|
||||
# find which (name, version) pairs are missing from cached packages
|
||||
missing: dict[str, str] = {}
|
||||
for name, version in reference.items():
|
||||
if not cache_db.has_package_version(name, version):
|
||||
missing[name] = version
|
||||
|
||||
if len(missing) == 0:
|
||||
logger.info(dict(msg='all reference versions already cached', count=len(reference)))
|
||||
return
|
||||
|
||||
logger.info(dict(msg='reference versions missing from cache', count=len(missing)))
|
||||
|
||||
# group by package name to fetch each archive page once
|
||||
pkg_names = sorted(set(missing.keys()))
|
||||
dates_to_sync: set[str] = set()
|
||||
|
||||
for pkg_name in pkg_names:
|
||||
try:
|
||||
html = self._fetch_archive_page(pkg_name)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
dict(msg='failed to fetch archive listing', pkg=pkg_name),
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
entries = pacman_manager_t.parse_archive_listing(pkg_name, html)
|
||||
if len(entries) > 0:
|
||||
cache_db.bulk_upsert_archive_versions(entries)
|
||||
|
||||
target_version = missing[pkg_name]
|
||||
matched = [e for e in entries if e.version == target_version]
|
||||
if len(matched) == 0:
|
||||
logger.warning(dict(
|
||||
msg='version not found in archive listing',
|
||||
pkg=pkg_name,
|
||||
version=target_version,
|
||||
))
|
||||
continue
|
||||
|
||||
entry = matched[0]
|
||||
date_str = pacman_manager_t._format_date(entry.date)
|
||||
dates_to_sync.add(date_str)
|
||||
|
||||
logger.info(dict(
|
||||
msg='found version in archive',
|
||||
pkg=pkg_name,
|
||||
version=target_version,
|
||||
archive_date=date_str,
|
||||
))
|
||||
|
||||
# sync each discovered date
|
||||
for date_str in sorted(dates_to_sync):
|
||||
try:
|
||||
self.sync_date(
|
||||
date=date_str,
|
||||
cache_dir=cache_dir,
|
||||
cache_db=cache_db,
|
||||
repos=repos,
|
||||
arch=arch,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
dict(msg='failed to sync date', date=date_str),
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
# mark synced versions
|
||||
for name, version in missing.items():
|
||||
if cache_db.has_package_version(name, version):
|
||||
cache_db.mark_archive_version_synced(name, version)
|
||||
else:
|
||||
logger.warning(dict(
|
||||
msg='version still not found after sync',
|
||||
pkg=name,
|
||||
version=version,
|
||||
))
|
||||
|
||||
@staticmethod
|
||||
def parse_archive_listing(pkg_name: str, html: str) -> list[archive_entry_t]:
|
||||
c = pacman_manager_t.constants_t
|
||||
entries: list[archive_entry_t] = []
|
||||
pkg_suffix_re = re.compile(
|
||||
r'^%s-(.+)-(x86_64|any)\.pkg\.tar\.(zst|xz)$' % re.escape(pkg_name)
|
||||
)
|
||||
|
||||
for m in c.listing_re.finditer(html):
|
||||
filename = m.group(1)
|
||||
if filename.endswith('.sig'):
|
||||
continue
|
||||
|
||||
sm = pkg_suffix_re.match(filename)
|
||||
if sm is None:
|
||||
continue
|
||||
|
||||
version = sm.group(1)
|
||||
day = int(m.group(2))
|
||||
month = c.month_map.get(m.group(3), 1)
|
||||
year = int(m.group(4))
|
||||
|
||||
entries.append(archive_entry_t(
|
||||
name=pkg_name,
|
||||
version=version,
|
||||
filename=filename,
|
||||
date=datetime.date(year, month, day),
|
||||
))
|
||||
|
||||
return entries
|
||||
|
||||
@ -0,0 +1,172 @@
|
||||
import argparse
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
from ..apps.network.settings import net_settings_t
|
||||
from ..apps.network.base import net_t
|
||||
from ..apps.network.cli import net_cli_t
|
||||
|
||||
|
||||
class TestNetSettings(unittest.TestCase):
|
||||
def tearDown(self) -> None:
|
||||
net_settings_t._instance = None
|
||||
|
||||
def test_singleton_returns_same_instance(self) -> None:
|
||||
a = net_settings_t.singleton()
|
||||
b = net_settings_t.singleton()
|
||||
self.assertIs(a, b)
|
||||
|
||||
def test_default_timeout(self) -> None:
|
||||
s = net_settings_t.singleton()
|
||||
self.assertEqual(s.timeout, 8.0)
|
||||
|
||||
def test_default_max_size(self) -> None:
|
||||
s = net_settings_t.singleton()
|
||||
self.assertEqual(s.max_size, 50 * 1024 * 1024)
|
||||
|
||||
def test_reset_changes_timeout(self) -> None:
|
||||
net_settings_t.reset(timeout=4.0)
|
||||
s = net_settings_t.singleton()
|
||||
self.assertEqual(s.timeout, 4.0)
|
||||
|
||||
def test_reset_changes_max_size(self) -> None:
|
||||
net_settings_t.reset(max_size=1024)
|
||||
s = net_settings_t.singleton()
|
||||
self.assertEqual(s.max_size, 1024)
|
||||
|
||||
def test_reset_returns_new_instance(self) -> None:
|
||||
a = net_settings_t.singleton()
|
||||
b = net_settings_t.reset(timeout=2.0)
|
||||
self.assertIsNot(a, b)
|
||||
self.assertEqual(b.timeout, 2.0)
|
||||
|
||||
def test_reset_partial_preserves_defaults(self) -> None:
|
||||
s = net_settings_t.reset(timeout=3.0)
|
||||
self.assertEqual(s.timeout, 3.0)
|
||||
self.assertEqual(s.max_size, 50 * 1024 * 1024)
|
||||
|
||||
def test_env_override(self) -> None:
|
||||
net_settings_t._instance = None
|
||||
with unittest.mock.patch.dict('os.environ', {'ARCHLINUX_NET_TIMEOUT': '2.5'}):
|
||||
s = net_settings_t()
|
||||
self.assertEqual(s.timeout, 2.5)
|
||||
|
||||
|
||||
class TestNetReadLimited(unittest.TestCase):
|
||||
def tearDown(self) -> None:
|
||||
net_settings_t._instance = None
|
||||
|
||||
def test_read_within_limit(self) -> None:
|
||||
net_settings_t.reset(max_size=1024)
|
||||
data = b'x' * 512
|
||||
|
||||
import io
|
||||
resp = io.BytesIO(data)
|
||||
|
||||
from http.client import HTTPResponse
|
||||
with unittest.mock.patch.object(
|
||||
net_t, '_read_limited',
|
||||
wraps=net_t._read_limited,
|
||||
):
|
||||
# call directly with a mock response
|
||||
mock_resp = unittest.mock.MagicMock()
|
||||
mock_resp.read = io.BytesIO(data).read
|
||||
result = net_t._read_limited(mock_resp, max_size=1024)
|
||||
self.assertEqual(result, data)
|
||||
|
||||
def test_read_exceeds_limit(self) -> None:
|
||||
net_settings_t.reset(max_size=100)
|
||||
data = b'x' * 200
|
||||
|
||||
import io
|
||||
mock_resp = unittest.mock.MagicMock()
|
||||
mock_resp.read = io.BytesIO(data).read
|
||||
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
net_t._read_limited(mock_resp, max_size=100)
|
||||
self.assertIn('max_size', str(ctx.exception))
|
||||
|
||||
|
||||
class TestNetFetchUrl(unittest.TestCase):
|
||||
def tearDown(self) -> None:
|
||||
net_settings_t._instance = None
|
||||
|
||||
def test_fetch_url_uses_settings_timeout(self) -> None:
|
||||
net_settings_t.reset(timeout=4.5)
|
||||
|
||||
with unittest.mock.patch('urllib.request.urlopen') as mock_urlopen:
|
||||
mock_resp = unittest.mock.MagicMock()
|
||||
mock_resp.read.side_effect = [b'hello', b'']
|
||||
mock_urlopen.return_value = mock_resp
|
||||
|
||||
result = net_t.fetch_url('http://example.com')
|
||||
self.assertEqual(result, b'hello')
|
||||
mock_urlopen.assert_called_once_with('http://example.com', timeout=4.5)
|
||||
|
||||
def test_fetch_url_explicit_timeout_overrides(self) -> None:
|
||||
net_settings_t.reset(timeout=8.0)
|
||||
|
||||
with unittest.mock.patch('urllib.request.urlopen') as mock_urlopen:
|
||||
mock_resp = unittest.mock.MagicMock()
|
||||
mock_resp.read.side_effect = [b'data', b'']
|
||||
mock_urlopen.return_value = mock_resp
|
||||
|
||||
net_t.fetch_url('http://example.com', timeout=2.0)
|
||||
mock_urlopen.assert_called_once_with('http://example.com', timeout=2.0)
|
||||
|
||||
def test_fetch_text_returns_str(self) -> None:
|
||||
net_settings_t.reset(timeout=8.0)
|
||||
|
||||
with unittest.mock.patch('urllib.request.urlopen') as mock_urlopen:
|
||||
mock_resp = unittest.mock.MagicMock()
|
||||
mock_resp.read.side_effect = [b'hello world', b'']
|
||||
mock_urlopen.return_value = mock_resp
|
||||
|
||||
result = net_t.fetch_text('http://example.com')
|
||||
self.assertIsInstance(result, str)
|
||||
self.assertEqual(result, 'hello world')
|
||||
|
||||
|
||||
class TestNetCli(unittest.TestCase):
|
||||
def tearDown(self) -> None:
|
||||
net_settings_t._instance = None
|
||||
|
||||
def test_add_arguments(self) -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
net_cli_t.add_arguments(parser)
|
||||
ns = parser.parse_args(['--net-timeout', '4.5', '--net-max-size', '1024'])
|
||||
self.assertEqual(ns.net_timeout, 4.5)
|
||||
self.assertEqual(ns.net_max_size, 1024)
|
||||
|
||||
def test_add_arguments_defaults(self) -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
net_cli_t.add_arguments(parser)
|
||||
ns = parser.parse_args([])
|
||||
self.assertIsNone(ns.net_timeout)
|
||||
self.assertIsNone(ns.net_max_size)
|
||||
|
||||
def test_extract_both(self) -> None:
|
||||
ns = argparse.Namespace(net_timeout=3.0, net_max_size=2048)
|
||||
kwargs = net_cli_t.extract(ns)
|
||||
self.assertEqual(kwargs, {'timeout': 3.0, 'max_size': 2048})
|
||||
|
||||
def test_extract_partial(self) -> None:
|
||||
ns = argparse.Namespace(net_timeout=5.0, net_max_size=None)
|
||||
kwargs = net_cli_t.extract(ns)
|
||||
self.assertEqual(kwargs, {'timeout': 5.0})
|
||||
|
||||
def test_extract_empty(self) -> None:
|
||||
ns = argparse.Namespace(net_timeout=None, net_max_size=None)
|
||||
kwargs = net_cli_t.extract(ns)
|
||||
self.assertEqual(kwargs, {})
|
||||
|
||||
def test_apply_with_overrides(self) -> None:
|
||||
s = net_cli_t.apply({'timeout': 2.0, 'max_size': 512})
|
||||
self.assertEqual(s.timeout, 2.0)
|
||||
self.assertEqual(s.max_size, 512)
|
||||
self.assertIs(s, net_settings_t.singleton())
|
||||
|
||||
def test_apply_empty_returns_default(self) -> None:
|
||||
s = net_cli_t.apply({})
|
||||
self.assertEqual(s.timeout, 8.0)
|
||||
self.assertEqual(s.max_size, 50 * 1024 * 1024)
|
||||
Loading…
Reference in New Issue
Block a user