diff --git a/src/s2_sdk/_ops.py b/src/s2_sdk/_ops.py index 86a2f4a..89c79c9 100644 --- a/src/s2_sdk/_ops.py +++ b/src/s2_sdk/_ops.py @@ -125,7 +125,7 @@ def __init__( self._basin_clients: dict[str, HttpClient] = {} self._retrier = Retrier( should_retry_on=http_retry_on, - max_attempts=retry.max_attempts, + max_retries=retry._max_retries(), min_base_delay=retry.min_base_delay.total_seconds(), max_base_delay=retry.max_base_delay.total_seconds(), ) @@ -568,7 +568,7 @@ def __init__( self._compression = compression self._retrier = Retrier( should_retry_on=http_retry_on, - max_attempts=retry.max_attempts, + max_retries=retry._max_retries(), min_base_delay=retry.min_base_delay.total_seconds(), max_base_delay=retry.max_base_delay.total_seconds(), ) @@ -806,7 +806,7 @@ def __init__( self._encryption_key = encryption_key self._retrier = Retrier( should_retry_on=http_retry_on, - max_attempts=retry.max_attempts, + max_retries=retry._max_retries(), min_base_delay=retry.min_base_delay.total_seconds(), max_base_delay=retry.max_base_delay.total_seconds(), ) @@ -814,7 +814,7 @@ def __init__( should_retry_on=lambda e: is_safe_to_retry_unary( e, retry.append_retry_policy ), - max_attempts=retry.max_attempts, + max_retries=retry._max_retries(), min_base_delay=retry.min_base_delay.total_seconds(), max_base_delay=retry.max_base_delay.total_seconds(), ) diff --git a/src/s2_sdk/_retrier.py b/src/s2_sdk/_retrier.py index d388bdc..3b98ce1 100644 --- a/src/s2_sdk/_retrier.py +++ b/src/s2_sdk/_retrier.py @@ -1,5 +1,6 @@ import asyncio import logging +import math import random from dataclasses import dataclass from typing import Callable @@ -15,28 +16,28 @@ class Retrier: def __init__( self, should_retry_on: Callable[[Exception], bool], - max_attempts: int, + max_retries: int, min_base_delay: float = 0.1, max_base_delay: float = 1.0, ): self.should_retry_on = should_retry_on - self.max_attempts = max_attempts + self.max_retries = max_retries self.min_base_delay = min_base_delay self.max_base_delay = max_base_delay async def __call__(self, f: Callable, *args, **kwargs): - backoffs = compute_backoffs( - attempts=max(self.max_attempts - 1, 0), - min_base_delay=self.min_base_delay, - max_base_delay=self.max_base_delay, - ) + max_retries = self.max_retries attempt = 0 while True: try: return await f(*args, **kwargs) except Exception as e: - if attempt < len(backoffs) and self.should_retry_on(e): - delay = backoffs[attempt] + if attempt < max_retries and self.should_retry_on(e): + delay = compute_backoff( + attempt, + min_base_delay=self.min_base_delay, + max_base_delay=self.max_base_delay, + ) retry_after = getattr(e, "_retry_after", None) if retry_after is not None: delay = max(delay, retry_after) @@ -44,7 +45,7 @@ async def __call__(self, f: Callable, *args, **kwargs): "retrying request: error=%s backoff=%.3fs retries_remaining=%d", e, delay, - len(backoffs) - attempt - 1, + max_retries - attempt - 1, ) await asyncio.sleep(delay) attempt += 1 @@ -53,7 +54,7 @@ async def __call__(self, f: Callable, *args, **kwargs): "not retrying request: error=%s is_retryable=%s retries_exhausted=%s", e, self.should_retry_on(e), - attempt >= len(backoffs), + attempt >= max_retries, ) raise e @@ -63,17 +64,17 @@ class Attempt: value: int -def compute_backoffs( - attempts: int, +def compute_backoff( + attempt: int, min_base_delay: float = 0.1, max_base_delay: float = 1.0, -) -> list[float]: - backoffs = [] - for n in range(attempts): - base_delay = min(min_base_delay * 2**n, max_base_delay) - jitter = random.uniform(0, base_delay) - backoffs.append(base_delay + jitter) - return backoffs +) -> float: + try: + base_delay = min(math.ldexp(min_base_delay, attempt), max_base_delay) + except OverflowError: + base_delay = max_base_delay + jitter = random.uniform(0, base_delay) + return base_delay + jitter def is_safe_to_retry_unary( diff --git a/src/s2_sdk/_s2s/_append_session.py b/src/s2_sdk/_s2s/_append_session.py index 38fbb5d..e2d6e03 100644 --- a/src/s2_sdk/_s2s/_append_session.py +++ b/src/s2_sdk/_s2s/_append_session.py @@ -9,7 +9,7 @@ from s2_sdk._exceptions import ReadTimeoutError, S2ClientError from s2_sdk._frame_signal import FrameSignal from s2_sdk._mappers import append_ack_from_proto, append_input_to_proto -from s2_sdk._retrier import Attempt, compute_backoffs, is_safe_to_retry_session +from s2_sdk._retrier import Attempt, compute_backoff, is_safe_to_retry_session from s2_sdk._s2s import _stream_records_path from s2_sdk._s2s._protocol import ( Message, @@ -64,11 +64,9 @@ async def pipe_inputs(): async def retrying_inner(): inflight_inputs: deque[_InflightInput] = deque() - backoffs = compute_backoffs( - retry._max_retries(), - min_base_delay=retry.min_base_delay.total_seconds(), - max_base_delay=retry.max_base_delay.total_seconds(), - ) + max_retries = retry._max_retries() + min_base_delay = retry.min_base_delay.total_seconds() + max_base_delay = retry.max_base_delay.total_seconds() attempt = Attempt(0) try: while True: @@ -92,18 +90,22 @@ async def retrying_inner(): return except Exception as e: has_inflight = len(inflight_inputs) > 0 - if attempt.value < len(backoffs) and is_safe_to_retry_session( + if attempt.value < max_retries and is_safe_to_retry_session( e, retry.append_retry_policy, has_inflight, frame_signal, ): - backoff = backoffs[attempt.value] + backoff = compute_backoff( + attempt.value, + min_base_delay=min_base_delay, + max_base_delay=max_base_delay, + ) logger.debug( "retrying append session: error=%s backoff=%.3fs retries_remaining=%d", e, backoff, - len(backoffs) - attempt.value - 1, + max_retries - attempt.value - 1, ) await asyncio.sleep(backoff) attempt.value += 1 @@ -111,7 +113,7 @@ async def retrying_inner(): logger.debug( "not retrying append session: error=%s retries_exhausted=%s", e, - attempt.value >= len(backoffs), + attempt.value >= max_retries, ) raise finally: diff --git a/src/s2_sdk/_s2s/_read_session.py b/src/s2_sdk/_s2s/_read_session.py index 0c70b84..3e01b76 100644 --- a/src/s2_sdk/_s2s/_read_session.py +++ b/src/s2_sdk/_s2s/_read_session.py @@ -8,7 +8,7 @@ from s2_sdk._client import HttpClient from s2_sdk._exceptions import ReadTimeoutError from s2_sdk._mappers import read_batch_from_proto, read_limit_params, read_start_params -from s2_sdk._retrier import Attempt, compute_backoffs, http_retry_on +from s2_sdk._retrier import Attempt, compute_backoff, http_retry_on from s2_sdk._s2s import _stream_records_path from s2_sdk._s2s._protocol import parse_error_info, read_messages from s2_sdk._types import ( @@ -40,11 +40,9 @@ async def run_read_session( encryption_key: str | None = None, ) -> AsyncIterable[ReadBatch]: params = _build_read_params(start, limit, until_timestamp, clamp_to_tail, wait) - backoffs = compute_backoffs( - retry._max_retries(), - min_base_delay=retry.min_base_delay.total_seconds(), - max_base_delay=retry.max_base_delay.total_seconds(), - ) + max_retries = retry._max_retries() + min_base_delay = retry.min_base_delay.total_seconds() + max_base_delay = retry.max_base_delay.total_seconds() attempt = Attempt(0) remaining_count = limit.count if limit and limit.count is not None else None @@ -122,13 +120,17 @@ async def run_read_session( return except Exception as e: - if attempt.value < len(backoffs) and http_retry_on(e): - backoff = backoffs[attempt.value] + if attempt.value < max_retries and http_retry_on(e): + backoff = compute_backoff( + attempt.value, + min_base_delay=min_base_delay, + max_base_delay=max_base_delay, + ) logger.debug( "retrying read session: error=%s backoff=%.3fs retries_remaining=%d", e, backoff, - len(backoffs) - attempt.value - 1, + max_retries - attempt.value - 1, ) await asyncio.sleep(backoff) attempt.value += 1 @@ -137,7 +139,7 @@ async def run_read_session( "not retrying read session: error=%s is_retryable=%s retries_exhausted=%s", e, http_retry_on(e), - attempt.value >= len(backoffs), + attempt.value >= max_retries, ) raise e diff --git a/tests/test_retrier.py b/tests/test_retrier.py new file mode 100644 index 0000000..da1a815 --- /dev/null +++ b/tests/test_retrier.py @@ -0,0 +1,26 @@ +import sys + +import pytest + +from s2_sdk._retrier import compute_backoff + + +class TestComputeBackoff: + @pytest.mark.parametrize( + ("attempt", "expected_min", "expected_max"), + [ + (0, 0.1, 0.2), + (1, 0.2, 0.4), + (2, 0.4, 0.8), + (3, 0.8, 1.6), + (4, 1.0, 2.0), + (5, 1.0, 2.0), + ], + ) + def test_backoff_range(self, attempt, expected_min, expected_max): + backoff = compute_backoff(attempt, min_base_delay=0.1, max_base_delay=1.0) + assert expected_min <= backoff <= expected_max + + def test_backoff_caps_for_max_int_attempt(self): + backoff = compute_backoff(sys.maxsize, min_base_delay=0.1, max_base_delay=1.0) + assert 1.0 <= backoff <= 2.0 diff --git a/tests/test_retry.py b/tests/test_retry.py deleted file mode 100644 index a2d6682..0000000 --- a/tests/test_retry.py +++ /dev/null @@ -1,145 +0,0 @@ -from s2_sdk._exceptions import ConnectError, ReadTimeoutError, S2ServerError -from s2_sdk._frame_signal import FrameSignal -from s2_sdk._retrier import ( - compute_backoffs, - has_no_side_effects, - is_safe_to_retry_session, - is_safe_to_retry_unary, -) -from s2_sdk._types import AppendRetryPolicy - - -class TestComputeBackoffs: - def test_backoffs_count(self): - backoffs = compute_backoffs(5) - assert len(backoffs) == 5 - - def test_backoffs_range(self): - backoffs = compute_backoffs(5, min_base_delay=0.1, max_base_delay=1.0) - for b in backoffs: - # Each backoff is base_delay + jitter where jitter in [0, base_delay] - # so max is 2 * max_base_delay - assert 0 <= b <= 2.0 - - def test_backoffs_empty(self): - backoffs = compute_backoffs(0) - assert backoffs == [] - - -class TestHasNoSideEffects: - def test_rate_limited(self): - e = S2ServerError("rate_limited", "rate limited", 429) - assert has_no_side_effects(e) is True - - def test_hot_server(self): - e = S2ServerError("hot_server", "hot server", 502) - assert has_no_side_effects(e) is True - - def test_other_server_error(self): - e = S2ServerError("internal", "internal", 500) - assert has_no_side_effects(e) is False - - def test_429_wrong_code(self): - e = S2ServerError("throttled", "throttled", 429) - assert has_no_side_effects(e) is False - - def test_connect_error_connection_refused(self): - cause = ConnectionRefusedError("connection refused") - e = ConnectError("connection refused") - e.__cause__ = cause - assert has_no_side_effects(e) is True - - def test_connect_error_without_refused_cause(self): - e = ConnectError("connection timed out") - assert has_no_side_effects(e) is False - - def test_other_transport_error(self): - e = ReadTimeoutError("timeout") - assert has_no_side_effects(e) is False - - def test_generic_exception(self): - e = RuntimeError("something") - assert has_no_side_effects(e) is False - - -class TestSafeToRetryUnary: - def test_no_policy_retries_retryable(self): - e = S2ServerError("internal", "error", 500) - assert is_safe_to_retry_unary(e, None) is True - - def test_all_policy_retries_retryable(self): - e = S2ServerError("internal", "error", 500) - assert is_safe_to_retry_unary(e, AppendRetryPolicy.ALL) is True - - def test_all_policy_skips_non_retryable(self): - e = S2ServerError("bad_request", "bad request", 400) - assert is_safe_to_retry_unary(e, AppendRetryPolicy.ALL) is False - - def test_nse_policy_retries_no_side_effect_error(self): - e = S2ServerError("rate_limited", "rate limited", 429) - assert is_safe_to_retry_unary(e, AppendRetryPolicy.NO_SIDE_EFFECTS) is True - - def test_nse_policy_retries_connect_error(self): - cause = ConnectionRefusedError("connection refused") - e = ConnectError("connection refused") - e.__cause__ = cause - assert is_safe_to_retry_unary(e, AppendRetryPolicy.NO_SIDE_EFFECTS) is True - - def test_nse_policy_skips_ambiguous_error(self): - e = S2ServerError("internal", "error", 500) - assert is_safe_to_retry_unary(e, AppendRetryPolicy.NO_SIDE_EFFECTS) is False - - def test_nse_policy_skips_timeout(self): - e = ReadTimeoutError("timeout") - assert is_safe_to_retry_unary(e, AppendRetryPolicy.NO_SIDE_EFFECTS) is False - - -class TestSafeToRetrySession: - def test_all_policy_always_retries(self): - e = S2ServerError("internal", "error", 500) - assert is_safe_to_retry_session(e, AppendRetryPolicy.ALL, True, None) is True - - def test_all_policy_skips_non_retryable(self): - e = S2ServerError("bad_request", "bad request", 400) - assert is_safe_to_retry_session(e, AppendRetryPolicy.ALL, False, None) is False - - def test_nse_no_inflight_retries(self): - e = S2ServerError("internal", "error", 500) - assert ( - is_safe_to_retry_session(e, AppendRetryPolicy.NO_SIDE_EFFECTS, False, None) - is True - ) - - def test_nse_inflight_signal_not_set_retries(self): - signal = FrameSignal() - e = S2ServerError("internal", "error", 500) - assert ( - is_safe_to_retry_session(e, AppendRetryPolicy.NO_SIDE_EFFECTS, True, signal) - is True - ) - - def test_nse_inflight_signal_set_no_side_effects_retries(self): - signal = FrameSignal() - signal.signal() - e = S2ServerError("rate_limited", "rate limited", 429) - assert ( - is_safe_to_retry_session(e, AppendRetryPolicy.NO_SIDE_EFFECTS, True, signal) - is True - ) - - def test_nse_inflight_signal_set_ambiguous_skips(self): - signal = FrameSignal() - signal.signal() - e = S2ServerError("internal", "error", 500) - assert ( - is_safe_to_retry_session(e, AppendRetryPolicy.NO_SIDE_EFFECTS, True, signal) - is False - ) - - def test_nse_inflight_no_signal_skips(self): - """No FrameSignal at all + inflight + ambiguous error -> not safe.""" - e = S2ServerError("internal", "error", 500) - assert ( - is_safe_to_retry_session(e, AppendRetryPolicy.NO_SIDE_EFFECTS, True, None) - is False - )