Skip to content
This repository has been archived by the owner on Dec 15, 2020. It is now read-only.

execute_futures: add max_in_memory_pages parameter #77

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions aiocassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class _Paginator:

def __init__(self, request, *, executor, loop):
def __init__(self, request, *, executor, loop, max_in_memory_pages=None):
self.cassandra_fut = None

self._request = request
Expand All @@ -30,10 +30,35 @@ def __init__(self, request, *, executor, loop):
self._deque = deque()
self._exc = None
self._drain_event = asyncio.Event(loop=loop)
self._no_fetching_page = asyncio.Event(loop=loop)
self._finish_event = asyncio.Event(loop=loop)
self._exit_event = Event()

self.__pages = set()
self._max_in_memory_pages = max_in_memory_pages
self._page_size = None

def _start_fetching_next_page(self):
    """Request the next result page from the driver on the executor.

    Clears ``_no_fetching_page`` to mark a fetch as in flight, then
    schedules ``cassandra_fut.start_fetching_next_page`` through the
    event loop's executor. The resulting future is kept in the private
    page set until it completes.
    """
    # Mark "a page fetch is in progress" before the work is scheduled.
    self._no_fetching_page.clear()

    page_task = self._loop.run_in_executor(
        self._executor,
        self.cassandra_fut.start_fetching_next_page,
    )

    # Hold a reference while the task runs; the done-callback receives
    # the future itself and drops it from the set again.
    self.__pages.add(page_task)
    page_task.add_done_callback(self.__pages.remove)

def _maybe_start_prefetch_next_page(self):
if self._finish_event.is_set() or not self._no_fetching_page.is_set():
return

if not self.cassandra_fut.has_more_pages:
self._finish_event.set()
return

if self._max_in_memory_pages is None:
pass
elif len(self._deque) > self._page_size * (self._max_in_memory_pages - 1):
return

self._start_fetching_next_page()

def _handle_page(self, rows):
if self._exit_event.is_set():
Expand All @@ -42,19 +67,15 @@ def _handle_page(self, rows):
'Paginator is closed, skipping new %i records', _len)
return

if self._page_size is None:
self._page_size = len(rows)

for row in rows:
self._deque.append(row)

self._loop.call_soon_threadsafe(self._no_fetching_page.set)
self._loop.call_soon_threadsafe(self._drain_event.set)

if self.cassandra_fut.has_more_pages:
_fn = self.cassandra_fut.start_fetching_next_page
fut = self._loop.run_in_executor(self._executor, _fn)
self.__pages.add(fut)
fut.add_done_callback(self.__pages.remove)
return

self._loop.call_soon_threadsafe(self._finish_event.set)
self._loop.call_soon_threadsafe(self._maybe_start_prefetch_next_page)

def _handle_err(self, exc):
self._exc = exc
Expand Down Expand Up @@ -102,8 +123,11 @@ async def _paginator(self):
if self._exc is not None:
raise self._exc

self._maybe_start_prefetch_next_page()

while self._deque:
await yield_(self._deque.popleft())
self._maybe_start_prefetch_next_page()

await asyncio.wait(
(
Expand Down Expand Up @@ -153,12 +177,13 @@ async def execute_future(self, *args, **kwargs):
return await asyncio_fut


def execute_futures(self, *args, **kwargs):
def execute_futures(self, *args, max_in_memory_pages=None, **kwargs):
_request = partial(self.execute_async, *args, **kwargs)
return _Paginator(
_request,
executor=self._asyncio_executor,
loop=self._asyncio_loop
loop=self._asyncio_loop,
max_in_memory_pages=max_in_memory_pages
)


Expand Down
19 changes: 19 additions & 0 deletions tests/test_aiocassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,25 @@ async def test_execute_futures_simple_statement(cassandra):
assert len(ret) != 0


@pytest.mark.asyncio
async def test_execute_futures_simple_statement_limit_pages(cassandra):
    """Check that ``max_in_memory_pages`` caps the paginator's row buffer.

    With ``fetch_size=10`` and ``max_in_memory_pages=3`` the buffer is
    expected to hold 30 rows after the initial prefetch, never to exceed
    30 while a slow consumer drains it, and to yield 50 rows in total.
    """
    statement = SimpleStatement(
        'SELECT * FROM system.size_estimates LIMIT 50;',
        fetch_size=10,
    )
    collected = []

    pager = cassandra.execute_futures(statement, max_in_memory_pages=3)
    async with pager as paginator:
        # Give the background prefetch time to fill the buffer.
        await asyncio.sleep(0.5)
        assert len(paginator._deque) == 30

        async for row in paginator:
            # Consume slowly so prefetching could run ahead if unbounded.
            await asyncio.sleep(0.2)
            assert isinstance(row, tuple)
            assert len(paginator._deque) <= 30
            collected.append(row)

    assert len(collected) == 50


@pytest.mark.asyncio
async def test_execute_futures_break(cassandra):
cql = 'SELECT * FROM system.size_estimates;'
Expand Down