Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Inline constants into instructions
  • Loading branch information
saulshanabrook committed Jul 4, 2022
commit 957770759ca5bf1f867bff677b6f133e9b35627d
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ jobs:
with:
python-version: ${{ matrix.py }}
- uses: actions/checkout@v2
# TODO: #48 Only install test requirements instead of all pinned in CI
- run: pip install -e . -r requirements.test.txt
# TODO: Enable dev mode for Python
# https://docs.python.org/3/library/devmode.html#devmode
- run: pytest -v code_data
- name: Commit and push failing example
if: failure()
Expand Down
65 changes: 47 additions & 18 deletions code_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class CodeData(DataclassHideDefault):
# Mapping of index in the names list to the name
additional_names: dict[int, str] = field(default_factory=dict)

# Additional constants to include, which do not appear in any instructions,
# Mapping of index in the names list to the name
additional_constants: dict[int, ConstantDataType] = field(default_factory=dict)

# number of arguments (not including keyword only arguments, * or ** args)
argcount: int = field(default=0)

Expand All @@ -70,9 +74,6 @@ class CodeData(DataclassHideDefault):
# code flags
flags: FlagsData = field(default_factory=set)

# All code objects are recursively transformed to CodeData objects
consts: Tuple["ConstantDataType", ...] = field(default=(None,))

# tuple of names of arguments and local variables
varnames: Tuple[str, ...] = field(default=tuple())

Expand Down Expand Up @@ -104,22 +105,24 @@ def to_code_data(code: CodeType) -> CodeData:

line_mapping = to_line_mapping(code)
first_line_number_override = line_mapping.set_first_line(code.co_firstlineno)

constants = tuple(map(to_code_constant, code.co_consts))
# retrieve the blocks and pop off used line mapping
blocks, additional_names = bytes_to_blocks(
code.co_code, line_mapping, code.co_names
blocks, additional_names, additional_constants = bytes_to_blocks(
code.co_code, line_mapping, code.co_names, constants
)
return CodeData(
blocks,
line_mapping,
first_line_number_override,
additional_names,
additional_constants,
code.co_argcount,
posonlyargcount,
code.co_kwonlyargcount,
code.co_nlocals,
code.co_stacksize,
to_flags_data(code.co_flags),
tuple(map(to_code_constant, code.co_consts)),
code.co_varnames,
code.co_filename,
code.co_name,
Expand All @@ -134,12 +137,13 @@ def from_code_data(code_data: CodeData) -> CodeType:

:rtype: types.CodeType
"""
consts = tuple(map(from_code_constant, code_data.consts))
flags = from_flags_data(code_data.flags)
code, line_mapping, names = blocks_to_bytes(
code_data.blocks, code_data.additional_names
code, line_mapping, names, constants = blocks_to_bytes(
code_data.blocks, code_data.additional_names, code_data.additional_constants
)

consts = tuple(map(from_code_constant, constants))

line_mapping.update(code_data.additional_line_mapping)
first_line_no = line_mapping.trim_first_line(code_data.first_line_number_override)

Expand Down Expand Up @@ -185,15 +189,12 @@ def from_code_data(code_data: CodeData) -> CodeType:
)


# We need to wrap the data structures in dataclasses to be able to represent
# them with MyPy, since it doesn't support recursive types
# https://github.com/python/mypy/issues/731
ConstantDataType = Union[
int,
"ConstantInt",
str,
float,
"ConstantFloat",
None,
bool,
"ConstantBool",
bytes,
"EllipsisType",
CodeData,
Expand All @@ -206,10 +207,14 @@ def from_code_data(code_data: CodeData) -> CodeType:
def to_code_constant(value: object) -> ConstantDataType:
if isinstance(value, CodeType):
return to_code_data(value)
if isinstance(
value, (int, str, float, type(None), bool, bytes, type(...), complex)
):
if isinstance(value, (str, type(None), bytes, type(...), complex)):
return value
if isinstance(value, bool):
return ConstantBool(value)
if isinstance(value, int):
return ConstantInt(value)
if isinstance(value, float):
return ConstantFloat(value)
if isinstance(value, tuple):
return ConstantTuple(tuple(map(to_code_constant, value)))
if isinstance(value, frozenset):
Expand All @@ -224,9 +229,33 @@ def from_code_constant(value: ConstantDataType) -> object:
return tuple(map(from_code_constant, value.tuple))
if isinstance(value, ConstantSet):
return frozenset(map(from_code_constant, value.frozenset))
if isinstance(value, (ConstantBool, ConstantInt, ConstantFloat)):
return value.value
return value


# Wrap these in types, so that, say, bytecode with constants of 1
# are not equal to bytecodes of constants of True.


@dataclass(frozen=True)
class ConstantBool:
value: bool = field(metadata={"positional": True})


@dataclass(frozen=True)
class ConstantInt:
value: int = field(metadata={"positional": True})


@dataclass(frozen=True)
class ConstantFloat:
value: float = field(metadata={"positional": True})


# We need to wrap the data structures in dataclasses to be able to represent
# them with MyPy, since it doesn't support recursive types
# https://github.com/python/mypy/issues/731
@dataclass(frozen=True)
class ConstantTuple(DataclassHideDefault):
tuple: Tuple[ConstantDataType, ...] = field(metadata={"positional": True})
Expand Down
114 changes: 101 additions & 13 deletions code_data/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,22 @@
import dis
import sys
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union

from code_data.line_mapping import LineMapping

from .dataclass_hide_default import DataclassHideDefault

if TYPE_CHECKING:
from . import ConstantDataType


def bytes_to_blocks(
b: bytes, line_mapping: LineMapping, names: tuple[str, ...]
) -> tuple[Blocks, dict[int, str]]:
b: bytes,
line_mapping: LineMapping,
names: tuple[str, ...],
constants: tuple[ConstantDataType, ...],
) -> tuple[Blocks, dict[int, str], dict[int, ConstantDataType]]:
"""
Parse a sequence of bytes as a sequence of blocks of instructions.
"""
Expand All @@ -33,11 +39,17 @@ def bytes_to_blocks(
# For recording what names we have found to understand the order of the names
found_names: list[str] = []

# For recording what constants we have found to understand the order of the
# constants
found_constants: list[ConstantDataType] = []

for opcode, arg, n_args, offset, next_offset in _parse_bytes(b):

# Compute the jump targets, initially with just the byte offset
# Once we know all the block targets, we will transform to be block offsets
processed_arg = to_arg(opcode, arg, next_offset, names, found_names)
processed_arg = to_arg(
opcode, arg, next_offset, names, found_names, constants, found_constants
)
if isinstance(processed_arg, Jump):
targets_set.add(processed_arg.target)
# Store the number of args if this is a jump instruction
Expand Down Expand Up @@ -81,12 +93,23 @@ def bytes_to_blocks(
additional_names = {
i: name for i, name in enumerate(names) if name not in found_names
}
return {i: block for i, block in enumerate(blocks)}, additional_names
additional_constants = {
i: constant
for i, constant in enumerate(constants)
if constant not in found_constants
}
return (
{i: block for i, block in enumerate(blocks)},
additional_names,
additional_constants,
)


def blocks_to_bytes(
blocks: Blocks, additional_names: dict[int, str]
) -> Tuple[bytes, LineMapping, tuple[str, ...]]:
blocks: Blocks,
additional_names: dict[int, str],
additional_consts: dict[int, ConstantDataType],
) -> Tuple[bytes, LineMapping, tuple[str, ...], tuple[ConstantDataType, ...]]:
# First compute mapping from block to offset
changed_instruction_lengths = True
# So that we know the bytecode offsets for jumps when iterating though instructions
Expand All @@ -100,6 +123,11 @@ def blocks_to_bytes(
# Mapping of name index to final name positions
name_final_positions: dict[int, int] = {}

# List of constants we have collected from the instructions
constants: list[ConstantDataType] = []
# Mapping of constant index to constant name positions
constant_final_positions: dict[int, int] = {}

# Iterate through all blocks and change jump instructions to offsets
while changed_instruction_lengths:

Expand All @@ -111,7 +139,13 @@ def blocks_to_bytes(
if (block_index, instruction_index) in args:
arg_value = args[block_index, instruction_index]
else:
arg_value = from_arg(instruction.arg, names, name_final_positions)
arg_value = from_arg(
instruction.arg,
names,
name_final_positions,
constants,
constant_final_positions,
)
args[block_index, instruction_index] = arg_value
n_instructions = instruction.n_args_override or _instrsize(arg_value)
current_instruction_offset += n_instructions
Expand Down Expand Up @@ -156,6 +190,17 @@ def blocks_to_bytes(
names[i] for _, i in sorted(name_final_positions.items(), key=lambda x: x[0])
]

# Add additional consts to the constants and add final positions
for i, constant in additional_consts.items():
constants.append(constant)
constant_final_positions[i] = len(constants) - 1

# Sort positions by final position
constants = [
constants[i]
for _, i in sorted(constant_final_positions.items(), key=lambda x: x[0])
]

# Finally go assemble the bytes and the line mapping
bytes_: list[int] = []
line_mapping = LineMapping()
Expand All @@ -180,7 +225,7 @@ def blocks_to_bytes(
)
bytes_.append((arg_value >> (8 * i)) & 0xFF)

return bytes(bytes_), line_mapping, tuple(names)
return bytes(bytes_), line_mapping, tuple(names), tuple(constants)


def to_arg(
Expand All @@ -189,6 +234,8 @@ def to_arg(
next_offset: int,
names: tuple[str, ...],
found_names: list[str],
consts: tuple[ConstantDataType, ...],
found_constants: list[ConstantDataType],
) -> Arg:
if opcode in dis.hasjabs:
return Jump((2 if _ATLEAST_310 else 1) * arg, False)
Expand All @@ -201,10 +248,22 @@ def to_arg(
found_names.append(name)
wrong_position = found_names.index(name) != arg
return Name(name, arg if wrong_position else None)
elif opcode in dis.hasconst:
constant = consts[arg]
if constant not in found_constants:
found_constants.append(constant)
wrong_position = found_constants.index(constant) != arg
return Constant(constant, arg if wrong_position else None)
return arg


def from_arg(arg: Arg, names: list[str], name_final_positions: dict[int, int]) -> int:
def from_arg(
arg: Arg,
names: list[str],
name_final_positions: dict[int, int],
constants: list[ConstantDataType],
constants_final_positions: dict[int, int],
) -> int:
# Use 1 as the arg_value, which will be update later
if isinstance(arg, Jump):
return 1
Expand All @@ -215,6 +274,13 @@ def from_arg(arg: Arg, names: list[str], name_final_positions: dict[int, int]) -
final_index = index if arg.index_override is None else arg.index_override
name_final_positions[final_index] = index
return final_index
if isinstance(arg, Constant):
if arg.value not in constants:
constants.append(arg.value)
index = constants.index(arg.value)
final_index = index if arg.index_override is None else arg.index_override
constants_final_positions[final_index] = index
return final_index
return arg


Expand Down Expand Up @@ -277,16 +343,25 @@ class Name(DataclassHideDefault):
index_override: Optional[int] = field(default=None)


@dataclass
class Constant(DataclassHideDefault):
"""
A constant argument.
"""

value: ConstantDataType = field(metadata={"positional": True})
# Optional override for the position if it is not ordered by occurance in the code.
index_override: Optional[int] = field(default=None)


# TODO: Add:
# 1. constant lookup
# 2. a name lookup
# 3. a local lookup
# 5. An unused value
# 6. Comparison lookup
# 7. format value
# 8. Generator kind

Arg = Union[int, Jump, Name]
Arg = Union[int, Jump, Name, Constant]


# dict mapping block offset to list of instructions in the block
Expand Down Expand Up @@ -344,3 +419,16 @@ def _instrsize(arg: int) -> int:
_c_int_upper_limit = (2 ** (_c_int_bit_size - 1)) - 1
# The number of values that can be stored in a signed int
_c_int_length = 2**_c_int_bit_size


# _PyCode_ConstantKey = ctypes.pythonapi._PyCode_ConstantKey
# _PyCode_ConstantKey.restype = ctypes.py_object


# def code_constant_key(value: object) -> object:
# """
# Transforms a value with the _ConstantKey function used in the Code type
# to compare equality of constants. It transforms objects so that the constant `1`
# is not equal to `True` for example, by adding their types.
# """
# return _PyCode_ConstantKey(ctypes.py_object(value))
Loading