[+] rewrite download command: async, pydantic entry model, terminal-aware progress

  1. download_entry_t pydantic model (frozen=True) replaces tuples;
  2. download_t class with async run(), _download_one(), _run_parallel(), _run_aria2c_batch();
  3. asyncio.Semaphore for -j concurrency, run_in_executor for blocking I/O;
  4. .part file pattern: download to dest.pkg.part, rename on success;
  5. curl -C - and aria2c --continue=true for resume support;
  6. curl/aria2c stdout/stderr redirected to devnull;
  7. --dry-run, --verify-checksum BooleanOptionalAction flags;
  8. --progress-mode plain|interactive using terminal.py renderer;
  9. progress_t with byte-based ETA (when --size= available) and pkg-rate fallback;
  10. parse --size=BYTES and --hash=sha256: from compiled requirements;
  11. update test_cli.py downloader tests for .part rename and **kwargs;
  12. update test_download_cli.py for pydantic model and new progress_t constructor;
This commit is contained in:
LLM 2026-04-13 09:00:00 +00:00
parent b98173511e
commit b7f7d3d291
3 changed files with 478 additions and 173 deletions

@ -1,10 +1,10 @@
"""Download compiled packages.""" """Download compiled packages."""
import argparse import argparse
import concurrent.futures import asyncio
import enum import enum
import hashlib
import logging import logging
import os
import pathlib import pathlib
import re import re
import subprocess import subprocess
@ -12,13 +12,28 @@ import time
import urllib.request import urllib.request
from typing import ( from typing import (
TYPE_CHECKING,
ClassVar, ClassVar,
Optional, Optional,
) )
import pydantic
if TYPE_CHECKING:
from online.fxreader.pr34.commands_typed.terminal import field_t
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class download_entry_t(pydantic.BaseModel):
    """One package to fetch, parsed from a compiled requirements file."""

    # frozen=True: entries are immutable value objects, safe to hash/share
    model_config = pydantic.ConfigDict(frozen=True)
    # full download URL of the package artifact
    url: str
    # destination file name (basename of url, or the package spec as fallback)
    filename: str
    # expected sha256 hex digest from a --hash=sha256: annotation; '' when unknown
    sha256: str = ''
    # compressed size in bytes from a --size= annotation; 0 when unknown
    csize: int = 0
class parse_rate_t: class parse_rate_t:
class constants_t: class constants_t:
rate_re: ClassVar[re.Pattern[str]] = re.compile( rate_re: ClassVar[re.Pattern[str]] = re.compile(
@ -63,122 +78,246 @@ class downloader_t:
limit_rate: int, limit_rate: int,
) -> None: ) -> None:
dest.parent.mkdir(parents=True, exist_ok=True) dest.parent.mkdir(parents=True, exist_ok=True)
part = dest.with_suffix(dest.suffix + '.part')
devnull = subprocess.DEVNULL
if backend is downloader_t.constants_t.backend_t.urllib: if backend is downloader_t.constants_t.backend_t.urllib:
urllib.request.urlretrieve(url, str(dest)) urllib.request.urlretrieve(url, str(part))
elif backend is downloader_t.constants_t.backend_t.curl: elif backend is downloader_t.constants_t.backend_t.curl:
cmd = [ subprocess.check_call(
'curl', '-fSL', [
'--limit-rate', '%d' % limit_rate, 'curl', '-fSL',
'-o', str(dest), '-C', '-',
url, '--limit-rate', '%d' % limit_rate,
] '-o', str(part),
subprocess.check_call(cmd) url,
],
stdout=devnull,
stderr=devnull,
)
elif backend is downloader_t.constants_t.backend_t.aria2c: elif backend is downloader_t.constants_t.backend_t.aria2c:
cmd = [ subprocess.check_call(
'aria2c', [
'--max-download-limit=%d' % limit_rate, 'aria2c',
'-d', str(dest.parent), '--continue=true',
'-o', dest.name, '--max-download-limit=%d' % limit_rate,
url, '-d', str(part.parent),
] '-o', part.name,
subprocess.check_call(cmd) url,
],
stdout=devnull,
stderr=devnull,
)
else: else:
raise NotImplementedError raise NotImplementedError
part.rename(dest)
@staticmethod @staticmethod
def download_batch_aria2c( def download_batch_aria2c(
entries: list[tuple[str, pathlib.Path]], entries: list[tuple[str, pathlib.Path]],
limit_rate: int, limit_rate: int,
jobs: int, jobs: int,
) -> None: ) -> None:
"""Download multiple files using a single aria2c process with -j."""
if len(entries) == 0: if len(entries) == 0:
return return
dest_dir = entries[0][1].parent dest_dir = entries[0][1].parent
dest_dir.mkdir(parents=True, exist_ok=True) dest_dir.mkdir(parents=True, exist_ok=True)
# write input file for aria2c
input_lines: list[str] = [] input_lines: list[str] = []
for url, dest in entries: for url, dest in entries:
input_lines.append(url) input_lines.append(url)
input_lines.append(' dir=%s' % str(dest.parent)) input_lines.append(' dir=%s' % str(dest.parent))
input_lines.append(' out=%s' % dest.name) input_lines.append(' out=%s' % dest.name)
input_txt = '\n'.join(input_lines) + '\n'
input_path = dest_dir / '.aria2c-input.txt' input_path = dest_dir / '.aria2c-input.txt'
input_path.write_text(input_txt) input_path.write_text('\n'.join(input_lines) + '\n')
cmd = [
'aria2c',
'--max-download-limit=%d' % limit_rate,
'-j', '%d' % jobs,
'-i', str(input_path),
]
try: try:
subprocess.check_call(cmd) subprocess.check_call(
[
'aria2c',
'--continue=true',
'--max-download-limit=%d' % limit_rate,
'-j', '%d' % jobs,
'-i', str(input_path),
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
finally: finally:
input_path.unlink(missing_ok=True) input_path.unlink(missing_ok=True)
@staticmethod
def verify_sha256(path: pathlib.Path, expected: str) -> bool:
h = hashlib.sha256()
with open(path, 'rb') as f:
while True:
chunk = f.read(64 * 1024)
if not chunk:
break
h.update(chunk)
return h.hexdigest() == expected
class progress_t: class progress_t:
class constants_t: class constants_t:
class mode_t(enum.Enum): class mode_t(enum.Enum):
plain = 'plain' plain = 'plain'
interactive = 'interactive'
def __init__( def __init__(
self, self,
total: int, total_pkgs: int,
total_bytes: int, already_pkgs: int,
already_done: int, to_download_pkgs: int,
already_bytes: int, already_bytes: int,
total_expected_bytes: int,
to_download_bytes: int,
has_sizes: bool,
mode: 'progress_t.constants_t.mode_t' = constants_t.mode_t.plain,
) -> None: ) -> None:
self.total = total self.total_pkgs = total_pkgs
self.total_bytes = total_bytes self.already_pkgs = already_pkgs
self.already_done = already_done self.to_download_pkgs = to_download_pkgs
self.already_bytes = already_bytes self.already_bytes = already_bytes
self.downloaded_this_run = 0 self.total_expected_bytes = total_expected_bytes
self.downloaded_bytes_this_run = 0 self.to_download_bytes = to_download_bytes
self.has_sizes = has_sizes
self.downloaded_pkgs = 0
self.downloaded_bytes = 0
self.start_time = time.monotonic() self.start_time = time.monotonic()
def update(self, file_bytes: int) -> None: from online.fxreader.pr34.commands_typed.terminal import (
self.downloaded_this_run += 1 render_mode_t,
self.downloaded_bytes_this_run += file_bytes renderer_t,
)
def format_plain(self) -> str: render_mode = (
done = self.already_done + self.downloaded_this_run render_mode_t.interactive
done_mb = (self.already_bytes + self.downloaded_bytes_this_run) / (1024 * 1024) if mode is progress_t.constants_t.mode_t.interactive
total_mb = self.total_bytes / (1024 * 1024) else render_mode_t.plain
)
self._renderer = renderer_t(mode=render_mode)
def update(self, file_bytes: int) -> None:
self.downloaded_pkgs += 1
self.downloaded_bytes += file_bytes
@staticmethod
def _fmt_bytes(b: int) -> str:
if b >= 1024 * 1024 * 1024:
return '%.1fG' % (b / (1024 * 1024 * 1024))
return '%.1fM' % (b / (1024 * 1024))
def _build_fields(self) -> 'list[field_t]':
from online.fxreader.pr34.commands_typed.terminal import (
field_t,
priority_t,
)
done_pkgs = self.already_pkgs + self.downloaded_pkgs
elapsed = time.monotonic() - self.start_time elapsed = time.monotonic() - self.start_time
if self.downloaded_this_run > 0 and elapsed > 0: # speed
rate = self.downloaded_this_run / elapsed if elapsed > 0 and self.downloaded_bytes > 0:
remaining = self.total - done speed = '%s/s' % self._fmt_bytes(int(self.downloaded_bytes / elapsed))
if rate > 0:
eta_s = remaining / rate
eta = '%dm%02ds' % (int(eta_s) // 60, int(eta_s) % 60)
else:
eta = '?'
if rate >= 1:
rate_str = '%.1f pkg/s' % rate
else:
rate_str = '%.1f s/pkg' % (1.0 / rate) if rate > 0 else '?'
else: else:
eta = '?' speed = '-'
rate_str = '?'
return ( # pkg rate
'[%d/%d] this_run=%d %.1f/%.1f MiB ETA=%s %s' if self.downloaded_pkgs > 0 and elapsed > 0:
% (done, self.total, self.downloaded_this_run, done_mb, total_mb, eta, rate_str) pkg_rate = self.downloaded_pkgs / elapsed
) if pkg_rate >= 1:
rate_str = '%.1f pkg/s' % pkg_rate
else:
rate_str = '%.1f s/pkg' % (1.0 / pkg_rate) if pkg_rate > 0 else '-'
else:
rate_str = '-'
# ETA
if self.has_sizes and elapsed > 0 and self.downloaded_bytes > 0:
remaining_bytes = self.to_download_bytes - self.downloaded_bytes
byte_rate = self.downloaded_bytes / elapsed
eta_s = max(0, remaining_bytes / byte_rate) if byte_rate > 0 else 0
eta = '%dm%02ds' % (int(eta_s) // 60, int(eta_s) % 60)
elif self.downloaded_pkgs > 0 and elapsed > 0:
remaining_pkgs = self.to_download_pkgs - self.downloaded_pkgs
pkg_rate_v = self.downloaded_pkgs / elapsed
eta_s = remaining_pkgs / pkg_rate_v if pkg_rate_v > 0 else 0
eta = '~%dm%02ds' % (int(eta_s) // 60, int(eta_s) % 60)
else:
eta = '-'
# total bytes
total_str = self._fmt_bytes(self.total_expected_bytes) if self.has_sizes else '?'
fields = [
field_t(
name='',
value='[%d/%d]' % (done_pkgs, self.total_pkgs),
priority=priority_t.critical,
),
field_t(
name='new',
value='%d/%d' % (self.downloaded_pkgs, self.to_download_pkgs),
priority=priority_t.critical,
),
field_t(
name='cached',
value='%d/%s' % (self.already_pkgs, self._fmt_bytes(self.already_bytes)),
priority=priority_t.normal,
),
field_t(
name='dl',
value='%s/%s' % (
self._fmt_bytes(self.downloaded_bytes),
self._fmt_bytes(self.to_download_bytes) if self.has_sizes else '?',
),
priority=priority_t.high,
),
field_t(
name='total',
value=total_str,
priority=priority_t.low,
),
field_t(
name='',
value=speed,
priority=priority_t.high,
),
field_t(
name='',
value=rate_str,
priority=priority_t.normal,
),
field_t(
name='ETA',
value=eta,
priority=priority_t.critical,
),
]
return fields
def emit(self) -> None:
self._renderer.emit(self._build_fields())
def finish(self) -> None:
self._renderer.finish()
def format_plain(self) -> str:
"""For tests and non-renderer usage."""
from online.fxreader.pr34.commands_typed.terminal import line_formatter_t
return line_formatter_t.format(self._build_fields(), 200)
class download_requirements_t: class download_requirements_t:
@staticmethod @staticmethod
def parse_requirements(txt: str) -> list[tuple[str, str]]: def parse_requirements(txt: str) -> list[download_entry_t]:
entries: list[tuple[str, str]] = [] """Parse compiled requirements into download entries."""
entries: list[download_entry_t] = []
url: Optional[str] = None url: Optional[str] = None
for line in txt.splitlines(): for line in txt.splitlines():
@ -187,14 +326,12 @@ class download_requirements_t:
continue continue
if line.startswith('#'): if line.startswith('#'):
candidate = line[1:].strip() candidate = line[1:].strip()
# strip trailing annotation like "URL # pinned"
if ' #' in candidate: if ' #' in candidate:
candidate = candidate.split(' #', 1)[0].strip() candidate = candidate.split(' #', 1)[0].strip()
if '/' in candidate and '://' in candidate: if '/' in candidate and '://' in candidate:
url = candidate url = candidate
continue continue
# strip trailing inline comment (e.g. "pkg==1.0 # pinned")
if ' #' in line: if ' #' in line:
line = line.split(' #', 1)[0].strip() line = line.split(' #', 1)[0].strip()
@ -202,16 +339,219 @@ class download_requirements_t:
if len(parts) == 0: if len(parts) == 0:
continue continue
pkg_spec = parts[0] sha256 = ''
csize = 0
for p in parts[1:]:
if p.startswith('--hash=sha256:'):
sha256 = p[len('--hash=sha256:'):]
elif p.startswith('--size='):
try:
csize = int(p[len('--size='):])
except ValueError:
pass
if url is not None: if url is not None:
filename = url.rsplit('/', 1)[-1] if '/' in url else pkg_spec filename = url.rsplit('/', 1)[-1] if '/' in url else parts[0]
entries.append((url, filename)) entries.append(download_entry_t(url=url, filename=filename, sha256=sha256, csize=csize))
url = None url = None
return entries return entries
class download_t:
    """Orchestrates package downloads.

    Classifies entries into cached vs. to-download, optionally verifies
    checksums, shows progress, and runs downloads either through a single
    batched aria2c process or via an asyncio semaphore-bounded pool that
    offloads the blocking backend calls to the default executor.
    """

    def __init__(
        self,
        dest_dir: pathlib.Path,
        backend: downloader_t.constants_t.backend_t,
        limit_rate: int,
        jobs: int,
        verify: bool,
        show_progress: bool,
        progress_mode: progress_t.constants_t.mode_t = progress_t.constants_t.mode_t.plain,
    ) -> None:
        # dest_dir: directory packages are written into
        self.dest_dir = dest_dir
        # backend: urllib / curl / aria2c selector passed to downloader_t
        self.backend = backend
        # limit_rate: bytes/sec cap forwarded to the backend tool
        self.limit_rate = limit_rate
        # jobs: max concurrent downloads (also aria2c -j in batch mode)
        self.jobs = jobs
        # verify: when True, sha256-check existing and downloaded files
        self.verify = verify
        self.show_progress = show_progress
        self.progress_mode = progress_mode

    def _classify(
        self, entries: list[download_entry_t],
    ) -> tuple[list[download_entry_t], int, int]:
        """Returns (to_download, already_count, already_bytes)."""
        to_download: list[download_entry_t] = []
        already_count = 0
        already_bytes = 0
        for e in entries:
            dest_path = self.dest_dir / e.filename
            if dest_path.exists():
                # existing file with a known hash: re-download on mismatch
                if self.verify and e.sha256 != '':
                    if not downloader_t.verify_sha256(dest_path, e.sha256):
                        logger.warning(dict(msg='checksum mismatch, re-downloading', file=e.filename))
                        to_download.append(e)
                        continue
                already_count += 1
                already_bytes += dest_path.stat().st_size
            else:
                to_download.append(e)
        return to_download, already_count, already_bytes

    async def _download_one(self, e: download_entry_t) -> int:
        """Download a single entry via the executor; return its on-disk size in bytes.

        A post-download checksum mismatch is logged as an error but does not
        raise — the caller still receives the (possibly bad) file's size.
        """
        dest_path = self.dest_dir / e.filename
        logger.info(dict(msg='downloading', file=e.filename))
        loop = asyncio.get_running_loop()
        # downloader_t.download blocks (subprocess / urllib), so run it off-loop
        await loop.run_in_executor(
            None,
            downloader_t.download,
            e.url,
            dest_path,
            self.backend,
            self.limit_rate,
        )
        sz = dest_path.stat().st_size if dest_path.exists() else 0
        if self.verify and e.sha256 != '' and dest_path.exists():
            ok = await loop.run_in_executor(None, downloader_t.verify_sha256, dest_path, e.sha256)
            if not ok:
                logger.error(dict(msg='checksum mismatch after download', file=e.filename))
        logger.info(dict(msg='downloaded', file=e.filename, size=sz))
        return sz

    def _make_progress(
        self,
        entries_total: int,
        already_count: int,
        to_download: list[download_entry_t],
        already_bytes: int,
        mode: progress_t.constants_t.mode_t = progress_t.constants_t.mode_t.plain,
    ) -> Optional[progress_t]:
        """Build a progress_t, or None when progress display is disabled.

        has_sizes is only True when every pending entry carries a --size=
        annotation, so byte-based ETA never mixes known and guessed sizes.
        """
        if not self.show_progress:
            return None
        to_download_bytes = sum(e.csize for e in to_download)
        has_sizes = all(e.csize > 0 for e in to_download) and len(to_download) > 0
        total_expected_bytes = already_bytes + to_download_bytes
        return progress_t(
            total_pkgs=entries_total,
            already_pkgs=already_count,
            to_download_pkgs=len(to_download),
            already_bytes=already_bytes,
            total_expected_bytes=total_expected_bytes,
            to_download_bytes=to_download_bytes,
            has_sizes=has_sizes,
            mode=mode,
        )

    async def run(self, entries: list[download_entry_t], dry_run: bool = False) -> int:
        """Entry point: plan, then either dry-run print or download everything.

        Returns a process exit code (currently always 0).
        """
        to_download, already_count, already_bytes = self._classify(entries)
        logger.info(dict(
            msg='download plan',
            total=len(entries),
            already=already_count,
            to_download=len(to_download),
        ))
        if dry_run:
            for e in to_download:
                print('%s -> %s' % (e.url, e.filename))
            print('total: %d to download, %d already present' % (len(to_download), already_count))
            return 0
        if len(to_download) == 0:
            logger.info(dict(msg='nothing to download'))
            if self.show_progress:
                print('[%d/%d] nothing to download' % (already_count, len(entries)))
            return 0
        progress = self._make_progress(
            len(entries), already_count, to_download, already_bytes,
            mode=self.progress_mode,
        )
        # print initial status before first download
        if progress is not None:
            progress.emit()
        # aria2c with -j > 1: one aria2c process handles the whole batch
        if self.backend is downloader_t.constants_t.backend_t.aria2c and self.jobs > 1:
            return await self._run_aria2c_batch(to_download, progress)
        return await self._run_parallel(to_download, progress)

    async def _run_aria2c_batch(
        self,
        to_download: list[download_entry_t],
        progress: Optional[progress_t],
    ) -> int:
        """Download all entries through a single aria2c invocation.

        Sizes, checksum verification and progress updates happen only after
        the whole batch completes, since aria2c gives no per-file callback.
        """
        batch = [(e.url, self.dest_dir / e.filename) for e in to_download]
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            downloader_t.download_batch_aria2c,
            batch,
            self.limit_rate,
            self.jobs,
        )
        for e in to_download:
            dest_path = self.dest_dir / e.filename
            sz = dest_path.stat().st_size if dest_path.exists() else 0
            if self.verify and e.sha256 != '' and dest_path.exists():
                ok = await loop.run_in_executor(None, downloader_t.verify_sha256, dest_path, e.sha256)
                if not ok:
                    logger.error(dict(msg='checksum mismatch after download', file=e.filename))
            if progress is not None:
                progress.update(sz)
            logger.info(dict(msg='downloaded', file=e.filename, size=sz))
        if progress is not None:
            progress.emit()
            progress.finish()
        return 0

    async def _run_parallel(
        self,
        to_download: list[download_entry_t],
        progress: Optional[progress_t],
    ) -> int:
        """Download entries concurrently, at most self.jobs at a time."""
        sem = asyncio.Semaphore(self.jobs)

        async def _bounded(e: download_entry_t) -> int:
            # the semaphore caps how many _download_one calls run at once
            async with sem:
                return await self._download_one(e)

        tasks = [asyncio.create_task(_bounded(e)) for e in to_download]
        # update progress as each task finishes, in completion order
        for coro in asyncio.as_completed(tasks):
            try:
                sz = await coro
            except asyncio.CancelledError:
                # NOTE(review): remaining tasks are left running on cancel —
                # presumably torn down by the loop shutdown; confirm
                break
            if progress is not None:
                progress.update(sz)
                progress.emit()
        if progress is not None:
            progress.finish()
        downloaded = progress.downloaded_pkgs if progress else len(to_download)
        logger.info(dict(
            msg='download complete',
            downloaded=downloaded,
            total=progress.total_pkgs if progress else len(to_download),
        ))
        return 0
def main(args: list[str]) -> int: def main(args: list[str]) -> int:
download_parser = argparse.ArgumentParser( download_parser = argparse.ArgumentParser(
prog='online-fxreader-pr34-archlinux download', prog='online-fxreader-pr34-archlinux download',
@ -259,94 +599,33 @@ def main(args: list[str]) -> int:
default=1, default=1,
help='parallel downloads (default: 1). For aria2c, passed as -j to aria2c directly.', help='parallel downloads (default: 1). For aria2c, passed as -j to aria2c directly.',
) )
download_parser.add_argument(
'--dry-run',
default=False,
action=argparse.BooleanOptionalAction,
help='print what would be downloaded without downloading',
)
download_parser.add_argument(
'--verify-checksum',
default=False,
action=argparse.BooleanOptionalAction,
help='verify sha256 checksum of existing and downloaded files',
)
download_options = download_parser.parse_args(args) opts = download_parser.parse_args(args)
dest_dir = pathlib.Path(download_options.dest_dir) dl = download_t(
dest_dir.mkdir(parents=True, exist_ok=True) dest_dir=pathlib.Path(opts.dest_dir),
backend=downloader_t.constants_t.backend_t(opts.downloader),
limit_rate=parse_rate_t.parse(opts.limit_rate),
jobs=opts.jobs,
verify=opts.verify_checksum,
show_progress=opts.progress,
progress_mode=progress_t.constants_t.mode_t(opts.progress_mode),
)
backend = downloader_t.constants_t.backend_t(download_options.downloader) entries = download_requirements_t.parse_requirements(
limit_rate = parse_rate_t.parse(download_options.limit_rate) pathlib.Path(opts.requirements).read_text()
jobs: int = download_options.jobs )
requirements_txt = pathlib.Path(download_options.requirements).read_text() return asyncio.run(dl.run(entries, dry_run=opts.dry_run))
entries = download_requirements_t.parse_requirements(requirements_txt)
# split into already-done vs to-download
to_download: list[tuple[str, str]] = []
already_count = 0
already_bytes = 0
total_bytes = 0
for url, filename in entries:
dest_path = dest_dir / filename
if dest_path.exists():
already_count += 1
sz = dest_path.stat().st_size
already_bytes += sz
total_bytes += sz
else:
to_download.append((url, filename))
# estimate total bytes (already + to_download as average of already)
avg_size = already_bytes // already_count if already_count > 0 else 10 * 1024 * 1024
total_bytes += avg_size * len(to_download)
progress: Optional[progress_t] = None
if download_options.progress:
progress = progress_t(
total=len(entries),
total_bytes=total_bytes,
already_done=already_count,
already_bytes=already_bytes,
)
if len(to_download) == 0:
print(progress.format_plain())
# aria2c with -j: batch all into single process
if backend is downloader_t.constants_t.backend_t.aria2c and jobs > 1 and len(to_download) > 0:
batch = [(url, dest_dir / filename) for url, filename in to_download]
downloader_t.download_batch_aria2c(batch, limit_rate, jobs)
if progress is not None:
for url, filename in to_download:
dest_path = dest_dir / filename
sz = dest_path.stat().st_size if dest_path.exists() else avg_size
progress.update(sz)
total_bytes = total_bytes - avg_size + sz
progress.total_bytes = total_bytes
print(progress.format_plain())
logger.info(dict(msg='download complete', count=len(entries)))
return 0
def _download_one(url: str, filename: str) -> int:
dest_path = dest_dir / filename
logger.debug(dict(msg='downloading', url=url, dest=str(dest_path)))
downloader_t.download(
url=url,
dest=dest_path,
backend=backend,
limit_rate=limit_rate,
)
return dest_path.stat().st_size if dest_path.exists() else 0
if jobs > 1 and backend is not downloader_t.constants_t.backend_t.aria2c:
with concurrent.futures.ThreadPoolExecutor(max_workers=jobs) as executor:
futures = {
executor.submit(_download_one, url, filename): (url, filename)
for url, filename in to_download
}
for future in concurrent.futures.as_completed(futures):
sz = future.result()
if progress is not None:
progress.update(sz)
print(progress.format_plain())
else:
for url, filename in to_download:
sz = _download_one(url, filename)
if progress is not None:
progress.update(sz)
print(progress.format_plain())
logger.info(dict(msg='download complete', count=len(entries)))
return 0

@ -5,7 +5,7 @@ import tempfile
import unittest import unittest
import unittest.mock import unittest.mock
from typing import Optional from typing import Any, Optional
from ..cli.download import ( from ..cli.download import (
parse_rate_t, parse_rate_t,
@ -85,8 +85,8 @@ class TestDownloadRequirementsParse(unittest.TestCase):
txt = '# https://example.com/core/bash-5.2-1-x86_64.pkg.tar.zst\nbash==5.2-1 --hash=sha256:abc123\n' txt = '# https://example.com/core/bash-5.2-1-x86_64.pkg.tar.zst\nbash==5.2-1 --hash=sha256:abc123\n'
entries = download_requirements_t.parse_requirements(txt) entries = download_requirements_t.parse_requirements(txt)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0][0], 'https://example.com/core/bash-5.2-1-x86_64.pkg.tar.zst') self.assertEqual(entries[0].url, 'https://example.com/core/bash-5.2-1-x86_64.pkg.tar.zst')
self.assertEqual(entries[0][1], 'bash-5.2-1-x86_64.pkg.tar.zst') self.assertEqual(entries[0].filename, 'bash-5.2-1-x86_64.pkg.tar.zst')
def test_multiple(self) -> None: def test_multiple(self) -> None:
txt = ( txt = (
@ -94,8 +94,8 @@ class TestDownloadRequirementsParse(unittest.TestCase):
) )
entries = download_requirements_t.parse_requirements(txt) entries = download_requirements_t.parse_requirements(txt)
self.assertEqual(len(entries), 2) self.assertEqual(len(entries), 2)
self.assertEqual(entries[0][1], 'bash-5.2-1-x86_64.pkg.tar.zst') self.assertEqual(entries[0].filename, 'bash-5.2-1-x86_64.pkg.tar.zst')
self.assertEqual(entries[1][1], 'glibc-2.38-1-x86_64.pkg.tar.zst') self.assertEqual(entries[1].filename, 'glibc-2.38-1-x86_64.pkg.tar.zst')
def test_no_url_skipped(self) -> None: def test_no_url_skipped(self) -> None:
txt = 'bash==5.2-1\n' txt = 'bash==5.2-1\n'
@ -122,18 +122,31 @@ class TestDownloader(unittest.TestCase):
def test_urllib_backend(self, mock_urlretrieve: unittest.mock.MagicMock) -> None: def test_urllib_backend(self, mock_urlretrieve: unittest.mock.MagicMock) -> None:
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
dest = pathlib.Path(tmpdir) / 'test.pkg' dest = pathlib.Path(tmpdir) / 'test.pkg'
part = dest.with_suffix(dest.suffix + '.part')
def fake_retrieve(url: str, path: str) -> None:
pathlib.Path(path).write_bytes(b'\x00')
mock_urlretrieve.side_effect = fake_retrieve
downloader_t.download( downloader_t.download(
url='https://example.com/test.pkg', url='https://example.com/test.pkg',
dest=dest, dest=dest,
backend=downloader_t.constants_t.backend_t.urllib, backend=downloader_t.constants_t.backend_t.urllib,
limit_rate=128 * 1024, limit_rate=128 * 1024,
) )
mock_urlretrieve.assert_called_once_with('https://example.com/test.pkg', str(dest)) mock_urlretrieve.assert_called_once_with('https://example.com/test.pkg', str(part))
self.assertTrue(dest.exists())
@unittest.mock.patch('subprocess.check_call') @unittest.mock.patch('subprocess.check_call')
def test_curl_backend(self, mock_check_call: unittest.mock.MagicMock) -> None: def test_curl_backend(self, mock_check_call: unittest.mock.MagicMock) -> None:
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
dest = pathlib.Path(tmpdir) / 'test.pkg' dest = pathlib.Path(tmpdir) / 'test.pkg'
part = dest.with_suffix(dest.suffix + '.part')
def fake_call(cmd: list[str], **kwargs: Any) -> None:
pathlib.Path(cmd[cmd.index('-o') + 1]).write_bytes(b'\x00')
mock_check_call.side_effect = fake_call
downloader_t.download( downloader_t.download(
url='https://example.com/test.pkg', url='https://example.com/test.pkg',
dest=dest, dest=dest,
@ -143,12 +156,19 @@ class TestDownloader(unittest.TestCase):
cmd = mock_check_call.call_args[0][0] cmd = mock_check_call.call_args[0][0]
self.assertEqual(cmd[0], 'curl') self.assertEqual(cmd[0], 'curl')
self.assertIn('--limit-rate', cmd) self.assertIn('--limit-rate', cmd)
self.assertIn(str(128 * 1024), cmd) self.assertTrue(dest.exists())
@unittest.mock.patch('subprocess.check_call') @unittest.mock.patch('subprocess.check_call')
def test_aria2c_backend(self, mock_check_call: unittest.mock.MagicMock) -> None: def test_aria2c_backend(self, mock_check_call: unittest.mock.MagicMock) -> None:
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
dest = pathlib.Path(tmpdir) / 'test.pkg' dest = pathlib.Path(tmpdir) / 'test.pkg'
def fake_call(cmd: list[str], **kwargs: Any) -> None:
d = cmd[cmd.index('-d') + 1]
o = cmd[cmd.index('-o') + 1]
pathlib.Path(d, o).write_bytes(b'\x00')
mock_check_call.side_effect = fake_call
downloader_t.download( downloader_t.download(
url='https://example.com/test.pkg', url='https://example.com/test.pkg',
dest=dest, dest=dest,
@ -158,6 +178,7 @@ class TestDownloader(unittest.TestCase):
cmd = mock_check_call.call_args[0][0] cmd = mock_check_call.call_args[0][0]
self.assertEqual(cmd[0], 'aria2c') self.assertEqual(cmd[0], 'aria2c')
self.assertIn('--max-download-limit=%d' % (1024 * 1024), cmd) self.assertIn('--max-download-limit=%d' % (1024 * 1024), cmd)
self.assertTrue(dest.exists())
class TestGroupExpansion(unittest.TestCase): class TestGroupExpansion(unittest.TestCase):

@ -48,7 +48,12 @@ def _prefill(dest: pathlib.Path, filenames: list[str]) -> None:
(dest / f).write_bytes(b'\x00' * 200) (dest / f).write_bytes(b'\x00' * 200)
def _fake_download(url: str, dest: pathlib.Path, **kwargs: object) -> None: def _fake_download(
url: str,
dest: pathlib.Path,
backend: object = None,
limit_rate: int = 0,
) -> None:
dest.parent.mkdir(parents=True, exist_ok=True) dest.parent.mkdir(parents=True, exist_ok=True)
dest.write_bytes(b'\x00' * 200) dest.write_bytes(b'\x00' * 200)
@ -78,22 +83,22 @@ def _run(extra_args: list[str], tmpdir: str) -> tuple[int, pathlib.Path, pathlib
class TestProgressFormat(unittest.TestCase): class TestProgressFormat(unittest.TestCase):
def test_initial(self) -> None: def test_initial(self) -> None:
p = progress_t(total=10, total_bytes=100 * 1024 * 1024, already_done=3, already_bytes=30 * 1024 * 1024) p = progress_t(total_pkgs=10, already_pkgs=3, to_download_pkgs=7, already_bytes=30 * 1024 * 1024, total_expected_bytes=100 * 1024 * 1024, to_download_bytes=70 * 1024 * 1024, has_sizes=True)
txt = p.format_plain() txt = p.format_plain()
self.assertIn('[3/10]', txt) self.assertIn('[3/10]', txt)
self.assertIn('this_run=0', txt) self.assertIn('new=0/7', txt)
def test_after_updates(self) -> None: def test_after_updates(self) -> None:
p = progress_t(total=10, total_bytes=100 * 1024 * 1024, already_done=0, already_bytes=0) p = progress_t(total_pkgs=10, already_pkgs=0, to_download_pkgs=10, already_bytes=0, total_expected_bytes=100 * 1024 * 1024, to_download_bytes=100 * 1024 * 1024, has_sizes=True)
p.update(5 * 1024 * 1024) p.update(5 * 1024 * 1024)
p.update(5 * 1024 * 1024) p.update(5 * 1024 * 1024)
txt = p.format_plain() txt = p.format_plain()
self.assertIn('[2/10]', txt) self.assertIn('[2/10]', txt)
self.assertIn('this_run=2', txt) self.assertIn('new=2/10', txt)
def test_eta_and_rate(self) -> None: def test_eta_and_rate(self) -> None:
p = progress_t(total=100, total_bytes=1000 * 1024 * 1024, already_done=0, already_bytes=0) p = progress_t(total_pkgs=100, already_pkgs=0, to_download_pkgs=100, already_bytes=0, total_expected_bytes=1000 * 1024 * 1024, to_download_bytes=1000 * 1024 * 1024, has_sizes=True)
p.start_time -= 5.0 # simulate 5s elapsed for 10 pkgs → 2 pkg/s p.start_time -= 5.0
for _ in range(10): for _ in range(10):
p.update(10 * 1024 * 1024) p.update(10 * 1024 * 1024)
txt = p.format_plain() txt = p.format_plain()
@ -101,8 +106,8 @@ class TestProgressFormat(unittest.TestCase):
self.assertIn('pkg/s', txt) self.assertIn('pkg/s', txt)
def test_slow_rate_shows_s_per_pkg(self) -> None: def test_slow_rate_shows_s_per_pkg(self) -> None:
p = progress_t(total=10, total_bytes=100 * 1024 * 1024, already_done=0, already_bytes=0) p = progress_t(total_pkgs=10, already_pkgs=0, to_download_pkgs=10, already_bytes=0, total_expected_bytes=100 * 1024 * 1024, to_download_bytes=100 * 1024 * 1024, has_sizes=True)
p.start_time -= 30.0 # 30s for 1 package → 30 s/pkg p.start_time -= 30.0
p.update(10 * 1024 * 1024) p.update(10 * 1024 * 1024)
txt = p.format_plain() txt = p.format_plain()
self.assertIn('s/pkg', txt) self.assertIn('s/pkg', txt)