Skip to content

Commit

Permalink
.catch: accept multiple exception types
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Jan 26, 2025
1 parent 8beb852 commit d63305d
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 38 deletions.
21 changes: 12 additions & 9 deletions streamable/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,22 @@

def catch(
iterator: Iterator[T],
kind: Type[Exception] = Exception,
when: Callable[[Exception], Any] = bool,
kind: Optional[Type[Exception]] = Exception,
*others: Type[Exception],
when: Optional[Callable[[Exception], Any]] = None,
replacement: T = NO_REPLACEMENT, # type: ignore
finally_raise: bool = False,
) -> Iterator[T]:
validate_iterator(iterator)
return CatchIterator(
iterator,
kind,
when,
replacement,
finally_raise=finally_raise,
)
if kind or others:
return CatchIterator(
iterator,
(kind, *others) if kind else others,
when,
replacement,
finally_raise=finally_raise,
)
return iterator


def distinct(
Expand Down
12 changes: 6 additions & 6 deletions streamable/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ class CatchIterator(Iterator[T]):
def __init__(
self,
iterator: Iterator[T],
kind: Type[Exception],
when: Callable[[Exception], Any],
kinds: Tuple[Type[Exception], ...],
when: Optional[Callable[[Exception], Any]],
replacement: T,
finally_raise: bool,
) -> None:
validate_iterator(iterator)
self.iterator = iterator
self.kind = kind
self.when = wrap_error(when, StopIteration)
self.kinds = kinds
self.when = wrap_error(when, StopIteration) if when else None
self.replacement = replacement
self.finally_raise = finally_raise
self._to_be_finally_raised: Optional[Exception] = None
Expand All @@ -85,8 +85,8 @@ def __next__(self) -> T:
self._to_be_finally_raised = None
raise exception
raise
except Exception as exception:
if isinstance(exception, self.kind) and self.when(exception):
except self.kinds as exception:
if not self.when or self.when(exception):
if self._to_be_finally_raised is None:
self._to_be_finally_raised = exception
if self.replacement is not NO_REPLACEMENT:
Expand Down
27 changes: 19 additions & 8 deletions streamable/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,24 +114,33 @@ def accept(self, visitor: "Visitor[V]") -> V:

def catch(
self,
kind: Type[Exception] = Exception,
when: Callable[[Exception], Any] = bool,
kind: Optional[Type[Exception]] = Exception,
*others: Type[Exception],
when: Optional[Callable[[Exception], Any]] = None,
replacement: T = NO_REPLACEMENT, # type: ignore
finally_raise: bool = False,
) -> "Stream[T]":
"""
Catches the upstream exceptions if they are instances of `kind` and they satisfy the `when` predicate.
Catches the upstream exceptions if they are instances of `kind` (or `others`) and they satisfy the `when` predicate.
Args:
kind (Type[Exception], optional): The type of exceptions to catch. (default: catches base Exception)
when (Callable[[Exception], Any], optional): An additional condition that must be satisfied to catch the exception, i.e. `when(exception)` must be truthy. (default: no additional condition)
kind (Optional[Type[Exception]], optional): The type of exceptions to catch. (default: catches Exception)
*others (Type[Exception], optional): Additional types of exceptions to catch.
when (Optional[Callable[[Exception], Any]], optional): An additional condition that must be satisfied to catch the exception, i.e. `when(exception)` must be truthy. (default: no additional condition)
replacement (T, optional): The value to yield when an exception is catched. (default: do not yield any replacement value)
finally_raise (bool, optional): If True the first catched exception is raised when upstream's iteration ends. (default: iteration ends without raising)
Returns:
Stream[T]: A stream of upstream elements catching the eligible exceptions.
"""
return CatchStream(self, kind, when, replacement, finally_raise)
return CatchStream(
self,
kind,
*others,
when=when,
replacement=replacement,
finally_raise=finally_raise,
)

def count(self) -> int:
"""
Expand Down Expand Up @@ -539,13 +548,15 @@ class CatchStream(DownStream[T, T]):
def __init__(
self,
upstream: Stream[T],
kind: Type[Exception],
when: Callable[[Exception], Any],
kind: Optional[Type[Exception]],
*others: Type[Exception],
when: Optional[Callable[[Exception], Any]],
replacement: T,
finally_raise: bool,
) -> None:
super().__init__(upstream)
self._kind = kind
self._others = others
self._when = when
self._replacement = replacement
self._finally_raise = finally_raise
Expand Down
5 changes: 3 additions & 2 deletions streamable/visitors/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> Iterator[T]:
return functions.catch(
stream.upstream.accept(self),
stream._kind,
stream._when,
stream._replacement,
*stream._others,
when=stream._when,
replacement=stream._replacement,
finally_raise=stream._finally_raise,
)

Expand Down
11 changes: 5 additions & 6 deletions streamable/visitors/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@ def __init__(self) -> None:
def to_string(o: object) -> str: ...

def visit_catch_stream(self, stream: CatchStream[T]) -> str:
replacement = (
f", replacement={self.to_string(stream._replacement)}"
if stream._replacement is not NO_REPLACEMENT
else ""
)
replacement = ""
if stream._replacement is not NO_REPLACEMENT:
replacement = f", replacement={self.to_string(stream._replacement)}"

self.methods_reprs.append(
f"catch({self.to_string(stream._kind)}, when={self.to_string(stream._when)}{replacement}, finally_raise={self.to_string(stream._finally_raise)})"
f"catch({', '.join(map(StrVisitor.to_string, (stream._kind, *stream._others)))}, when={self.to_string(stream._when)}{replacement}, finally_raise={self.to_string(stream._finally_raise)})"
)
return stream.upstream.accept(self)

Expand Down
41 changes: 34 additions & 7 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ class CustomCallable:
.flatten(concurrency=4)
.throttle(64, interval=datetime.timedelta(seconds=1))
.observe("foos")
.catch(TypeError, finally_raise=True)
.catch(finally_raise=True)
.catch(None, finally_raise=True)
.catch(TypeError, ValueError, ZeroDivisionError)
.catch(TypeError, replacement=None, finally_raise=True)
)

Expand Down Expand Up @@ -279,8 +281,10 @@ class CustomCallable:
.flatten(concurrency=4)
.throttle(per_second=64, per_minute=inf, per_hour=inf, interval=datetime.timedelta(seconds=1))
.observe('foos')
.catch(TypeError, when=bool, finally_raise=True)
.catch(TypeError, when=bool, replacement=None, finally_raise=True)
.catch(Exception, when=None, finally_raise=True)
.catch(None, when=None, finally_raise=True)
.catch(TypeError, ValueError, ZeroDivisionError, when=None, finally_raise=False)
.catch(TypeError, when=None, replacement=None, finally_raise=True)
)""",
msg="`repr` should work as expected on a stream with many operation",
)
Expand Down Expand Up @@ -542,7 +546,7 @@ def side_effect(x: int, func: Callable[[int], int]):
throw_for_odd_func_,
]
for raised_exc, catched_exc in [
(TestError, TestError),
(TestError, (TestError,)),
(StopIteration, (WrappedError, RuntimeError)),
]
for concurrency in [1, 2]
Expand All @@ -556,11 +560,11 @@ def side_effect(x: int, func: Callable[[int], int]):
def test_map_or_foreach_with_exception(
self,
raised_exc: Type[Exception],
catched_exc: Type[Exception],
catched_exc: Tuple[Type[Exception], ...],
concurrency: int,
method: Callable[[Stream, Callable[[Any], int], int], Stream],
throw_func: Callable[[Exception], Callable[[Any], int]],
throw_for_odd_func: Callable[[Exception], Callable[[Any], int]],
throw_for_odd_func: Callable[[Type[Exception]], Callable[[Any], int]],
) -> None:
with self.assertRaises(
catched_exc,
Expand All @@ -570,7 +574,7 @@ def test_map_or_foreach_with_exception(

self.assertListEqual(
list(
method(Stream(src), throw_for_odd_func(raised_exc), concurrency).catch(catched_exc) # type: ignore
method(Stream(src), throw_for_odd_func(raised_exc), concurrency).catch(*catched_exc)
),
list(even_src),
msg="At any concurrency, `map` and `foreach` and `amap` must not stop after one exception occured.",
Expand Down Expand Up @@ -1464,6 +1468,29 @@ def f(i):
[None, 1, 0.5, 0.25],
msg="`catch` should be able to yield a None replacement",
)
self.assertListEqual(
list(
Stream(
map(
lambda n: 1 / n, # potential ZeroDivisionError
map(
throw_for_odd_func(TestError), # potential TestError
map(
int, # potential ValueError
"01234foo56789",
),
),
)
).catch(ValueError, TestError, ZeroDivisionError)
),
list(map(lambda n: 1 / n, range(2, 10, 2))),
msg="`catch` should accept multiple types",
)
with self.assertRaises(
ZeroDivisionError,
msg="`catch` must catch nothing when `kind` is None",
):
list(Stream(map(lambda n: 1 / n, src)).catch(None))

def test_observe(self) -> None:
value_error_rainsing_stream: Stream[List[int]] = (
Expand Down

0 comments on commit d63305d

Please sign in to comment.