Python AST rewriting: ‘how does pytest do that?’

By Disconnect3d

import pytest

@pytest.mark.parametrize('string', ('a', '1', 'ß'))
def test_lower_upper(string):
    # Fails for 'ß': 'ß'.upper() is 'SS', and 'SS'.lower() is 'ss'
    assert string == string.upper().lower()
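Why does pytest need to do anything special here? Run the same check as a plain function, with no pytest involved, and the failing assert tells you nothing about the values. A small sketch (the function is renamed check_lower_upper only so it is clearly not a pytest test):

```python
def check_lower_upper(string):
    # Same body as the test above, but called directly, without pytest.
    assert string == string.upper().lower()

messages = {}
for string in ('a', '1', 'ß'):
    try:
        check_lower_upper(string)
    except AssertionError as exc:
        # A bare assert raises AssertionError() with no arguments,
        # so str(exc) is an empty string: no values, no explanation.
        messages[string] = str(exc)

print(messages)  # {'ß': ''}
```

Note that running this with python -O strips assert statements entirely, so nothing would fail at all.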

But what is an Abstract Syntax Tree?

Image source: https://www.codeproject.com/Articles/28299/Generate-AST-for-the-DLR

AST

The code:

r = r * (n-1)

Each element of this tree is a node.
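Python's own ast module builds the same kind of tree for that statement. A short sketch for poking at it; the indent argument of ast.dump assumes Python 3.9 or newer, and the exact dump text differs between versions:

```python
import ast

# Parse the statement from the slide into a tree of nodes.
tree = ast.parse("r = r * (n-1)")

# Show the nested structure: an Assign whose value is a BinOp (Mult)
# whose right operand is another BinOp (Sub). indent= needs Python 3.9+.
print(ast.dump(tree, indent=2))

# ast.walk yields every node in the tree, so we can list their types.
node_types = [type(node).__name__ for node in ast.walk(tree)]
print(node_types)
```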

So the whole idea is to find particular nodes (here: assert statements) and replace them with different nodes.
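A minimal sketch of that idea with the standard library's ast.NodeTransformer. The rewrite itself is a made-up toy (every * becomes +), not anything pytest does, and ast.unparse assumes Python 3.9 or newer:

```python
import ast

class MultToAdd(ast.NodeTransformer):
    """Toy example: replace every `*` operator with `+`."""

    def visit_BinOp(self, node):
        # Rewrite nested expressions first, then this node.
        self.generic_visit(node)
        if isinstance(node.op, ast.Mult):
            node.op = ast.Add()  # swap the operator node for a different one
        return node

tree = MultToAdd().visit(ast.parse("r = r * (n-1)"))
rewritten = ast.unparse(tree)
print(rewritten)
```

The method name visit_BinOp is how NodeTransformer (and NodeVisitor) dispatch: visit_ followed by the node's class name.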

We can almost do it with just the ast module

import ast

class AssertionRewriter(ast.NodeVisitor):
    # Called by visit() for every Assert node found in the tree
    def visit_Assert(self, assert_):
        print(ast.dump(assert_))

nodes = ast.parse('assert 1==2')

x = AssertionRewriter()
x.visit(nodes)

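This prints the tree for assert 1==2:

Assert(
    test=Compare(
        left=Num(n=1),
        ops=[Eq()],
        comparators=[Num(n=2)]
    ),
    msg=None
)

(This is what Python 3.7 and earlier print. Since Python 3.8, number literals are parsed into Constant nodes, so 1 and 2 show up as Constant nodes instead of Num nodes.)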

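Btw, why is ops a list? Because Python comparisons can be chained: 1 < x <= 3 is a single Compare node with two operators (Lt and LtE) and two comparators (x and 3).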
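Parsing a chained comparison shows the list in use:

```python
import ast

# mode="eval" parses a single expression; .body is that expression's node.
expr = ast.parse("1 < x <= 3", mode="eval").body

op_names = [type(op).__name__ for op in expr.ops]
print(type(expr).__name__)    # Compare
print(op_names)               # ['Lt', 'LtE']
print(len(expr.comparators))  # 2: one comparator per operator
```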

What about rewriting?

We need to turn the nodes back into Python source code or bytecode

Obviously, it can be done with builtins

But I have no idea how
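For the record, the builtin in question is compile(): it accepts an AST module in place of a source string and returns a code object (bytecode) that exec() can run. A sketch that builds the nodes for x = 40 + 2 by hand, assuming the node signatures of Python 3.8 or newer:

```python
import ast

# Build the tree for `x = 40 + 2` by hand, node by node.
tree = ast.Module(
    body=[
        ast.Assign(
            targets=[ast.Name(id="x", ctx=ast.Store())],
            value=ast.BinOp(left=ast.Constant(40), op=ast.Add(), right=ast.Constant(2)),
        )
    ],
    type_ignores=[],
)

# Hand-made nodes have no line numbers, but compile() requires them.
ast.fix_missing_locations(tree)

# Nodes -> code object (bytecode), then run it in a fresh namespace.
code = compile(tree, filename="<ast>", mode="exec")
namespace = {}
exec(code, namespace)
print(namespace["x"])  # 42
```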

So let’s use an external module: astunparse
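Since Python 3.9 the standard library ships ast.unparse as well, so on newer versions the external module is optional. A quick check:

```python
import ast

# ast.unparse (Python 3.9+) turns a tree back into source text.
tree = ast.parse("assert 1==2")
print(ast.unparse(tree))  # assert 1 == 2
```

The source is regenerated from the tree, so the original formatting (here the missing spaces around ==) and any comments are not preserved.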

import ast, astunparse

class AssertionRewriter(ast.NodeTransformer):
    def visit_Assert(self, assert_):
        change_to = "print('Assert was here')"
        # ast.parse returns a Module; .body[0] is its first statement
        print_call = ast.parse(change_to).body[0]

        # Returning a list lets one node be replaced by several statements
        statements = [print_call]

        return statements

nodes = ast.parse('assert 1==2')

x = AssertionRewriter()

transformed = x.visit(nodes)

print(astunparse.unparse(transformed))

astunparse.unparse(transformed) returns the string:

"\nprint('Assert was here')\n"

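Swapping an assert for a print is the mechanical half of the trick. The other half, and what pytest actually uses it for, is rewriting each assert so that a failure reports the values involved. pytest's real rewriter is far more thorough; the following is only a simplified sketch of the idea, assuming Python 3.8+, that handles nothing but assert a == b. TEMPLATE, _FillPlaceholders and ExplainingAssertRewriter are hypothetical names made up for this example.

```python
import ast

# Hypothetical template: LEFT and RIGHT are placeholders for the two operands.
TEMPLATE = """
_left = LEFT
_right = RIGHT
if not _left == _right:
    raise AssertionError(f"assert {_left!r} == {_right!r}")
"""


class _FillPlaceholders(ast.NodeTransformer):
    """Swap the placeholder names in TEMPLATE for real expression nodes."""

    def __init__(self, left, right):
        self.replacements = {"LEFT": left, "RIGHT": right}

    def visit_Name(self, node):
        return self.replacements.get(node.id, node)


class ExplainingAssertRewriter(ast.NodeTransformer):
    """Hypothetical, simplified: rewrite `assert a == b` to report both values."""

    def visit_Assert(self, node):
        test = node.test
        # Only the simplest case is handled: a single == and no custom message.
        if not (isinstance(test, ast.Compare) and len(test.ops) == 1
                and isinstance(test.ops[0], ast.Eq) and node.msg is None):
            return node
        filler = _FillPlaceholders(test.left, test.comparators[0])
        # Returning a list replaces the assert with several statements.
        return [filler.visit(stmt) for stmt in ast.parse(TEMPLATE).body]


source = """
def check(string):
    assert string == string.upper().lower()
"""
tree = ExplainingAssertRewriter().visit(ast.parse(source))
ast.fix_missing_locations(tree)

namespace = {}
exec(compile(tree, "<rewritten>", "exec"), namespace)

try:
    namespace["check"]("ß")
except AssertionError as exc:
    message = str(exc)
print(message)  # assert 'ß' == 'ss'
```

Applied to the lower/upper check from the start of the talk, the 'ß' case now fails with assert 'ß' == 'ss' instead of an empty message. Storing each operand in a temporary first means each side is evaluated exactly once, even though its value is reused in the message.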
The end.

Hope you enjoyed it!