Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Code quality check
run: uv run poe cq-check
- name: Unit tests
run: uv run pytest tests/ -v -m 'not (account or basin or stream or metrics)'
run: uv run pytest tests/ -v -m 'not (account or basin or stream or metrics or correctness)'
- name: Check docs build
working-directory: ./docs
run: |
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,4 @@ e2e-tests = "uv run pytest tests/ -v -s -m 'account or basin or stream'"
e2e-account-tests = "uv run pytest tests/ -v -s -m account"
e2e-basin-tests = "uv run pytest tests/ -v -s -m basin"
e2e-stream-tests = "uv run pytest tests/ -v -s -m stream"
correctness-tests = "uv run pytest tests/ -v -s -m correctness"
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ markers =
stream: tests for stream operations
metrics: tests for metrics operations
access_tokens: tests for access token operations
correctness: correctness tests
39 changes: 28 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import pytest_asyncio

from s2_sdk import S2, Compression, Endpoints, S2Basin, S2Stream
from s2_sdk import S2, Compression, Endpoints, Retry, S2Basin, S2Stream

pytest_plugins = ["pytest_asyncio"]

Expand Down Expand Up @@ -50,22 +50,35 @@ def endpoints() -> Endpoints | None:
return None


@pytest.fixture(scope="session")
def retry() -> Retry | None:
return None


@pytest_asyncio.fixture(scope="session")
async def s2(
access_token: str, compression: Compression, endpoints: Endpoints | None
access_token: str,
compression: Compression,
endpoints: Endpoints | None,
retry: Retry | None,
) -> AsyncGenerator[S2, None]:
async with S2(access_token, endpoints=endpoints, compression=compression) as s2:
async with S2(
access_token,
endpoints=endpoints,
compression=compression,
retry=retry,
) as s2:
yield s2


@pytest.fixture
def basin_name() -> str:
return _basin_name()
def basin_name(basin_prefix: str) -> str:
return _basin_name(basin_prefix)


@pytest.fixture
def basin_names() -> list[str]:
return [_basin_name() for _ in range(3)]
def basin_names(basin_prefix: str) -> list[str]:
return [_basin_name(basin_prefix) for _ in range(3)]


@pytest.fixture
Expand Down Expand Up @@ -94,8 +107,8 @@ async def basin(s2: S2, basin_name: str) -> AsyncGenerator[S2Basin, None]:


@pytest_asyncio.fixture(scope="class")
async def shared_basin(s2: S2) -> AsyncGenerator[S2Basin, None]:
basin_name = _basin_name()
async def shared_basin(s2: S2, basin_prefix: str) -> AsyncGenerator[S2Basin, None]:
basin_name = _basin_name(basin_prefix)
await s2.create_basin(name=basin_name)

try:
Expand All @@ -117,8 +130,12 @@ async def stream(
await basin.delete_stream(stream_name)


def _basin_name() -> str:
return f"{BASIN_PREFIX}-{uuid.uuid4().hex[:8]}"
def _basin_name(prefix: str) -> str:
suffix = uuid.uuid4().hex[:8]
prefix = prefix.strip("-")[: 48 - len(suffix) - 1].strip("-")
if not prefix:
return suffix
return f"{prefix}-{suffix}"


def _stream_name() -> str:
Expand Down
71 changes: 71 additions & 0 deletions tests/test_correctness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import asyncio
import sys

import pytest

from s2_sdk import Batching, ReadLimit, Record, Retry, S2Stream, SeqNum

TOTAL_RECORDS = 1024


@pytest.fixture(scope="session")
def retry() -> Retry:
return Retry(max_attempts=sys.maxsize)


@pytest.fixture(scope="session")
def basin_prefix() -> str:
return "python-correctness"


@pytest.mark.correctness
@pytest.mark.asyncio
async def test_concurrent_producer_and_consumer_remain_gapless(stream: S2Stream):
async def read_records() -> None:
highest_contiguous_index = -1
last_seq_num: int | None = None
observed_records = 0

async for batch in stream.read_session(
start=SeqNum(0), limit=ReadLimit(count=TOTAL_RECORDS), wait=60
):
for record in batch.records:
assert observed_records < TOTAL_RECORDS

seq_num = record.seq_num
if last_seq_num is None:
assert seq_num == 0
else:
assert seq_num == last_seq_num + 1
last_seq_num = seq_num

body = record.body.decode()
index = int(body)
assert 0 <= index < TOTAL_RECORDS
assert index <= highest_contiguous_index + 1

if index == highest_contiguous_index + 1:
highest_contiguous_index = index
observed_records += 1

assert highest_contiguous_index == TOTAL_RECORDS - 1
assert last_seq_num == TOTAL_RECORDS - 1
assert observed_records == TOTAL_RECORDS

async def append_records() -> None:
async with stream.producer(batching=Batching(max_records=16)) as producer:
tickets = []
for i in range(TOTAL_RECORDS):
ticket = await producer.submit(Record(body=str(i).encode()))
tickets.append(ticket)

for ticket in tickets:
ack = await ticket
assert ack.seq_num >= 0
Comment thread
quettabit marked this conversation as resolved.

async with asyncio.TaskGroup() as task_group:
task_group.create_task(read_records())
task_group.create_task(append_records())

tail = await stream.check_tail()
assert tail.seq_num == TOTAL_RECORDS
Loading