Skip to content

Commit

Permalink
Convert asyncio.run(fn()) to await fn() (#1175)
Browse files Browse the repository at this point in the history
* Convert `asyncio.run(fn())` to `await fn()`

* Format
  • Loading branch information
whitphx authored Oct 18, 2024
1 parent a602160 commit 48783f9
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
9 changes: 9 additions & 0 deletions packages/kernel/py/stlite-lib/stlite_lib/codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ class TransformRuleAction(Enum):
AWAIT_CALL = 1
TIME_SLEEP = 2
STREAMLIT_NAVIGATION_RUN = 3
ASYNCIO_RUN = 4


class TransformHandler:
Expand Down Expand Up @@ -801,6 +802,12 @@ def _handle_target_call(
keywords=[],
),
)
elif action == TransformRuleAction.ASYNCIO_RUN:
# asyncio.run(fn()) -> await fn()
self._set_await_added()
return ast.Await(
value=node.args[0],
)

def on_exit_code_block(self, node: CodeBlockNode) -> CodeBlockNode:
_insert_import_statement(node, self._get_required_imports_in_code_block())
Expand Down Expand Up @@ -905,6 +912,7 @@ def patch(code: str | ast.Module, script_path: str) -> ast.Module:
[
WildcardImportTarget(module="time", attr="sleep"),
WildcardImportTarget(module="streamlit", attr="write_stream"),
WildcardImportTarget(module="asyncio", attr="run"),
]
)

Expand All @@ -921,6 +929,7 @@ def patch(code: str | ast.Module, script_path: str) -> ast.Module:
obj=ReturnValue(called_function="streamlit.navigation"),
attr="run",
): TransformRuleAction.STREAMLIT_NAVIGATION_RUN,
FunctionCall(name="asyncio.run"): TransformRuleAction.ASYNCIO_RUN,
}
)
func_call_transformer = CodeBlockTransformer(
Expand Down
92 changes: 92 additions & 0 deletions packages/kernel/py/stlite-lib/stlite_lib_tests/codemod_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,3 +909,95 @@ def test_not_convert_page_run(test_input):
assert ast.dump(tree, indent=4) == ast.dump(
ast.parse(test_input, "test.py", "exec"), indent=4
)


@pytest.mark.parametrize(
"test_input,expected",
[
pytest.param(
"""
import asyncio
async def main():
return 42
asyncio.run(main())
""",
"""
import asyncio
async def main():
return 42
await main()
""",
id="asyncio_run_basic",
),
pytest.param(
"""
import asyncio
async def main():
return 42
awaitable = main()
asyncio.run(awaitable)
""",
"""
import asyncio
async def main():
return 42
awaitable = main()
await awaitable
""",
id="asyncio_run_non_function_call",
),
pytest.param(
"""
from asyncio import run
async def main():
return 42
run(main())
""",
"""
from asyncio import run
async def main():
return 42
await main()
""",
id="asyncio_run_from_import",
),
pytest.param(
"""
from asyncio import *
async def main():
return 42
run(main())
""",
"""
from asyncio import *
async def main():
return 42
await main()
""",
id="asyncio_run_wildcard_import",
),
],
)
def test_convert_asyncio_run(test_input, expected):
tree = patch(test_input, "test.py")
assert ast.dump(tree, indent=4) == ast.dump(
ast.parse(expected, "test.py", "exec"), indent=4
)

0 comments on commit 48783f9

Please sign in to comment.