Source code for clouddrift.adapters.utils

import concurrent.futures
import logging
import os
from datetime import datetime
from io import BufferedIOBase
from typing import Callable, Sequence

import requests
from requests import Response
from tenacity import (
    RetryCallState,
    WrappedFn,
    retry,
    retry_if_exception,
    stop_after_attempt,
    wait_exponential_jitter,
)
from tqdm import tqdm

_DISABLE_SHOW_PROGRESS = False  # purely to de-noise our test suite output, should never be used/configured outside of that.


def _before_call(rcs: RetryCallState):
    if rcs.attempt_number > 1:
        src = rcs.args[0]
        dst = "io-buffer" if isinstance(rcs.args[1], BufferedIOBase) else rcs.args[1]
        _logger.warning(
            f"retrying download request for (src, dst): {(src, dst)}, attempt: {rcs.attempt_number}"
        )


_CHUNK_SIZE = 1024
_logger = logging.getLogger(__name__)
_standard_retry_protocol: Callable[[WrappedFn], WrappedFn] = retry(
    retry=retry_if_exception(
        lambda ex: isinstance(ex, (requests.Timeout, requests.HTTPError))
    ),
    wait=wait_exponential_jitter(
        initial=1.25
    ),  # ~ 20-25 minutes total time before completely failing
    stop=stop_after_attempt(10),
    before=_before_call,
)
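
A caller can substitute its own tenacity policy of the same shape through the ``custom_retry_protocol`` argument of ``download_with_progress``. A minimal, illustrative sketch (the policy name and its settings below are hypothetical, not part of this module):

_quick_retry_protocol: Callable[[WrappedFn], WrappedFn] = retry(
    # Hypothetical alternative policy with the same Callable[[WrappedFn], WrappedFn] shape,
    # passed as download_with_progress(..., custom_retry_protocol=_quick_retry_protocol).
    retry=retry_if_exception(lambda ex: isinstance(ex, requests.ConnectionError)),
    wait=wait_exponential_jitter(initial=0.5),
    stop=stop_after_attempt(3),
)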


def download_with_progress(
    download_map: Sequence[tuple[str, BufferedIOBase | str, float | None]],
    show_list_progress: bool | None = None,
    desc: str = "Downloading files",
    custom_retry_protocol: Callable[[WrappedFn], WrappedFn] | None = None,
):
    if show_list_progress is None:
        show_list_progress = len(download_map) > 20
    if custom_retry_protocol is None:
        retry_protocol = _standard_retry_protocol
    else:
        retry_protocol = custom_retry_protocol  # type: ignore

    executor = concurrent.futures.ThreadPoolExecutor()
    futures: dict[concurrent.futures.Future, tuple[str, BufferedIOBase | str]] = dict()
    bar = None

    for src, dst, exp_size in download_map:
        futures[
            executor.submit(
                retry_protocol(_download_with_progress),
                src,
                dst,
                exp_size or 0,
                not show_list_progress,
            )
        ] = (src, dst)

    try:
        if show_list_progress:
            bar = tqdm(
                desc=desc,
                total=len(futures),
                unit="Files",
                disable=_DISABLE_SHOW_PROGRESS,
            )
        for fut in concurrent.futures.as_completed(futures):
            (src, dst) = futures[fut]
            ex = fut.exception(0)
            if ex is None:
                _logger.debug(f"Finished download job: ({src}, {dst})")
                if bar is not None:
                    bar.update(1)
            else:
                raise ex
    except Exception as e:
        _logger.error(
            f"Got the following exception: {str(e)}, cancelling all other jobs and cleaning up "
            "any created resources."
        )
        for x in futures.keys():
            (src, dst) = futures[x]
            if isinstance(dst, str) and os.path.exists(dst) and not x.done():
                os.remove(dst)
            if not x.done():
                x.cancel()
        raise e
    finally:
        executor.shutdown(True)
        if bar is not None:
            bar.close()
def _download_with_progress(
    url: str,
    output: BufferedIOBase | str,
    expected_size: float,
    show_progress: bool,
):
    if isinstance(output, str) and os.path.exists(output):
        _logger.debug(f"File exists {output} checking for updates...")
        local_last_modified = os.path.getmtime(output)

        # Get last modified time of the remote file
        with requests.head(url, timeout=5) as res:
            if "Last-Modified" in res.headers:
                remote_last_modified = datetime.strptime(
                    res.headers.get("Last-Modified", ""),
                    "%a, %d %b %Y %H:%M:%S %Z",
                )

                # compare with local modified time
                if local_last_modified >= remote_last_modified.timestamp():
                    _logger.debug(f"File: {output} is up to date; skip download.")
                    return
            else:
                _logger.warning(
                    "Cannot determine if the file has been updated on the remote source. "
                    + "'Last-Modified' header not present in server response."
                )

    _logger.debug(f"Downloading from {url} to {output}...")

    force_close = False
    response: Response | None = None
    buffer: BufferedIOBase | None = None
    bar = None

    try:
        response = requests.get(url, timeout=5, stream=True)

        if isinstance(output, str):
            buffer = open(output, "wb")
        else:
            buffer = output

        if show_progress:
            bar = tqdm(
                desc=url,
                total=float(response.headers.get("Content-Length", expected_size)),
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                nrows=2,
                disable=_DISABLE_SHOW_PROGRESS,
            )

        for chunk in response.iter_content(_CHUNK_SIZE):
            if not chunk:
                break
            buffer.write(chunk)
            if bar is not None:
                bar.update(len(chunk))
    except Exception as e:
        force_close = True
        error_msg = f"Error downloading data file: {url} to: {output}, error: {e}"
        _logger.debug(error_msg)
        raise e
    finally:
        if response is not None:
            response.close()
        if buffer is not None and (
            not isinstance(output, BufferedIOBase) or force_close
        ):
            _logger.debug(f"closing buffer {buffer}")
            buffer.close()
        if bar is not None:
            bar.close()


__all__ = ["download_with_progress"]
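
A minimal usage sketch, based only on the signature above: destinations may be file paths or writable binary buffers, and the expected size may be None. The URLs and paths below are hypothetical.

from io import BytesIO

from clouddrift.adapters.utils import download_with_progress

buf = BytesIO()
download_with_progress(
    [
        # (source URL, destination path or buffer, expected size in bytes or None)
        ("https://example.com/drifter_0001.nc", "/tmp/drifter_0001.nc", None),
        ("https://example.com/drifter_0002.nc", buf, 1024.0),
    ],
    desc="Downloading example files",
)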