11from __future__ import annotations
22
3- from typing import TYPE_CHECKING
3+ import asyncio
4+ import functools
5+ import inspect
6+ from typing import TYPE_CHECKING , Any , cast
47from typing_extensions import NamedTuple , ParamSpec , TypeVar , overload
58
69if TYPE_CHECKING :
7- from collections .abc import Callable
10+ from collections .abc import AsyncGenerator , Callable , Coroutine
811
912
1013_T = TypeVar ("_T" )
@@ -17,6 +20,61 @@ class Reducer(NamedTuple):
1720 reducer : Callable [[object , object ], object ]
1821
1922
23+ def _check_loop () -> None :
24+ try :
25+ loop = asyncio .get_running_loop ()
26+ except RuntimeError :
27+ return
28+
29+ from duron .loop import EventLoop # noqa: PLC0415
30+
31+ if isinstance (loop , EventLoop ):
32+ msg = (
33+ "Effects cannot be called from within a duron EventLoop. "
34+ "Use 'ctx.run()' to execute effects."
35+ )
36+ raise RuntimeError (msg ) # noqa: TRY004
37+
38+
39+ def _wrap_effect (fn : Callable [_P , Coroutine [Any , Any , _T ] | _T ]) -> Callable [_P , Any ]:
40+ if inspect .iscoroutinefunction (fn ):
41+
42+ @functools .wraps (fn )
43+ async def async_wrapper (* args : _P .args , ** kwargs : _P .kwargs ) -> _T :
44+ _check_loop ()
45+ return cast ("_T" , await fn (* args , ** kwargs ))
46+
47+ return async_wrapper
48+ if inspect .isasyncgenfunction (fn ):
49+
50+ @functools .wraps (fn )
51+ async def async_gen_wrapper (
52+ * args : _P .args , ** kwargs : _P .kwargs
53+ ) -> AsyncGenerator [Any , Any ]:
54+ _check_loop ()
55+ gen = fn (* args , ** kwargs )
56+ try :
57+ value = await anext (gen )
58+ while True :
59+ try :
60+ sent = yield value
61+ value = await gen .asend (sent )
62+ except GeneratorExit : # noqa: PERF203
63+ await gen .aclose ()
64+ raise
65+ except StopAsyncIteration :
66+ return
67+
68+ return async_gen_wrapper
69+
70+ @functools .wraps (fn )
71+ def sync_wrapper (* args : _P .args , ** kwargs : _P .kwargs ) -> _T :
72+ _check_loop ()
73+ return cast ("_T" , fn (* args , ** kwargs ))
74+
75+ return sync_wrapper
76+
77+
2078@overload
2179def effect (fn : Callable [_P , _T ], / ) -> Callable [_P , _T ]: ...
2280@overload
@@ -50,9 +108,5 @@ async def counter(
50108
51109 """
52110 if fn is not None :
53- return fn
54-
55- def decorate (fn : Callable [_P , _T ]) -> Callable [_P , _T ]:
56- return fn
57-
58- return decorate
111+ return _wrap_effect (fn )
112+ return _wrap_effect
0 commit comments