Python AST rewriting: ‘How does pytest do that?’
By Disconnect3d
import pytest


@pytest.mark.parametrize('string', ('a', '1', 'ß'))
def test_lower_upper(string):
    assert string == string.upper().lower()
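The 'ß' case fails, because 'ß'.upper() is 'SS' and 'SS'.lower() is 'ss'. pytest's failure report shows those intermediate values instead of a bare AssertionError. Below is a rough, hand-written sketch of that idea. It does not reproduce pytest's real rewritten code, and check_lower_upper is a hypothetical name.

```python
def check_lower_upper(string):
    # Hypothetical hand-expansion of: assert string == string.upper().lower()
    # NOT pytest's actual generated code; it only mimics the idea of keeping
    # every intermediate value so the failure message can show it.
    upper = string.upper()
    lowered = upper.lower()
    if not string == lowered:
        raise AssertionError(
            f"assert {string!r} == {lowered!r}\n"
            f" +  where {lowered!r} = {upper!r}.lower()\n"
            f" +    and {upper!r} = {string!r}.upper()"
        )


results = {}
for s in ("a", "1", "ß"):
    try:
        check_lower_upper(s)
        results[s] = None        # passed
    except AssertionError as exc:
        results[s] = str(exc)    # failed: keep the detailed message

print(results["ß"])
```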
But what is an Abstract Syntax Tree?
[Figure: the AST of the code r = r * (n-1). Image source: https://www.codeproject.com/Articles/28299/Generate-AST-for-the-DLR]
Each element of that tree is called a node.
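The standard library's ast module can print the same tree. Here is a minimal sketch. The exact dump text varies between Python versions (for example, Num vs Constant), so the full dump is not shown here, and the indent= argument requires Python 3.9+.

```python
import ast

tree = ast.parse("r = r * (n-1)")
print(ast.dump(tree, indent=4))  # indent= needs Python 3.9+

assign = tree.body[0]        # Assign node: r = ...
product = assign.value       # BinOp: r * (n-1)
difference = product.right   # BinOp: n - 1 (parentheses create no node)
print(type(assign).__name__, type(product.op).__name__, type(difference.op).__name__)
# Assign Mult Sub
```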
So the whole idea is to change particular nodes into different nodes.
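As an illustration, the sketch below uses a hypothetical transformer (MultToAdd is our own name) that swaps every * operator for +. It turns r = r * (n-1) into r = r + (n-1). The sketch relies on ast.NodeTransformer, which replaces each visited node with whatever the matching visit_* method returns.

```python
import ast


class MultToAdd(ast.NodeTransformer):
    # Hypothetical example: rewrite every `a * b` into `a + b`.
    def visit_BinOp(self, node):
        self.generic_visit(node)      # rewrite nested BinOps first
        if isinstance(node.op, ast.Mult):
            node.op = ast.Add()       # swap only the operator node
        return node                   # the returned node replaces the original


tree = MultToAdd().visit(ast.parse("r = r * (n-1)"))
expected = ast.parse("r = r + (n-1)")
print(ast.dump(tree) == ast.dump(expected))  # True
```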
We can almost do it using the ast module:
import ast


class AssertionRewriter(ast.NodeVisitor):
    # NodeVisitor calls visit_Assert for every Assert node in the tree
    def visit_Assert(self, assert_):
        print(ast.dump(assert_))


nodes = ast.parse('assert 1==2')

x = AssertionRewriter()
x.visit(nodes)
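Output for assert 1==2. This was captured on an older Python; since Python 3.8, number literals appear as Constant nodes instead of Num:

Assert(
    test=Compare(
        left=Num(n=1),
        ops=[Eq()],
        comparators=[Num(n=2)]
    ),
    msg=None
)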
Btw, no idea why ops and comparators are lists.
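One likely reason for the lists is that Python allows chained comparisons. A single Compare node holds every operator and every right-hand operand of the chain. The quick check below shows this:

```python
import ast

# mode="eval" parses one expression; .body is that expression's node
compare = ast.parse("1 < x <= 3", mode="eval").body

# One Compare node for the whole chain: left=1, then (Lt, x) and (LtE, 3)
print(type(compare).__name__)                    # Compare
print([type(op).__name__ for op in compare.ops])  # ['Lt', 'LtE']
print(len(compare.comparators))                   # 2
```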
What about rewriting?
We need to turn the nodes back into Python (byte)code.
Obviously it can be done using builtins
But I have no idea how
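For reference, the standard library does cover this. compile() accepts an AST (an ast.Module) in place of source text. Since Python 3.9, ast.unparse() also turns a tree back into source, much like astunparse does. The minimal sketch below follows the same assert-to-print idea; AssertToPrint is our own hypothetical name.

```python
import ast
import contextlib
import io


class AssertToPrint(ast.NodeTransformer):
    # Hypothetical transformer: replace each assert with a print call.
    def visit_Assert(self, node):
        return ast.parse("print('Assert was here')").body[0]


tree = ast.parse("r = 3\nassert 1 == 2")
# Fills in lineno/col_offset on nodes that lack them; usually needed when
# building nodes by hand (here every node already has a location).
tree = ast.fix_missing_locations(AssertToPrint().visit(tree))

# Builtin route 1: AST -> code object -> run it.
code = compile(tree, filename="<rewritten>", mode="exec")
namespace = {}
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
    exec(code, namespace)  # the original assert would have raised here

# Builtin route 2 (Python 3.9+): AST -> source text.
source = ast.unparse(tree)
print(source)
print(stdout.getvalue(), end="")
```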
So let’s use an external module: astunparse
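import ast

import astunparse  # third-party: pip install astunparse


class AssertionRewriter(ast.NodeTransformer):
    def visit_Assert(self, assert_):
        # Build the replacement node by parsing a snippet of source code
        change_to = "print('Assert was here')"
        print_call = ast.parse(change_to).body[0]

        # A NodeTransformer may return a list of statements;
        # they replace the original assert node
        statements = [print_call]

        return statements


nodes = ast.parse('assert 1==2')

x = AssertionRewriter()

transformed = x.visit(nodes)

print(astunparse.unparse(transformed))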
astunparse.unparse returns "\nprint('Assert was here')\n": the assert has been replaced by the print call.
The end.
Hope you enjoyed!