Coverage for sources/appcore/asyncf.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-26 19:13 +0000

1# vim: set filetype=python fileencoding=utf-8: 

2# -*- coding: utf-8 -*- 

3 

4#============================================================================# 

5# # 

6# Licensed under the Apache License, Version 2.0 (the "License"); # 

7# you may not use this file except in compliance with the License. # 

8# You may obtain a copy of the License at # 

9# # 

10# http://www.apache.org/licenses/LICENSE-2.0 # 

11# # 

12# Unless required by applicable law or agreed to in writing, software # 

13# distributed under the License is distributed on an "AS IS" BASIS, # 

14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # 

15# See the License for the specific language governing permissions and # 

16# limitations under the License. # 

17# # 

18#============================================================================# 

19 

20 

21''' Helper functions for async execution. ''' 

22 

23 

24from . import __ 

25from . import exceptions as _exceptions 

26from . import generics as _generics 

27 

28 

@__.typx.overload
async def gather_async( # pragma: no cover
    *operands: __.typx.Any,
    return_exceptions: __.typx.Literal[ True ],
    error_message: str = 'Failure of async operations.',
    ignore_nonawaitables: bool = False,
) -> tuple[ _generics.GenericResult, ... ]: ...


@__.typx.overload
async def gather_async( # noqa: F811 # pragma: no cover
    *operands: __.typx.Any,
    return_exceptions: __.typx.Literal[ False ] = False,
    error_message: str = 'Failure of async operations.',
    ignore_nonawaitables: bool = False,
) -> tuple[ __.typx.Any, ... ]: ...


async def gather_async( # noqa: F811
    *operands: __.typx.Any,
    return_exceptions: __.typx.Annotated[
        bool,
        __.ddoc.Doc( ''' Raw or wrapped results. Wrapped, if true. ''' )
    ] = False,
    error_message: str = 'Failure of async operations.',
    ignore_nonawaitables: __.typx.Annotated[
        bool,
        __.ddoc.Doc(
            ''' Ignore or error on non-awaitables. Ignore, if true. ''' )
    ] = False,
) -> tuple[ __.typx.Any, ... ]:
    ''' Gathers results from awaitables concurrently and asynchronously.

        Every awaitable operand is awaited concurrently. Any
        :py:exc:`Exception` raised by an operand is captured as an error
        result instead of cancelling its siblings.

        With ``return_exceptions`` true, the wrapped results (values and
        errors) come back as a tuple in operand order. Otherwise, the
        captured errors, if any, are raised together as an
        ``ExceptionGroup`` carrying ``error_message``. With no errors, the
        unwrapped values come back as a tuple in operand order.

        With ``ignore_nonawaitables`` true, non-awaitable operands are
        passed through as values. Otherwise, any non-awaitable operand
        causes an assertion failure before anything is awaited.
    '''
    from exceptiongroup import ExceptionGroup # TODO: Python 3.11: builtin
    gatherer = (
        _gather_async_permissive if ignore_nonawaitables
        else _gather_async_strict )
    outcomes = tuple( await gatherer( operands ) )
    if return_exceptions: return outcomes
    failures = tuple(
        outcome.error for outcome in outcomes if outcome.is_error( ) )
    if failures: raise ExceptionGroup( error_message, failures )
    return tuple( outcome.extract( ) for outcome in outcomes )

70 

71 

async def intercept_error_async(
    awaitable: __.cabc.Awaitable[ __.typx.Any ]
) -> _generics.Result[ object, Exception ]:
    ''' Awaits an awaitable, turning a raised exception into an error result.

        The awaited value is returned wrapped as a value result. If the
        awaitable raises an instance of :py:exc:`Exception` (or a
        subclass), that exception is returned wrapped as an error result.

        Exceptions outside the :py:exc:`Exception` hierarchy are not
        caught and propagate to the caller. In particular,
        :py:exc:`KeyboardInterrupt` and :py:exc:`SystemExit` must
        propagate, to stay consistent with :py:class:`asyncio.TaskGroup`.

        This is useful with :py:func:`asyncio.gather`, for example. Failures
        can then be told apart from computed values and collected into an
        exception group.

        Swallowing exceptions is generally a bad idea. Here the intent is
        to collect them into an exception group so they keep propagating.
    '''
    try:
        # Wrapping stays inside the ``try`` so that a failure while
        # building the value result is captured as well.
        outcome = _generics.Value( await awaitable )
    except Exception as exc:
        outcome = _generics.Error( exc )
    return outcome

94 

95 

async def _gather_async_permissive(
    operands: __.cabc.Sequence[ __.typx.Any ]
) -> __.cabc.Sequence[ __.typx.Any ]:
    ''' Gathers awaitable operands and passes all other operands through.

        Awaitable operands are awaited concurrently. Each outcome becomes a
        value result, or an error result if it raised an
        :py:exc:`Exception`. Non-awaitable operands are wrapped, unchanged,
        as value results. Results are returned in operand order.

        Awaitability is judged with :py:func:`inspect.isawaitable`, the
        same test the strict gatherer uses. This means generator-based
        coroutines (from :py:func:`types.coroutine`) are awaited rather
        than silently returned as plain, never-run values.
    '''
    from asyncio import gather # TODO? Python 3.11: TaskGroup
    from inspect import isawaitable
    awaitables: dict[ int, __.cabc.Awaitable[ __.typx.Any ] ] = { }
    results: list[ _generics.GenericResult ] = [ ]
    for i, operand in enumerate( operands ):
        if isawaitable( operand ):
            awaitables[ i ] = intercept_error_async( __.typx.cast(
                __.cabc.Awaitable[ __.typx.Any ], operand ) )
            # Placeholder; replaced below once the awaitable settles.
            results.append( _generics.Value( None ) )
        else: results.append( _generics.Value( operand ) )
    settled = await gather( *awaitables.values( ) )
    # Dicts preserve insertion order, so the keys line up with the order
    # of results returned by 'gather'.
    for i, result in zip( awaitables.keys( ), settled ):
        results[ i ] = result
    return results

113 

114 

async def _gather_async_strict(
    operands: __.cabc.Sequence[ __.typx.Any ]
) -> __.cabc.Sequence[ __.typx.Any ]:
    ''' Gathers awaitable operands, rejecting any non-awaitable operand.

        All operands are validated before any of them is awaited. If a
        non-awaitable operand is found, every native coroutine among the
        operands is closed, so none is left un-awaited. Then
        ``AsyncAssertionFailure`` is raised for the first offending
        operand.

        Otherwise, all operands are awaited concurrently. Any
        :py:exc:`Exception` is captured as an error result. Results are
        returned in operand order.
    '''
    from asyncio import gather # TODO? Python 3.11: TaskGroup
    from inspect import isawaitable, iscoroutine
    absent = object( )
    # Sanity check: find the first non-awaitable operand, if any.
    misfit = next(
        ( operand for operand in operands if not isawaitable( operand ) ),
        absent )
    if misfit is not absent:
        for candidate in operands: # Cleanup.
            if iscoroutine( candidate ): candidate.close( )
        raise _exceptions.AsyncAssertionFailure( misfit )
    wrapped: list[ __.cabc.Awaitable[ __.typx.Any ] ] = [
        intercept_error_async( __.typx.cast(
            __.cabc.Awaitable[ __.typx.Any ], operand ) )
        for operand in operands ]
    return await gather( *wrapped )

130 

131 

132if __.typx.TYPE_CHECKING: # pragma: no cover 

133 async def _type_check_canary( ) -> None: 

134 ''' Canary function to verify overload type checking works correctly. 

135 

136 This function is never executed but helps ensure that Pyright 

137 correctly understands our gather_async overloads. 

138 ''' 

139 async def dummy_operation( ) -> str: return "test" 

140 

141 operations = ( dummy_operation( ), dummy_operation( ) ) 

142 results = await gather_async( *operations, return_exceptions = True ) 

143 for result in results: 

144 match result: 

145 case _generics.Value( value ): _ = value 

146 case _generics.Error( error ): _ = error 

147 case _: pass 

148 for result in results: 

149 if _generics.is_error( result ): _ = result.error 

150 else: _ = result.extract( ) 

151 values = await gather_async( *operations, return_exceptions = False ) 

152 for value in values: _ = str( value )