Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/s2_sdk/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
self._basin_clients: dict[str, HttpClient] = {}
self._retrier = Retrier(
should_retry_on=http_retry_on,
max_attempts=retry.max_attempts,
max_retries=retry._max_retries(),
min_base_delay=retry.min_base_delay.total_seconds(),
max_base_delay=retry.max_base_delay.total_seconds(),
)
Expand Down Expand Up @@ -568,7 +568,7 @@ def __init__(
self._compression = compression
self._retrier = Retrier(
should_retry_on=http_retry_on,
max_attempts=retry.max_attempts,
max_retries=retry._max_retries(),
min_base_delay=retry.min_base_delay.total_seconds(),
max_base_delay=retry.max_base_delay.total_seconds(),
)
Expand Down Expand Up @@ -806,15 +806,15 @@ def __init__(
self._encryption_key = encryption_key
self._retrier = Retrier(
should_retry_on=http_retry_on,
max_attempts=retry.max_attempts,
max_retries=retry._max_retries(),
min_base_delay=retry.min_base_delay.total_seconds(),
max_base_delay=retry.max_base_delay.total_seconds(),
)
self._append_retrier = Retrier(
should_retry_on=lambda e: is_safe_to_retry_unary(
e, retry.append_retry_policy
),
max_attempts=retry.max_attempts,
max_retries=retry._max_retries(),
min_base_delay=retry.min_base_delay.total_seconds(),
max_base_delay=retry.max_base_delay.total_seconds(),
)
Expand Down
41 changes: 21 additions & 20 deletions src/s2_sdk/_retrier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import math
import random
from dataclasses import dataclass
from typing import Callable
Expand All @@ -15,36 +16,36 @@ class Retrier:
def __init__(
self,
should_retry_on: Callable[[Exception], bool],
max_attempts: int,
max_retries: int,
min_base_delay: float = 0.1,
max_base_delay: float = 1.0,
):
self.should_retry_on = should_retry_on
self.max_attempts = max_attempts
self.max_retries = max_retries
self.min_base_delay = min_base_delay
self.max_base_delay = max_base_delay

async def __call__(self, f: Callable, *args, **kwargs):
backoffs = compute_backoffs(
attempts=max(self.max_attempts - 1, 0),
min_base_delay=self.min_base_delay,
max_base_delay=self.max_base_delay,
)
max_retries = self.max_retries
attempt = 0
while True:
try:
return await f(*args, **kwargs)
except Exception as e:
if attempt < len(backoffs) and self.should_retry_on(e):
delay = backoffs[attempt]
if attempt < max_retries and self.should_retry_on(e):
delay = compute_backoff(
attempt,
min_base_delay=self.min_base_delay,
max_base_delay=self.max_base_delay,
)
retry_after = getattr(e, "_retry_after", None)
if retry_after is not None:
delay = max(delay, retry_after)
logger.debug(
"retrying request: error=%s backoff=%.3fs retries_remaining=%d",
e,
delay,
len(backoffs) - attempt - 1,
max_retries - attempt - 1,
)
await asyncio.sleep(delay)
attempt += 1
Expand All @@ -53,7 +54,7 @@ async def __call__(self, f: Callable, *args, **kwargs):
"not retrying request: error=%s is_retryable=%s retries_exhausted=%s",
e,
self.should_retry_on(e),
attempt >= len(backoffs),
attempt >= max_retries,
)
raise e

Expand All @@ -63,17 +64,17 @@ class Attempt:
value: int


def compute_backoffs(
attempts: int,
def compute_backoff(
attempt: int,
min_base_delay: float = 0.1,
max_base_delay: float = 1.0,
) -> list[float]:
backoffs = []
for n in range(attempts):
base_delay = min(min_base_delay * 2**n, max_base_delay)
jitter = random.uniform(0, base_delay)
backoffs.append(base_delay + jitter)
return backoffs
) -> float:
try:
base_delay = min(math.ldexp(min_base_delay, attempt), max_base_delay)
except OverflowError:
base_delay = max_base_delay
jitter = random.uniform(0, base_delay)
return base_delay + jitter


def is_safe_to_retry_unary(
Expand Down
22 changes: 12 additions & 10 deletions src/s2_sdk/_s2s/_append_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from s2_sdk._exceptions import ReadTimeoutError, S2ClientError
from s2_sdk._frame_signal import FrameSignal
from s2_sdk._mappers import append_ack_from_proto, append_input_to_proto
from s2_sdk._retrier import Attempt, compute_backoffs, is_safe_to_retry_session
from s2_sdk._retrier import Attempt, compute_backoff, is_safe_to_retry_session
from s2_sdk._s2s import _stream_records_path
from s2_sdk._s2s._protocol import (
Message,
Expand Down Expand Up @@ -64,11 +64,9 @@ async def pipe_inputs():

async def retrying_inner():
inflight_inputs: deque[_InflightInput] = deque()
backoffs = compute_backoffs(
retry._max_retries(),
min_base_delay=retry.min_base_delay.total_seconds(),
max_base_delay=retry.max_base_delay.total_seconds(),
)
max_retries = retry._max_retries()
min_base_delay = retry.min_base_delay.total_seconds()
max_base_delay = retry.max_base_delay.total_seconds()
attempt = Attempt(0)
try:
while True:
Expand All @@ -92,26 +90,30 @@ async def retrying_inner():
return
except Exception as e:
has_inflight = len(inflight_inputs) > 0
if attempt.value < len(backoffs) and is_safe_to_retry_session(
if attempt.value < max_retries and is_safe_to_retry_session(
e,
retry.append_retry_policy,
has_inflight,
frame_signal,
):
backoff = backoffs[attempt.value]
backoff = compute_backoff(
attempt.value,
min_base_delay=min_base_delay,
max_base_delay=max_base_delay,
)
logger.debug(
"retrying append session: error=%s backoff=%.3fs retries_remaining=%d",
e,
backoff,
len(backoffs) - attempt.value - 1,
max_retries - attempt.value - 1,
)
await asyncio.sleep(backoff)
attempt.value += 1
else:
logger.debug(
"not retrying append session: error=%s retries_exhausted=%s",
e,
attempt.value >= len(backoffs),
attempt.value >= max_retries,
)
raise
finally:
Expand Down
22 changes: 12 additions & 10 deletions src/s2_sdk/_s2s/_read_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from s2_sdk._client import HttpClient
from s2_sdk._exceptions import ReadTimeoutError
from s2_sdk._mappers import read_batch_from_proto, read_limit_params, read_start_params
from s2_sdk._retrier import Attempt, compute_backoffs, http_retry_on
from s2_sdk._retrier import Attempt, compute_backoff, http_retry_on
from s2_sdk._s2s import _stream_records_path
from s2_sdk._s2s._protocol import parse_error_info, read_messages
from s2_sdk._types import (
Expand Down Expand Up @@ -40,11 +40,9 @@ async def run_read_session(
encryption_key: str | None = None,
) -> AsyncIterable[ReadBatch]:
params = _build_read_params(start, limit, until_timestamp, clamp_to_tail, wait)
backoffs = compute_backoffs(
retry._max_retries(),
min_base_delay=retry.min_base_delay.total_seconds(),
max_base_delay=retry.max_base_delay.total_seconds(),
)
max_retries = retry._max_retries()
min_base_delay = retry.min_base_delay.total_seconds()
max_base_delay = retry.max_base_delay.total_seconds()
attempt = Attempt(0)

remaining_count = limit.count if limit and limit.count is not None else None
Expand Down Expand Up @@ -122,13 +120,17 @@ async def run_read_session(

return
except Exception as e:
if attempt.value < len(backoffs) and http_retry_on(e):
backoff = backoffs[attempt.value]
if attempt.value < max_retries and http_retry_on(e):
backoff = compute_backoff(
attempt.value,
min_base_delay=min_base_delay,
max_base_delay=max_base_delay,
)
logger.debug(
"retrying read session: error=%s backoff=%.3fs retries_remaining=%d",
e,
backoff,
len(backoffs) - attempt.value - 1,
max_retries - attempt.value - 1,
)
await asyncio.sleep(backoff)
attempt.value += 1
Expand All @@ -137,7 +139,7 @@ async def run_read_session(
"not retrying read session: error=%s is_retryable=%s retries_exhausted=%s",
e,
http_retry_on(e),
attempt.value >= len(backoffs),
attempt.value >= max_retries,
)
raise e

Expand Down
26 changes: 26 additions & 0 deletions tests/test_retrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import sys

import pytest

from s2_sdk._retrier import compute_backoff


class TestComputeBackoff:
    # compute_backoff returns base_delay + uniform(0, base_delay) jitter, where
    # base_delay = min(min_base_delay * 2**attempt, max_base_delay), so every
    # result lies in [base_delay, 2 * base_delay].
    @pytest.mark.parametrize(
        ("attempt", "expected_min", "expected_max"),
        [
            (0, 0.1, 0.2),
            (1, 0.2, 0.4),
            (2, 0.4, 0.8),
            (3, 0.8, 1.6),
            (4, 1.0, 2.0),  # 0.1 * 2**4 = 1.6 exceeds max_base_delay, so capped at 1.0
            (5, 1.0, 2.0),  # stays capped for all later attempts
        ],
    )
    def test_backoff_range(self, attempt, expected_min, expected_max):
        """Backoff for each attempt stays within the expected jittered range."""
        backoff = compute_backoff(attempt, min_base_delay=0.1, max_base_delay=1.0)
        assert expected_min <= backoff <= expected_max

    def test_backoff_caps_for_max_int_attempt(self):
        """A huge attempt count must not blow up: the implementation catches the
        OverflowError from math.ldexp and falls back to max_base_delay, so the
        result is still within [max_base_delay, 2 * max_base_delay]."""
        backoff = compute_backoff(sys.maxsize, min_base_delay=0.1, max_base_delay=1.0)
        assert 1.0 <= backoff <= 2.0
Comment thread
quettabit marked this conversation as resolved.
Loading
Loading