Skip to content

Commit

Permalink
.catch: accept multiple exception types (closes #58) (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal authored Jan 26, 2025
1 parent 8beb852 commit 7a1760a
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 34 deletions.
7 changes: 4 additions & 3 deletions streamable/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,18 @@
def catch(
iterator: Iterator[T],
kind: Type[Exception] = Exception,
when: Callable[[Exception], Any] = bool,
*others: Type[Exception],
when: Optional[Callable[[Exception], Any]] = None,
replacement: T = NO_REPLACEMENT, # type: ignore
finally_raise: bool = False,
) -> Iterator[T]:
validate_iterator(iterator)
return CatchIterator(
iterator,
kind,
(kind, *others),
when,
replacement,
finally_raise=finally_raise,
finally_raise,
)


Expand Down
12 changes: 6 additions & 6 deletions streamable/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ class CatchIterator(Iterator[T]):
def __init__(
self,
iterator: Iterator[T],
kind: Type[Exception],
when: Callable[[Exception], Any],
kinds: Tuple[Type[Exception], ...],
when: Optional[Callable[[Exception], Any]],
replacement: T,
finally_raise: bool,
) -> None:
validate_iterator(iterator)
self.iterator = iterator
self.kind = kind
self.when = wrap_error(when, StopIteration)
self.kinds = kinds
self.when = wrap_error(when, StopIteration) if when else None
self.replacement = replacement
self.finally_raise = finally_raise
self._to_be_finally_raised: Optional[Exception] = None
Expand All @@ -85,8 +85,8 @@ def __next__(self) -> T:
self._to_be_finally_raised = None
raise exception
raise
except Exception as exception:
if isinstance(exception, self.kind) and self.when(exception):
except self.kinds as exception:
if not self.when or self.when(exception):
if self._to_be_finally_raised is None:
self._to_be_finally_raised = exception
if self.replacement is not NO_REPLACEMENT:
Expand Down
23 changes: 17 additions & 6 deletions streamable/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,23 +115,32 @@ def accept(self, visitor: "Visitor[V]") -> V:
def catch(
self,
kind: Type[Exception] = Exception,
when: Callable[[Exception], Any] = bool,
*others: Type[Exception],
when: Optional[Callable[[Exception], Any]] = None,
replacement: T = NO_REPLACEMENT, # type: ignore
finally_raise: bool = False,
) -> "Stream[T]":
"""
Catches the upstream exceptions if they are instances of `kind` and they satisfy the `when` predicate.
Catches the upstream exceptions if they are instances of `kind` (or `others`) and they satisfy the `when` predicate.
Args:
kind (Type[Exception], optional): The type of exceptions to catch. (default: catches base Exception)
when (Callable[[Exception], Any], optional): An additional condition that must be satisfied to catch the exception, i.e. `when(exception)` must be truthy. (default: no additional condition)
kind (Type[Exception], optional): The type of exceptions to catch. (default: catches `Exception`)
*others (Type[Exception], optional): Additional types of exceptions to catch.
when (Optional[Callable[[Exception], Any]], optional): An additional condition that must be satisfied to catch the exception, i.e. `when(exception)` must be truthy. (default: no additional condition)
replacement (T, optional): The value to yield when an exception is catched. (default: do not yield any replacement value)
finally_raise (bool, optional): If True the first catched exception is raised when upstream's iteration ends. (default: iteration ends without raising)
Returns:
Stream[T]: A stream of upstream elements catching the eligible exceptions.
"""
return CatchStream(self, kind, when, replacement, finally_raise)
return CatchStream(
self,
kind,
*others,
when=when,
replacement=replacement,
finally_raise=finally_raise,
)

def count(self) -> int:
"""
Expand Down Expand Up @@ -540,12 +549,14 @@ def __init__(
self,
upstream: Stream[T],
kind: Type[Exception],
when: Callable[[Exception], Any],
*others: Type[Exception],
when: Optional[Callable[[Exception], Any]],
replacement: T,
finally_raise: bool,
) -> None:
super().__init__(upstream)
self._kind = kind
self._others = others
self._when = when
self._replacement = replacement
self._finally_raise = finally_raise
Expand Down
5 changes: 3 additions & 2 deletions streamable/visitors/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> Iterator[T]:
return functions.catch(
stream.upstream.accept(self),
stream._kind,
stream._when,
stream._replacement,
*stream._others,
when=stream._when,
replacement=stream._replacement,
finally_raise=stream._finally_raise,
)

Expand Down
17 changes: 7 additions & 10 deletions streamable/visitors/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def __init__(self) -> None:
def to_string(o: object) -> str: ...

def visit_catch_stream(self, stream: CatchStream[T]) -> str:
replacement = (
f", replacement={self.to_string(stream._replacement)}"
if stream._replacement is not NO_REPLACEMENT
else ""
)
replacement = ""
if stream._replacement is not NO_REPLACEMENT:
replacement = f", replacement={self.to_string(stream._replacement)}"

kinds = ', '.join(map(lambda err: getattr(err, "__name__", self.to_string(err)), (stream._kind, *stream._others)))
self.methods_reprs.append(
f"catch({self.to_string(stream._kind)}, when={self.to_string(stream._when)}{replacement}, finally_raise={self.to_string(stream._finally_raise)})"
f"catch({kinds}, when={self.to_string(stream._when)}{replacement}, finally_raise={self.to_string(stream._finally_raise)})"
)
return stream.upstream.accept(self)

Expand Down Expand Up @@ -142,8 +142,5 @@ def to_string(o: object) -> str:
if isinstance(o, _Star):
return f"star({StrVisitor.to_string(o.func)})"
if repr(o).startswith("<"):
try:
return getattr(o, "__name__")
except AttributeError:
return f"{o.__class__.__name__}(...)"
return getattr(o, "__name__", f"{o.__class__.__name__}(...)")
return repr(o)
36 changes: 29 additions & 7 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ class CustomCallable:
.flatten(concurrency=4)
.throttle(64, interval=datetime.timedelta(seconds=1))
.observe("foos")
.catch(TypeError, finally_raise=True)
.catch(finally_raise=True)
.catch(TypeError, ValueError, ZeroDivisionError)
.catch(TypeError, replacement=None, finally_raise=True)
)

Expand Down Expand Up @@ -279,8 +280,9 @@ class CustomCallable:
.flatten(concurrency=4)
.throttle(per_second=64, per_minute=inf, per_hour=inf, interval=datetime.timedelta(seconds=1))
.observe('foos')
.catch(TypeError, when=bool, finally_raise=True)
.catch(TypeError, when=bool, replacement=None, finally_raise=True)
.catch(Exception, when=None, finally_raise=True)
.catch(TypeError, ValueError, ZeroDivisionError, when=None, finally_raise=False)
.catch(TypeError, when=None, replacement=None, finally_raise=True)
)""",
msg="`repr` should work as expected on a stream with many operation",
)
Expand Down Expand Up @@ -542,7 +544,7 @@ def side_effect(x: int, func: Callable[[int], int]):
throw_for_odd_func_,
]
for raised_exc, catched_exc in [
(TestError, TestError),
(TestError, (TestError,)),
(StopIteration, (WrappedError, RuntimeError)),
]
for concurrency in [1, 2]
Expand All @@ -556,11 +558,11 @@ def side_effect(x: int, func: Callable[[int], int]):
def test_map_or_foreach_with_exception(
self,
raised_exc: Type[Exception],
catched_exc: Type[Exception],
catched_exc: Tuple[Type[Exception], ...],
concurrency: int,
method: Callable[[Stream, Callable[[Any], int], int], Stream],
throw_func: Callable[[Exception], Callable[[Any], int]],
throw_for_odd_func: Callable[[Exception], Callable[[Any], int]],
throw_for_odd_func: Callable[[Type[Exception]], Callable[[Any], int]],
) -> None:
with self.assertRaises(
catched_exc,
Expand All @@ -570,7 +572,9 @@ def test_map_or_foreach_with_exception(

self.assertListEqual(
list(
method(Stream(src), throw_for_odd_func(raised_exc), concurrency).catch(catched_exc) # type: ignore
method(Stream(src), throw_for_odd_func(raised_exc), concurrency).catch(
*catched_exc
)
),
list(even_src),
msg="At any concurrency, `map` and `foreach` and `amap` must not stop after one exception occured.",
Expand Down Expand Up @@ -1464,6 +1468,24 @@ def f(i):
[None, 1, 0.5, 0.25],
msg="`catch` should be able to yield a None replacement",
)
self.assertListEqual(
list(
Stream(
map(
lambda n: 1 / n, # potential ZeroDivisionError
map(
throw_for_odd_func(TestError), # potential TestError
map(
int, # potential ValueError
"01234foo56789",
),
),
)
).catch(ValueError, TestError, ZeroDivisionError)
),
list(map(lambda n: 1 / n, range(2, 10, 2))),
msg="`catch` should accept multiple types",
)

def test_observe(self) -> None:
value_error_rainsing_stream: Stream[List[int]] = (
Expand Down

0 comments on commit 7a1760a

Please sign in to comment.