diff --git a/.clang-format b/.clang-format index b549eb37..4a578de7 100644 --- a/.clang-format +++ b/.clang-format @@ -1,5 +1,4 @@ --- ---- Language: Cpp BasedOnStyle: Microsoft AccessModifierOffset: -4 diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index e09afbfe..3e434b2c 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -13,24 +13,12 @@ jobs: runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: '3.10' - - name: Install Development Dependencies - run: pip install -r requirements-dev.txt - - - name: Installing Component-specific Dependencies - run: pip install -r assembler_tools/hec-assembler-tools/requirements.txt - - - name: Install Apt Dependencies - run: sudo apt install -y clang-format-14 - - - name: Fetch main branch for diff - run: git fetch origin main - - - name: Run pre-commit on changed files only - run: pre-commit run --from-ref origin/main --to-ref HEAD + - name: Run pre-commit + uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da87ed46..4cc0edc9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 default_language_version: - # force all unspecified python hooks to run python3 + # force all unspecified python hooks to run python3.10 python: python3.10 repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -15,6 +15,7 @@ repos: - id: check-yaml args: - --allow-multiple-documents + - id: fix-byte-order-marker - repo: https://github.com/crate-ci/typos rev: v1.33.1 # Updated 2025/06 hooks: @@ -28,75 +29,45 @@ repos: name: insert-license-shell files: \.(sh|py)$ args: - - --license-filepath - # defaults to: LICENSE.txt - - HEADER + - --license-filepath=HEADER + - --use-current-year + - --allow-past-years + - --detect-license-in-X-top-lines=10 + - --fuzzy-ratio-cut-off=50 + - --remove-header - id: insert-license name: insert-license-cpp files: \.(c|cc|cxx|cpp|h|hpp|hxx|inl|h.in)$ args: - - --license-filepath - # defaults to: LICENSE.txt - - HEADER - - --comment-style - - // # defaults to: # + - --license-filepath=HEADER + - --comment-style=// + - --use-current-year + - --allow-past-years + - --detect-license-in-X-top-lines=10 + - --fuzzy-ratio-cut-off=50 + - --remove-header - id: remove-tabs name: remove-tabs files: \.(py)$ args: [--whitespaces-count, '4'] - - repo: https://github.com/psf/black-pre-commit-mirror - rev: 25.1.0 # Updated 2025/06 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.4 hooks: - - id: black - language_version: python3.10 + - id: ruff + args: [--fix] # Automatically fix issues when possible + - id: ruff-format # Replaces black - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.16.0 # Last checked 2025/06 hooks: - id: mypy - language: system - exclude: >- - ^(assembler_tools/hec-assembler-tools/assembler/common/run_config\.py| - *assembler_tools/hec-assembler-tools/assembler/instructions/| - *assembler_tools/hec-assembler-tools/assembler/memory_model/| - *assembler_tools/hec-assembler-tools/assembler/stages/asm_scheduler\.py| - *assembler_tools/hec-assembler-tools/assembler/stages/scheduler\.py| - *assembler_tools/hec-assembler-tools/debug_tools/main\.py| - *assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/| - *assembler_tools/hec-assembler-tools/he_as\.py| - *assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py) - args: ["--follow-imports=skip", "--install-types", "--non-interactive"] - - repo: local + pass_filenames: false + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: "v18.1.2" + hooks: + - id: clang-format + exclude_types: [json] # skip *.json and *.JSON + args: ["--style=file"] + - repo: https://github.com/cpplint/cpplint + rev: "2.0.2" hooks: - - id: pylint - name: pylint - entry: pylint - language: system - types: [python] - exclude: >- - ^(assembler_tools/hec-assembler-tools/assembler/common/run_config\.py| - *assembler_tools/hec-assembler-tools/assembler/instructions/| - *assembler_tools/hec-assembler-tools/assembler/memory_model/| - *assembler_tools/hec-assembler-tools/assembler/stages/asm_scheduler\.py| - *assembler_tools/hec-assembler-tools/assembler/stages/scheduler\.py| - *assembler_tools/hec-assembler-tools/debug_tools/main\.py| - *assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/| - *assembler_tools/hec-assembler-tools/he_as\.py| - *assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py) - args: - - -rn # Only display messages - - -sn # Don't display the score - - --source-roots=p-isa_tools/kerngen,assembler_tools/hec-assembler-tools - - id: clang-format-14 - name: clang-format-14 - entry: clang-format-14 - language: system - files: \.(c|cc|cxx|cpp|h|hpp|hxx|inl)$ - args: ["-i", "--style=file"] - id: cpplint - name: cpplint - entry: cpplint - language: system - files: \.(c|cc|cxx|cpp|h|hpp|hxx)$ - args: - - --recursive - - --filter=-build/c++17 diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index e518bd32..00000000 --- a/.pylintrc +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (C) 2024 Intel Corporation - -# Pylint config containing overrides -[BASIC] -good-names = iN, logN - -[MESSAGES CONTROL] -# W0511 are TODO -disable=W0511 - -# pydantic and pylint don't always play nice -# apparently due to libraries with compiled code -[MASTER] -extension-pkg-allow-list=pydantic - -[CLASSES] -# Minimum number of public methods for a class (see R0903). -min-public-methods = 0 - -[DESIGN] - -# Maximum number of attributes for a class (see R0902). -max-attributes=8 - -[FORMAT] -# `black` takes care of our line lengths, but just in case it gets ridiculous -max-line-length=230 diff --git a/.typos.toml b/.typos.toml index fee970e1..ff6fa075 100644 --- a/.typos.toml +++ b/.typos.toml @@ -8,3 +8,8 @@ # variation of params parms = "parms" bload = "bload" + +[files] +extend-exclude = [ + "requirements-dev.txt", +] diff --git a/HEADER b/HEADER index bd8b8ebc..e78f8972 100644 --- a/HEADER +++ b/HEADER @@ -1,2 +1,2 @@ -Copyright (C) 2025 Intel Corporation +Copyright (C) {year} Intel Corporation SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/assembler/common/__init__.py b/assembler_tools/hec-assembler-tools/assembler/common/__init__.py index 535ed76a..ea923e17 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/__init__.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/__init__.py @@ -1,5 +1,9 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import os + def makeUniquePath(path: str) -> str: """ Returns a unique, normalized, and absolute version of the given file path. diff --git a/assembler_tools/hec-assembler-tools/assembler/common/constants.py b/assembler_tools/hec-assembler-tools/assembler/common/constants.py index 999f84f3..40f10640 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/constants.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/constants.py @@ -1,10 +1,13 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .decorators import classproperty -from .decorators import * class Constants: """ Contains project level and global constants that won't fit logically into any other category. - + Attributes: KILOBYTE (int): Number of bytes in a kilobyte (2^10). MEGABYTE (int): Number of bytes in a megabyte (2^20). @@ -64,8 +67,8 @@ def REPLACEMENT_POLICY_LRU(cls) -> str: @classproperty def REPLACEMENT_POLICIES(cls) -> tuple: - """Tuple containing all replacement policy identifiers.""" - return ( cls.REPLACEMENT_POLICY_FTBU, cls.REPLACEMENT_POLICY_LRU ) + """Tuple containing all replacement policy identifiers.""" + return (cls.REPLACEMENT_POLICY_FTBU, cls.REPLACEMENT_POLICY_LRU) # Misc Constants # -------------- @@ -97,18 +100,31 @@ def TW_GRAMMAR_SEPARATOR(cls) -> str: @classproperty def OPERATIONS(cls) -> list: """List of high-level operations supported by the system.""" - return [ "add", "mul", "ntt", "intt", "relin", "mod_switch", "rotate", - "square", "add_plain", "add_corrected", "mul_plain", "rescale", - "boot_dot_prod", "boot_mod_drop_scale", "boot_mul_const", "boot_galois_plain" ] + return [ + "add", + "mul", + "ntt", + "intt", + "relin", + "mod_switch", + "rotate", + "square", + "add_plain", + "add_corrected", + "mul_plain", + "rescale", + "boot_dot_prod", + "boot_mod_drop_scale", + "boot_mul_const", + "boot_galois_plain", + ] @classmethod def hw_spec_as_dict(cls) -> dict: """ Returns hw configurable attributes as dictionary. """ - dict = {"bytes_per_xinstruction": cls.XINSTRUCTION_SIZE_BYTES, - "max_instructions_per_bundle": cls.MAX_BUNDLE_SIZE} - return dict + return {"bytes_per_xinstruction": cls.XINSTRUCTION_SIZE_BYTES, "max_instructions_per_bundle": cls.MAX_BUNDLE_SIZE} @classmethod def setMaxBundleSize(cls, val: int): @@ -120,7 +136,8 @@ def setXInstructionSizeBytes(cls, val: int): """Updates size of single XInstruction""" cls.__XINSTRUCTION_SIZE_BYTES = val -def convertBytes2Words(bytes: int) -> int: + +def convertBytes2Words(bytes_in: int) -> int: """ Converts a size in bytes to the equivalent number of words. @@ -130,7 +147,8 @@ def convertBytes2Words(bytes: int) -> int: Returns: int: The equivalent size in words. """ - return int(bytes / Constants.WORD_SIZE) + return int(bytes_in / Constants.WORD_SIZE) + def convertWords2Bytes(words: int) -> int: """ @@ -144,6 +162,7 @@ def convertWords2Bytes(words: int) -> int: """ return words * Constants.WORD_SIZE + class MemInfo: """ Constants related to memory information, read from the P-ISA kernel memory file. @@ -160,6 +179,7 @@ class Keyword: These keywords are used to identify different operations and data types within the memory file. """ + @classproperty def KEYGEN(cls): """Keyword for key generation.""" @@ -219,6 +239,7 @@ class MetaFields: """ Names of different metadata fields. """ + @classproperty def FIELD_KEYGEN_SEED(cls): return MemInfo.Keyword.LOAD_KEYGEN_SEED @@ -266,18 +287,21 @@ def FIELD_METADATA(cls): @classproperty def FIELD_METADATA_SUBFIELDS(cls): """Tuple of subfield names for metadata.""" - return ( cls.MetaFields.FIELD_KEYGEN_SEED, - cls.MetaFields.FIELD_TWIDDLE, - cls.MetaFields.FIELD_ONES, - cls.MetaFields.FIELD_NTT_AUX_TABLE, - cls.MetaFields.FIELD_NTT_ROUTING_TABLE, - cls.MetaFields.FIELD_iNTT_AUX_TABLE, - cls.MetaFields.FIELD_iNTT_ROUTING_TABLE ) + return ( + cls.MetaFields.FIELD_KEYGEN_SEED, + cls.MetaFields.FIELD_TWIDDLE, + cls.MetaFields.FIELD_ONES, + cls.MetaFields.FIELD_NTT_AUX_TABLE, + cls.MetaFields.FIELD_NTT_ROUTING_TABLE, + cls.MetaFields.FIELD_iNTT_AUX_TABLE, + cls.MetaFields.FIELD_iNTT_ROUTING_TABLE, + ) class MetaTargets: """ Targets for different metadata. """ + @classproperty def TARGET_ONES(cls): """Special target register for Ones.""" @@ -303,6 +327,7 @@ def TARGET_iNTT_ROUTING_TABLE(cls): """Special target register for rshuffle iNTT routing table.""" return 3 + class MemoryModel: """ Constants related to memory model. @@ -332,30 +357,37 @@ class MemoryModel: def XINST_QUEUE_MAX_CAPACITY(cls): """Maximum capacity of the XINST queue in bytes.""" return cls.__XINST_QUEUE_MAX_CAPACITY + @classproperty def XINST_QUEUE_MAX_CAPACITY_WORDS(cls): """Maximum capacity of the XINST queue in words.""" return convertBytes2Words(cls.__XINST_QUEUE_MAX_CAPACITY) + @classproperty def CINST_QUEUE_MAX_CAPACITY(cls): """Maximum capacity of the CINST queue in bytes.""" return cls.__CINST_QUEUE_MAX_CAPACITY + @classproperty def CINST_QUEUE_MAX_CAPACITY_WORDS(cls): """Maximum capacity of the CINST queue in words.""" return convertBytes2Words(cls.__CINST_QUEUE_MAX_CAPACITY) + @classproperty def MINST_QUEUE_MAX_CAPACITY(cls): """Maximum capacity of the MINST queue in bytes.""" return cls.__MINST_QUEUE_MAX_CAPACITY + @classproperty def MINST_QUEUE_MAX_CAPACITY_WORDS(cls): """Maximum capacity of the MINST queue in words.""" return convertBytes2Words(cls.__MINST_QUEUE_MAX_CAPACITY) + @classproperty def STORE_BUFFER_MAX_CAPACITY(cls): """Maximum capacity of the store buffer in bytes.""" return cls.__STORE_BUFFER_MAX_CAPACITY + @classproperty def STORE_BUFFER_MAX_CAPACITY_WORDS(cls): """Maximum capacity of the store buffer in words.""" @@ -432,20 +464,21 @@ def hw_spec_as_dict(cls) -> dict: """ Returns hw configurable attributes as dictionary. """ - dict = {"max_xinst_queue_size_in_bytes": cls.__XINST_QUEUE_MAX_CAPACITY, - "max_cinst_queue_size_in_bytes": cls.__CINST_QUEUE_MAX_CAPACITY, - "max_minst_queue_size_in_bytes": cls.__MINST_QUEUE_MAX_CAPACITY, - "max_store_buffer_size_in_bytes": cls.__STORE_BUFFER_MAX_CAPACITY, - "num_blocks_per_twid_meta_word": cls.NUM_BLOCKS_PER_TWID_META_WORD, - "num_blocks_per_kgseed_meta_word": cls.NUM_BLOCKS_PER_KGSEED_META_WORD, - "num_routing_table_registers": cls.NUM_ROUTING_TABLE_REGISTERS, - "num_ones_meta_registers": cls.NUM_ONES_META_REGISTERS, - "num_twiddle_meta_registers": cls.NUM_TWIDDLE_META_REGISTERS, - "twiddle_meta_register_size_in_bytes": cls.TWIDDLE_META_REGISTER_SIZE_BYTES, - "max_residuals": cls.MAX_RESIDUALS, - "num_register_banks": cls.NUM_REGISTER_BANKS, - "num_registers_per_bank": cls.NUM_REGISTERS_PER_BANK} - return dict + return { + "max_xinst_queue_size_in_bytes": cls.__XINST_QUEUE_MAX_CAPACITY, + "max_cinst_queue_size_in_bytes": cls.__CINST_QUEUE_MAX_CAPACITY, + "max_minst_queue_size_in_bytes": cls.__MINST_QUEUE_MAX_CAPACITY, + "max_store_buffer_size_in_bytes": cls.__STORE_BUFFER_MAX_CAPACITY, + "num_blocks_per_twid_meta_word": cls.NUM_BLOCKS_PER_TWID_META_WORD, + "num_blocks_per_kgseed_meta_word": cls.NUM_BLOCKS_PER_KGSEED_META_WORD, + "num_routing_table_registers": cls.NUM_ROUTING_TABLE_REGISTERS, + "num_ones_meta_registers": cls.NUM_ONES_META_REGISTERS, + "num_twiddle_meta_registers": cls.NUM_TWIDDLE_META_REGISTERS, + "twiddle_meta_register_size_in_bytes": cls.TWIDDLE_META_REGISTER_SIZE_BYTES, + "max_residuals": cls.MAX_RESIDUALS, + "num_register_banks": cls.NUM_REGISTER_BANKS, + "num_registers_per_bank": cls.NUM_REGISTERS_PER_BANK, + } @classmethod def setMaxXInstQueueCapacity(cls, val: int): @@ -544,6 +577,7 @@ class HBM: This class defines the maximum capacity of HBM in both bytes and words. """ + __MAX_CAPACITY: int @classproperty @@ -561,8 +595,7 @@ def hw_spec_as_dict(cls) -> dict: """ Returns hw configurable attributes as dictionary. """ - dict = {"max_hbm_size_in_bytes": cls.__MAX_CAPACITY} - return dict + return {"max_hbm_size_in_bytes": cls.__MAX_CAPACITY} @classmethod def setMaxCapacity(cls, val: int): @@ -577,6 +610,7 @@ class SPAD: This class defines the maximum capacity of SPAD in both bytes and words. """ + __MAX_CAPACITY: int # Class methods and properties @@ -591,14 +625,13 @@ def MAX_CAPACITY(cls) -> int: def MAX_CAPACITY_WORDS(cls) -> int: """Total capacity of SPAD in Words""" return convertBytes2Words(cls.__MAX_CAPACITY) - + @classmethod def hw_spec_as_dict(cls) -> dict: """ Returns hw configurable attributes as dictionary. """ - dict = {"max_cache_size_in_bytes": cls.__MAX_CAPACITY} - return dict + return {"max_cache_size_in_bytes": cls.__MAX_CAPACITY} @classmethod def setMaxCapacity(cls, val: int): diff --git a/assembler_tools/hec-assembler-tools/assembler/common/counter.py b/assembler_tools/hec-assembler-tools/assembler/common/counter.py index 0f16ec8d..7db53e5c 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/counter.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/counter.py @@ -11,7 +11,7 @@ """ import itertools -from typing import Set, Optional +from typing import Optional class Counter: @@ -80,7 +80,7 @@ def reset(self): """ self.__counter = itertools.count(self.start, self.step) - __counters: Set["Counter.CounterIter"] = set() + __counters: set["Counter.CounterIter"] = set() @classmethod def count(cls, start=0, step=1) -> "Counter.CounterIter": diff --git a/assembler_tools/hec-assembler-tools/assembler/common/cycle_tracking.py b/assembler_tools/hec-assembler-tools/assembler/common/cycle_tracking.py index 4e91213c..acdc2109 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/cycle_tracking.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/cycle_tracking.py @@ -1,6 +1,9 @@ -import numbers +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from typing import NamedTuple + class PrioritizedPlaceholder: """ Base class for priority queue items. @@ -22,9 +25,8 @@ class PrioritizedPlaceholder: _get_priority(): Returns the base priority of the item. _get_priority_delta(): Returns the priority delta of the item. """ - def __init__(self, - priority = (0, 0), - priority_delta = (0, 0)): + + def __init__(self, priority=(0, 0), priority_delta=(0, 0)): """ Initializes a new PrioritizedPlaceholder object. @@ -45,7 +47,7 @@ def priority(self): Returns: tuple: The current priority. """ - return tuple([sum(x) for x in zip(self._get_priority(), self.priority_delta)]) + return tuple([sum(x) for x in zip(self._get_priority(), self.priority_delta, strict=False)]) @property def priority_delta(self): @@ -113,6 +115,7 @@ def __gt__(self, other): """ return self.priority > other.priority + class CycleType(NamedTuple): """ Named tuple to add structure to a cycle type. @@ -158,9 +161,9 @@ def __add__(self, other): elif isinstance(other, tuple): return self.__binaryop_tuple(other, lambda m, n: m + n) else: - raise TypeError('`other`: expected type `int` or `tuple`.') + raise TypeError("`other`: expected type `int` or `tuple`.") - def __sub__(self, other): + def __sub__(self, other): """ Subtracts a tuple or an integer from the `CycleType`. @@ -178,7 +181,7 @@ def __sub__(self, other): elif isinstance(other, tuple): return self.__binaryop_tuple(other, lambda m, n: m - n) else: - raise TypeError('`other`: expected type `int` or `tuple`.') + raise TypeError("`other`: expected type `int` or `tuple`.") def __binaryop_cycles(self, cycles, binaryop_callable): """ @@ -191,7 +194,7 @@ def __binaryop_cycles(self, cycles, binaryop_callable): Returns: CycleType: The resulting `CycleType` after the operation. """ - assert(isinstance(cycles, int)) + assert isinstance(cycles, int) return CycleType(self.bundle, binaryop_callable(self.cycle, cycles)) def __binaryop_tuple(self, other, binaryop_callable): @@ -205,8 +208,11 @@ def __binaryop_tuple(self, other, binaryop_callable): Returns: CycleType: The resulting `CycleType` after the operation. """ - return CycleType(binaryop_callable(self.bundle, int(other[0]) if len(other) > 0 else 0), - binaryop_callable(self.cycle, int(other[1]) if len(other) > 1 else 0)) + return CycleType( + binaryop_callable(self.bundle, int(other[0]) if len(other) > 0 else 0), + binaryop_callable(self.cycle, int(other[1]) if len(other) > 1 else 0), + ) + class CycleTracker: """ @@ -239,9 +245,9 @@ def __init__(self, cycle_ready: CycleType): cycle_ready (CycleType): The initial cycle when the object is ready to be used. Must be a tuple with at least two elements (bundle, cycle). """ - assert(len(cycle_ready) > 1) + assert len(cycle_ready) > 1 self.__cycle_ready = CycleType(*cycle_ready) - self.tag = 0 # User-defined tag + self.tag = 0 # User-defined tag @property def cycle_ready(self): @@ -287,7 +293,7 @@ def _set_cycle_ready(self, value: CycleType): value (CycleType or tuple): New clock cycle when this object will be ready for use. The tuple should be in the form (bundle: int, cycle: int). """ - assert(len(value) > 1) - #if self.cycle_ready < value: + assert len(value) > 1 + # if self.cycle_ready < value: # self.__cycle_ready = CycleType(*value) self.__cycle_ready = CycleType(*value) diff --git a/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py b/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py index 26f8bf71..4923d448 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py @@ -8,10 +8,10 @@ and supports operations to update, remove, and retrieve tasks based on their priorities. """ -import heapq import bisect +import heapq import itertools -from typing import List, Dict, Optional, Tuple, Any +from typing import Any class PriorityQueue: @@ -59,10 +59,7 @@ def __next__(self): raise RuntimeError("PriorityQueue changed size during iteration.") # Skip all removed tasks - while ( - self.__current < len(self.__pq) - and self.__pq[self.__current][-1] is self.__removed - ): + while self.__current < len(self.__pq) and self.__pq[self.__current][-1] is self.__removed: self.__current += 1 if self.__current >= len(self.__pq): raise StopIteration @@ -82,12 +79,8 @@ def __init__(self): """ Initializes the priority tracker with empty mappings. """ - self.__priority_dict = ( - {} - ) # dict(int, SortedList(task)): maps priority to unordered set of tasks with same priority - self.__priority_dict_set = ( - {} - ) # dict(int, set(task)): maps priority to unordered set of tasks with same priority + self.__priority_dict = {} # dict(int, SortedList(task)): maps priority to unordered set of tasks with same priority + self.__priority_dict_set = {} # dict(int, set(task)): maps priority to unordered set of tasks with same priority def find(self, priority: int) -> object: """ @@ -99,11 +92,7 @@ def find(self, priority: int) -> object: Returns: object: A task with the specified priority, or None if not found. """ - return ( - next(iter(self.__priority_dict[priority]))[1] - if priority in self.__priority_dict - else None - ) + return next(iter(self.__priority_dict[priority]))[1] if priority in self.__priority_dict else None def push(self, priority: int, tie_breaker: tuple, task: object): """ @@ -151,13 +140,7 @@ def pop(self, priority: int, task=None) -> object: if task: # Find index for task idx = next( - ( - i - for i, (_, contained_task) in enumerate( - self.__priority_dict[priority] - ) - if contained_task == task - ), + (i for i, (_, contained_task) in enumerate(self.__priority_dict[priority]) if contained_task == task), len(self.__priority_dict[priority]), ) if idx >= len(self.__priority_dict[priority]): @@ -178,7 +161,7 @@ def pop(self, priority: int, task=None) -> object: __REMOVED = object() # Placeholder for a removed task - def __init__(self, queue: Optional[List[Tuple[int, Any]]] = None): + def __init__(self, queue: list[tuple[int, Any]] | None = None): """ Creates a new PriorityQueue object. @@ -190,15 +173,9 @@ def __init__(self, queue: Optional[List[Tuple[int, Any]]] = None): ValueError: If any task in the queue is None. """ # entry: [priority: int, nonce: int, task: hashable_object] - self.__pq: List[List[Any]] = ( - [] - ) # list(entry) - List of entries arranged in a heap - self.__entry_finder: Dict[Any, List[Any]] = ( - {} - ) # dictionary(task: Hashable_object, entry) - mapping of tasks to entries - self.__priority_tracker = ( - PriorityQueue.__PriorityTracker() - ) # Tracks tasks by priority + self.__pq: list[list[Any]] = [] # list(entry) - List of entries arranged in a heap + self.__entry_finder: dict[Any, list[Any]] = {} # dictionary(task: Hashable_object, entry) - mapping of tasks to entries + self.__priority_tracker = PriorityQueue.__PriorityTracker() # Tracks tasks by priority self.__counter = itertools.count(1) # Unique sequence count if queue: @@ -260,9 +237,7 @@ def __repr__(self): """ return f"<{type(self).__name__} object at {hex(id(self))}>(len={len(self)}, pq={self.__pq})" - def push( - self, priority: int, task: object, tie_breaker: Optional[Tuple[int, ...]] = None - ): + def push(self, priority: int, task: object, tie_breaker: tuple[int, ...] | None = None): """ Adds a new task or update the priority of an existing task. @@ -296,9 +271,7 @@ def push( if b_add_needed: if len(self.__pq) == 0: - self.__counter = itertools.count( - 1 - ) # restart sequence count when queue is empty + self.__counter = itertools.count(1) # restart sequence count when queue is empty count = next(self.__counter) entry = [priority, (tie_breaker, count), task] self.__entry_finder[task] = entry @@ -318,12 +291,10 @@ def remove(self, task: object): # mark an existing task as PriorityQueue.__REMOVED. entry = self.__entry_finder.pop(task) priority, *_ = entry - self.__priority_tracker.pop( - priority, task - ) # remove it from the priority tracker + self.__priority_tracker.pop(priority, task) # remove it from the priority tracker entry[-1] = PriorityQueue.__REMOVED - def peek(self) -> Optional[Tuple[int, Any]]: + def peek(self) -> tuple[int, Any] | None: """ Returns the task with the lowest priority without removing it from the queue. diff --git a/assembler_tools/hec-assembler-tools/assembler/common/queue_dict.py b/assembler_tools/hec-assembler-tools/assembler/common/queue_dict.py index d2ecefb9..5c39b9e9 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/queue_dict.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/queue_dict.py @@ -1,4 +1,8 @@ -from collections import deque +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections import deque + class QueueDict: """ @@ -12,6 +16,7 @@ class QueueDict: of items occur at the start of the queue structure only. No removals are allowed on any other items of the structure. """ + def __init__(self): """ Initializes a new, empty QueueDict object. @@ -87,7 +92,7 @@ def clear(self): self.__q.clear() self.__lookup = {} - def copy(self) -> object: # QueueDict + def copy(self) -> object: # QueueDict """ Returns a shallow copy of the QueueDict. @@ -116,7 +121,7 @@ def peek(self) -> tuple: def pop(self) -> tuple: """ Removes and returns the (key, value) pair item at the start of the QueueDict. - + Returns: tuple: The (key, value) pair that was removed. """ diff --git a/assembler_tools/hec-assembler-tools/assembler/common/run_config.py b/assembler_tools/hec-assembler-tools/assembler/common/run_config.py index e5ccaef2..2ea4939c 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/run_config.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/run_config.py @@ -26,9 +26,7 @@ class RunConfig: debug_verbose: int __initialized = False # Specifies whether static members have been initialized - __default_config = ( - {} - ) # Dictionary of all configuration items supported and their default values + __default_config = {} # Dictionary of all configuration items supported and their default values def __init__(self, **kwargs): """ @@ -78,23 +76,15 @@ def __init__(self, **kwargs): # Validate inputs if self.repl_policy not in constants.Constants.REPLACEMENT_POLICIES: - raise ValueError( - 'Invalid `repl_policy`. "{}" not in {}'.format( - self.repl_policy, constants.Constants.REPLACEMENT_POLICIES - ) - ) + raise ValueError('Invalid `repl_policy`. "{}" not in {}'.format(self.repl_policy, constants.Constants.REPLACEMENT_POLICIES)) @classproperty def DEFAULT_HBM_SIZE_KB(cls) -> int: - return int( - constants.MemoryModel.HBM.MAX_CAPACITY / constants.Constants.KILOBYTE - ) + return int(constants.MemoryModel.HBM.MAX_CAPACITY / constants.Constants.KILOBYTE) @classproperty def DEFAULT_SPAD_SIZE_KB(cls) -> int: - return int( - constants.MemoryModel.SPAD.MAX_CAPACITY / constants.Constants.KILOBYTE - ) + return int(constants.MemoryModel.SPAD.MAX_CAPACITY / constants.Constants.KILOBYTE) @classproperty def DEFAULT_REPL_POLICY(cls) -> int: @@ -148,7 +138,4 @@ def as_dict(self) -> dict: dict: A dictionary representation of the current configuration settings. """ tmp_self_dict = vars(self) - return { - config_name: tmp_self_dict[config_name] - for config_name in self.__default_config - } + return {config_name: tmp_self_dict[config_name] for config_name in self.__default_config} diff --git a/assembler_tools/hec-assembler-tools/assembler/common/utilities.py b/assembler_tools/hec-assembler-tools/assembler/common/utilities.py index 55b30d65..bfff346f 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/utilities.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/utilities.py @@ -1,5 +1,8 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 -def clamp(x, minimum = float("-inf"), maximum = float("inf")): + +def clamp(x, minimum=float("-inf"), maximum=float("inf")): """ Clamp a value between a specified minimum and maximum. diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/__init__.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/__init__.py index 837b3ae6..02db1a93 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/__init__.py @@ -1,4 +1,6 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from . import bload, bones, cexit, cload, cnop, cstore, csyncm, ifetch, kgload, kgseed, kgstart, nload, xinstfetch # MInst aliases diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py index 9d2b0837..890da01b 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py @@ -78,12 +78,7 @@ def __repr__(self): str: A string representation. """ assert len(self.sources) > 0 - retval = ( - "<{}({}) object at {}>(id={}[0], " - "col_num={}, m_idx={}, src={}, " - "mem_model, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "col_num={}, m_idx={}, src={}, " "mem_model, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -106,9 +101,7 @@ def _set_dests(self, value): Raises: RuntimeError: Always, as `bload` does not have destination parameters. """ - raise RuntimeError( - f"Instruction `{self.name}` does not have destination parameters." - ) + raise RuntimeError(f"Instruction `{self.name}` does not have destination parameters.") def _set_sources(self, value): """ @@ -123,8 +116,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -148,28 +140,19 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction, i.e., the number of cycles by which to advance the current cycle counter. """ - assert ( - Instruction._OP_NUM_SOURCES > 0 - and len(self.sources) == Instruction._OP_NUM_SOURCES - ) + assert Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES variable: Variable = self.sources[0] # expected sources to contain a Variable if variable.spad_address < 0: - raise RuntimeError( - f'Null Access Violation: Variable "{variable}" not allocated in SPAD.' - ) + raise RuntimeError(f'Null Access Violation: Variable "{variable}" not allocated in SPAD.') if self.m_idx < 0: raise RuntimeError(f"Invalid negative index `m_idx`.") if self.col_num not in range(4): - raise RuntimeError( - f"Invalid `col_num`: {self.col_num}. Must be in range [0, 4)." - ) + raise RuntimeError(f"Invalid `col_num`: {self.col_num}. Must be in range [0, 4).") retval = super()._schedule(cycle_count, schedule_id) # Track last access to SPAD address - spad_access_tracking = self.__mem_model.spad.getAccessTracking( - variable.spad_address - ) + spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) spad_access_tracking.last_cload = self # No need to sync to any previous MLoads after bload spad_access_tracking.last_mload = None diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py index 31fa5108..f40df3c4 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py @@ -73,12 +73,7 @@ def __repr__(self): str: A string representation of the Instruction object. """ assert len(self.sources) > 0 - retval = ( - "<{}({}) object at {}>(id={}[0], " - "src_col_num={}, src={}, " - "mem_model, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "src_col_num={}, src={}, " "mem_model, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -115,8 +110,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -140,24 +134,17 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction, i.e., the number of cycles by which to advance the current cycle counter. """ - assert ( - Instruction._OP_NUM_SOURCES > 0 - and len(self.sources) == Instruction._OP_NUM_SOURCES - ) + assert Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES variable: Variable = self.sources[0] # Expected sources to contain a Variable. if variable.spad_address < 0: - raise RuntimeError( - f"Null Access Violation: Variable `{variable}` not allocated in SPAD." - ) + raise RuntimeError(f"Null Access Violation: Variable `{variable}` not allocated in SPAD.") if self.src_col_num < 0: raise RuntimeError("Invalid `src_col_num` negative `Ones` target index.") retval = super()._schedule(cycle_count, schedule_id) # Track last access to SPAD address. - spad_access_tracking = self.__mem_model.spad.getAccessTracking( - variable.spad_address - ) + spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) spad_access_tracking.last_cload = self # No need to sync to any previous MLoads after bones. spad_access_tracking.last_mload = None diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py index 124542b8..38794432 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py @@ -24,9 +24,7 @@ def _get_op_name_asm(cls) -> str: """ return "cexit" - def __init__( - self, id: int, throughput: int = None, latency: int = None, comment: str = "" - ): + def __init__(self, id: int, throughput: int = None, latency: int = None, comment: str = ""): """ Constructs a new `cexit` CInstruction. @@ -49,9 +47,7 @@ def __repr__(self): Returns: str: A string representation of the Instruction object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py index 0dc8bfe9..778a3ab3 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py @@ -25,9 +25,7 @@ class CInstruction(BaseInstruction): # Constructor # ----------- - def __init__( - self, instruction_id: int, throughput: int, latency: int, comment: str = "" - ): + def __init__(self, instruction_id: int, throughput: int, latency: int, comment: str = ""): """ Constructs a new CInstruction. @@ -49,9 +47,7 @@ def _get_op_name_asm(cls) -> str: Returns: str: The ASM name for the operation. """ - raise NotImplementedError( - "Derived CInstruction must implement _get_op_name_asm." - ) + raise NotImplementedError("Derived CInstruction must implement _get_op_name_asm.") # Methods and properties # ---------------------- diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py index 34c27c74..233ce2fc 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py @@ -55,9 +55,7 @@ def __init__( Raises: AssertionError: If the destination register bank index is not 0. """ - assert ( - dst.bank.bank_index == 0 - ) # We must be following convention of loading from SPAD into bank 0 + assert dst.bank.bank_index == 0 # We must be following convention of loading from SPAD into bank 0 if not throughput: throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: @@ -75,11 +73,7 @@ def __repr__(self): str: A string representation of the Instruction object. """ assert len(self.dests) > 0 - retval = ( - "<{}({}) object at {}>(id={}[0], " - "dst={}, src={}," - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "dst={}, src={}," "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -105,8 +99,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} `Register` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Register` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -128,8 +121,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -155,40 +147,26 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction, i.e., the number of cycles by which to advance the current cycle counter. """ - assert ( - Instruction._OP_NUM_DESTS > 0 - and len(self.dests) == Instruction._OP_NUM_DESTS - ) - assert ( - Instruction._OP_NUM_SOURCES > 0 - and len(self.sources) == Instruction._OP_NUM_SOURCES - ) + assert Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS + assert Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES variable: Variable = self.sources[0] # Expected sources to contain a Variable target_register: Register = self.dests[0] if variable.spad_address < 0: - raise RuntimeError( - f"Null Access Violation: Variable `{variable}` not allocated in SPAD." - ) + raise RuntimeError(f"Null Access Violation: Variable `{variable}` not allocated in SPAD.") # Cannot allocate variable to more than one register (memory coherence) # and must not overwrite a register that already contains a variable. if variable.register: - raise RuntimeError( - f"Variable `{variable}` already allocated in register `{variable.register}`." - ) + raise RuntimeError(f"Variable `{variable}` already allocated in register `{variable.register}`.") if target_register.contained_variable: - raise RuntimeError( - f"Register `{target_register}` already contains a Variable object." - ) + raise RuntimeError(f"Register `{target_register}` already contains a Variable object.") retval = super()._schedule(cycle_count, schedule_id) # Perform the load target_register.allocateVariable(variable) # Track last access to SPAD address - spad_access_tracking = self.__mem_model.spad.getAccessTracking( - variable.spad_address - ) + spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) spad_access_tracking.last_cload = self # No need to sync to any previous MLoads after cload spad_access_tracking.last_mload = None diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py index 21161e67..b68950e6 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py @@ -75,11 +75,7 @@ def __repr__(self): str: A string representation of the Instruction object, including its type, name, memory address, ID, throughput, and latency. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " - "mem_model, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "mem_model, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -102,8 +98,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -126,8 +121,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Register` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Register` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -161,9 +155,7 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: """ spad = self.__mem_model.spad - var_name, (variable, self.__spad_addr) = ( - self.__mem_model.store_buffer.pop() - ) # Will raise IndexError if popping from empty queue + var_name, (variable, self.__spad_addr) = self.__mem_model.store_buffer.pop() # Will raise IndexError if popping from empty queue assert var_name == variable.name assert self.__spad_addr >= 0 and ( variable.spad_address < 0 or variable.spad_address == self.__spad_addr diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py index 0268689a..61cf2482 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py @@ -69,11 +69,7 @@ def __repr__(self): str: A string representation of the Instruction object, including its type, name, memory address, ID, minstr, throughput, and latency. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " - "minstr={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "minstr={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py index 1072221d..f7988edf 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py @@ -56,9 +56,7 @@ def __init__( if not latency: latency = Instruction._OP_DEFAULT_LATENCY super().__init__(id, throughput, latency, comment=comment) - self.bundle_id = ( - bundle_id # Instruction number from the MINST queue for which to wait - ) + self.bundle_id = bundle_id # Instruction number from the MINST queue for which to wait def __repr__(self): """ @@ -68,11 +66,7 @@ def __repr__(self): str: A string representation of the Instruction object, including its type, name, memory address, ID, bundle_id, throughput, and latency. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " - "bundle_id={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "bundle_id={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py index ad9fb69e..67383cd0 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py @@ -28,9 +28,7 @@ class Instruction(CInstruction): @classmethod def set_num_sources(cls, val): - cls._OP_NUM_SOURCES = ( - val + 1 - ) # Adding the keygen variable (since the actual instruction needs no sources) + cls._OP_NUM_SOURCES = val + 1 # Adding the keygen variable (since the actual instruction needs no sources) @classmethod def _get_op_name_asm(cls) -> str: @@ -86,12 +84,7 @@ def __repr__(self): its type, name, memory address, ID, column number, memory index, source, throughput, and latency. """ assert len(self.sources) > 0 - retval = ( - "<{}({}) object at {}>(id={}[0], " - "col_num={}, m_idx={}, src={}, " - "mem_model, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "col_num={}, m_idx={}, src={}, " "mem_model, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -118,8 +111,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} `Register` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Register` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -142,8 +134,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -170,14 +161,8 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert ( - Instruction._OP_NUM_DESTS > 0 - and len(self.dests) == Instruction._OP_NUM_DESTS - ) - assert ( - Instruction._OP_NUM_SOURCES > 0 - and len(self.sources) == Instruction._OP_NUM_SOURCES - ) + assert Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS + assert Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES variable: Variable = self.sources[0] # Expected sources to contain a Variable target_register: Register = self.dests[0] @@ -187,13 +172,9 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: # Cannot allocate variable to more than one register (memory coherence) # and must not overwrite a register that already contains a variable. if variable.register: - raise RuntimeError( - f"Variable `{variable}` already allocated in register `{variable.register}`." - ) + raise RuntimeError(f"Variable `{variable}` already allocated in register `{variable.register}`.") if target_register.contained_variable: - raise RuntimeError( - f"Register `{target_register}` already contains a Variable object." - ) + raise RuntimeError(f"Register `{target_register}` already contains a Variable object.") retval = super()._schedule(cycle_count, schedule_id) # Variable generated, reflect the load diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py index 92f16661..0cbd1f55 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py @@ -88,12 +88,7 @@ def __repr__(self): its type, name, memory address, ID, block index, source, throughput, and latency. """ assert len(self.sources) > 0 - retval = ( - "<{}({}) object at {}>(id={}[0], " - "col_num={}, m_idx={}, src={}, " - "mem_model, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "col_num={}, m_idx={}, src={}, " "mem_model, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -115,9 +110,7 @@ def _set_dests(self, value): Raises: RuntimeError: Always raised as the instruction does not have destination parameters. """ - raise RuntimeError( - f"Instruction `{self.name}` does not have destination parameters." - ) + raise RuntimeError(f"Instruction `{self.name}` does not have destination parameters.") def _set_sources(self, value): """ @@ -133,8 +126,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -160,26 +152,17 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert ( - Instruction._OP_NUM_SOURCES > 0 - and len(self.sources) == Instruction._OP_NUM_SOURCES - ) + assert Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES variable: Variable = self.sources[0] # Expected sources to contain a Variable if variable.spad_address < 0: - raise RuntimeError( - f'Null Access Violation: Variable "{variable}" not allocated in SPAD.' - ) + raise RuntimeError(f'Null Access Violation: Variable "{variable}" not allocated in SPAD.') if self.block_index not in range(4): - raise RuntimeError( - f"Invalid `block_index`: {self.block_index}. Must be in range [0, 4)." - ) + raise RuntimeError(f"Invalid `block_index`: {self.block_index}. Must be in range [0, 4).") retval = super()._schedule(cycle_count, schedule_id) # Track last access to SPAD address - spad_access_tracking = self.__mem_model.spad.getAccessTracking( - variable.spad_address - ) + spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) spad_access_tracking.last_cload = self # No need to sync to any previous MLoads after kg_seed spad_access_tracking.last_mload = None diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py index a8eafca8..17a656a3 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py @@ -30,9 +30,7 @@ def _get_op_name_asm(cls) -> str: """ return "kg_start" - def __init__( - self, id: int, throughput: int = None, latency: int = None, comment: str = "" - ): + def __init__(self, id: int, throughput: int = None, latency: int = None, comment: str = ""): """ Constructs a new `kg_start` CInstruction. @@ -59,9 +57,7 @@ def __repr__(self): str: A string representation of the Instruction object, including its type, name, memory address, ID, throughput, and latency. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py index 794b3d64..cd814732 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py @@ -82,12 +82,7 @@ def __repr__(self): its type, name, memory address, ID, table index, source, throughput, and latency. """ assert len(self.sources) > 0 - retval = ( - "<{}({}) object at {}>(id={}[0], " - "table_idx={}, src={}, " - "mem_model, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "table_idx={}, src={}, " "mem_model, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -125,8 +120,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -152,24 +146,17 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert ( - Instruction._OP_NUM_SOURCES > 0 - and len(self.sources) == Instruction._OP_NUM_SOURCES - ) + assert Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES variable: Variable = self.sources[0] # Expected sources to contain a Variable if variable.spad_address < 0: - raise RuntimeError( - f"Null Access Violation: Variable `{variable}` not allocated in SPAD." - ) + raise RuntimeError(f"Null Access Violation: Variable `{variable}` not allocated in SPAD.") if self.table_idx < 0: raise RuntimeError("Invalid `table_idx` negative routing table index.") retval = super()._schedule(cycle_count, schedule_id) # Track last access to SPAD address - spad_access_tracking = self.__mem_model.spad.getAccessTracking( - variable.spad_address - ) + spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) spad_access_tracking.last_cload = self # No need to sync to any previous MLoads after bones spad_access_tracking.last_mload = None diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py index a5085233..8d070d7b 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py @@ -76,11 +76,7 @@ def __repr__(self): its type, name, memory address, ID, xq_dst, hbm_src, throughput, and latency. """ assert len(self.dests) > 0 - retval = ( - "<{}({}) object at {}>(id={}[0], " - "xq_dst={}, hbm_src={}," - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "xq_dst={}, hbm_src={}," "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -133,14 +129,10 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - if ( - self.xq_dst < 0 - or self.xq_dst >= constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS - ): + if self.xq_dst < 0 or self.xq_dst >= constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS: raise RuntimeError( ( - "Invalid `xq_dst` XINST queue destination address. Expected value in range " - "[0, {}), but received {}.".format( + "Invalid `xq_dst` XINST queue destination address. Expected value in range " "[0, {}), but received {}.".format( constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS, self.xq_dst, ) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py index 0418b49e..af1c51e4 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """BaseInstruction and related classes for assembler instructions.""" + from typing import final, NamedTuple, List, Optional # pylint: disable=too-many-instance-attributes, too-many-public-methods @@ -112,9 +113,7 @@ class BaseInstruction(CycleTracker): _OP_DEFAULT_THROUGHPUT: int _OP_DEFAULT_LATENCY: int - __id_count = Counter.count( - 0 - ) # internal unique sequence counter to generate unique IDs + __id_count = Counter.count(0) # internal unique sequence counter to generate unique IDs # Class methods and properties # ---------------------------- @@ -182,9 +181,7 @@ def _get_op_name_asm(cls) -> str: # Constructor # ----------- - def __init__( - self, instruction_id: int, throughput: int, latency: int, comment: str = "" - ): + def __init__(self, instruction_id: int, throughput: int, latency: int, comment: str = ""): """ Initializes a new BaseInstruction object. @@ -204,19 +201,9 @@ def __init__( """ # validate inputs if throughput < 1: - raise ValueError( - ( - f"`throughput`: must be a positive number, " - f"but {throughput} received." - ) - ) + raise ValueError((f"`throughput`: must be a positive number, " f"but {throughput} received.")) if latency < throughput: - raise ValueError( - ( - f"`latency`: cannot be less than throughput. " - f"Expected, at least, {throughput}, but {latency} received." - ) - ) + raise ValueError((f"`latency`: cannot be less than throughput. " f"Expected, at least, {throughput}, but {latency} received.")) super().__init__(CycleType(0, 0)) self.__id = (instruction_id, next(BaseInstruction.__id_count)) @@ -286,12 +273,8 @@ def set_schedule_timing_index(self, value: int): ValueError: If the value is less than 0. """ if value < 0: - raise ValueError( - "`value`: expected a value of `0` or greater for `schedule_timing.index`." - ) - self.__schedule_timing = ScheduleTiming( - cycle=self.__schedule_timing.cycle, index=value - ) + raise ValueError("`value`: expected a value of `0` or greater for `schedule_timing.index`.") + self.__schedule_timing = ScheduleTiming(cycle=self.__schedule_timing.cycle, index=value) @property def is_scheduled(self) -> bool: @@ -428,9 +411,7 @@ def _get_cycle_ready(self): # INST1's dests are ready in cycle 6 and they are written to in cycle 5. # If INST2 uses any INST1 dest as its dest, INST2 can start the cycle # following INST1, 2, because INST2 will write to the same dest in cycle 6. - retval = max( - retval, *(dst.cycle_ready - self.latency + 1 for dst in self.dests) - ) + retval = max(retval, *(dst.cycle_ready - self.latency + 1 for dst in self.dests)) return retval def freeze(self): @@ -455,9 +436,7 @@ def freeze(self): RuntimeError: If the instruction has not been scheduled yet. """ if not self.is_scheduled: - raise RuntimeError( - f"Instruction `{self.name}` (id = {self.id}) is not yet scheduled." - ) + raise RuntimeError(f"Instruction `{self.name}` (id = {self.id}) is not yet scheduled.") self._frozen_pisa = self._to_pisa_format() self._frozen_xisa = self._to_xasmisa_format() @@ -487,15 +466,11 @@ def _schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: the current cycle counter. """ if self.is_scheduled: - raise RuntimeError( - f"Instruction `{self.name}` (id = {self.id}) is already scheduled." - ) + raise RuntimeError(f"Instruction `{self.name}` (id = {self.id}) is already scheduled.") if schedule_idx < 1: raise ValueError("`schedule_idx`: expected a value of `1` or greater.") if len(cycle_count) < 2: - raise ValueError( - "`cycle_count`: expected a pair/tuple with two components." - ) + raise ValueError("`cycle_count`: expected a pair/tuple with two components.") if cycle_count < self.cycle_ready: raise RuntimeError( f"Instruction {self.name}, id: {self.id}, not ready to schedule. " diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/__init__.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/__init__.py index 7b485ccb..977c2d06 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/__init__.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/__init__.py @@ -1,4 +1,6 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from . import mload, mstore, msyncc # MInst aliases diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py index f7075eab..9116d8b7 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py @@ -89,11 +89,7 @@ def __repr__(self): its type, name, memory address, ID, source, destination SPAD address, throughput, and latency. """ assert len(self.dests) > 0 - retval = ( - "<{}({}) object at {}>(id={}[0], " - "src={}, dst_spad_addr={}, mem_model, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "src={}, dst_spad_addr={}, mem_model, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -131,8 +127,7 @@ def __internal_set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -155,8 +150,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -186,14 +180,8 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert ( - Instruction._OP_NUM_SOURCES > 0 - and len(self.sources) == Instruction._OP_NUM_SOURCES - ) - assert ( - Instruction._OP_NUM_DESTS > 0 - and len(self.dests) == Instruction._OP_NUM_DESTS - ) + assert Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES + assert Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS assert all(src == dst for src, dst in zip(self.sources, self.dests)) hbm = self.__mem_model.hbm @@ -202,13 +190,9 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: variable: Variable = self.sources[0] if variable.spad_address >= 0: - raise RuntimeError( - "Source variable is already in SPAD. Cannot load a variable into SPAD more than once." - ) + raise RuntimeError("Source variable is already in SPAD. Cannot load a variable into SPAD more than once.") if variable.hbm_address < 0: - raise RuntimeError( - "Null reference exception: source variable is not in HBM." - ) + raise RuntimeError("Null reference exception: source variable is not in HBM.") retval = super()._schedule(cycle_count, schedule_id) # Perform the load diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py index 4c521adc..2d05b29b 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py @@ -91,11 +91,7 @@ def __repr__(self): its type, name, memory address, ID, source, destination HBM address, throughput, and latency. """ assert len(self.dests) > 0 - retval = ( - "<{}({}) object at {}>(id={}[0], " - "src={}, dst_hbm_addr={}, mem_model, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "src={}, dst_hbm_addr={}, mem_model, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -134,8 +130,7 @@ def __internal_set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -158,8 +153,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -189,14 +183,8 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert ( - Instruction._OP_NUM_SOURCES > 0 - and len(self.sources) == Instruction._OP_NUM_SOURCES - ) - assert ( - Instruction._OP_NUM_DESTS > 0 - and len(self.dests) == Instruction._OP_NUM_DESTS - ) + assert Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES + assert Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS assert all(src == dst for src, dst in zip(self.sources, self.dests)) hbm = self.__mem_model.hbm @@ -208,23 +196,17 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: if variable.hbm_address >= 0: if self.dst_hbm_addr != variable.hbm_address: - raise RuntimeError( - "Source variable is already in different HBM location. Cannot store a variable into HBM more than once." - ) + raise RuntimeError("Source variable is already in different HBM location. Cannot store a variable into HBM more than once.") assert hbm.buffer[variable.hbm_address] == variable if self.__source_spad_address < 0: - raise RuntimeError( - "Null reference exception: source variable is not in SPAD." - ) + raise RuntimeError("Null reference exception: source variable is not in SPAD.") if self.comment: self.comment += ";" # self.comment += ' variable "{}": HBM({}) <- SPAD({})'.format(variable.name, # self.dst_hbm_addr, # variable.spad_address) - self.comment += ' variable "{}" <- SPAD({})'.format( - variable.name, variable.spad_address - ) + self.comment += ' variable "{}" <- SPAD({})'.format(variable.name, variable.spad_address) retval = super()._schedule(cycle_count, schedule_id) # Perform the store diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py index a45e6160..0c4c3c4d 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py @@ -56,9 +56,7 @@ def __init__( if not latency: latency = Instruction._OP_DEFAULT_LATENCY super().__init__(id, throughput, latency, comment=comment) - self.cinstr = ( - cinstr # Instruction number from the MINST queue for which to wait - ) + self.cinstr = cinstr # Instruction number from the MINST queue for which to wait def __repr__(self): """ @@ -69,11 +67,7 @@ def __repr__(self): its type, name, memory address, ID, cinstr, throughput, and latency. """ assert len(self.dests) > 0 - retval = ( - "<{}({}) object at {}>(id={}[0], " - "cinstr={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "cinstr={}, " "throughput={}, latency={})").format( type(self).__name__, self.op_name_pisa, hex(id(self)), diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py index d9abaeab..6b6075a4 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py @@ -66,9 +66,7 @@ GLOBAL_CYCLE_TRACKING_INSTRUCTIONS = (rShuffle, irShuffle, XStore) -def createFromParsedObj( - mem_model: MemoryModel, inst_type, parsed_op, new_id: int = 0 -) -> XInstruction: +def createFromParsedObj(mem_model: MemoryModel, inst_type, parsed_op, new_id: int = 0) -> XInstruction: """ Creates an XInstruction object XInst from the specified namespace data. @@ -123,9 +121,7 @@ def createFromParsedObj( return inst_type(new_id, **parsed_op) -def createFromPISALine( - mem_model: MemoryModel, line: str, line_no: int = 0 -) -> XInstruction: +def createFromPISALine(mem_model: MemoryModel, line: str, line_no: int = 0) -> XInstruction: """ Parses an XInst from the specified string (in P-ISA kernel input format) and returns a XInstruction object encapsulating the resulting instruction. @@ -156,7 +152,6 @@ def createFromPISALine( retval = None try: - for inst_type in __PISA_INSTRUCTIONS: parsed_op = inst_type.parseFromPISALine(line) if parsed_op: diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py index 84af953b..cc62dcab 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py @@ -80,9 +80,7 @@ def parseFromPISALine(cls, line: str) -> list: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) retval["res"] = int(instr_tokens[params_end]) @@ -141,11 +139,7 @@ def __repr__(self): Returns: str: A string representation of object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], res={}, " - "dst={}, src={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], res={}, " "dst={}, src={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -171,8 +165,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -194,8 +187,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py index 7946e57f..bc848c3d 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py @@ -100,9 +100,7 @@ def parseFromPISALine(cls, line: str) -> list: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) if len(instr_tokens) < cls._OP_NUM_TOKENS: # temporary warning to avoid syntax error during testing @@ -153,9 +151,7 @@ def __init__( N = 0 # does not require ring-size super().__init__(id, N, throughput, latency, comment=comment) if dst[0].name == src[0].name: - raise ValueError( - f'`dst`: Source and destination cannot be the same for instruction "{self.name}".' - ) + raise ValueError(f'`dst`: Source and destination cannot be the same for instruction "{self.name}".') self._set_dests(dst) self._set_sources(src) @@ -166,11 +162,7 @@ def __repr__(self): Returns: str: A string representation of object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " - "dst={}, src={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "dst={}, src={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -195,8 +187,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -218,8 +209,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py index f21334c7..472cc758 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py @@ -25,9 +25,7 @@ def _get_op_name_asm(cls) -> str: """ return "bexit" - def __init__( - self, id: int, throughput: int = None, latency: int = None, comment: str = "" - ): + def __init__(self, id: int, throughput: int = None, latency: int = None, comment: str = ""): """ Initializes an Instruction object with the given parameters. @@ -51,9 +49,7 @@ def __repr__(self): Returns: str: A string representation. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -73,9 +69,7 @@ def _set_dests(self, value): Raises: RuntimeError: Always raised as `bexit` does not have parameters. """ - raise RuntimeError( - f"Instruction `{self.op_name_pisa}` does not have parameters." - ) + raise RuntimeError(f"Instruction `{self.op_name_pisa}` does not have parameters.") def _set_sources(self, value): """ @@ -87,9 +81,7 @@ def _set_sources(self, value): Raises: RuntimeError: Always raised as `bexit` does not have parameters. """ - raise RuntimeError( - f"Instruction `{self.op_name_pisa}` does not have parameters." - ) + raise RuntimeError(f"Instruction `{self.op_name_pisa}` does not have parameters.") def _to_pisa_format(self, *extra_args) -> str: """ diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py index 03d6e7ef..4dcd3a1f 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py @@ -88,9 +88,7 @@ def parseFromPISALine(cls, line: str) -> object: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) retval["stage"] = int(instr_tokens[params_end]) retval["res"] = int(instr_tokens[params_end + 1]) @@ -153,11 +151,7 @@ def __repr__(self): Returns: str: A string representation of the Instruction object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], res={}, " - "dst={}, src={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], res={}, " "dst={}, src={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -193,8 +187,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -216,8 +209,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py index 2c27d914..8355636f 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py @@ -38,12 +38,8 @@ class Instruction(XInstruction): _OP_IRMOVE_LATENCY_MAX: int _OP_IRMOVE_LATENCY_INC: int - __irshuffle_global_cycle_ready = CycleType( - 0, 0 - ) # private class attribute to track cycle ready among irshuffles - __rshuffle_global_cycle_ready = CycleType( - 0, 0 - ) # private class attribute to track the cycle ready based on last rshuffle + __irshuffle_global_cycle_ready = CycleType(0, 0) # private class attribute to track cycle ready among irshuffles + __rshuffle_global_cycle_ready = CycleType(0, 0) # private class attribute to track the cycle ready based on last rshuffle @classmethod def isa_spec_as_dict(cls) -> dict: @@ -159,9 +155,7 @@ def parseFromPISALine(cls, line: str) -> object: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) # ignore "res", but make sure it exists (syntax) assert instr_tokens[params_end] is not None @@ -233,10 +227,7 @@ def __init__( latency = Instruction._OP_DEFAULT_LATENCY if latency < Instruction._OP_IRMOVE_LATENCY: raise ValueError( - ( - f"`latency`: expected a value greater than or equal to " - "{Instruction._OP_IRMOVE_LATENCY}, but {latency} received." - ) + (f"`latency`: expected a value greater than or equal to " "{Instruction._OP_IRMOVE_LATENCY}, but {latency} received.") ) super().__init__(id, N, throughput, latency, comment=comment) @@ -252,9 +243,7 @@ def __repr__(self): Returns: str: A string representation of the Instruction object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " "dst={}, src={}, " "wait_cyc={}, res={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "dst={}, src={}, " "wait_cyc={}, res={})").format( type(self).__name__, self.name, hex(id(self)), @@ -383,15 +372,9 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: original_throughput = super()._schedule(cycle_count, schedule_id) retval = self.throughput + self.wait_cyc assert original_throughput <= retval - Instruction.__set_irshuffleGlobalCycleReady( - CycleType( - cycle_count.bundle, cycle_count.cycle + Instruction._OP_IRMOVE_LATENCY - ) - ) + Instruction.__set_irshuffleGlobalCycleReady(CycleType(cycle_count.bundle, cycle_count.cycle + Instruction._OP_IRMOVE_LATENCY)) # Avoid rshuffles and irshuffles in the same bundle - rshuffle.Instruction.set_irshuffleGlobalCycleReady( - CycleType(cycle_count.bundle + 1, 0) - ) + rshuffle.Instruction.set_irshuffleGlobalCycleReady(CycleType(cycle_count.bundle + 1, 0)) return retval def _to_pisa_format(self, *extra_args) -> str: diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py index 82d7e531..41dc6268 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py @@ -91,9 +91,7 @@ def parseFromPISALine(cls, line: str) -> list: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_PISA_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_PISA_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_PISA_SOURCES, params_start) retval.update(dst_src) retval["res"] = int(instr_tokens[params_end]) @@ -160,11 +158,7 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], res={}, " - "dst={}, src={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], res={}, " "dst={}, src={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -190,8 +184,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -213,8 +206,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -243,9 +235,7 @@ def _to_pisa_format(self, *extra_args) -> str: raise ValueError("`extra_args` not supported.") preamble = (self.N,) - extra_args = ( - tuple(src.to_pisa_format() for src in self.sources[1:]) + extra_args - ) + extra_args = tuple(src.to_pisa_format() for src in self.sources[1:]) + extra_args extra_args = tuple(dst.to_pisa_format() for dst in self.dests) + extra_args if self.res is not None: extra_args += (self.res,) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py index f52d41f4..8c6c446d 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py @@ -89,9 +89,7 @@ def parseFromPISALine(cls, line: str) -> list: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_PISA_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_PISA_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_PISA_SOURCES, params_start) retval.update(dst_src) retval["imm"] = instr_tokens[params_end] retval["res"] = int(instr_tokens[params_end + 1]) @@ -163,11 +161,7 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], res={}, imm={}, " - "dst={}, src={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], res={}, imm={}, " "dst={}, src={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -204,8 +198,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -227,8 +220,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -259,9 +251,7 @@ def _to_pisa_format(self, *extra_args) -> str: # N, muli, dst (bank), src0 (bank), imm, res # comment preamble = (self.N,) extra_args = (self.imm,) - extra_args = ( - tuple(src.to_pisa_format() for src in self.sources[1:]) + extra_args - ) + extra_args = tuple(src.to_pisa_format() for src in self.sources[1:]) + extra_args extra_args = tuple(dst.to_pisa_format() for dst in self.dests) + extra_args if self.res is not None: extra_args += (self.res,) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py index bf3409c6..a87271fa 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py @@ -66,12 +66,8 @@ def __init__( if not latency: latency = Instruction._OP_DEFAULT_LATENCY if any(isinstance(v, DummyVariable) or not v.name for v in src): - raise ValueError( - f"{Instruction.op_name_asm} cannot have dummy variable as source." - ) - if dst.contained_variable and not isinstance( - dst.contained_variable, DummyVariable - ): + raise ValueError(f"{Instruction.op_name_asm} cannot have dummy variable as source.") + if dst.contained_variable and not isinstance(dst.contained_variable, DummyVariable): raise ValueError( "{}: destination register must be empty, but variable {}.{} found.".format( Instruction.op_name_asm, @@ -92,11 +88,7 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " - "dst={}, src={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "dst={}, src={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -122,8 +114,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} `Register` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Register` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -145,8 +136,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -174,14 +164,8 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction, i.e., the number of cycles by which to advance the current cycle counter. """ - assert ( - Instruction._OP_NUM_DESTS > 0 - and len(self.dests) == Instruction._OP_NUM_DESTS - ) - assert ( - Instruction._OP_NUM_SOURCES > 0 - and len(self.sources) == Instruction._OP_NUM_SOURCES - ) + assert Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS + assert Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES variable = self.sources[0] # Expected sources to contain a Variable target_register = self.dests[0] @@ -189,13 +173,9 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: # Source and target types are swapped after scheduling # Instruction already scheduled: can only schedule once assert isinstance(target_register, Variable) - raise RuntimeError( - f"Instruction `{self.name}` (id = {self.id}) already scheduled." - ) + raise RuntimeError(f"Instruction `{self.name}` (id = {self.id}) already scheduled.") - if target_register.contained_variable and not isinstance( - target_register.contained_variable, DummyVariable - ): + if target_register.contained_variable and not isinstance(target_register.contained_variable, DummyVariable): raise RuntimeError( ( "Instruction `{}` (id = {}) " @@ -209,17 +189,12 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: ) ) - assert ( - not target_register.contained_variable - or self.__dummy_var == target_register.contained_variable - ) + assert not target_register.contained_variable or self.__dummy_var == target_register.contained_variable # Perform the move register_dirty = variable.register_dirty source_register = variable.register target_register.allocateVariable(variable) - source_register.allocateVariable( - self.__dummy_var - ) # Mark source register as free for next bundle + source_register.allocateVariable(self.__dummy_var) # Mark source register as free for next bundle assert source_register.bank.bank_index == 0 # Swap source and dest to keep the output format of the string instruction consistent self.sources[0] = source_register diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py index 2698b78e..d29cad7c 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py @@ -81,9 +81,7 @@ def parseFromPISALine(cls, line: str) -> Namespace: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) retval["res"] = int(instr_tokens[params_end]) @@ -141,11 +139,7 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], res={}, " - "dst={}, src={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], res={}, " "dst={}, src={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -171,8 +165,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -194,8 +187,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py index 3d1852aa..397bc10f 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py @@ -85,9 +85,7 @@ def parseFromPISALine(cls, line: str) -> list: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) retval["imm"] = instr_tokens[params_end] retval["res"] = int(instr_tokens[params_end + 1]) @@ -150,11 +148,7 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], res={}, imm={}, " - "dst={}, src={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], res={}, imm={}, " "dst={}, src={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -191,8 +185,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -214,8 +207,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py index 22c81924..e87a1bfd 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py @@ -82,9 +82,7 @@ def parseFromPISALine(cls, line: str) -> object: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) retval["stage"] = int(instr_tokens[params_end]) retval["res"] = int(instr_tokens[params_end + 1]) @@ -147,11 +145,7 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ( - "<{}({}) object at {}>(id={}[0], res={}, " - "dst={}, src={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], res={}, " "dst={}, src={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -187,8 +181,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -210,8 +203,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py index c455e79f..b871e7f5 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py @@ -50,17 +50,13 @@ def parseXNTTKernelLine(line: str, op_name: str, tw_separator: str) -> Namespace instr_tokens = tokens[0] if len(instr_tokens) > OP_NUM_TOKENS: - warnings.warn( - f'Extra tokens detected for instruction "{op_name}"', SyntaxWarning - ) + warnings.warn(f'Extra tokens detected for instruction "{op_name}"', SyntaxWarning) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + OP_NUM_DESTS + OP_NUM_SOURCES - dst_src = xinst.XInstruction.parsePISASourceDestsFromTokens( - instr_tokens, OP_NUM_DESTS, OP_NUM_SOURCES, params_start - ) + dst_src = xinst.XInstruction.parsePISASourceDestsFromTokens(instr_tokens, OP_NUM_DESTS, OP_NUM_SOURCES, params_start) retval.update(dst_src) twiddle = instr_tokens[params_end] retval["res"] = int(instr_tokens[params_end + 1]) @@ -68,17 +64,11 @@ def parseXNTTKernelLine(line: str, op_name: str, tw_separator: str) -> Namespace # Parse twiddle (w___, where "_" is the `tw_separator`) twiddle_tokens = list(map(lambda s: s.strip(), twiddle.split(tw_separator))) if len(twiddle_tokens) != 4: - raise ValueError( - f'Error parsing twiddle information for "{op_name}" in line "{line}".' - ) + raise ValueError(f'Error parsing twiddle information for "{op_name}" in line "{line}".') if twiddle_tokens[0] != "w": - raise ValueError( - f'Invalid twiddle detected for "{op_name}" in line "{line}".' - ) + raise ValueError(f'Invalid twiddle detected for "{op_name}" in line "{line}".') if int(twiddle_tokens[1]) != retval["res"]: - raise ValueError( - f'Invalid "residual" component detected in twiddle information for "{op_name}" in line "{line}".' - ) + raise ValueError(f'Invalid "residual" component detected in twiddle information for "{op_name}" in line "{line}".') retval["stage"] = int(twiddle_tokens[2]) retval["block"] = int(twiddle_tokens[3]) @@ -113,11 +103,7 @@ def __generateRMoveParsedOp(kntt_parsed_op: Namespace) -> (type, Namespace): xrshuffle_type = xinst.irShuffle parsed_op["dst"] = [s for s in kntt_parsed_op.src] else: - raise ValueError( - '`kntt_parsed_op`: cannot process operation with name "{}".'.format( - kntt_parsed_op.op_name - ) - ) + raise ValueError('`kntt_parsed_op`: cannot process operation with name "{}".'.format(kntt_parsed_op.op_name)) assert xrshuffle_type @@ -156,11 +142,7 @@ def __generateTWNTTParsedOp(xntt_parsed_op: Namespace) -> Namespace: # Find types depending on whether we are doing ntt or intt twxntt_type = next( - ( - t - for t in (xinst.twNTT, xinst.twiNTT) - if t.op_name_pisa == parsed_op["op_name"] - ), + (t for t in (xinst.twNTT, xinst.twiNTT) if t.op_name_pisa == parsed_op["op_name"]), None, ) assert twxntt_type @@ -225,9 +207,7 @@ def __generateTWNTTParsedOp(xntt_parsed_op: Namespace) -> Namespace: return retval, Namespace(**parsed_op), tw_var_name_bank -def generateXNTT( - mem_model: MemoryModel, xntt_parsed_op: Namespace, new_id: int = 0 -) -> list: +def generateXNTT(mem_model: MemoryModel, xntt_parsed_op: Namespace, new_id: int = 0) -> list: """ Parses an `xntt` instruction from a P-ISA kernel instruction string. @@ -248,43 +228,27 @@ def generateXNTT( # Find xntt type depending on whether we are doing ntt or intt xntt_type = next( - ( - t - for t in (xinst.NTT, xinst.iNTT) - if t.op_name_pisa == xntt_parsed_op.op_name - ), + (t for t in (xinst.NTT, xinst.iNTT) if t.op_name_pisa == xntt_parsed_op.op_name), None, ) if not xntt_type: - raise ValueError( - '`xntt_parsed_op`: cannot process parsed kernel operation with name "{}".'.format( - xntt_parsed_op.op_name - ) - ) + raise ValueError('`xntt_parsed_op`: cannot process parsed kernel operation with name "{}".'.format(xntt_parsed_op.op_name)) # Generate twiddle instruction # ----------------------------- - twxntt_type, twxntt_parsed_op, last_twxinput_name = __generateTWNTTParsedOp( - xntt_parsed_op - ) + twxntt_type, twxntt_parsed_op, last_twxinput_name = __generateTWNTTParsedOp(xntt_parsed_op) # print(twxntt_parsed_op) twxntt_inst = None if twxntt_type: - twxntt_inst = xinst.createFromParsedObj( - mem_model, twxntt_type, twxntt_parsed_op, new_id - ) + twxntt_inst = xinst.createFromParsedObj(mem_model, twxntt_type, twxntt_parsed_op, new_id) # Generate corresponding rshuffle # ----------------------------- rshuffle_type, rshuffle_parsed_op = __generateRMoveParsedOp(xntt_parsed_op) - rshuffle_parsed_op.comment += ( - (" " + twxntt_parsed_op.comment) if twxntt_parsed_op else "" - ) - rshuffle_inst = xinst.createFromParsedObj( - mem_model, rshuffle_type, rshuffle_parsed_op, new_id - ) + rshuffle_parsed_op.comment += (" " + twxntt_parsed_op.comment) if twxntt_parsed_op else "" + rshuffle_inst = xinst.createFromParsedObj(mem_model, rshuffle_type, rshuffle_parsed_op, new_id) # Generate xntt instruction # -------------------------- diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py index 378d6c70..3a2a0560 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py @@ -31,12 +31,8 @@ class Instruction(XInstruction): _OP_REMOVE_LATENCY_MAX: int _OP_REMOVE_LATENCY_INC: int - __rshuffle_global_cycle_ready = CycleType( - 0, 0 - ) # Private class attribute to track cycle ready among rshuffles - __irshuffle_global_cycle_ready = CycleType( - 0, 0 - ) # Private class attribute to track the cycle ready based on last irshuffle + __rshuffle_global_cycle_ready = CycleType(0, 0) # Private class attribute to track cycle ready among rshuffles + __irshuffle_global_cycle_ready = CycleType(0, 0) # Private class attribute to track the cycle ready based on last irshuffle @classmethod def isa_spec_as_dict(cls) -> dict: @@ -152,9 +148,7 @@ def parseFromPISALine(cls, line: str) -> object: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) # Ignore "res", but make sure it exists (syntax) assert instr_tokens[params_end] is not None @@ -206,10 +200,7 @@ def __init__( latency = Instruction._OP_DEFAULT_LATENCY if latency < Instruction._OP_REMOVE_LATENCY: raise ValueError( - ( - f"`latency`: expected a value greater than or equal to " - "{Instruction._OP_REMOVE_LATENCY}, but {latency} received." - ) + (f"`latency`: expected a value greater than or equal to " "{Instruction._OP_REMOVE_LATENCY}, but {latency} received.") ) super().__init__(id, N, throughput, latency, comment=comment) @@ -226,9 +217,7 @@ def __repr__(self): str: A string representation of the Instruction object, including its type, name, memory address, ID, destinations, sources, and wait cycles. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " "dst={}, src={}, " "wait_cyc={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "dst={}, src={}, " "wait_cyc={})").format( type(self).__name__, self.name, hex(id(self)), @@ -357,15 +346,9 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: original_throughput = super()._schedule(cycle_count, schedule_id) retval = self.throughput + self.wait_cyc assert original_throughput <= retval - Instruction.__set_rshuffleGlobalCycleReady( - CycleType( - cycle_count.bundle, cycle_count.cycle + Instruction._OP_REMOVE_LATENCY - ) - ) + Instruction.__set_rshuffleGlobalCycleReady(CycleType(cycle_count.bundle, cycle_count.cycle + Instruction._OP_REMOVE_LATENCY)) # Avoid rshuffles and irshuffles in the same bundle - irshuffle.Instruction.set_rshuffleGlobalCycleReady( - CycleType(cycle_count.bundle + 1, 0) - ) + irshuffle.Instruction.set_rshuffleGlobalCycleReady(CycleType(cycle_count.bundle + 1, 0)) return retval def _to_pisa_format(self, *extra_args) -> str: diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py index d28febab..1ebf1468 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py @@ -79,9 +79,7 @@ def parseFromPISALine(cls, line: str) -> list: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) retval["res"] = int(instr_tokens[params_end]) @@ -148,11 +146,7 @@ def __repr__(self): str: A string representation of the Instruction object, including its type, name, memory address, ID, residual, destinations, sources, throughput, and latency. """ - retval = ( - "<{}({}) object at {}>(id={}[0], res={}, " - "dst={}, src={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], res={}, " "dst={}, src={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -178,8 +172,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -201,8 +194,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py index 4c5904af..3ba58d1c 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py @@ -90,9 +90,7 @@ def parseFromPISALine(cls, line: str) -> object: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) retval["tw_meta"] = int(instr_tokens[params_end]) retval["stage"] = int(instr_tokens[params_end + 1]) @@ -165,9 +163,7 @@ def __repr__(self): its type, name, memory address, ID, residual, tw_meta, stage, block, destinations, sources, throughput, and latency. """ retval = ( - "<{}({}) object at {}>(id={}[0], res={}, tw_meta={}, stage={}, block={}, " - "dst={}, src={}, " - "throughput={}, latency={})" + "<{}({}) object at {}>(id={}[0], res={}, tw_meta={}, stage={}, block={}, " "dst={}, src={}, " "throughput={}, latency={})" ).format( type(self).__name__, self.name, @@ -227,8 +223,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -250,8 +245,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py index 33928f16..5860f2c7 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py @@ -90,9 +90,7 @@ def parseFromPISALine(cls, line: str) -> object: retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens( - instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start - ) + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start) retval.update(dst_src) retval["tw_meta"] = int(instr_tokens[params_end]) retval["stage"] = int(instr_tokens[params_end + 1]) @@ -175,9 +173,7 @@ def __repr__(self): its type, name, memory address, ID, residual, tw_meta, stage, block, destinations, sources, throughput, and latency. """ retval = ( - "<{}({}) object at {}>(id={}[0], res={}, tw_meta={}, stage={}, block={}, " - "dst={}, src={}, " - "throughput={}, latency={})" + "<{}({}) object at {}>(id={}[0], res={}, tw_meta={}, stage={}, block={}, " "dst={}, src={}, " "throughput={}, latency={})" ).format( type(self).__name__, self.name, @@ -237,8 +233,7 @@ def _set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_DESTS, len(value) ) ) @@ -260,8 +255,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} Variable objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py index 719205e5..ff8481cc 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py @@ -48,9 +48,7 @@ def tokenizeFromPISALine(op_name: str, line: str) -> list: return retval @staticmethod - def parsePISASourceDestsFromTokens( - tokens: list, num_dests: int, num_sources: int, offset: int = 0 - ) -> dict: + def parsePISASourceDestsFromTokens(tokens: list, num_dests: int, num_sources: int, offset: int = 0) -> dict: """ Parses the sources and destinations for an instruction, given sources and destinations in tokens in P-ISA format. @@ -183,11 +181,7 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: # Check that variable is in register file if not v.register: # All variables must be in register before scheduling instruction - raise RuntimeError( - "Instruction( {}, id={} ): Variable {} not in register file.".format( - self.name, self.id, v.name - ) - ) + raise RuntimeError("Instruction( {}, id={} ): Variable {} not in register file.".format(self.name, self.id, v.name)) # Update accessed cycle v.last_x_access = cycle_count # Remove this instruction from access list @@ -197,16 +191,11 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: accessed_idx = idx break assert accessed_idx >= 0 - v.accessed_by_xinsts = ( - v.accessed_by_xinsts[:accessed_idx] - + v.accessed_by_xinsts[accessed_idx + 1 :] - ) + v.accessed_by_xinsts = v.accessed_by_xinsts[:accessed_idx] + v.accessed_by_xinsts[accessed_idx + 1 :] # Update ready cycle and dirty state of dests for dst in self.dests: - dst.cycle_ready = CycleType( - cycle_count.bundle, cycle_count.cycle + self.latency - ) + dst.cycle_ready = CycleType(cycle_count.bundle, cycle_count.cycle + self.latency) dst.register_dirty = True return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py index 96899866..102336ec 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py @@ -25,9 +25,7 @@ class Instruction(XInstruction): reset_GlobalCycleReady: Resets the global cycle ready for `xstore` instructions. """ - __xstore_global_cycle_ready = CycleType( - 0, 0 - ) # private class attribute to track cycle ready among xstores + __xstore_global_cycle_ready = CycleType(0, 0) # private class attribute to track cycle ready among xstores @classmethod def _get_op_name_asm(cls) -> str: @@ -86,20 +84,10 @@ def __init__( self.__internal_set_dests(src) if dest_spad_addr < 0 and src[0].spad_address < 0: - raise ValueError( - "`dest_spad_addr` must be a valid SPAD address if source variable is not allocated in SPAD." - ) - if ( - dest_spad_addr >= 0 - and src[0].spad_address >= 0 - and dest_spad_addr != src[0].spad_address - ): - raise ValueError( - "`dest_spad_addr` must be null SPAD address (negative) if source variable is allocated in SPAD." - ) - self.dest_spad_address = ( - src[0].spad_address if dest_spad_addr < 0 else dest_spad_addr - ) + raise ValueError("`dest_spad_addr` must be a valid SPAD address if source variable is not allocated in SPAD.") + if dest_spad_addr >= 0 and src[0].spad_address >= 0 and dest_spad_addr != src[0].spad_address: + raise ValueError("`dest_spad_addr` must be null SPAD address (negative) if source variable is allocated in SPAD.") + self.dest_spad_address = src[0].spad_address if dest_spad_addr < 0 else dest_spad_addr def __repr__(self): """ @@ -109,11 +97,7 @@ def __repr__(self): str: A string representation of the Instruction object, including its type, name, memory address, ID, source, memory model, destination SPAD address, throughput, and latency. """ - retval = ( - "<{}({}) object at {}>(id={}[0], " - "src={}, mem_model, dest_spad_addr={}, " - "throughput={}, latency={})" - ).format( + retval = ("<{}({}) object at {}>(id={}[0], " "src={}, mem_model, dest_spad_addr={}, " "throughput={}, latency={})").format( type(self).__name__, self.name, hex(id(self)), @@ -172,8 +156,7 @@ def __internal_set_dests(self, value): if len(value) != Instruction._OP_NUM_DESTS: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -196,8 +179,7 @@ def _set_sources(self, value): if len(value) != Instruction._OP_NUM_SOURCES: raise ValueError( ( - "`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format( + "`value`: Expected list of {} `Variable` objects, " "but list with {} elements received.".format( Instruction._OP_NUM_SOURCES, len(value) ) ) @@ -243,37 +225,21 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert ( - Instruction._OP_NUM_SOURCES > 0 - and len(self.sources) == Instruction._OP_NUM_SOURCES - ) - assert ( - Instruction._OP_NUM_DESTS > 0 - and len(self.dests) == Instruction._OP_NUM_DESTS - ) + assert Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES + assert Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS assert all(src == dst for src, dst in zip(self.sources, self.dests)) if not isinstance(self.sources[0], Variable): - raise RuntimeError( - "XInstruction ({}, id = {}) already scheduled.".format( - self.name, self.id - ) - ) + raise RuntimeError("XInstruction ({}, id = {}) already scheduled.".format(self.name, self.id)) - store_buffer_item = MemoryModel.StoreBufferValueType( - variable=self.sources[0], dest_spad_address=self.dest_spad_address - ) + store_buffer_item = MemoryModel.StoreBufferValueType(variable=self.sources[0], dest_spad_address=self.dest_spad_address) register = self.sources[0].register retval = super()._schedule(cycle_count, schedule_id) # Perform xstore register.register_dirty = False # Register has been flushed register.allocateVariable(None) - self.sources[0] = ( - register # Make the register the source for freezing, since variable is no longer in it - ) - self.__mem_model.store_buffer[store_buffer_item.variable.name] = ( - store_buffer_item - ) + self.sources[0] = register # Make the register the source for freezing, since variable is no longer in it + self.__mem_model.store_buffer[store_buffer_item.variable.name] = store_buffer_item # Matching CInst cstore completes the xstore if self.comment: @@ -285,9 +251,7 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: ) # Set the global cycle ready for next xstore - Instruction.__set_xstoreGlobalCycleReady( - CycleType(cycle_count.bundle, cycle_count.cycle + self.latency) - ) + Instruction.__set_xstoreGlobalCycleReady(CycleType(cycle_count.bundle, cycle_count.cycle + self.latency)) return retval def _to_pisa_format(self, *extra_args) -> str: diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/__init__.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/__init__.py index 7bf6ccfa..b8c84135 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/__init__.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/__init__.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import os import math import pathlib @@ -13,6 +16,7 @@ from .variable import findVarByName from pickle import NONE + class MemoryModel: """ Represents a memory model with various components such as HBM, SPAD, and register banks. @@ -20,6 +24,7 @@ class MemoryModel: This class provides methods and properties to manage and interact with different parts of the memory model, including metadata variables and output variables. """ + class StoreBufferValueType(NamedTuple): """ Represents a value type for the store buffer. @@ -28,6 +33,7 @@ class StoreBufferValueType(NamedTuple): variable (Variable): The variable associated with the store buffer entry. dest_spad_address (int): The destination SPAD address for the variable. """ + variable: Variable dest_spad_address: int @@ -39,18 +45,16 @@ def MAX_TWIDDLE_META_VARS_PER_SEGMENT(cls): Returns: int: The number of variables per segment. """ - return math.ceil(constants.MemoryModel.NUM_TWIDDLE_META_REGISTERS * \ - constants.MemoryModel.TWIDDLE_META_REGISTER_SIZE_BYTES / \ - constants.Constants.WORD_SIZE) + return math.ceil( + constants.MemoryModel.NUM_TWIDDLE_META_REGISTERS + * constants.MemoryModel.TWIDDLE_META_REGISTER_SIZE_BYTES + / constants.Constants.WORD_SIZE + ) # Constructor # ----------- - def __init__(self, - hbm_capacity_words: int, - spad_capacity_words: int, - num_register_banks: int, - register_range: range = None): + def __init__(self, hbm_capacity_words: int, spad_capacity_words: int, num_register_banks: int, register_range: range = None): """ Initializes a new MemoryModel object. @@ -68,27 +72,30 @@ def __init__(self, assert self.MAX_TWIDDLE_META_VARS_PER_SEGMENT == 8 if num_register_banks < constants.MemoryModel.NUM_REGISTER_BANKS: - raise ValueError(('`num_register_banks`: there must be at least {} register banks, ' - 'but {} requested.').format(constants.MemoryModel.NUM_REGISTER_BANKS, - num_register_banks)) + raise ValueError( + ("`num_register_banks`: there must be at least {} register banks, " "but {} requested.").format( + constants.MemoryModel.NUM_REGISTER_BANKS, num_register_banks + ) + ) self.__register_range = range(constants.MemoryModel.NUM_REGISTERS_PER_BANK) if not register_range else register_range # initialize members - self.__store_buffer = QueueDict() # QueueDict(var_name: str, StoreBufferValueType) - self.__variables = {} # dict(var_name, Variable) - self.__meta_ones_vars = [] # list(QueueDict()) - self.meta_ntt_aux_table: str = "" # var name - self.meta_ntt_routing_table: str = "" # var name - self.meta_intt_aux_table: str = "" # var name - self.meta_intt_routing_table: str = "" # var name - self.__meta_twiddle_vars = [] # list(QueueDict()) - self.__meta_keygen_seed_vars = QueueDict() # QueueDict(var_name: str, None): set of variables that are seeds to this operation - self.__keygen_vars = dict() # dict(var_name: str, tuple(seed_idx: int, key_idx: int)): set of variables that are output to this operation - self.__output_vars = QueueDict() # QueueDict(var_name: str, None): set of variables that are output to this operation - self.__last_keygen_order = (0, -1) # tracks the generation order of last keygen variable; next must be 1 above this order. + self.__store_buffer = QueueDict() # QueueDict(var_name: str, StoreBufferValueType) + self.__variables = {} # dict(var_name, Variable) + self.__meta_ones_vars = [] # list(QueueDict()) + self.meta_ntt_aux_table: str = "" # var name + self.meta_ntt_routing_table: str = "" # var name + self.meta_intt_aux_table: str = "" # var name + self.meta_intt_routing_table: str = "" # var name + self.__meta_twiddle_vars = [] # list(QueueDict()) + self.__meta_keygen_seed_vars = QueueDict() # QueueDict(var_name: str, None): set of variables that are seeds to this operation + self.__keygen_vars = ( + dict() + ) # dict(var_name: str, tuple(seed_idx: int, key_idx: int)): set of variables that are output to this operation + self.__output_vars = QueueDict() # QueueDict(var_name: str, None): set of variables that are output to this operation + self.__last_keygen_order = (0, -1) # tracks the generation order of last keygen variable; next must be 1 above this order. self.__hbm = hbm.HBM(hbm_capacity_words) self.__spad = spad.SPAD(spad_capacity_words) - self.__register_file = tuple([register_file.RegisterBank(idx, self.__register_range) \ - for idx in range(num_register_banks)]) + self.__register_file = tuple([register_file.RegisterBank(idx, self.__register_range) for idx in range(num_register_banks)]) # Special Methods # --------------- @@ -100,18 +107,18 @@ def __repr__(self): Returns: str: The string representation. """ - retval = ('<{} object at {}>(hbm_capacity_words={}, ' - 'spad_capacity_words={}, ' - 'num_register_banks={}, ' - 'register_range={})').format(type(self).__name__, - hex(id(self)), - self.spad.CAPACITY_WORDS, - self.hbm.CAPACITY_WORDS, - len(self.reister_banks), - self.__register_range) + retval = ( + "<{} object at {}>(hbm_capacity_words={}, " "spad_capacity_words={}, " "num_register_banks={}, " "register_range={})" + ).format( + type(self).__name__, + hex(id(self)), + self.spad.CAPACITY_WORDS, + self.hbm.CAPACITY_WORDS, + len(self.register_banks), + self.__register_range, + ) return retval - # Methods and properties # ---------------------- @@ -188,13 +195,13 @@ def add_meta_ones_var(self, var_name: str): def meta_ones_vars_segments(self) -> list: """ Retrieves the set of variable names that have been marked as Metadata Ones variables. - - A list of segments (list[QueueDict(str, None)]), where each segment is - the set of variable names that have been marked as Metadata Ones variables. - The size of each set is given by the number of variables needed to fill up + + A list of segments (list[QueueDict(str, None)]), where each segment is + the set of variable names that have been marked as Metadata Ones variables. + The size of each set is given by the number of variables needed to fill up the ones metadata registers (see constants.MemoryModel.NUM_ONES_META_REGISTERS). Clients should not change these values. Use add_meta_ones_var() to add new ones metadata. - + Returns: list: A list of segments, each containing variable names. @@ -214,8 +221,7 @@ def add_meta_twiddle_var(self, var_name: str): if var_name not in self.variables: raise RuntimeError(f'Variable "{var_name}" is not in memory model.') # Twiddle metadata variables are grouped in segments of 8 - if len(self.__meta_twiddle_vars) <= 0 \ - or len(self.__meta_twiddle_vars[-1]) >= self.MAX_TWIDDLE_META_VARS_PER_SEGMENT: + if len(self.__meta_twiddle_vars) <= 0 or len(self.__meta_twiddle_vars[-1]) >= self.MAX_TWIDDLE_META_VARS_PER_SEGMENT: self.__meta_twiddle_vars.append(QueueDict()) self.__meta_twiddle_vars[-1].push(var_name, None) @@ -226,7 +232,7 @@ def meta_twiddle_vars_segments(self) -> list: Clients should not change these values. Use meta_twiddle_vars_segments() to add new twiddle metadata. - + A list of segments (list[QueueDict(str, None)]), where each segment is a set of variable names that have been marked as Metadata Twiddle variables. The size of each set is given by the number of variables needed to fill up the twiddle @@ -247,12 +253,13 @@ def isMetaVar(self, var_name: str) -> bool: Returns: bool: True if the variable is a meta variable, False otherwise. """ - return bool(var_name) and \ - (var_name in self.meta_keygen_seed_vars \ - or any(var_name in meta_twiddle_vars for meta_twiddle_vars in self.meta_twiddle_vars_segments) \ - or any(var_name in meta_ones_vars for meta_ones_vars in self.meta_ones_vars_segments) \ - or var_name in set((self.meta_ntt_aux_table, self.meta_ntt_routing_table, - self.meta_intt_aux_table, self.meta_intt_routing_table))) + return bool(var_name) and ( + var_name in self.meta_keygen_seed_vars + or any(var_name in meta_twiddle_vars for meta_twiddle_vars in self.meta_twiddle_vars_segments) + or any(var_name in meta_ones_vars for meta_ones_vars in self.meta_ones_vars_segments) + or var_name + in set((self.meta_ntt_aux_table, self.meta_ntt_routing_table, self.meta_intt_aux_table, self.meta_intt_routing_table)) + ) @property def output_variables(self) -> QueueDict: @@ -325,11 +332,14 @@ def add_keygen_variable(self, var_name: str, seed_index: int, key_index: int): if var_name in self.output_variables: raise RuntimeError(f'Variable "{var_name}" is marked as output and cannot be marked as key material.') if key_index < 0: - raise IndexError('`key_index` must be a valid zero-based index.') + raise IndexError("`key_index` must be a valid zero-based index.") if seed_index < 0 or seed_index >= len(self.meta_keygen_seed_vars): - raise IndexError(('`seed_index` must be a valid index into the existing keygen seeds. ' - 'Expected value in range [0, {}), but {} received.').format(len(self.meta_keygen_seed_vars), - seed_index)) + raise IndexError( + ( + "`seed_index` must be a valid index into the existing keygen seeds. " + "Expected value in range [0, {}), but {} received." + ).format(len(self.meta_keygen_seed_vars), seed_index) + ) self.keygen_variables[var_name] = (seed_index, key_index) @@ -353,9 +363,7 @@ def isVarInMem(self, var_name: str) -> bool: variable: Variable = self.variables[var_name] return variable.hbm_address >= 0 or variable.spad_address >= 0 or variable.register is not None - def retrieveVarAdd(self, - var_name: str, - suggested_bank: int = -1) -> Variable: + def retrieveVarAdd(self, var_name: str, suggested_bank: int = -1) -> Variable: """ Retrieves a Variable object from the global list of variables or add a new variable if not found. @@ -378,10 +386,11 @@ def retrieveVarAdd(self, retval.suggested_bank = suggested_bank elif suggested_bank >= 0: if retval.suggested_bank != suggested_bank: - raise ValueError(('`suggested_bank`: value {} does not match existing variable "{}" ' - 'suggested bank of {}.').format(suggested_bank, - var_name, - retval.suggested_bank)) + raise ValueError( + ('`suggested_bank`: value {} does not match existing variable "{}" ' "suggested bank of {}.").format( + suggested_bank, var_name, retval.suggested_bank + ) + ) return retval def findUniqueVarName(self) -> str: @@ -407,46 +416,49 @@ def __dumpVariables(self, ostream): """ print("name, hbm, spad, spad dirty, suggested bank, register, register_dirty, last xinst use, pending xinst use", file=ostream) for _, variable in self.variables.items(): - print('{}, {}, {}, {}, {}, {}, {}'.format(variable.name, - variable.hbm_address, - variable.spad_address, - variable.spad_dirty, - variable.suggested_bank, - variable.register, - variable.register_dirty, - repr(variable.last_x_access), - repr(variable.accessed_by_xinsts)), - file = ostream) - - def dump(self, - output_dir = ''): + print( + "{}, {}, {}, {}, {}, {}, {}".format( + variable.name, + variable.hbm_address, + variable.spad_address, + variable.spad_dirty, + variable.suggested_bank, + variable.register, + variable.register_dirty, + repr(variable.last_x_access), + repr(variable.accessed_by_xinsts), + ), + file=ostream, + ) + + def dump(self, output_dir=""): """ Dump the memory model information to files in the specified output directory. Args: - output_dir (str, optional): - The directory to write the dump files to. + output_dir (str, optional): + The directory to write the dump files to. Defaults to the current working directory. """ if not output_dir: output_dir = os.path.join(pathlib.Path.cwd(), "tmp") - pathlib.Path(output_dir).mkdir(exist_ok = True, parents=True) - print('******************') - print(f'Dumping to: {output_dir}') + pathlib.Path(output_dir).mkdir(exist_ok=True, parents=True) + print("******************") + print(f"Dumping to: {output_dir}") vars_filename = os.path.join(output_dir, "variables.dump.csv") hbm_filename = os.path.join(output_dir, "hbm.dump.csv") spad_filename = os.path.join(output_dir, "spad.dump.csv") - with open(vars_filename, 'w') as outnum: + with open(vars_filename, "w") as outnum: self.__dumpVariables(outnum) - with open(hbm_filename, 'w') as outnum: + with open(hbm_filename, "w") as outnum: self.hbm.dump(outnum) - with open(spad_filename, 'w') as outnum: + with open(spad_filename, "w") as outnum: self.spad.dump(outnum) for idx, rb in enumerate(self.register_banks): register_filename = os.path.join(output_dir, f"register_bank_{idx}.dump.csv") - with open(register_filename, 'w') as outnum: + with open(register_filename, "w") as outnum: rb.dump(outnum) - print('******************') + print("******************") diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/hbm.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/hbm.py index 105d790c..01554ab1 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/hbm.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/hbm.py @@ -1,9 +1,13 @@ -from assembler.common.constants import MemoryModel as mmconstants +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common.constants import MemoryModel as mmconstants from assembler.common.decorators import * from .memory_bank import MemoryBank from .variable import Variable, findVarByName from . import mem_utilities as utilities + class HBM(MemoryBank): """ Encapsulates the high-bandwidth DRAM memory model, also known as HBM. @@ -36,8 +40,7 @@ class HBM(MemoryBank): Dumps the current state of the HBM to the specified output stream. """ - def __init__(self, - hbm_data_capacity_words: int): + def __init__(self, hbm_data_capacity_words: int): """ Initializes a new HBM object. @@ -49,15 +52,18 @@ def __init__(self, """ # validate input if hbm_data_capacity_words > mmconstants.HBM.MAX_CAPACITY_WORDS: - raise ValueError(("`hbm_data_capacity_words` must be in the range (0, {}], " - "but {} received.".format(mmconstants.HBM.MAX_CAPACITY_WORDS, hbm_data_capacity_words))) + raise ValueError( + ( + "`hbm_data_capacity_words` must be in the range (0, {}], " "but {} received.".format( + mmconstants.HBM.MAX_CAPACITY_WORDS, hbm_data_capacity_words + ) + ) + ) # initialize base super().__init__(hbm_data_capacity_words) - def allocateForce(self, - hbm_addr: int, - var: Variable): + def allocateForce(self, hbm_addr: int, var: Variable): """ Forces the allocation of an existing variable at a specific address. @@ -72,8 +78,9 @@ def allocateForce(self, # validate variable if var.hbm_address >= 0: # variable is already allocated (avoid dangling pointers) - raise ValueError(('`var`: Variable {} address is not cleared. ' - 'Expected negative address, but {} received.'.format(var, var.hbm_address))) + raise ValueError( + ("`var`: Variable {} address is not cleared. " "Expected negative address, but {} received.".format(var, var.hbm_address)) + ) # allocate in memory bank super().allocateForce(hbm_addr, var) @@ -113,11 +120,10 @@ def deallocateVariable(self, var: Variable) -> Variable: Variable: The object that was contained in the deallocated slot. """ retval = self.deallocate(var.hbm_address) - assert(retval.name == var.name) + assert retval.name == var.name return retval - def findAvailableAddress(self, - live_var_names) -> int: + def findAvailableAddress(self, live_var_names) -> int: """ Retrieves the next available HBM address. @@ -136,25 +142,21 @@ def dump(self, ostream): Args: ostream: The output stream to write the HBM state to. """ - print('HBM', file = ostream) - print(f'Max Capacity, {self.CAPACITY}, Bytes', file = ostream) - print(f'Max Capacity, {self.CAPACITY_WORDS}, Words', file = ostream) - print(f'Current Capacity, {self.currentCapacityWords}, Words', file = ostream) - print(f'Current Occupied, {self.CAPACITY_WORDS - self.currentCapacityWords}, Words', file = ostream) - print("", file = ostream) - print("address, variable, variable hbm", file = ostream) + print("HBM", file=ostream) + print(f"Max Capacity, {self.CAPACITY}, Bytes", file=ostream) + print(f"Max Capacity, {self.CAPACITY_WORDS}, Words", file=ostream) + print(f"Current Capacity, {self.currentCapacityWords}, Words", file=ostream) + print(f"Current Occupied, {self.CAPACITY_WORDS - self.currentCapacityWords}, Words", file=ostream) + print("", file=ostream) + print("address, variable, variable hbm", file=ostream) last_addr = 0 for addr, variable in enumerate(self.buffer): if variable is not None: for idx in range(last_addr, addr): # empty addresses - print(f'{idx}, None', file = ostream) + print(f"{idx}, None", file=ostream) if variable.name: - print('{}, {}'.format(addr, - variable.name, - variable.hbm_address), - file = ostream) + print("{}, {}".format(addr, variable.name, variable.hbm_address), file=ostream) else: - print('f{addr}, Dummy_{variable.tag}', - file = ostream) + print("f{addr}, Dummy_{variable.tag}", file=ostream) last_addr = addr + 1 diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py index 8d034d59..6bc5613a 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py @@ -264,10 +264,7 @@ def parse_meta_field_from_mem_tokens( """ retval = None if len(tokens) >= 3: - if ( - tokens[0] == MemInfo.Const.Keyword.LOAD - and tokens[1] == meta_field_name - ): + if tokens[0] == MemInfo.Const.Keyword.LOAD and tokens[1] == meta_field_name: hbm_addr = int(tokens[2]) if len(tokens) >= 4 and tokens[3]: # name supplied in the tokenized line @@ -290,9 +287,7 @@ def __init__(self, **kwargs): """ self.__meta_dict = {} for meta_field in MemInfo.Const.FIELD_METADATA_SUBFIELDS: - self.__meta_dict[meta_field] = [ - MemInfoVariable(**d) for d in kwargs.get(meta_field, []) - ] + self.__meta_dict[meta_field] = [MemInfoVariable(**d) for d in kwargs.get(meta_field, [])] def __getitem__(self, key): """ @@ -394,9 +389,7 @@ def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: seed_idx = int(tokens[1]) key_idx = int(tokens[2]) var_name = tokens[3] - retval = MemInfoKeygenVariable( - var_name=var_name, seed_index=seed_idx, key_index=key_idx - ) + retval = MemInfoKeygenVariable(var_name=var_name, seed_index=seed_idx, key_index=key_idx) return retval class Input: @@ -413,16 +406,11 @@ def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ retval = None if len(tokens) >= 4: - if ( - tokens[0] == MemInfo.Const.Keyword.LOAD - and tokens[1] == MemInfo.Const.Keyword.LOAD_INPUT - ): + if tokens[0] == MemInfo.Const.Keyword.LOAD and tokens[1] == MemInfo.Const.Keyword.LOAD_INPUT: hbm_addr = int(tokens[2]) var_name = tokens[3] if Variable.validateName(var_name): - retval = MemInfoVariable( - var_name=var_name, hbm_address=hbm_addr - ) + retval = MemInfoVariable(var_name=var_name, hbm_address=hbm_addr) return retval class Output: @@ -443,9 +431,7 @@ def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: hbm_addr = int(tokens[2]) var_name = tokens[1] if Variable.validateName(var_name): - retval = MemInfoVariable( - var_name=var_name, hbm_address=hbm_addr - ) + retval = MemInfoVariable(var_name=var_name, hbm_address=hbm_addr) return retval def __init__(self, **kwargs): @@ -459,19 +445,10 @@ def __init__(self, **kwargs): kwargs (dict): A dictionary as generated by the method MemInfo.as_dict(). This is provided as a shortcut to creating a MemInfo object from structured data such as the contents of a YAML file. """ - self._keygens = [ - MemInfoKeygenVariable(**d) - for d in kwargs.get(MemInfo.Const.FIELD_KEYGENS, []) - ] - self._inputs = [ - MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_INPUTS, []) - ] - self._outputs = [ - MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_OUTPUTS, []) - ] - self._metadata = MemInfo.Metadata( - **kwargs.get(MemInfo.Const.FIELD_METADATA, {}) - ) + self._keygens = [MemInfoKeygenVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_KEYGENS, [])] + self._inputs = [MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_INPUTS, [])] + self._outputs = [MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_OUTPUTS, [])] + self._metadata = MemInfo.Metadata(**kwargs.get(MemInfo.Const.FIELD_METADATA, {})) self.validate() @property @@ -504,9 +481,7 @@ def mem_info_types(cls): return dummy.factory_dict.keys() @classmethod - def get_meminfo_var_from_tokens( - cls, tokens - ) -> tuple[Optional[MemInfoVariable], Optional[type]]: + def get_meminfo_var_from_tokens(cls, tokens) -> tuple[Optional[MemInfoVariable], Optional[type]]: """ Parses a MemInfo variable from a list of tokens. @@ -665,14 +640,9 @@ def validate(self): Raises: RuntimeError: If the validation fails due to inconsistent metadata or duplicate variable names. """ - if len( - self.metadata.ones - ) * MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT != len(self.metadata.twiddle): + if len(self.metadata.ones) * MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT != len(self.metadata.twiddle): raise RuntimeError( - ( - "Expected {} times as many twiddles as ones metadata values, " - "but received {} twiddles and {} ones." - ).format( + ("Expected {} times as many twiddles as ones metadata values, " "but received {} twiddles and {} ones.").format( MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT, len(self.metadata.twiddle), len(self.metadata.ones), @@ -695,10 +665,7 @@ def validate(self): mem_info_vars[var_info.var_name] = var_info elif mem_info_vars[var_info.var_name].hbm_address != var_info.hbm_address: raise RuntimeError( - ( - 'Variable "{}" already allocated in HBM address {}, ' - "but new allocation requested into address {}." - ).format( + ('Variable "{}" already allocated in HBM address {}, ' "but new allocation requested into address {}.").format( var_info.var_name, mem_info_vars[var_info.var_name].hbm_address, var_info.hbm_address, @@ -728,15 +695,10 @@ def _allocateMemInfoVariable(mem_model: MemoryModel, v_info: MemInfoVariable): f"Variable {v_info.var_name} not in memory model. All variables used in mem info must be present in P-ISA kernel." ) if mem_model.variables[v_info.var_name].hbm_address < 0: - mem_model.hbm.allocateForce( - v_info.hbm_address, mem_model.variables[v_info.var_name] - ) + mem_model.hbm.allocateForce(v_info.hbm_address, mem_model.variables[v_info.var_name]) elif v_info.hbm_address != mem_model.variables[v_info.var_name].hbm_address: raise RuntimeError( - ( - "Variable {} already allocated in HBM address {}, " - "but new allocation requested into address {}." - ).format( + ("Variable {} already allocated in HBM address {}, " "but new allocation requested into address {}.").format( v_info.var_name, mem_model.variables[v_info.var_name].hbm_address, v_info.hbm_address, diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_utilities.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_utilities.py index ed9b29dc..8cfd6768 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_utilities.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_utilities.py @@ -1,8 +1,11 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from assembler.common.constants import Constants from assembler.common.cycle_tracking import CycleType from assembler.common.priority_queue import PriorityQueue + def computePriority(variable, replacement_policy): """ Computes the priority for reusing the location of a specified variable. @@ -18,30 +21,28 @@ def computePriority(variable, replacement_policy): Returns: tuple: A tuple representing the priority for reusing the variable's location. """ - retval = (float("-inf"), ) # Default: highest priority if no variable + retval = (float("-inf"),) # Default: highest priority if no variable if variable: # Register in use # last_x_access = variable.last_x_access.bundle * Constants.MAX_BUNDLE_SIZE + variable.last_x_access.cycle \ - last_x_access = variable.last_x_access if variable.last_x_access \ - else CycleType(0, 0) + last_x_access = variable.last_x_access if variable.last_x_access else CycleType(0, 0) if replacement_policy == Constants.REPLACEMENT_POLICY_FTBU: if variable.accessed_by_xinsts: # Priority by - retval = (-variable.accessed_by_xinsts[0].index, # Largest (furthest) accessing instruction - *last_x_access, # Oldest accessed cycle (oldest == smallest) - len(variable.accessed_by_xinsts)) # How many more uses this variable has + retval = ( + -variable.accessed_by_xinsts[0].index, # Largest (furthest) accessing instruction + *last_x_access, # Oldest accessed cycle (oldest == smallest) + len(variable.accessed_by_xinsts), + ) # How many more uses this variable has elif replacement_policy == Constants.REPLACEMENT_POLICY_LRU: # Priority by oldest accessed cycle (oldest == smallest) - retval = (*last_x_access, ) + retval = (*last_x_access,) else: raise ValueError(f'`replacement_policy`: invalid value "{replacement_policy}". Expected value in {REPLACEMENT_POLICIES}.') return retval -def flushRegisterBank(register_bank, - current_cycle: CycleType, - replacement_policy, - live_var_names = None, - pct: float = 0.5): + +def flushRegisterBank(register_bank, current_cycle: CycleType, replacement_policy, live_var_names=None, pct: float = 0.5): """ Cleans up a register bank by removing variables assigned to registers. @@ -70,24 +71,20 @@ def flushRegisterBank(register_bank, v = reg.contained_variable if v is not None: occupied_count += 1 - if not reg.register_dirty \ - and (v.name and v.name not in live_var_names) \ - and current_cycle >= v.cycle_ready: + if not reg.register_dirty and (v.name and v.name not in live_var_names) and current_cycle >= v.cycle_ready: # Variable can be cleared from the register if needed priority = computePriority(v, replacement_policy) - local_heap.push(priority, reg, (idx, )) + local_heap.push(priority, reg, (idx,)) # Clean up registers until we reach the specified pct occupancy or we have # no registers left that can be cleaned up - while local_heap \ - and occupied_count / register_bank.register_count > pct: + while local_heap and occupied_count / register_bank.register_count > pct: _, reg = local_heap.pop() reg.allocateVariable(None) occupied_count -= 1 -def findAvailableLocation(vars_lst, - live_var_names, - replacement_policy: str = None): + +def findAvailableLocation(vars_lst, live_var_names, replacement_policy: str = None): """ Retrieves the index of the next available location in a collection of Variable objects. @@ -112,18 +109,19 @@ def findAvailableLocation(vars_lst, if no suitable location is found. """ if replacement_policy and replacement_policy not in Constants.REPLACEMENT_POLICIES: - raise ValueError(('`replacement_policy`: invalid value "{}". ' - 'Expected value in {} or None.').format(replacement_policy, - Constants.REPLACEMENT_POLICIES)) + raise ValueError( + ('`replacement_policy`: invalid value "{}". ' "Expected value in {} or None.").format( + replacement_policy, Constants.REPLACEMENT_POLICIES + ) + ) retval = -1 priority = (float("inf"), float("inf"), float("inf")) for idx, v in enumerate(vars_lst): if not v: retval = idx - break # Found an empty spot - elif replacement_policy \ - and (v.name and v.name not in live_var_names): # Avoids dummy variables + break # Found an empty spot + elif replacement_policy and (v.name and v.name not in live_var_names): # Avoids dummy variables # Find priority for replacement of this location v_priority = computePriority(v, replacement_policy) if v_priority < priority: diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/memory_bank.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/memory_bank.py index 3029ea0e..8daeca59 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/memory_bank.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/memory_bank.py @@ -1,4 +1,8 @@ -from assembler.common import constants +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common import constants + class MemoryBank: """ @@ -52,8 +56,7 @@ def fromCapacityBytes(cls, data_capacity_bytes: int): # Constructor # ----------- - def __init__(self, - data_capacity_words: int): + def __init__(self, data_capacity_words: int): """ Initializes a new MemoryBank object with a specified capacity in words. @@ -64,9 +67,8 @@ def __init__(self, ValueError: If the capacity is not a positive number. """ if data_capacity_words <= 0: - raise ValueError(("`data_capacity_words` must be a positive number, " - "but {} received.".format(data_capacity_words))) - self.__data_capacity_words = data_capacity_words # max capacity in words + raise ValueError(("`data_capacity_words` must be a positive number, " "but {} received.".format(data_capacity_words))) + self.__data_capacity_words = data_capacity_words # max capacity in words self.__data_capacity = constants.convertWords2Bytes(data_capacity_words) self.__buffer = [None for _ in range(self.__data_capacity_words)] self._current_data_capacity_words = self.__data_capacity_words @@ -114,9 +116,7 @@ def buffer(self): """ return self.__buffer - def allocateForce(self, - addr: int, - obj: object): + def allocateForce(self, addr: int, obj: object): """ Force the allocation of an existing object at a specific address. @@ -134,8 +134,7 @@ def allocateForce(self, if self.currentCapacityWords <= 0: raise RuntimeError("Critical error: Out of memory.") if addr < 0 or addr >= len(self.buffer): - raise ValueError(("`addr` out of range. Must be in range [0, {})," - "but {} received.".format(len(self.buffer), addr))) + raise ValueError(("`addr` out of range. Must be in range [0, {})," "but {} received.".format(len(self.buffer), addr))) if not self.buffer[addr]: # track the obj our buffer self.buffer[addr] = obj @@ -159,12 +158,11 @@ def deallocate(self, addr) -> object: object: The object that was contained in the deallocated slot. """ if addr < 0 or addr >= len(self.buffer): - raise ValueError(("`addr` out of range. Must be in range [0, {})," - "but {} received.".format(len(self.buffer), addr))) + raise ValueError(("`addr` out of range. Must be in range [0, {})," "but {} received.".format(len(self.buffer), addr))) obj = self.buffer[addr] if not obj: - raise ValueError('`addr`: Adress "{}" is already free.'.format(addr)) + raise ValueError('`addr`: Address "{}" is already free.'.format(addr)) self.buffer[addr] = None self._current_data_capacity_words += 1 diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py index 88be136c..a1f63379 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py @@ -61,12 +61,7 @@ def __init__(self, bank_index: int, register_range: range = None): ValueError: If the bank index is negative or if the register range is invalid. """ if bank_index < 0: - raise ValueError( - ( - f"`bank_index`: expected non-negative a index for bank, " - f"but {bank_index} received." - ) - ) + raise ValueError((f"`bank_index`: expected non-negative a index for bank, " f"but {bank_index} received.")) if not register_range: register_range = range(constants.MemoryModel.NUM_REGISTER_PER_BANKS) elif len(register_range) < 1: @@ -77,12 +72,7 @@ def __init__(self, bank_index: int, register_range: range = None): ) ) elif abs(register_range.step) != 1: - raise ValueError( - ( - f"`register_range`: expected a range within step of 1 or -1, " - f"but {register_range} received." - ) - ) + raise ValueError((f"`register_range`: expected a range within step of 1 or -1, " f"but {register_range} received.")) self.__bank_index = bank_index # list of registers in this bank self.__registers = [Register(self, register_i) for register_i in register_range] @@ -106,9 +96,7 @@ def __repr__(self): Returns: str: A string representation of the RegisterBank. """ - return "<{} object at {}>(bank_index = {})".format( - type(self).__name__, hex(id(self)), self.bank_index - ) + return "<{} object at {}>(bank_index = {})".format(type(self).__name__, hex(id(self)), self.bank_index) # Methods and properties # ---------------------- @@ -203,9 +191,7 @@ def dump(self, ostream): variable = register.contained_variable if variable is not None: if variable.name: - var_data = "{}, {}".format( - variable.name, variable.register, variable.register_dirty - ) + var_data = "{}, {}".format(variable.name, variable.register, variable.register_dirty) else: var_data = f"Dummy_{variable.tag}" print("{}, {}".format(register.name, var_data), file=ostream) @@ -246,10 +232,7 @@ def __init__(self, bank: RegisterBank, register_index: int): Raises: ValueError: If the register index is out of the valid range. """ - if ( - register_index < 0 - or register_index >= constants.MemoryModel.NUM_REGISTERS_PER_BANK - ): + if register_index < 0 or register_index >= constants.MemoryModel.NUM_REGISTERS_PER_BANK: raise ValueError( ( f"`register_index`: expected an index for register in the range [0, {constants.MemoryModel.NUM_REGISTERS_PER_BANK}), " @@ -275,9 +258,7 @@ def __eq__(self, other): Returns: bool: True if the other Register is the same as this one, False otherwise. """ - return other is self or ( - isinstance(other, Register) and other.name == self.name - ) + return other is self or (isinstance(other, Register) and other.name == self.name) def __hash__(self): """ @@ -307,9 +288,7 @@ def __repr__(self): var_section = "" if self.contained_variable: var_section = "Variable='{}'".format(self.contained_variable.name) - return "<{}({}) object at {}>({})".format( - type(self).__name__, self.name, hex(id(self)), var_section - ) + return "<{}({}) object at {}>({})".format(type(self).__name__, self.name, hex(id(self)), var_section) # Methods and properties # ---------------------- @@ -384,9 +363,7 @@ def allocateVariable(self, variable: Variable = None): old_var: Variable = self.contained_variable if old_var: # make old variable aware that it is no longer in this register - assert ( - not old_var.register_dirty - ) # we should not be deallocating dirty variables + assert not old_var.register_dirty # we should not be deallocating dirty variables old_var.register = None if variable: # make variable aware of new register diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/spad.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/spad.py index dc32fc32..962ad5e3 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/spad.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/spad.py @@ -1,4 +1,7 @@ -import itertools +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import itertools from assembler.common.constants import MemoryModel as mmconstants from assembler.common.counter import Counter @@ -7,6 +10,7 @@ from .variable import Variable from . import mem_utilities as utilities + class SPAD(MemoryBank): """ Encapsulates the SRAM cache, also known as SPAD, within the memory model. @@ -45,13 +49,9 @@ class AccessTracker: allowing clients to determine the order of accesses. """ - __idx_counter = Counter.count(0) # internal unique sequence counter to generate monotonous indices + __idx_counter = Counter.count(0) # internal unique sequence counter to generate monotonous indices - def __init__(self, - last_mload = None, - last_mstore = None, - last_cload = None, - last_cstore = None): + def __init__(self, last_mload=None, last_mstore=None, last_cload=None, last_cstore=None): self.__last_mload = (next(SPAD.AccessTracker.__idx_counter), last_mload) self.__last_mstore = (next(SPAD.AccessTracker.__idx_counter), last_mstore) self.__last_cload = (next(SPAD.AccessTracker.__idx_counter), last_cload) @@ -122,8 +122,7 @@ def last_cstore(self, value: object): # Constructor # ----------- - def __init__(self, - data_capacity_words: int): + def __init__(self, data_capacity_words: int): """ Initializes a new SPAD object representing the SRAM cache or scratchpad. @@ -135,13 +134,16 @@ def __init__(self, """ # validate input if data_capacity_words > mmconstants.SPAD.MAX_CAPACITY_WORDS: - raise ValueError(("`data_capacity_words` must be in the range (0, {}], " - "but {} received.").format(mmconstants.SPAD.MAX_CAPACITY_WORDS, data_capacity_words)) + raise ValueError( + ("`data_capacity_words` must be in the range (0, {}], " "but {} received.").format( + mmconstants.SPAD.MAX_CAPACITY_WORDS, data_capacity_words + ) + ) # initialize base super().__init__(data_capacity_words) - self.__var_lookup = {} # dict(var_name: str, variable: Variable) - reverse look-up on variable name - self.__access_tracker = [ SPAD.AccessTracker() for _ in range(len(self.buffer)) ] + self.__var_lookup = {} # dict(var_name: str, variable: Variable) - reverse look-up on variable name + self.__access_tracker = [SPAD.AccessTracker() for _ in range(len(self.buffer))] # Special methods # --------------- @@ -220,9 +222,7 @@ def findContainedVariable(self, var_name: str) -> Variable: """ return self.__var_lookup[var_name] if var_name in self.__var_lookup else None - def allocateForce(self, - addr: int, - variable: Variable): + def allocateForce(self, addr: int, variable: Variable): """ Forces the allocation of an existing `Variable` object at a specific address. @@ -235,17 +235,21 @@ def allocateForce(self, RuntimeError: If the SPAD is out of capacity. """ if variable.spad_address < 0: - assert(variable.name not in self.__var_lookup) + assert variable.name not in self.__var_lookup # Allocate variable in SPAD super().allocateForce(addr, variable) variable.spad_address = addr - if variable.name: # avoid dummy vars + if variable.name: # avoid dummy vars self.__var_lookup[variable.name] = variable elif addr >= 0 and variable.spad_address != addr: # Multiple allocations not allowed - raise ValueError(('`variable` already allocated in address "{}", ' - 'but new allocation requested in address "{}".'.format(variable.spad_address, - addr))) + raise ValueError( + ( + '`variable` already allocated in address "{}", ' 'but new allocation requested in address "{}".'.format( + variable.spad_address, addr + ) + ) + ) def deallocate(self, addr) -> object: """ @@ -261,14 +265,12 @@ def deallocate(self, addr) -> object: Variable: The Variable object that was contained in the deallocated slot. """ retval = super().deallocate(addr) - retval.spad_address = -1 # deallocate variable - if retval.name: # avoid dummy vars + retval.spad_address = -1 # deallocate variable + if retval.name: # avoid dummy vars self.__var_lookup.pop(retval.name) return retval - def findAvailableAddress(self, - live_var_names, - replacement_policy: str = None) -> int: + def findAvailableAddress(self, live_var_names, replacement_policy: str = None) -> int: """ Retrieves the next available SPAD address or propose an address to use if all are occupied. @@ -279,9 +281,7 @@ def findAvailableAddress(self, Returns: int: The first empty address, or the address to replace if all are occupied. Returns -1 if no suitable address is found. """ - return utilities.findAvailableLocation(self.buffer, - live_var_names, - replacement_policy) + return utilities.findAvailableLocation(self.buffer, live_var_names, replacement_policy) def dump(self, ostream): """ @@ -290,33 +290,35 @@ def dump(self, ostream): Args: ostream: The output stream to write the SPAD state to. """ - print('SPAD', file = ostream) - print(f'Max Capacity, {self.CAPACITY}, Bytes', file = ostream) - print(f'Max Capacity, {self.CAPACITY_WORDS}, Words', file = ostream) - print(f'Current Capacity, {self.currentCapacityWords}, Words', file = ostream) - print(f'Current Occupied, {self.CAPACITY_WORDS - self.currentCapacityWords}, Words', file = ostream) - print("", file = ostream) - print("address, variable, variable spad, dirty, last mload, last mstore, last cload, last cstore", file = ostream) + print("SPAD", file=ostream) + print(f"Max Capacity, {self.CAPACITY}, Bytes", file=ostream) + print(f"Max Capacity, {self.CAPACITY_WORDS}, Words", file=ostream) + print(f"Current Capacity, {self.currentCapacityWords}, Words", file=ostream) + print(f"Current Occupied, {self.CAPACITY_WORDS - self.currentCapacityWords}, Words", file=ostream) + print("", file=ostream) + print("address, variable, variable spad, dirty, last mload, last mstore, last cload, last cstore", file=ostream) last_addr = 0 for addr, variable in enumerate(self.buffer): if variable is not None: for idx in range(last_addr, addr): # empty addresses - print(f'{idx}, None', file = ostream) + print(f"{idx}, None", file=ostream) if variable.name: spad_access_tracker = self.getAccessTracking(addr) - print('{}, {}, {}, {}, {}, {}, {}'.format(addr, - variable.name, - variable.spad_address, - variable.spad_dirty, - repr(spad_access_tracker.last_mload), - repr(spad_access_tracker.last_mstore), - repr(spad_access_tracker.last_cload), - repr(spad_access_tracker.last_cstore)), - - file = ostream) + print( + "{}, {}, {}, {}, {}, {}, {}".format( + addr, + variable.name, + variable.spad_address, + variable.spad_dirty, + repr(spad_access_tracker.last_mload), + repr(spad_access_tracker.last_mstore), + repr(spad_access_tracker.last_cload), + repr(spad_access_tracker.last_cstore), + ), + file=ostream, + ) else: - print('f{addr}, Dummy_{variable.tag}', - file = ostream) + print("f{addr}, Dummy_{variable.tag}", file=ostream) last_addr = addr + 1 diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py index cf51bee4..af9c7259 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py @@ -126,15 +126,12 @@ def __init__(self, var_name: str, suggested_bank: int = -1): # validate bank number if suggested_bank >= constants.MemoryModel.NUM_REGISTER_BANKS: raise ValueError( - ( - "`suggested_bank`: Expected negative to indicate no " - "suggestion or a bank index less than {}, but {} received." - ).format(constants.MemoryModel.NUM_REGISTER_BANKS, suggested_bank) + ("`suggested_bank`: Expected negative to indicate no " "suggestion or a bank index less than {}, but {} received.").format( + constants.MemoryModel.NUM_REGISTER_BANKS, suggested_bank + ) ) - super().__init__( - CycleType(0, 0) - ) # cycle ready in the form (bundle, clock_cycle) + super().__init__(CycleType(0, 0)) # cycle ready in the form (bundle, clock_cycle) self.__suggested_bank = suggested_bank # HBM data region address (zero-based word index) where this variable is stored. @@ -144,9 +141,7 @@ def __init__(self, var_name: str, suggested_bank: int = -1): self.__spad_dirty = False self.__register = None # Register self.__register_dirty = False - self.accessed_by_xinsts = ( - [] - ) # list of AccessElements containing instruction IDs that access this variable + self.accessed_by_xinsts = [] # list of AccessElements containing instruction IDs that access this variable self.last_x_access = None # last xinstruction that accessed this variable # Special methods @@ -230,9 +225,7 @@ def suggested_bank(self): def suggested_bank(self, value: int): if value >= constants.MemoryModel.NUM_REGISTER_BANKS: raise ValueError( - "`value`: must be in range [0, {}), but {} received.".format( - constants.MemoryModel.NUM_REGISTER_BANKS, str(value) - ) + "`value`: must be in range [0, {}), but {} received.".format(constants.MemoryModel.NUM_REGISTER_BANKS, str(value)) ) if value >= 0: # ignore negative values self.__suggested_bank = value @@ -256,13 +249,7 @@ def _set_register(self, value): if value: if not isinstance(value, Register): - raise ValueError( - ( - "`value`: expected a `Register`, but received a `{}`.".format( - type(value).__name__ - ) - ) - ) + raise ValueError(("`value`: expected a `Register`, but received a `{}`.".format(type(value).__name__))) self.__register = value else: self.__register = None @@ -358,9 +345,7 @@ def to_xasmisa_format(self) -> str: RuntimeError: If the variable is not allocated to a register. """ if not self.register: - raise RuntimeError( - "`Variable` object not allocated to register. Cannot convert to XInst ASM-ISA format." - ) + raise RuntimeError("`Variable` object not allocated to register. Cannot convert to XInst ASM-ISA format.") return self.register.to_xasmisa_format() def to_casmisa_format(self) -> str: @@ -374,9 +359,7 @@ def to_casmisa_format(self) -> str: RuntimeError: If the variable is not stored in SPAD. """ if self.spad_address < 0: - raise RuntimeError( - "`Variable` object not allocated in SPAD. Cannot convert to CInst ASM-ISA format." - ) + raise RuntimeError("`Variable` object not allocated in SPAD. Cannot convert to CInst ASM-ISA format.") return self.spad_address if GlobalConfig.hasHBM else self.name def to_masmisa_format(self) -> str: @@ -390,9 +373,7 @@ def to_masmisa_format(self) -> str: RuntimeError: If the variable is not stored in HBM. """ if self.hbm_address < 0: - raise RuntimeError( - "`Variable` object not allocated in HBM. Cannot convert to MInst ASM-ISA format." - ) + raise RuntimeError("`Variable` object not allocated in HBM. Cannot convert to MInst ASM-ISA format.") return self.name if GlobalConfig.useHBMPlaceHolders else self.hbm_address diff --git a/assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py b/assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py index d3999932..2db7d7a1 100644 --- a/assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py +++ b/assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py @@ -113,15 +113,11 @@ def init_isa_spec_from_json(cls, filename): for inst_type, ops in cls._target_ops.items(): if inst_type not in isa_spec: - raise ValueError( - f"Instruction type '{inst_type}' is not found in the JSON file." - ) + raise ValueError(f"Instruction type '{inst_type}' is not found in the JSON file.") for op_name, op in ops.items(): if op_name not in isa_spec[inst_type]: - raise ValueError( - f"Operation '{op_name}' is not found in the JSON file for instruction type '{inst_type}'." - ) + raise ValueError(f"Operation '{op_name}' is not found in the JSON file for instruction type '{inst_type}'.") attributes = isa_spec[inst_type][op_name] @@ -135,7 +131,6 @@ def init_isa_spec_from_json(cls, filename): @classmethod def initialize_isa_spec(cls, module_dir, isa_spec_file): - if not isa_spec_file: isa_spec_file = os.path.join(module_dir, "config/isa_spec.json") isa_spec_file = os.path.abspath(isa_spec_file) diff --git a/assembler_tools/hec-assembler-tools/assembler/spec_config/mem_spec.py b/assembler_tools/hec-assembler-tools/assembler/spec_config/mem_spec.py index 0fb5a061..ef1c9ce9 100644 --- a/assembler_tools/hec-assembler-tools/assembler/spec_config/mem_spec.py +++ b/assembler_tools/hec-assembler-tools/assembler/spec_config/mem_spec.py @@ -1,10 +1,14 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json import os import re -import json + from assembler.common.constants import Constants, MemoryModel -class MemSpecConfig: +class MemSpecConfig: _target_attributes = { "bytes_per_xinstruction": Constants.setXInstructionSizeBytes, "max_instructions_per_bundle": Constants.setMaxBundleSize, @@ -47,20 +51,19 @@ def dump_mem_spec_to_json(cls, filename): output_dict = {"mem_spec": hw_specs} # Write the dictionary to a JSON file - with open(filename, 'w') as json_file: + with open(filename, "w") as json_file: json.dump(output_dict, json_file, indent=4) - @classmethod def init_mem_spec_from_json(cls, filename): """ Updates class attributes using methods specified in the target_attributes dictionary based on a JSON file. - This method checks wether values found on json file exists in target dictionaries. + This method checks whether values found on json file exists in target dictionaries. Args: filename (str): The name of the JSON file to read from. """ - with open(filename, 'r') as json_file: + with open(filename) as json_file: data = json.load(json_file) # Check for the "mem_spec" section @@ -73,39 +76,38 @@ def init_mem_spec_from_json(cls, filename): missing_keys = set(cls._target_attributes.keys()) - set(mem_spec.keys()) if missing_keys: raise ValueError(f"The JSON file is missing the following attributes: {', '.join(missing_keys)}") - + # Internal function to convert size expressions to bytes def parse_size_expression(value): size_map = { - 'kb': Constants.KILOBYTE, - 'mb': Constants.MEGABYTE, - 'gb': Constants.GIGABYTE, - 'kib': Constants.KILOBYTE, - 'mib': Constants.MEGABYTE, - 'gib': Constants.GIGABYTE, - 'b': 1 + "kb": Constants.KILOBYTE, + "mb": Constants.MEGABYTE, + "gb": Constants.GIGABYTE, + "kib": Constants.KILOBYTE, + "mib": Constants.MEGABYTE, + "gib": Constants.GIGABYTE, + "b": 1, } value = value.strip() - match = re.match(r'^\s*(\d+(\.\d+)?)\s*(b|kb|mb|gb|tb|kib|mib|gib|tib)?\s*$', value.lower()) + match = re.match(r"^\s*(\d+(\.\d+)?)\s*(b|kb|mb|gb|tb|kib|mib|gib|tib)?\s*$", value.lower()) if not match: raise ValueError(f"Invalid size expression: {value}") number, _, unit = match.groups() - unit = unit or 'b' # Default to bytes if no unit is specified + unit = unit or "b" # Default to bytes if no unit is specified return int(float(number) * size_map[unit]) - + for key, value in mem_spec.items(): if key not in cls._target_attributes: raise ValueError(f"Attribute key '{key}' is not valid.") else: # Convert value to bytes if necessary - if 'bytes' in key: + if "bytes" in key: value = parse_size_expression(str(value)) update_method = cls._target_attributes[key] update_method(value) - + @classmethod def initialize_mem_spec(cls, module_dir, mem_spec_file): - if not mem_spec_file: mem_spec_file = os.path.join(module_dir, "config/mem_spec.json") mem_spec_file = os.path.abspath(mem_spec_file) @@ -115,8 +117,8 @@ def initialize_mem_spec(cls, module_dir, mem_spec_file): f"Required Mem Spec file not found: {mem_spec_file}\n" "Please provide a valid path using the `mem_spec` option, " "or use a valid default file at: `/config/mem_spec.json`." - ) - + ) + cls.init_mem_spec_from_json(mem_spec_file) return mem_spec_file diff --git a/assembler_tools/hec-assembler-tools/assembler/stages/__init__.py b/assembler_tools/hec-assembler-tools/assembler/stages/__init__.py index 98e940f8..fb0ee7ac 100644 --- a/assembler_tools/hec-assembler-tools/assembler/stages/__init__.py +++ b/assembler_tools/hec-assembler-tools/assembler/stages/__init__.py @@ -1,6 +1,11 @@ -import networkx as nx +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import networkx as nx + from assembler.memory_model.variable import Variable + def buildVarAccessListFromTopoSort(dependency_graph: nx.DiGraph): """ Given the dependency directed acyclic graph of XInsts, builds the list of @@ -21,9 +26,9 @@ def buildVarAccessListFromTopoSort(dependency_graph: nx.DiGraph): topo_sort = list(nx.topological_sort(dependency_graph)) for idx, node in enumerate(topo_sort): - instr = dependency_graph.nodes[node]['instruction'] - vars = set(instr.sources + instr.dests) - for v in vars: + instr = dependency_graph.nodes[node]["instruction"] + vars_ = set(instr.sources + instr.dests) + for v in vars_: v.accessed_by_xinsts.append(Variable.AccessElement(idx, instr.id)) return topo_sort diff --git a/assembler_tools/hec-assembler-tools/assembler/stages/asm_scheduler.py b/assembler_tools/hec-assembler-tools/assembler/stages/asm_scheduler.py index de7f6aeb..d02832e3 100644 --- a/assembler_tools/hec-assembler-tools/assembler/stages/asm_scheduler.py +++ b/assembler_tools/hec-assembler-tools/assembler/stages/asm_scheduler.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import warnings from typing import NamedTuple import networkx as nx @@ -26,8 +29,8 @@ # -------------------- # FUTURE: -# - Analyze about adding instruction window to dependecy graph creation. -# - Analize about adding terms to priority that will prioritize P-ISA instructions over all others as tie-breaker +# - Analyze about adding instruction window to dependency graph creation. +# - Analyze about adding terms to priority that will prioritize P-ISA instructions over all others as tie-breaker # in simulation priority queue. # Maybe add a way to track preparation stage of instructions as part of the priority. # - Separate variable xinst usage by inputs and outputs to avoid xstoring vars where next usage is a write. @@ -35,6 +38,7 @@ auto_allocate = True + class XStoreAssign(xinst.XStore): """ Encapsulates a compound operation of an `xstore` instruction and a @@ -43,15 +47,18 @@ class XStoreAssign(xinst.XStore): This is used for variable eviction from the register file, when the register being flushed is needed for a new variable. """ - def __init__(self, - id: int, - src: list, - mem_model: MemoryModel, - var_target: Variable, - dest_spad_addr: int = -1, - throughput : int = None, - latency : int = None, - comment: str = ""): + + def __init__( + self, + id: int, + src: list, + mem_model: MemoryModel, + var_target: Variable, + dest_spad_addr: int = -1, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `XStoreAssign` object. @@ -73,7 +80,7 @@ def __init__(self, ValueError: If `var_target` is an invalid empty or dummy `Variable` object. """ if not var_target or isinstance(var_target, DummyVariable): - raise ValueError('`var_target`: Invalid empty or dummy `Variable` object.') + raise ValueError("`var_target`: Invalid empty or dummy `Variable` object.") super().__init__(id, src, mem_model, dest_spad_addr, throughput, latency, comment) self.__var_target = var_target @@ -98,6 +105,7 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: register.allocateVariable(self.__var_target) return retval + class BundleData(NamedTuple): """ Structure for a completed bundle of instructions. @@ -109,10 +117,12 @@ class BundleData(NamedTuple): It is used to track latency before scheduling the next bundle with ifetch more easily, and to attempt to avoid too many idle cycles. """ + xinsts: list latency: int latency_from_xstore: int + class XWriteCycleTrack(NamedTuple): """ Tracks the cycle where a write occurs to the register file by an XInstruction @@ -122,9 +132,11 @@ class XWriteCycleTrack(NamedTuple): cycle (CycleType): The cycle in which the write happens. banks (set): A set of indices of banks being written to in this cycle. """ + cycle: CycleType banks: set + class CurrentrShuffleTable(NamedTuple): """ Tracks the current rShuffle routing table. @@ -133,9 +145,11 @@ class CurrentrShuffleTable(NamedTuple): r_type (type): The type of rShuffle currently loaded. It can be one of {rShuffle, irShuffle, None}. bundle (int): The bundle where the specified r_type was set. """ + r_type: type bundle: int + class Simulation: """ Simulates the scheduling of instructions in a dependency graph. @@ -205,12 +219,14 @@ class Simulation: INSTRUCTION_WINDOW_SIZE = 100 MIN_INSTRUCTIONS_IN_TOPO_SORT = 10 - def __init__(self, - dependency_graph: nx.DiGraph, - max_bundle_size: int, # Max number of instructions in a bundle - mem_model: MemoryModel, - replacement_policy: str, - progress_verbose: bool): + def __init__( + self, + dependency_graph: nx.DiGraph, + max_bundle_size: int, # Max number of instructions in a bundle + mem_model: MemoryModel, + replacement_policy: str, + progress_verbose: bool, + ): """ Initializes the simulation of schedule. @@ -228,22 +244,22 @@ def __init__(self, self.minsts = [] self.cinsts = [] - self.xinsts = [] # List of bundles + self.xinsts = [] # List of bundles # Scheduling vars - self.current_cycle = CycleType(bundle = len(self.xinsts), cycle = 1) + self.current_cycle = CycleType(bundle=len(self.xinsts), cycle=1) self.full_topo_sort = buildVarAccessListFromTopoSort(dependency_graph) - self.topo_start_idx = 0 # Starting index of the instruction window in full topo_sort - self.topo_sort = [] # Current slice of topo sort being scheduled - self.b_topo_sort_changed = True # All changed-tracking flags start as true because scheduling has changed (brought into existence) - self.dependency_graph = nx.DiGraph(dependency_graph) # Make a copy of the incoming graph to avoid modifying input + self.topo_start_idx = 0 # Starting index of the instruction window in full topo_sort + self.topo_sort = [] # Current slice of topo sort being scheduled + self.b_topo_sort_changed = True # All changed-tracking flags start as true because scheduling has changed (brought into existence) + self.dependency_graph = nx.DiGraph(dependency_graph) # Make a copy of the incoming graph to avoid modifying input self.b_dependency_graph_changed = True # Contains instructions without parent dependencies: sorted list by priority: ready cycle # (never edit directly unless absolutely necessary; use priority_queue_remove/push instead) self.priority_queue = PriorityQueue() - self.xstore_pq = PriorityQueue() # Sorted list by priority: ready cycle - self.b_priority_queue_changed = True # Tracks when there are changes in the priority queue + self.xstore_pq = PriorityQueue() # Sorted list by priority: ready cycle + self.b_priority_queue_changed = True # Tracks when there are changes in the priority queue self.total_idle_cycles = 0 # Tracks instructions that are in priority queue or have been removed from graph to avoid encountering # if duplicated in the topo sort (instructions are only added to this when extracting them from the topo sort) @@ -257,7 +273,7 @@ def __init__(self, # Bundle vars self.__max_bundle_size = max_bundle_size - self.b_empty_bundle: bool = False # Tracks if last bundle was empty + self.b_empty_bundle: bool = False # Tracks if last bundle was empty # Tracks if last bundle was flushed with very few instructions self.num_short_bundles: int = 0 # Local dummy variable to be updated per bundle: used to indicate that a register in bank 0 is live @@ -266,7 +282,7 @@ def __init__(self, # Tracks instructions in current bundle getting constructed # (never add to this manually, use appendXInstToBundle() method instead) self.xinsts_bundle = [] - self.current_bundle_latency = 0 # Tracks current bundle latency + self.current_bundle_latency = 0 # Tracks current bundle latency self.pre_bundle_csync_minstr = (0, None) self.post_bundle_cinsts = [] # Initial value for live vars (these will always be live) @@ -290,8 +306,8 @@ def __init__(self, for meta_twid_var_name in meta_twid_vars_segment: self.live_vars_0[meta_twid_var_name] = None # Tracks live in variable names for current bundle (variables to be used by current bundle) - self.live_vars: dict = self.live_vars_0 # dict(var_name:str, pending_uses: set(XInstruction)) - self.live_outs = set() # Contains variables being stored in this bundle to avoid reusing them + self.live_vars: dict = self.live_vars_0 # dict(var_name:str, pending_uses: set(XInstruction)) + self.live_outs = set() # Contains variables being stored in this bundle to avoid reusing them # Ordered list of XWriteCycleTrack to track the cycle in which rshuffles are writing. # This is used to avoid scheduling instructions that write to these banks on the same cycle as # rshuffles. @@ -304,9 +320,9 @@ def __init__(self, # Starting SPAD address for keygen seed metadata: # this will be overwritten by new keygen seed metadata whenever a swap is needed. self.metadata_spad_addr_start_kgseed = -1 - self.bundle_current_kgseed = -1 # Tracks current index of keygen seed metadata loaded - self.bundle_used_kg_seed = -1 # Tracks the last bundle that used current keygen seed - self.last_keygen_index = -1 # Tracks the last key material generation index with current seed + self.bundle_current_kgseed = -1 # Tracks current index of keygen seed metadata loaded + self.bundle_used_kg_seed = -1 # Tracks the last bundle that used current keygen seed + self.last_keygen_index = -1 # Tracks the last key material generation index with current seed # Book-keeping to track residual metadata @@ -316,8 +332,8 @@ def __init__(self, # Metadata for ones segment `i` supports computation of arithmetic operations # with rns in range `[i * 64, (i + 1) * 64)` # i == -1 means uninitialized - self.bundle_current_ones_segment = -1 # Tracks current ones segment metadata loaded - self.bundle_needed_ones_segment = -1 # Signals the ones segment metadata needed + self.bundle_current_ones_segment = -1 # Tracks current ones segment metadata loaded + self.bundle_needed_ones_segment = -1 # Signals the ones segment metadata needed # Starting SPAD address for twid metadata: # this will be overwritten by new twid metadata whenever a swap is needed. @@ -325,23 +341,23 @@ def __init__(self, # Metadata for twiddles segment `i` supports computation of twiddle factors # with rns in range `[i * 64, (i + 1) * 64)` # i == -1 means uninitialized - self.bundle_current_twid_segment = -1 # Tracks current twid segment metadata loaded - self.bundle_needed_twid_segment = -1 # Signals the twid segment metadata needed + self.bundle_current_twid_segment = -1 # Tracks current twid segment metadata loaded + self.bundle_needed_twid_segment = -1 # Signals the twid segment metadata needed # Book-keeping to track that rShuffle and irShuffle don't mix in the same bundle # Tracks the current type of rshuffle supported (rShuffle, irShuffle, None), # and what bundle was it last set - self.bundle_current_rshuffle_type = (None, 0) # (type: {rShuffle, irShuffle, None}, bundle: int) - self.bundle_needed_rshuffle_type = None # Type of last rshuffle {rShuffle, irShuffle, None} scheduled in current bundle + self.bundle_current_rshuffle_type = (None, 0) # (type: {rShuffle, irShuffle, None}, bundle: int) + self.bundle_needed_rshuffle_type = None # Type of last rshuffle {rShuffle, irShuffle, None} scheduled in current bundle # xinstfetch vars self.xinstfetch_hbm_addr = 0 self.xinstfetch_xq_addr = 0 self.__max_bundles_per_xinstfetch = Constants.WORD_SIZE / (self.max_bundle_size * Constants.XINSTRUCTION_SIZE_BYTES) - self.xinstfetch_cinsts_buffer = [] # Used to group all xinstfetch per capacity of XInst queue - self.xinstfetch_location_idx_in_cinsts = 0 # Location in cinst where to insert xinstfetch's when a group is completed + self.xinstfetch_cinsts_buffer = [] # Used to group all xinstfetch per capacity of XInst queue + self.xinstfetch_location_idx_in_cinsts = 0 # Location in cinst where to insert xinstfetch's when a group is completed # Progress report vars @@ -446,9 +462,9 @@ def addXInstrBackIntoPipeline(self, xinstr: object): ValueError: If `xinstr` is a `Move` instruction or is already scheduled. """ if isinstance(xinstr, xinst.Move): - raise ValueError('`xinstr` is a `Move` instruction. `Move` instructions cannot be inserted into the pipeline.') + raise ValueError("`xinstr` is a `Move` instruction. `Move` instructions cannot be inserted into the pipeline.") if xinstr.is_scheduled: - raise ValueError('`xinstr` already scheduled.') + raise ValueError("`xinstr` already scheduled.") assert xinstr.id in self.dependency_graph if self.dependency_graph.in_degree(xinstr.id) > 0: if xinstr in self.priority_queue: @@ -464,9 +480,7 @@ def addXInstrBackIntoPipeline(self, xinstr: object): # Pending xstore variables must be kept alive to avoid attempts to flush them again. if not isinstance(xinstr, xinst.XStore): for v in xinstr.sources + xinstr.dests: - if isinstance(v, Variable) \ - and v.name in self.live_vars \ - and xinstr in self.live_vars[v.name]: + if isinstance(v, Variable) and v.name in self.live_vars and xinstr in self.live_vars[v.name]: self.addUsedVar(v.name, xinstr) def addXInstrToTopoSort(self, xinstr_id: tuple): @@ -486,24 +500,23 @@ def addXInstrToTopoSort(self, xinstr_id: tuple): raise ValueError("`xinstr_id`: cannot be in priority queue.") # Find position in topo sort target_idx = len(self.topo_sort) - match_idxs = [] # Locations where the same xinstr was found in topo sort + match_idxs = [] # Locations where the same xinstr was found in topo sort for idx, topo_instr_id in enumerate(self.topo_sort): if topo_instr_id == xinstr_id: match_idxs.append(idx) - elif topo_instr_id in self.dependency_graph \ - and self.dependency_graph.in_degree(topo_instr_id) >= self.dependency_graph.in_degree(xinstr_id): + elif topo_instr_id in self.dependency_graph and self.dependency_graph.in_degree( + topo_instr_id + ) >= self.dependency_graph.in_degree(xinstr_id): target_idx = idx break - self.topo_sort = self.topo_sort[:target_idx] + [ xinstr_id ] + self.topo_sort[target_idx:] + self.topo_sort = self.topo_sort[:target_idx] + [xinstr_id] + self.topo_sort[target_idx:] # Remove the previous instances found of xinstr from topo sort as it has incorrect order now for idx, match_idx in enumerate(match_idxs): del self.topo_sort[match_idx - idx] self.b_topo_sort_changed = True self.set_extracted_xinstrs.discard(xinstr_id) - def addDependency(self, - new_dependency_instr, - original_instr): + def addDependency(self, new_dependency_instr, original_instr): """ Adds `new_dependency_instr` to the instruction listing as a new dependency of `original_instr`. @@ -522,11 +535,14 @@ def addDependency(self, self.b_dependency_graph_changed = True if original_instr: assert original_instr.id in self.dependency_graph - self.dependency_graph.add_edge(new_dependency_instr.id, original_instr.id) # Link as dependency to input instruction + self.dependency_graph.add_edge(new_dependency_instr.id, original_instr.id) # Link as dependency to input instruction self.addXInstrBackIntoPipeline(original_instr) - all_vars = set(v for v in new_dependency_instr.sources + new_dependency_instr.dests \ - if isinstance(v, Variable) and not isinstance(v, DummyVariable)) + all_vars = set( + v + for v in new_dependency_instr.sources + new_dependency_instr.dests + if isinstance(v, Variable) and not isinstance(v, DummyVariable) + ) for v in all_vars: # Add dependencies to all other instructions deps_added = 0 @@ -538,22 +554,22 @@ def addDependency(self, break if next_instr_id != new_dependency_instr.id: assert next_instr_id in self.dependency_graph - self.dependency_graph.add_edge(new_dependency_instr.id, next_instr_id) # Link as dependency to input instruction + self.dependency_graph.add_edge(new_dependency_instr.id, next_instr_id) # Link as dependency to input instruction if self.dependency_graph.in_degree(next_instr_id) == 1: # We need to add next instruction back to topo sort because it will have a dependency - next_instr = self.dependency_graph.nodes[next_instr_id]['instruction'] + next_instr = self.dependency_graph.nodes[next_instr_id]["instruction"] self.addXInstrBackIntoPipeline(next_instr) deps_added += 1 - self.addLiveVar(v.name, new_dependency_instr) # Source and dests variables are now a live-in for new_dependency_instr + self.addLiveVar(v.name, new_dependency_instr) # Source and dests variables are now a live-in for new_dependency_instr def addExtraXStoreDependencies(self, original_instr, xstore_instr, new_var): """ - Adds instructions using `new_var` as new dependencies of `xstore_instr`. + Adds instructions using `new_var` as new dependencies of `xstore_instr`. `new_var` is awaiting `xstore_instr` to get a register free. Dependency graph and topo sort are updated as appropriate. `xstore_instr` is NOT added to the topo_sort. Parameters: - new_var: Variable waiting for eviction to occurr. + new_var: Variable waiting for eviction to occur. xstore_instr: The instruction in charge of eviction. original_instr: The original instruction awaiting `new_var` to be ready. """ @@ -566,16 +582,14 @@ def addExtraXStoreDependencies(self, original_instr, xstore_instr, new_var): break if next_instr_id != xstore_instr.id and next_instr_id != original_instr.id: assert next_instr_id in self.dependency_graph - self.dependency_graph.add_edge(xstore_instr.id, next_instr_id) # Link as dependency to input instruction + self.dependency_graph.add_edge(xstore_instr.id, next_instr_id) # Link as dependency to input instruction if self.dependency_graph.in_degree(next_instr_id) == 1: # We need to add next instruction back to topo sort because it will have a dependency - next_instr = self.dependency_graph.nodes[next_instr_id]['instruction'] + next_instr = self.dependency_graph.nodes[next_instr_id]["instruction"] self.addXInstrBackIntoPipeline(next_instr) deps_added += 1 - def addLiveVar(self, - var_name: str, - instr): + def addLiveVar(self, var_name: str, instr): """ Adds a live variable to the current bundle. @@ -587,9 +601,7 @@ def addLiveVar(self, self.live_vars[var_name] = set() self.live_vars[var_name].add(instr) - def addUsedVar(self, - var_name: str, - instr): + def addUsedVar(self, var_name: str, instr): """ Removes a used variable from the current bundle. @@ -613,8 +625,8 @@ def appendXInstToBundle(self, xinstr): AssertionError: If the bundle is already full. """ if not xinstr: - raise ValueError('`xinstr` cannot be `None`.') - assert len(self.xinsts_bundle) < self.max_bundle_size, 'Cannot append XInstruction to full bundle.' + raise ValueError("`xinstr` cannot be `None`.") + assert len(self.xinsts_bundle) < self.max_bundle_size, "Cannot append XInstruction to full bundle." self.xinsts_bundle.append(xinstr) if self.current_bundle_latency < self.current_cycle.cycle + xinstr.latency: self.current_bundle_latency = self.current_cycle.cycle + xinstr.latency @@ -624,13 +636,13 @@ def cleanupPendingWriteCycles(self): Cleans up pending write cycles that have passed. """ # Remove any write cycles that passed - front_write_cycle_idx = -1 # len(self.pending_write_cycles) + front_write_cycle_idx = -1 # len(self.pending_write_cycles) for idx, write_cycle in enumerate(self.pending_write_cycles): - if write_cycle.cycle < self.current_cycle: # Not <= because no instruction writes on its decoding (first) cycle + if write_cycle.cycle < self.current_cycle: # Not <= because no instruction writes on its decoding (first) cycle # Found first write cycle in the list that occurs after current cycle front_write_cycle_idx = idx break - self.pending_write_cycles = self.pending_write_cycles[front_write_cycle_idx + 1:] + self.pending_write_cycles = self.pending_write_cycles[front_write_cycle_idx + 1 :] def canSchedulerShuffle(self, xinstr) -> CycleType: """ @@ -650,36 +662,39 @@ def canSchedulerShuffle(self, xinstr) -> CycleType: retval = instr_ready_cycle - if xinstr.cycle_ready.bundle <= self.current_cycle.bundle \ - and self.last_xrshuffle is not None \ - and isinstance(xinstr, (xinst.rShuffle, xinst.irShuffle)): + if ( + xinstr.cycle_ready.bundle <= self.current_cycle.bundle + and self.last_xrshuffle is not None + and isinstance(xinstr, (xinst.rShuffle, xinst.irShuffle)) + ): # Attempting to schedule an rshuffle after a previous one already got # scheduled in the same bundle last_rshuffle_cycle = self.last_xrshuffle.schedule_timing.cycle - assert self.current_cycle.bundle >= last_rshuffle_cycle.bundle, \ - "Last scheduled rshuffle cannot be in the future." + assert self.current_cycle.bundle >= last_rshuffle_cycle.bundle, "Last scheduled rshuffle cannot be in the future." if self.current_cycle.bundle == last_rshuffle_cycle.bundle: # Last scheduled rshuffle was in this bundle cycle_delta = abs(instr_ready_cycle.cycle - last_rshuffle_cycle.cycle) - if (isinstance(xinstr, xinst.rShuffle) and isinstance(self.last_xrshuffle, xinst.rShuffle)) \ - or (isinstance(xinstr, xinst.irShuffle) and isinstance(self.last_xrshuffle, xinst.irShuffle)): + if (isinstance(xinstr, xinst.rShuffle) and isinstance(self.last_xrshuffle, xinst.rShuffle)) or ( + isinstance(xinstr, xinst.irShuffle) and isinstance(self.last_xrshuffle, xinst.irShuffle) + ): # New rshuffle and previous are of the same kind if cycle_delta < self.last_xrshuffle.SpecialLatencyMax: # Trying to schedule within max special latency: attempt to slot r = cycle_delta % self.last_xrshuffle.SpecialLatencyIncrement - cycle_delta += ((0 if r == 0 else self.last_xrshuffle.SpecialLatencyIncrement) - r) + cycle_delta += (0 if r == 0 else self.last_xrshuffle.SpecialLatencyIncrement) - r if cycle_delta >= self.last_xrshuffle.SpecialLatencyMax: # Slot found is greater than max latency, so, we can schedule at max latency cycle_delta = self.last_xrshuffle.SpecialLatencyMax - retval = CycleType(bundle = self.current_cycle.bundle, - cycle = last_rshuffle_cycle.cycle + cycle_delta) + retval = CycleType(bundle=self.current_cycle.bundle, cycle=last_rshuffle_cycle.cycle + cycle_delta) else: # New rshuffle and previous are inverse: only schedule outside # of the full latency - retval = CycleType(bundle = self.current_cycle.bundle, - cycle = max(self.current_cycle.cycle, last_rshuffle_cycle.cycle + self.last_xrshuffle.latency)) + retval = CycleType( + bundle=self.current_cycle.bundle, + cycle=max(self.current_cycle.cycle, last_rshuffle_cycle.cycle + self.last_xrshuffle.latency), + ) if retval < instr_ready_cycle: retval = instr_ready_cycle @@ -706,13 +721,13 @@ def canSchedulerShuffleType(self, xinstr) -> bool: # Can schedule if not on this bundle, or is instance of previously scheduled # rshuffles in this bundle - retval = xinstr.cycle_ready.bundle > self.current_cycle.bundle \ - or self.bundle_needed_rshuffle_type is None \ - or isinstance(xinstr, self.bundle_needed_rshuffle_type) + retval = ( + xinstr.cycle_ready.bundle > self.current_cycle.bundle + or self.bundle_needed_rshuffle_type is None + or isinstance(xinstr, self.bundle_needed_rshuffle_type) + ) - if self.bundle_current_rshuffle_type[0] is not None \ - and not isinstance(xinstr, self.bundle_current_rshuffle_type[0]) \ - and retval: + if self.bundle_current_rshuffle_type[0] is not None and not isinstance(xinstr, self.bundle_current_rshuffle_type[0]) and retval: # Routing table change will be needed if we want to schedule specified xrshuffle # Search priority queue to see if there are any rshuffles matching @@ -721,15 +736,22 @@ def canSchedulerShuffleType(self, xinstr) -> bool: # messing with its contents, but it is needed for the single type # of rshuffle per bundle restriction. - retval = next((False for _, inv_rshuffle in self.priority_queue \ - if isinstance(inv_rshuffle, self.bundle_current_rshuffle_type[0]) \ - and inv_rshuffle.cycle_ready.bundle <= self.current_cycle.bundle), - retval) - - assert not retval \ - or xinstr.cycle_ready.bundle > self.current_cycle.bundle \ - or self.bundle_needed_rshuffle_type is None or isinstance(xinstr, self.bundle_needed_rshuffle_type), \ - f'Found rshuffle of type {type(xinstr)}, but type {self.bundle_needed_rshuffle_type} already scheduled in bundle.' + retval = next( + ( + False + for _, inv_rshuffle in self.priority_queue + if isinstance(inv_rshuffle, self.bundle_current_rshuffle_type[0]) + and inv_rshuffle.cycle_ready.bundle <= self.current_cycle.bundle + ), + retval, + ) + + assert ( + not retval + or xinstr.cycle_ready.bundle > self.current_cycle.bundle + or self.bundle_needed_rshuffle_type is None + or isinstance(xinstr, self.bundle_needed_rshuffle_type) + ), f"Found rshuffle of type {type(xinstr)}, but type {self.bundle_needed_rshuffle_type} already scheduled in bundle." return retval @@ -750,22 +772,24 @@ def canScheduleArithmeticXInstr(self, xinstr: xinst.XInstruction) -> bool: if xinstr.res is not None: # Instruction has residual - assert self.bundle_current_ones_segment == self.bundle_current_twid_segment, \ - 'Current Ones and Twiddle metadata segments are not synchronized.' - assert self.bundle_needed_ones_segment == self.bundle_needed_twid_segment, \ - 'Needed Ones and Twiddle metadata segments are not synchronized.' + assert ( + self.bundle_current_ones_segment == self.bundle_current_twid_segment + ), "Current Ones and Twiddle metadata segments are not synchronized." + assert ( + self.bundle_needed_ones_segment == self.bundle_needed_twid_segment + ), "Needed Ones and Twiddle metadata segments are not synchronized." xinstr_required_segment = xinstr.res // constants.MemoryModel.MAX_RESIDUALS # Can schedule if not on this bundle, or required residual segment # is same as previously scheduled in this bundle - retval = xinstr.cycle_ready.bundle > self.current_cycle.bundle \ - or self.bundle_needed_ones_segment == -1 \ - or self.bundle_needed_ones_segment == xinstr_required_segment + retval = ( + xinstr.cycle_ready.bundle > self.current_cycle.bundle + or self.bundle_needed_ones_segment == -1 + or self.bundle_needed_ones_segment == xinstr_required_segment + ) # Check if a metadata change is needed for specified XInst - if self.bundle_current_ones_segment != -1 \ - and self.bundle_current_ones_segment != xinstr_required_segment \ - and retval: + if self.bundle_current_ones_segment != -1 and self.bundle_current_ones_segment != xinstr_required_segment and retval: # Metadata change will be needed if we want to schedule specified XInst # Search priority queue to see if there are any arithmetic instructions matching @@ -773,16 +797,23 @@ def canScheduleArithmeticXInstr(self, xinstr: xinst.XInstruction) -> bool: # NOTE: Traversing a priority queue is not good practice because we should # not be messing with its contents, but it is needed for the single # metadata segment per bundle restriction. - retval = next((False for _, other_xinstr in self.priority_queue \ - if other_xinstr.res is not None \ - and other_xinstr.res // constants.MemoryModel.MAX_RESIDUALS == self.bundle_current_ones_segment \ - and other_xinstr.cycle_ready.bundle <= self.current_cycle.bundle), - retval) - - assert not retval \ - or xinstr.cycle_ready.bundle > self.current_cycle.bundle \ - or self.bundle_needed_ones_segment == -1 or xinstr_required_segment == self.bundle_needed_ones_segment, \ - f'Found XInst of residual segment {xinstr_required_segment}, but segment {self.bundle_needed_ones_segment} already scheduled in bundle.' + retval = next( + ( + False + for _, other_xinstr in self.priority_queue + if other_xinstr.res is not None + and other_xinstr.res // constants.MemoryModel.MAX_RESIDUALS == self.bundle_current_ones_segment + and other_xinstr.cycle_ready.bundle <= self.current_cycle.bundle + ), + retval, + ) + + assert ( + not retval + or xinstr.cycle_ready.bundle > self.current_cycle.bundle + or self.bundle_needed_ones_segment == -1 + or xinstr_required_segment == self.bundle_needed_ones_segment + ), f"Found XInst of residual segment {xinstr_required_segment}, but segment {self.bundle_needed_ones_segment} already scheduled in bundle." return retval @@ -799,8 +830,7 @@ def findNextInstructionToSchedule(self) -> object: retval = None if self.priority_queue: - while retval is None \ - and self.priority_queue.peek()[1].cycle_ready.bundle <= self.current_cycle.bundle: + while retval is None and self.priority_queue.peek()[1].cycle_ready.bundle <= self.current_cycle.bundle: # Check if there is any immediate instruction we can schedule # in this cycle immediate_instr = self.priority_queue.find(self.current_cycle) @@ -810,17 +840,19 @@ def findNextInstructionToSchedule(self) -> object: # Check for write cycle conflicts if hasBankWriteConflict(immediate_instr, self): # Write cycle conflict found, so, update found instruction cycle ready - new_cycle_ready = CycleType(bundle = self.current_cycle.bundle, - cycle = max(immediate_instr.cycle_ready.cycle, self.current_cycle.cycle) + 1) + new_cycle_ready = CycleType( + bundle=self.current_cycle.bundle, cycle=max(immediate_instr.cycle_ready.cycle, self.current_cycle.cycle) + 1 + ) immediate_instr.cycle_ready = new_cycle_ready else: new_cycle_ready = self.canSchedulerShuffle(immediate_instr) if immediate_instr.cycle_ready != new_cycle_ready: # Only xrshuffles should have a changed cycle ready if slotted # and got picked outside of a slot to schedule. - assert immediate_instr.cycle_ready < new_cycle_ready, \ - "Computed new cycle ready cannot be earlier than instruction's cycle ready." - immediate_instr.cycle_ready = new_cycle_ready # Update instruction's cycle ready + assert ( + immediate_instr.cycle_ready < new_cycle_ready + ), "Computed new cycle ready cannot be earlier than instruction's cycle ready." + immediate_instr.cycle_ready = new_cycle_ready # Update instruction's cycle ready else: # Found immediate instruction self.priority_queue_remove(immediate_instr) @@ -837,54 +869,53 @@ def findNextInstructionToSchedule(self) -> object: priority, p_inst = self.priority_queue.peek() if p_inst.cycle_ready.bundle < self.current_cycle.bundle: # Correct instruction ready cycle to this bundle - p_inst.cycle_ready = CycleType(bundle = self.current_cycle.bundle, - cycle = 0) + p_inst.cycle_ready = CycleType(bundle=self.current_cycle.bundle, cycle=0) # Check found instruction has correct priority if p_inst.cycle_ready == priority: # Check for write cycle conflicts if hasBankWriteConflict(p_inst, self): # Write cycle conflict found, so, update found instruction cycle ready - new_cycle_ready = CycleType(bundle = self.current_cycle.bundle, - cycle = max(p_inst.cycle_ready.cycle, self.current_cycle.cycle) + 1) + new_cycle_ready = CycleType( + bundle=self.current_cycle.bundle, cycle=max(p_inst.cycle_ready.cycle, self.current_cycle.cycle) + 1 + ) p_inst.cycle_ready = new_cycle_ready else: new_cycle_ready = self.canSchedulerShuffle(p_inst) if p_inst.cycle_ready != new_cycle_ready: # Only xrshuffles should have a changed cycle ready if slotted # and got picked outside of a slot to schedule. - assert p_inst.cycle_ready < new_cycle_ready, \ - "Computed new cycle ready cannot be earlier than instruction's cycle ready." - p_inst.cycle_ready = new_cycle_ready # Update instruction's cycle ready + assert ( + p_inst.cycle_ready < new_cycle_ready + ), "Computed new cycle ready cannot be earlier than instruction's cycle ready." + p_inst.cycle_ready = new_cycle_ready # Update instruction's cycle ready else: # Found instruction to schedule at the head of queue priority, retval = self.priority_queue.pop() - assert(retval.id == p_inst.id and priority == retval.cycle_ready) + assert retval.id == p_inst.id and priority == retval.cycle_ready if not retval: # Found instruction that has incorrect priority, so, correct it # (this may change its order in the priority queue) self.priority_queue_push(p_inst) - assert(retval) + assert retval if not self.canSchedulerShuffleType(retval): # Found rshuffle that requires routing table change, but other # rshuffles with current routing table still available: # Move rshuffle to next bundle - retval.cycle_ready = CycleType(bundle = self.current_cycle.bundle + 1, - cycle = 0) + retval.cycle_ready = CycleType(bundle=self.current_cycle.bundle + 1, cycle=0) # Put back in priority queue self.priority_queue_push(retval) - retval = None # Continue looping to find another suitable instruction + retval = None # Continue looping to find another suitable instruction if retval and not self.canScheduleArithmeticXInstr(retval): # Found XInst that requires metadata change, but other # XInst with current metadata still available: # Move XInst to next bundle - retval.cycle_ready = CycleType(bundle = self.current_cycle.bundle + 1, - cycle = 0) + retval.cycle_ready = CycleType(bundle=self.current_cycle.bundle + 1, cycle=0) # Put back in priority queue self.priority_queue_push(retval) - retval = None # Continue looping to find another suitable instruction + retval = None # Continue looping to find another suitable instruction return retval @@ -897,9 +928,9 @@ def flushBundle(self): """ if self.b_empty_bundle and len(self.xinsts_bundle) <= 0: # Previous bundle was short - raise RuntimeError('Cannot flush an empty bundle.') + raise RuntimeError("Cannot flush an empty bundle.") - self.b_empty_bundle = len(self.xinsts_bundle) <= 0 # Flag whether this is an empty bundle + self.b_empty_bundle = len(self.xinsts_bundle) <= 0 # Flag whether this is an empty bundle # Flag if this is a short bundle if len(self.xinsts_bundle) <= self.BUNDLE_INSTRUCTION_MIN_LIMIT: self.num_short_bundles += 1 @@ -920,17 +951,22 @@ def flushBundle(self): self.appendXInstToBundle(instr) # Find bundle latency measurements before padding bundle - assert(not isinstance(self.xinsts_bundle[-1], xinst.Nop)) # Last instruction in bundle is not a nop + assert not isinstance(self.xinsts_bundle[-1], xinst.Nop) # Last instruction in bundle is not a nop bundle_latency = self.current_bundle_latency # Find last xstore in bundle - bundle_last_xstore = next((self.xinsts_bundle[idx] for idx in reversed(range(len(self.xinsts_bundle))) \ - if isinstance(self.xinsts_bundle[idx], xinst.XStore)), - None) + bundle_last_xstore = next( + ( + self.xinsts_bundle[idx] + for idx in reversed(range(len(self.xinsts_bundle))) + if isinstance(self.xinsts_bundle[idx], xinst.XStore) + ), + None, + ) # Latency from last xstore is the total bundle latency minus the cycle where the xstore was scheduled: # Measured from the cycle where last xstore was scheduled to the total latency - bundle_latency_from_last_xstore = (bundle_latency - bundle_last_xstore.schedule_timing.cycle.cycle) \ - if bundle_last_xstore \ - else bundle_latency + bundle_latency_from_last_xstore = ( + (bundle_latency - bundle_last_xstore.schedule_timing.cycle.cycle) if bundle_last_xstore else bundle_latency + ) if bundle_latency_from_last_xstore < 0: bundle_latency_from_last_xstore = 0 @@ -939,40 +975,42 @@ def flushBundle(self): for _ in range(self.max_bundle_size - len(self.xinsts_bundle)): # Pad incomplete bundle with nops: # Incomplete bundles are finished by an xexit, but need to be padded to max_bundle_size - b_scheduled = scheduleXNOP(instr, - 1, # Idle cycles - self, - force_nop=True) # We want nop to be added regardless of last in bundle - assert(b_scheduled) + b_scheduled = scheduleXNOP( + instr, + 1, # Idle cycles + self, + force_nop=True, + ) # We want nop to be added regardless of last in bundle + assert b_scheduled - assert(len(self.xinsts_bundle) == self.max_bundle_size) + assert len(self.xinsts_bundle) == self.max_bundle_size # See if we need to sync to MInstQ before fetching bundle if self.pre_bundle_csync_minstr[1]: minstr = self.pre_bundle_csync_minstr[1] - assert(minstr.is_scheduled) + assert minstr.is_scheduled csyncm = cinst.CSyncm(minstr.id[0], minstr) csyncm.schedule(self.current_cycle, len(self.cinsts) + 1) self.cinsts.append(csyncm) - self.pre_bundle_csync_minstr = (0, None) # Clear sync because we may not need in next bundle + self.pre_bundle_csync_minstr = (0, None) # Clear sync because we may not need in next bundle # Schedule the bundle fetch - ifetch = cinst.IFetch(self.xinsts_bundle[0].id[1], # ID of first instruction in bundle, just for book-keeping - self.current_cycle.bundle) + ifetch = cinst.IFetch( + self.xinsts_bundle[0].id[1], # ID of first instruction in bundle, just for book-keeping + self.current_cycle.bundle, + ) # See if we need idle CInstQ cycles from previous bundle before ifetch this bundle if len(self.xinsts) > 0: # Find latency for the CInstQ since last cstore (or ifetch if not cstore) idx = len(self.cinsts) - 1 cq_throughput = 0 - while idx >= 0 \ - and not isinstance(self.cinsts[idx], (cinst.IFetch, cinst.CStore)): + while idx >= 0 and not isinstance(self.cinsts[idx], (cinst.IFetch, cinst.CStore)): cq_throughput += self.cinsts[idx].throughput idx -= 1 # Added ifetch latency to avoid timing errors when bundles are short or empty - idle_c_cycles = self.xinsts[-1].latency_from_xstore - cq_throughput \ - + ifetch.latency + idle_c_cycles = self.xinsts[-1].latency_from_xstore - cq_throughput + ifetch.latency if idle_c_cycles > 0: cnop = cinst.CNop(self.current_cycle.bundle, idle_c_cycles) cnop.schedule(self.current_cycle, len(self.cinsts) + 1) @@ -981,24 +1019,21 @@ def flushBundle(self): # See if we need to load a new rshuffle routing table # (not counted in the nops before next bundle because we don't want to # switch routing tables in mid rshuffle if it is still in flight) - if self.bundle_needed_rshuffle_type is not None \ - and self.bundle_current_rshuffle_type[0] != self.bundle_needed_rshuffle_type: + if self.bundle_needed_rshuffle_type is not None and self.bundle_current_rshuffle_type[0] != self.bundle_needed_rshuffle_type: self.loadrShuffleRoutingTable(self.bundle_needed_rshuffle_type.RSHUFFLE_DATA_TYPE) self.bundle_current_rshuffle_type = (self.bundle_needed_rshuffle_type, self.current_cycle.bundle) # See if we need to load new twid metadata # (not counted in the nops before next bundle because we don't want to # switch twid metadata in mid bundle if it is still in flight) - if self.bundle_needed_twid_segment >= 0 \ - and self.bundle_current_twid_segment != self.bundle_needed_twid_segment: + if self.bundle_needed_twid_segment >= 0 and self.bundle_current_twid_segment != self.bundle_needed_twid_segment: self.loadTwiddleMetadata(self.metadata_spad_addr_start_twid, self.bundle_needed_twid_segment) self.bundle_current_twid_segment = self.bundle_needed_twid_segment # See if we need to load new ones metadata # (not counted in the nops before next bundle because we don't want to # switch ones metadata in mid bundle if it is still in flight) - if self.bundle_needed_ones_segment >= 0 \ - and self.bundle_current_ones_segment != self.bundle_needed_ones_segment: + if self.bundle_needed_ones_segment >= 0 and self.bundle_current_ones_segment != self.bundle_needed_ones_segment: self.loadBOnesMetadata(self.metadata_spad_addr_start_ones, self.bundle_needed_ones_segment) self.bundle_current_ones_segment = self.bundle_needed_ones_segment @@ -1006,9 +1041,9 @@ def flushBundle(self): self.cinsts.append(ifetch) # Add bundle to list of bundles - self.xinsts.append(BundleData(xinsts=self.xinsts_bundle, - latency=bundle_latency, - latency_from_xstore=bundle_latency_from_last_xstore)) + self.xinsts.append( + BundleData(xinsts=self.xinsts_bundle, latency=bundle_latency, latency_from_xstore=bundle_latency_from_last_xstore) + ) # Schedule all the pending CInsts for idx, cstore_instr in enumerate(self.post_bundle_cinsts): @@ -1016,13 +1051,11 @@ def flushBundle(self): cstore_instr.schedule(self.current_cycle, len(self.cinsts) + idx + 1) # Check if this is an output variable which is done - if variable.name in self.mem_model.output_variables \ - and not variable.accessed_by_xinsts: + if variable.name in self.mem_model.output_variables and not variable.accessed_by_xinsts: # Variable is output and it is not used anymore # Sync to last CInst access to avoid storing before access completes - assert(self.mem_model.spad.getAccessTracking(dst_spad_addr).last_cstore[1] == cstore_instr) - msyncc = minst.MSyncc(cstore_instr.id[0], - cstore_instr) + assert self.mem_model.spad.getAccessTracking(dst_spad_addr).last_cstore[1] == cstore_instr + msyncc = minst.MSyncc(cstore_instr.id[0], cstore_instr) msyncc.schedule(self.current_cycle, len(self.minsts) + 1) self.minsts.append(msyncc) dest_hbm_addr = variable.hbm_address @@ -1032,11 +1065,9 @@ def flushBundle(self): dest_hbm_addr = self.mem_model.hbm.findAvailableAddress(self.mem_model.output_variables) if dest_hbm_addr < 0: raise RuntimeError("Out of HBM space.") - mstore = minst.MStore(cstore_instr.id[0], - [ variable ], - self.mem_model, - dest_hbm_addr, - comment=(' id: {} - flushing').format(cstore_instr.id)) + mstore = minst.MStore( + cstore_instr.id[0], [variable], self.mem_model, dest_hbm_addr, comment=(" id: {} - flushing").format(cstore_instr.id) + ) mstore.schedule(self.current_cycle, len(self.minsts) + 1) self.minsts.append(mstore) @@ -1045,14 +1076,14 @@ def flushBundle(self): # Clean up for next bundle self.current_bundle_latency = 0 - self.xinsts_bundle = [] - self.post_bundle_cinsts = [] - self.pending_write_cycles = [] - self.live_outs = set() + self.xinsts_bundle = [] + self.post_bundle_cinsts = [] + self.pending_write_cycles = [] + self.live_outs = set() self.bundle_needed_rshuffle_type = None - self.bundle_needed_ones_segment = -1 - self.bundle_needed_twid_segment = -1 + self.bundle_needed_ones_segment = -1 + self.bundle_needed_twid_segment = -1 # Reset all global cycle trackings for xinstr_type in xinst.GLOBAL_CYCLE_TRACKING_INSTRUCTIONS: @@ -1063,20 +1094,17 @@ def flushBundle(self): for idx in range(len(self.mem_model.register_banks)): bank = self.mem_model.register_banks[idx] for reg in bank: - if isinstance(reg.contained_variable, DummyVariable) \ - and reg.contained_variable.tag < self.current_cycle.bundle: + if isinstance(reg.contained_variable, DummyVariable) and reg.contained_variable.tag < self.current_cycle.bundle: # Register was used more than a bundle ago and can be re-used reg.allocateVariable(None) self.b_dependency_graph_changed = True # Next bundle starts - assert(len(self.xinsts) == self.current_cycle.bundle + 1) - self.current_cycle = CycleType(bundle = len(self.xinsts), cycle = 1) - self.bundle_dummy_var = DummyVariable(self.current_cycle.bundle) # Dummy variable for new bundle + assert len(self.xinsts) == self.current_cycle.bundle + 1 + self.current_cycle = CycleType(bundle=len(self.xinsts), cycle=1) + self.bundle_dummy_var = DummyVariable(self.current_cycle.bundle) # Dummy variable for new bundle - def flushOutputVariableFromRegister(self, - variable, - xinstr = None) -> bool: + def flushOutputVariableFromRegister(self, variable, xinstr=None) -> bool: """ Flushes an output variable from the register. @@ -1095,7 +1123,7 @@ def flushOutputVariableFromRegister(self, if not xinstr: xinstr = self.last_xinstr if not xinstr: - raise ValueError('`xinstr`: cannot be None when there are no other XInstructions available in the listing.') + raise ValueError("`xinstr`: cannot be None when there are no other XInstructions available in the listing.") if variable.register_dirty: # Variable is in a dirty register: # Flush the register @@ -1105,25 +1133,16 @@ def flushOutputVariableFromRegister(self, if dest_spad_addr < 0: dest_spad_addr = findSPADAddress(xinstr, self) if dest_spad_addr < 0: - retval = False # No SPAD available, flush later + retval = False # No SPAD available, flush later if retval: - xstore = _createXStore(xinstr.id[0], - dest_spad_addr, - variable, - None, - ' flushing output', - self) + xstore = _createXStore(xinstr.id[0], dest_spad_addr, variable, None, " flushing output", self) self.addDependency(xstore, None) # Add to topo_sort self.addXInstrToTopoSort(xstore.id) return retval - def generateKeyMaterial(self, - instr_id: int, - variable: Variable, - register: Register, - dep_id = None) -> int: + def generateKeyMaterial(self, instr_id: int, variable: Variable, register: Register, dep_id=None) -> int: """ Generates key material for the specified variable. @@ -1165,18 +1184,16 @@ def generateKeyMaterial(self, # Seed ready to be used to generate new key material if key_idx != self.last_keygen_index + 1: - raise RuntimeError(('Keygen variable "{}" generation out of order. ' - 'Expected key index {}, but received {} for seed {}.').format(variable.name, - self.last_keygen_index + 1, - key_idx, - self.bundle_current_kgseed)) - - comment = "" if dep_id is None else 'dep id: {}'.format(dep_id) - kg_load = cinst.KGLoad(instr_id, register, [ variable ], comment=comment) + raise RuntimeError( + ('Keygen variable "{}" generation out of order. ' "Expected key index {}, but received {} for seed {}.").format( + variable.name, self.last_keygen_index + 1, key_idx, self.bundle_current_kgseed + ) + ) + + comment = "" if dep_id is None else "dep id: {}".format(dep_id) + kg_load = cinst.KGLoad(instr_id, register, [variable], comment=comment) # Nop required because kg_load/kg_start instructions have a resource dependency among them - cnop = cinst.CNop(instr_id, - kg_load.latency, - comment='kg_load {} wait period'.format(kg_load.id)) + cnop = cinst.CNop(instr_id, kg_load.latency, comment="kg_load {} wait period".format(kg_load.id)) cnop.schedule(self.current_cycle, len(self.cinsts) + 1) self.cinsts.append(cnop) @@ -1185,12 +1202,11 @@ def generateKeyMaterial(self, # Seed used this bundle self.bundle_used_kg_seed = self.current_cycle.bundle - self.last_keygen_index = key_idx # Advance the last generated index tracker + self.last_keygen_index = key_idx # Advance the last generated index tracker return retval - def loadrShuffleRoutingTable(self, - rshuffle_data_type_name: str): + def loadrShuffleRoutingTable(self, rshuffle_data_type_name: str): """ Queues CInstructions needed to load the `rshuffle` routing table into CE. @@ -1208,39 +1224,37 @@ def loadrShuffleRoutingTable(self, routing_table_name = "" routing_table_target = -1 if rshuffle_data_type_name == xinst.rShuffle.RSHUFFLE_DATA_TYPE: - aux_table_name = self.mem_model.meta_ntt_aux_table + aux_table_name = self.mem_model.meta_ntt_aux_table routing_table_name = self.mem_model.meta_ntt_routing_table elif rshuffle_data_type_name == xinst.irShuffle.RSHUFFLE_DATA_TYPE: - aux_table_name = self.mem_model.meta_intt_aux_table + aux_table_name = self.mem_model.meta_intt_aux_table routing_table_name = self.mem_model.meta_intt_routing_table else: - raise ValueError(('`rshuffle_data_type_name`: invalid value "{}". Expected one of {}.').format(rshuffle_data_type_name, - { xinst.rShuffle.RSHUFFLE_DATA_TYPE, - xinst.irShuffle.RSHUFFLE_DATA_TYPE })) + raise ValueError( + ('`rshuffle_data_type_name`: invalid value "{}". Expected one of {}.').format( + rshuffle_data_type_name, {xinst.rShuffle.RSHUFFLE_DATA_TYPE, xinst.irShuffle.RSHUFFLE_DATA_TYPE} + ) + ) # Only NTT targets are supported for both NTT and iNTT in RTL 0.9 - aux_table_target = RegisterTargets.TARGET_NTT_AUX_TABLE + aux_table_target = RegisterTargets.TARGET_NTT_AUX_TABLE routing_table_target = RegisterTargets.TARGET_NTT_ROUTING_TABLE if aux_table_name and routing_table_name: - spad_map = QueueDict() # dict(var_name, (Variable, target_register)) - spad_map[aux_table_name] = (self.mem_model.variables[aux_table_name], - aux_table_target) - spad_map[routing_table_name] = (self.mem_model.variables[routing_table_name], - routing_table_target) + spad_map = QueueDict() # dict(var_name, (Variable, target_register)) + spad_map[aux_table_name] = (self.mem_model.variables[aux_table_name], aux_table_target) + spad_map[routing_table_name] = (self.mem_model.variables[routing_table_name], routing_table_target) # Load meta SPAD -> special CE rshuffle registers for shuffle_meta_table_name in spad_map: variable, target_idx = spad_map[shuffle_meta_table_name] - assert variable.spad_address >= 0, f'Metadata variable {variable.name} must be in SPAD' + assert variable.spad_address >= 0, f"Metadata variable {variable.name} must be in SPAD" self.queueCSyncmLoad(0, variable.spad_address) nload = cinst.NLoad(0, target_idx, variable, self.mem_model) - nload.comment = f' loading routing table for `{rshuffle_data_type_name}`' + nload.comment = f" loading routing table for `{rshuffle_data_type_name}`" nload.schedule(self.current_cycle, len(self.cinsts) + 1) self.cinsts.append(nload) else: - raise RuntimeError(f'`rshuffle`: required routing table for `{rshuffle_data_type_name}` not present in metadata.') + raise RuntimeError(f"`rshuffle`: required routing table for `{rshuffle_data_type_name}` not present in metadata.") - def loadBOnesMetadata(self, - spad_addr_offset: int, - ones_metadata_segment: int) -> int: + def loadBOnesMetadata(self, spad_addr_offset: int, ones_metadata_segment: int) -> int: """ Queues MInstructions and CInstructions needed to load the Ones metadata. @@ -1263,20 +1277,28 @@ def loadBOnesMetadata(self, assert constants.MemoryModel.NUM_ONES_META_REGISTERS == 1 if ones_metadata_segment < 0 or ones_metadata_segment >= len(self.mem_model.meta_ones_vars_segments): - raise IndexError(('`twid_metadata_segment`: requested segment index {}, but there are only {} ' - 'segments of ones metadata available for up to {} residuals.').format(ones_metadata_segment, - len(self.mem_model.meta_ones_vars_segments), - len(self.mem_model.meta_ones_vars_segments) * constants.MemoryModel.MAX_RESIDUALS)) + raise IndexError( + ( + "`twid_metadata_segment`: requested segment index {}, but there are only {} " + "segments of ones metadata available for up to {} residuals." + ).format( + ones_metadata_segment, + len(self.mem_model.meta_ones_vars_segments), + len(self.mem_model.meta_ones_vars_segments) * constants.MemoryModel.MAX_RESIDUALS, + ) + ) RegisterTargets = constants.MemInfo.MetaTargets spad_addr = 0 - spad_map = QueueDict() # dict(var_name, (Variable, target_register)) + spad_map = QueueDict() # dict(var_name, (Variable, target_register)) meta_ones_vars = self.mem_model.meta_ones_vars_segments[ones_metadata_segment] - if meta_ones_vars \ - and len(meta_ones_vars) != constants.MemoryModel.NUM_ONES_META_REGISTERS: - raise RuntimeError("Required {} twiddle metadata variables per segment, but {} received.".format(constants.MemoryModel.NUM_ONES_META_REGISTERS, - len(meta_ones_vars))) + if meta_ones_vars and len(meta_ones_vars) != constants.MemoryModel.NUM_ONES_META_REGISTERS: + raise RuntimeError( + "Required {} twiddle metadata variables per segment, but {} received.".format( + constants.MemoryModel.NUM_ONES_META_REGISTERS, len(meta_ones_vars) + ) + ) # Load HBM -> SPAD for meta_ones_var_name in meta_ones_vars: @@ -1286,9 +1308,15 @@ def loadBOnesMetadata(self, self.mem_model.spad.deallocate(target_spad_addr) # Load variable into SPAD variable = self.mem_model.variables[meta_ones_var_name] - self.queueMLoad(0, target_spad_addr, variable, - comment='loading ones metadata for residuals [{}, {})'.format(ones_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, - (ones_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + self.queueMLoad( + 0, + target_spad_addr, + variable, + comment="loading ones metadata for residuals [{}, {})".format( + ones_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, + (ones_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS, + ), + ) spad_map[constants.MemInfo.MetaFields.FIELD_ONES] = (variable, RegisterTargets.TARGET_ONES) spad_addr += 1 @@ -1296,9 +1324,16 @@ def loadBOnesMetadata(self, for ones_meta_name in spad_map: variable, target_idx = spad_map[ones_meta_name] self.queueCSyncmLoad(0, variable.spad_address) - bones = cinst.BOnes(0, target_idx, variable, self.mem_model, - comment='loading ones metadata for residuals [{}, {})'.format(ones_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, - (ones_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + bones = cinst.BOnes( + 0, + target_idx, + variable, + self.mem_model, + comment="loading ones metadata for residuals [{}, {})".format( + ones_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, + (ones_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS, + ), + ) bones.schedule(self.current_cycle, len(self.cinsts) + 1) self.cinsts.append(bones) @@ -1307,9 +1342,7 @@ def loadBOnesMetadata(self, return spad_addr_offset + spad_addr - def loadTwiddleMetadata(self, - spad_addr_offset: int, - twid_metadata_segment: int): + def loadTwiddleMetadata(self, spad_addr_offset: int, twid_metadata_segment: int): """ Queues MInstructions and CInstructions needed to load the Twiddle factor generation metadata. @@ -1334,17 +1367,25 @@ def loadTwiddleMetadata(self, spad_addr = 0 if twid_metadata_segment < 0 or twid_metadata_segment >= len(self.mem_model.meta_twiddle_vars_segments): - raise IndexError(('`twid_metadata_segment`: requested segment index {}, but there are only {} ' - 'segments of twiddle metadata available for up to {} residuals.').format(twid_metadata_segment, - len(self.mem_model.meta_twiddle_vars_segments), - len(self.mem_model.meta_twiddle_vars_segments) * constants.MemoryModel.MAX_RESIDUALS)) + raise IndexError( + ( + "`twid_metadata_segment`: requested segment index {}, but there are only {} " + "segments of twiddle metadata available for up to {} residuals." + ).format( + twid_metadata_segment, + len(self.mem_model.meta_twiddle_vars_segments), + len(self.mem_model.meta_twiddle_vars_segments) * constants.MemoryModel.MAX_RESIDUALS, + ) + ) meta_twiddle_vars = self.mem_model.meta_twiddle_vars_segments[twid_metadata_segment] - if meta_twiddle_vars \ - and len(meta_twiddle_vars) != self.mem_model.MAX_TWIDDLE_META_VARS_PER_SEGMENT: - raise RuntimeError("Required {} twiddle metadata variables per segment, but {} received.".format(self.mem_model.MAX_TWIDDLE_META_VARS_PER_SEGMENT, - len(meta_twiddle_vars))) + if meta_twiddle_vars and len(meta_twiddle_vars) != self.mem_model.MAX_TWIDDLE_META_VARS_PER_SEGMENT: + raise RuntimeError( + "Required {} twiddle metadata variables per segment, but {} received.".format( + self.mem_model.MAX_TWIDDLE_META_VARS_PER_SEGMENT, len(meta_twiddle_vars) + ) + ) # Load HBM -> SPAD for meta_twiddle_var_name in meta_twiddle_vars: @@ -1354,24 +1395,34 @@ def loadTwiddleMetadata(self, self.mem_model.spad.deallocate(target_spad_addr) # Load variable into SPAD variable = self.mem_model.variables[meta_twiddle_var_name] - self.queueMLoad(0, target_spad_addr, variable, - comment='loading twid metadata for residuals [{}, {})'.format(twid_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, - (twid_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + self.queueMLoad( + 0, + target_spad_addr, + variable, + comment="loading twid metadata for residuals [{}, {})".format( + twid_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, + (twid_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS, + ), + ) spad_addr += 1 # Load meta SPAD -> special CE twiddle registers target_bload_register = 0 for meta_twiddle_var_name in meta_twiddle_vars: variable = self.mem_model.variables[meta_twiddle_var_name] - for col_num in range(constants.MemoryModel.NUM_BLOCKS_PER_TWID_META_WORD): # Block + for col_num in range(constants.MemoryModel.NUM_BLOCKS_PER_TWID_META_WORD): # Block self.queueCSyncmLoad(0, variable.spad_address) - bload = cinst.BLoad(0, - col_num, - target_bload_register, - variable, - self.mem_model, - comment='loading twid metadata for residuals [{}, {})'.format(twid_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, - (twid_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + bload = cinst.BLoad( + 0, + col_num, + target_bload_register, + variable, + self.mem_model, + comment="loading twid metadata for residuals [{}, {})".format( + twid_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, + (twid_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS, + ), + ) bload.schedule(self.current_cycle, len(self.cinsts) + 1) self.cinsts.append(bload) target_bload_register += 1 @@ -1381,9 +1432,7 @@ def loadTwiddleMetadata(self, return spad_addr_offset + spad_addr - def loadKeygenSeedMetadata(self, - spad_addr_offset: int, - kgseed_idx: int) -> int: + def loadKeygenSeedMetadata(self, spad_addr_offset: int, kgseed_idx: int) -> int: """ Queues MInstructions and CInstructions needed to load a new keygen seed. @@ -1401,13 +1450,17 @@ def loadKeygenSeedMetadata(self, Raises: IndexError: If the seed index is out of range. """ - if kgseed_idx < 0 \ - or kgseed_idx >= len(self.mem_model.meta_keygen_seed_vars) * constants.MemoryModel.NUM_BLOCKS_PER_KGSEED_META_WORD: - raise IndexError('`kgseed_idx` must index in the range [0, {}), but {} received'.format(len(self.mem_model.meta_keygen_seed_vars) * constants.MemoryModel.NUM_BLOCKS_PER_KGSEED_META_WORD, - kgseed_idx)) + if ( + kgseed_idx < 0 + or kgseed_idx >= len(self.mem_model.meta_keygen_seed_vars) * constants.MemoryModel.NUM_BLOCKS_PER_KGSEED_META_WORD + ): + raise IndexError( + "`kgseed_idx` must index in the range [0, {}), but {} received".format( + len(self.mem_model.meta_keygen_seed_vars) * constants.MemoryModel.NUM_BLOCKS_PER_KGSEED_META_WORD, kgseed_idx + ) + ) # Only switch seeds if different from current if kgseed_idx != self.bundle_current_kgseed: - spad_addr = 0 # One word contains 4 seeds: find the right seed seed_word_block = kgseed_idx % constants.MemoryModel.NUM_BLOCKS_PER_KGSEED_META_WORD @@ -1426,20 +1479,17 @@ def loadKeygenSeedMetadata(self, if self.mem_model.spad.buffer[target_spad_addr]: self.mem_model.spad.deallocate(target_spad_addr) # Load variable into SPAD - self.queueMLoad(0, target_spad_addr, seed_variable, - comment='loading keygen seed ({}, block = {})'.format(seed_word_idx, - seed_word_block)) + self.queueMLoad( + 0, target_spad_addr, seed_variable, comment="loading keygen seed ({}, block = {})".format(seed_word_idx, seed_word_block) + ) spad_addr += 1 # Load seed SPAD -> key material generation subsystem self.queueCSyncmLoad(len(self.cinsts), seed_variable.spad_address) - kg_seed = cinst.KGSeed(len(self.cinsts), - seed_word_block, - seed_variable, - self.mem_model) - kg_start = cinst.KGStart(len(self.cinsts) + 1, comment=f'seed {kgseed_idx}') + kg_seed = cinst.KGSeed(len(self.cinsts), seed_word_block, seed_variable, self.mem_model) + kg_start = cinst.KGStart(len(self.cinsts) + 1, comment=f"seed {kgseed_idx}") kg_seed.schedule(self.current_cycle, len(self.cinsts) + 1) self.cinsts.append(kg_seed) @@ -1448,7 +1498,7 @@ def loadKeygenSeedMetadata(self, # Update the currently loaded seed self.bundle_current_kgseed = kgseed_idx - self.last_keygen_index = -1 # Restart the keygen index + self.last_keygen_index = -1 # Restart the keygen index return spad_addr_offset + spad_addr @@ -1467,8 +1517,7 @@ def loadMetadata(self): self.metadata_spad_addr_start_kgseed = spad_addr_offset spad_addr_offset = self.loadKeygenSeedMetadata(spad_addr_offset, 0) - def prepareShuffleMetadata(self, - spad_addr_offset: int) -> int: + def prepareShuffleMetadata(self, spad_addr_offset: int) -> int: """ Queues MInstructions needed to load the `rshuffle` metadata into SPAD. @@ -1484,8 +1533,7 @@ def prepareShuffleMetadata(self, spad_addr = 0 # Load HBM -> SPAD - if self.mem_model.meta_ntt_aux_table \ - and self.mem_model.meta_ntt_routing_table: + if self.mem_model.meta_ntt_aux_table and self.mem_model.meta_ntt_routing_table: variable = self.mem_model.variables[self.mem_model.meta_ntt_aux_table] self.queueMLoad(0, spad_addr_offset + spad_addr, variable) spad_addr += 1 @@ -1495,10 +1543,9 @@ def prepareShuffleMetadata(self, spad_addr += 1 else: # If one of NTT aux table or routing table is specified, so must be the other - raise RuntimeError('Both, NTT Auxiliary table and Routing table must exist in memory model.') + raise RuntimeError("Both, NTT Auxiliary table and Routing table must exist in memory model.") - if self.mem_model.meta_intt_aux_table \ - and self.mem_model.meta_intt_routing_table: + if self.mem_model.meta_intt_aux_table and self.mem_model.meta_intt_routing_table: variable = self.mem_model.variables[self.mem_model.meta_intt_aux_table] self.queueMLoad(0, spad_addr_offset + spad_addr, variable) spad_addr += 1 @@ -1508,11 +1555,11 @@ def prepareShuffleMetadata(self, spad_addr += 1 else: # If one of iNTT aux table or routing table is specified, so must be the other - raise RuntimeError('Both, iNTT Auxiliary table and Routing table must exist in memory model.') + raise RuntimeError("Both, iNTT Auxiliary table and Routing table must exist in memory model.") return spad_addr_offset + spad_addr - def priority_queue_push(self, xinstr, tie_breaker = None): + def priority_queue_push(self, xinstr, tie_breaker=None): """ Adds a new instruction to the priority queue. @@ -1525,14 +1572,14 @@ def priority_queue_push(self, xinstr, tie_breaker = None): Raises: AssertionError: If the instruction is not in the dependency graph. """ - assert xinstr.id in self.dependency_graph, f'{xinstr.id} NOT in simulation.dependency_graph' + assert xinstr.id in self.dependency_graph, f"{xinstr.id} NOT in simulation.dependency_graph" if isinstance(xinstr, xinst.XStore): if tie_breaker is None: - tie_breaker = (-1, ) + tie_breaker = (-1,) self.xstore_pq.push(xinstr.cycle_ready, xinstr, tie_breaker) if isinstance(xinstr, xinst.Move): if tie_breaker is None: - tie_breaker = (-2, ) + tie_breaker = (-2,) self.priority_queue.push(xinstr.cycle_ready, xinstr, tie_breaker) self.set_extracted_xinstrs.add(xinstr.id) @@ -1550,9 +1597,7 @@ def priority_queue_remove(self, xinstr): assert isinstance(xinstr, xinst.XStore) self.xstore_pq.remove(xinstr) - def queueCSyncmLoad(self, - instr_id: int, - source_spad_addr: int): + def queueCSyncmLoad(self, instr_id: int, source_spad_addr: int): """ Checks if needed, and, if so, queues a CSyncm CInstruction to sync to SPAD access from HBM in order to write from SPAD into CE. @@ -1568,11 +1613,7 @@ def queueCSyncmLoad(self, csyncm.schedule(self.current_cycle, len(self.cinsts) + 1) self.cinsts.append(csyncm) - def queueMLoad(self, - instr_id: int, - target_spad_addr: int, - variable, - comment = ""): + def queueMLoad(self, instr_id: int, target_spad_addr: int, variable, comment=""): """ Generates instructions to copy from HBM into SPAD. @@ -1588,7 +1629,7 @@ def queueMLoad(self, """ # Generate instructions to copy from HBM into SPAD if target_spad_addr < 0: - raise ValueError('Argument Null Exception: Target SPAD address cannot be null (negative address).') + raise ValueError("Argument Null Exception: Target SPAD address cannot be null (negative address).") self.queueMSynccLoad(instr_id, target_spad_addr) if variable.hbm_address < 0: @@ -1598,13 +1639,11 @@ def queueMLoad(self, if hbm_addr < 0: raise RuntimeError("Out of HBM space.") self.mem_model.hbm.allocateForce(hbm_addr, variable) - mload = minst.MLoad(instr_id, [ variable ], self.mem_model, target_spad_addr, comment=comment) + mload = minst.MLoad(instr_id, [variable], self.mem_model, target_spad_addr, comment=comment) mload.schedule(self.current_cycle, len(self.minsts) + 1) self.minsts.append(mload) - def queueMSynccLoad(self, - instr_id: int, - target_spad_addr: int): + def queueMSynccLoad(self, instr_id: int, target_spad_addr: int): """ Checks if needed, and, if so, queues an MSyncc MInstruction to sync to SPAD access to write from HBM into specified SPAD address. @@ -1617,20 +1656,18 @@ def queueMSynccLoad(self, ValueError: If the target SPAD address is negative. """ if target_spad_addr < 0: - raise ValueError('Argument Null Exception: Target SPAD address cannot be null (negative address).') + raise ValueError("Argument Null Exception: Target SPAD address cannot be null (negative address).") # mload depends on the last c access (cload or cstore) last_access = self.mem_model.spad.getAccessTracking(target_spad_addr) last_c_access = last_access.last_cstore - if not last_access.last_cstore[1] \ - or (last_access.last_cload[1] \ - and last_access.last_cload[0] > last_access.last_cstore[0]): + if not last_access.last_cstore[1] or (last_access.last_cload[1] and last_access.last_cload[0] > last_access.last_cstore[0]): # No last cstore or cload happened after cstore last_c_access = last_access.last_cload last_c_access = last_c_access[1] if last_c_access: # Need to sync to CInst - assert(last_c_access.is_scheduled) + assert last_c_access.is_scheduled msyncc = minst.MSyncc(instr_id, last_c_access) msyncc.schedule(self.current_cycle, len(self.minsts) + 1) self.minsts.append(msyncc) @@ -1656,14 +1693,14 @@ def updateQueuesSyncsPass2(self): target_cinstr.set_schedule_timing_index(map_cinsts[target_cinstr.id] + 1) else: target_cinstr.set_schedule_timing_index(map_cinsts[target_cinstr.id]) - minstr.freeze() # Re-freeze with new value + minstr.freeze() # Re-freeze with new value # Traverse CInstQ and update csyncm targets for cinstr in self.cinsts: if isinstance(cinstr, cinst.CSyncm): target_minstr = cinstr.minstr target_minstr.set_schedule_timing_index(map_minsts[target_minstr.id]) - cinstr.freeze() # Re-freeze with new value + cinstr.freeze() # Re-freeze with new value def updateSchedule(self, instr) -> bool: """ @@ -1680,23 +1717,23 @@ def updateSchedule(self, instr) -> bool: RuntimeError: If the bundle is already full or if an attempt is made to schedule an instruction in a bundle that only allows specific types or residuals. """ if not instr: - raise ValueError('`instr` cannot be `None`.') + raise ValueError("`instr` cannot be `None`.") if instr.id not in self.dependency_graph: raise ValueError(f'`instr`: invalid instruction "{instr}" not in dependency graph.') if len(self.xinsts_bundle) >= self.max_bundle_size: raise RuntimeError("Bundle already full.") - dependents = list(self.dependency_graph.successors(instr.id)) # Find instructions that depend on this instruction - self.dependency_graph.remove_node(instr.id) # Remove from graph to update the in_degree of dependent instrs + dependents = list(self.dependency_graph.successors(instr.id)) # Find instructions that depend on this instruction + self.dependency_graph.remove_node(instr.id) # Remove from graph to update the in_degree of dependent instrs self.b_dependency_graph_changed = True # "move" dependent instrs that have no other dependencies to the top of the topo sort if isinstance(instr, xinst.XStore): for instr_id in dependents: if self.dependency_graph.in_degree(instr_id) <= 0: if instr_id not in self.set_extracted_xinstrs: - self.priority_queue_push(self.dependency_graph.nodes[instr_id]['instruction']) + self.priority_queue_push(self.dependency_graph.nodes[instr_id]["instruction"]) else: - self.topo_sort = [ instr_id for instr_id in dependents if self.dependency_graph.in_degree(instr_id) <= 0 ] + self.topo_sort + self.topo_sort = [instr_id for instr_id in dependents if self.dependency_graph.in_degree(instr_id) <= 0] + self.topo_sort self.b_topo_sort_changed = True if instr in self.priority_queue: @@ -1709,16 +1746,13 @@ def updateSchedule(self, instr) -> bool: if isinstance(instr, xinst.XStore): # Add corresponding cstore - cstore = cinst.CStore(instr.id[0], - self.mem_model, - comment=instr.comment) + cstore = cinst.CStore(instr.id[0], self.mem_model, comment=instr.comment) self.post_bundle_cinsts.append(cstore) # Make sure bundle syncs to last mstore before fetching because # it does cstores that overwrite SPAD addresses that may still be in process # of storing to HBM: last_mstore = self.mem_model.spad.getAccessTracking(instr.dest_spad_address).last_mstore - if self.pre_bundle_csync_minstr[0] <= last_mstore[0] \ - and last_mstore[1] is not None: + if self.pre_bundle_csync_minstr[0] <= last_mstore[0] and last_mstore[1] is not None: self.pre_bundle_csync_minstr = last_mstore if isinstance(instr, (xinst.rShuffle, xinst.irShuffle)): @@ -1728,9 +1762,10 @@ def updateSchedule(self, instr) -> bool: # Add rshuffle to list of pending writes scheduled_cycle = instr.schedule_timing.cycle - write_cycle = XWriteCycleTrack(cycle = CycleType(bundle = scheduled_cycle.bundle, - cycle = scheduled_cycle.cycle + instr.latency - 1), - banks = set(v.suggested_bank for v in instr.dests)) + write_cycle = XWriteCycleTrack( + cycle=CycleType(bundle=scheduled_cycle.bundle, cycle=scheduled_cycle.cycle + instr.latency - 1), + banks=set(v.suggested_bank for v in instr.dests), + ) self.pending_write_cycles.append(write_cycle) # Track the scheduled xrshuffle to try to schedule others in slotted intervals @@ -1741,38 +1776,52 @@ def updateSchedule(self, instr) -> bool: if self.bundle_needed_rshuffle_type is None: self.bundle_needed_rshuffle_type = type(instr) elif not isinstance(instr, self.bundle_needed_rshuffle_type): - raise RuntimeError('Attempted to schedule {} in bundle that only allows {}.'.format(instr, - self.bundle_needed_rshuffle_type)) + raise RuntimeError( + "Attempted to schedule {} in bundle that only allows {}.".format(instr, self.bundle_needed_rshuffle_type) + ) # Rule: cannot mix XInsts of different residual segments in same bundle. if instr.res is not None: instr_needed_segment = instr.res // constants.MemoryModel.MAX_RESIDUALS - assert self.bundle_needed_ones_segment == self.bundle_needed_twid_segment, \ - 'Needed Ones and Twiddle metadata segments are not synchronized.' + assert ( + self.bundle_needed_ones_segment == self.bundle_needed_twid_segment + ), "Needed Ones and Twiddle metadata segments are not synchronized." if self.bundle_needed_ones_segment == -1: self.bundle_needed_ones_segment = instr_needed_segment elif self.bundle_needed_ones_segment != instr_needed_segment: - raise RuntimeError(('Attempted to schedule XInstruction "{}", residual = {}, ' - 'in bundle {} that only allows residuals in range [{}, {}).').format(str(instr), - instr.res, - self.current_cycle.bundle, - self.bundle_needed_ones_segment * constants.MemoryModel.MAX_RESIDUALS, - (self.bundle_needed_ones_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + raise RuntimeError( + ( + 'Attempted to schedule XInstruction "{}", residual = {}, ' + "in bundle {} that only allows residuals in range [{}, {})." + ).format( + str(instr), + instr.res, + self.current_cycle.bundle, + self.bundle_needed_ones_segment * constants.MemoryModel.MAX_RESIDUALS, + (self.bundle_needed_ones_segment + 1) * constants.MemoryModel.MAX_RESIDUALS, + ) + ) if self.bundle_needed_twid_segment == -1: self.bundle_needed_twid_segment = instr_needed_segment elif self.bundle_needed_twid_segment != instr_needed_segment: - raise RuntimeError(('Attempted to schedule XInstruction {}, residual = {}, ' - 'in bundle {} that only allows residuals in range [{}, {}).').format(str(instr), - instr.res, - self.current_cycle.bundle, - self.bundle_needed_twid_segment * constants.MemoryModel.MAX_RESIDUALS, - (self.bundle_needed_twid_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) - - self.appendXInstToBundle(instr) # add instruction to bundle + raise RuntimeError( + ( + "Attempted to schedule XInstruction {}, residual = {}, " + "in bundle {} that only allows residuals in range [{}, {})." + ).format( + str(instr), + instr.res, + self.current_cycle.bundle, + self.bundle_needed_twid_segment * constants.MemoryModel.MAX_RESIDUALS, + (self.bundle_needed_twid_segment + 1) * constants.MemoryModel.MAX_RESIDUALS, + ) + ) + + self.appendXInstToBundle(instr) # add instruction to bundle # True <=> bundle needs to be flushed (because of exit or full) - return isinstance(instr, xinst.Exit) \ - or (len(self.xinsts_bundle) >= self.max_bundle_size) + return isinstance(instr, xinst.Exit) or (len(self.xinsts_bundle) >= self.max_bundle_size) + def __canScheduleInBundle(instr, simulation: Simulation, padding: int = 1) -> bool: """ @@ -1790,9 +1839,10 @@ def __canScheduleInBundle(instr, simulation: Simulation, padding: int = 1) -> bo # Look into this function to see if we can bring back skip scheduling of rshuffles at the end of bundles. # Right now, this featuer is disabled because ifetch does not have the same latency as XInstrs, so # the simulation keeps track of the whole bundle latency and just adds nops to the CInstQ as needed. - #----------------- + # ----------------- return len(simulation.xinsts_bundle) < simulation.max_bundle_size and True + def __flushVariableFromSPAD(instr, dest_hbm_addr: int, variable: Variable, simulation: Simulation) -> bool: """ Flushes a variable from the SPAD to HBM. @@ -1809,15 +1859,15 @@ def __flushVariableFromSPAD(instr, dest_hbm_addr: int, variable: Variable, simul Raises: AssertionError: If the destination HBM address is invalid. """ - assert(dest_hbm_addr >= 0) + assert dest_hbm_addr >= 0 spad = simulation.mem_model.spad - comment = (' id: {} - flushing').format(instr.id) + comment = (" id: {} - flushing").format(instr.id) last_cstore = spad.getAccessTracking(variable.spad_address).last_cstore[1] if last_cstore: # mstore needs to happen after last cstore - assert(last_cstore.is_scheduled) + assert last_cstore.is_scheduled # Sync to last CInst access to avoid storing before access completes msyncc = minst.MSyncc(instr.id[0], last_cstore, comment=comment) msyncc.schedule(simulation.current_cycle, len(simulation.minsts) + 1) @@ -1829,7 +1879,10 @@ def __flushVariableFromSPAD(instr, dest_hbm_addr: int, variable: Variable, simul return True -def _createXStore(instr_id: int, dest_spad_addr: int, evict_variable: Variable, new_variable: Variable, comment: str, simulation: Simulation) -> object: + +def _createXStore( + instr_id: int, dest_spad_addr: int, evict_variable: Variable, new_variable: Variable, comment: str, simulation: Simulation +) -> object: """ Creates an XStore instruction to move a variable into SPAD. @@ -1847,25 +1900,30 @@ def _createXStore(instr_id: int, dest_spad_addr: int, evict_variable: Variable, Raises: AssertionError: If the evict variable's register is None or if the SPAD address is invalid. """ - assert(evict_variable.register is not None) - assert(evict_variable.spad_address < 0 or evict_variable.spad_address == dest_spad_addr) + assert evict_variable.register is not None + assert evict_variable.spad_address < 0 or evict_variable.spad_address == dest_spad_addr spad = simulation.mem_model.spad # Block SPAD address to avoid it being found by another findSPADAddress if spad[dest_spad_addr]: - assert(not isinstance(spad[dest_spad_addr], DummyVariable)) + assert not isinstance(spad[dest_spad_addr], DummyVariable) spad.deallocate(dest_spad_addr) spad.allocateForce(dest_spad_addr, DummyVariable()) # Generate the xstore instruction to move variable into SPAD - xstore = XStoreAssign(instr_id, [evict_variable], simulation.mem_model, new_variable, dest_spad_addr=dest_spad_addr, comment=comment) \ - if new_variable else \ - xinst.XStore(instr_id, [evict_variable], simulation.mem_model, dest_spad_addr=dest_spad_addr, comment=comment) + xstore = ( + XStoreAssign(instr_id, [evict_variable], simulation.mem_model, new_variable, dest_spad_addr=dest_spad_addr, comment=comment) + if new_variable + else xinst.XStore(instr_id, [evict_variable], simulation.mem_model, dest_spad_addr=dest_spad_addr, comment=comment) + ) evict_variable.accessed_by_xinsts = [Variable.AccessElement(0, xstore.id)] + evict_variable.accessed_by_xinsts return xstore -def __flushVariableFromRegisterFile(instr, dest_spad_addr: int, evict_variable: Variable, new_variable: Variable, simulation: Simulation) -> object: + +def __flushVariableFromRegisterFile( + instr, dest_spad_addr: int, evict_variable: Variable, new_variable: Variable, simulation: Simulation +) -> object: """ Flushes a variable from the register file to SPAD. @@ -1879,12 +1937,13 @@ def __flushVariableFromRegisterFile(instr, dest_spad_addr: int, evict_variable: Returns: object: The created XStore instruction. """ - comment = (' dep id: {} - flushing'.format(instr.id)) + comment = " dep id: {} - flushing".format(instr.id) xstore = _createXStore(instr.id[0], dest_spad_addr, evict_variable, new_variable, comment, simulation) simulation.addDependency(xstore, instr) return xstore + def scheduleXNOP(instr, idle_cycles: int, simulation: Simulation, force_nop: bool = False) -> bool: """ Schedules a NOP instruction if necessary. @@ -1902,32 +1961,33 @@ def scheduleXNOP(instr, idle_cycles: int, simulation: Simulation, force_nop: boo ValueError: If idle_cycles is not greater than 0. """ if idle_cycles <= 0: - raise ValueError(f'`idle_cycles`: expected greater than `0`, but {idle_cycles} received.') + raise ValueError(f"`idle_cycles`: expected greater than `0`, but {idle_cycles} received.") retval = True comment = "" if not isinstance(instr, xinst.Exit): comment = f" nop for not ready instr {instr.id}" - #prev_xinst = simulation.xinsts_bundle[-1] if len(simulation.xinsts_bundle) > 0 else None - prev_xinst = None # rshuffle wait cycle no longer works + # prev_xinst = simulation.xinsts_bundle[-1] if len(simulation.xinsts_bundle) > 0 else None + prev_xinst = None # rshuffle wait cycle no longer works if not force_nop and isinstance(prev_xinst, (xinst.rShuffle, xinst.irShuffle)): # Add idle cycles using previous rshuffle prev_xinst.wait_cyc = idle_cycles if comment: prev_xinst.comment += "{} {}".format(";" if len(prev_xinst.comment) > 0 else "", comment) prev_xinst.freeze() # Refreeze rshuffle to reflect the new wait_cyc - simulation.current_cycle += idle_cycles # Advance current cycle + simulation.current_cycle += idle_cycles # Advance current cycle else: retval = force_nop or len(simulation.xinsts_bundle) < simulation.max_bundle_size - 1 if retval: - assert len(simulation.xinsts_bundle) < simulation.max_bundle_size, 'Cannot queue NOP into full bundle.' + assert len(simulation.xinsts_bundle) < simulation.max_bundle_size, "Cannot queue NOP into full bundle." xnop = xinst.Nop(instr.id[0], idle_cycles, comment=comment) simulation.current_cycle += xnop.schedule(simulation.current_cycle, len(simulation.xinsts_bundle) + 1) simulation.appendXInstToBundle(xnop) return retval + def findSPADAddress(instr, simulation: Simulation) -> int: """ Finds an available SPAD address for an instruction. @@ -1976,7 +2036,7 @@ def findSPADAddress(instr, simulation: Simulation) -> int: # SPAD address found variable: Variable = spad.buffer[retval_addr] if variable: # Contains a variable - assert(variable.spad_address == retval_addr) + assert variable.spad_address == retval_addr # Address needs to be evicted if variable.spad_dirty: # Check usage @@ -2015,14 +2075,17 @@ def findSPADAddress(instr, simulation: Simulation) -> int: if retval_addr >= 0: if variable.spad_address >= 0: - assert(variable.spad_address == retval_addr) + assert variable.spad_address == retval_addr # Variable still in SPAD # SPAD address now clean, just free the address spad.deallocate(retval_addr) return retval_addr -def findRegister(instr, bank_idx: int, simulation: Simulation, override_replacement_policy: str = None, dest_var: Variable = None) -> object: + +def findRegister( + instr, bank_idx: int, simulation: Simulation, override_replacement_policy: str = None, dest_var: Variable = None +) -> object: """ Finds an available register for an instruction. @@ -2049,7 +2112,7 @@ def findRegister(instr, bank_idx: int, simulation: Simulation, override_replacem # if found retval_register to evict: # Eviction: # if register is clean: - # no need to flush register, just evict variable since it has not been writen to. + # no need to flush register, just evict variable since it has not been written to. # else, register is dirty: # need to flush variable to SPAD cache: # flush logic: @@ -2088,7 +2151,7 @@ def inner_computeLiveVars(register_bank): if retval_register.contained_variable: # Register needs to evict contained variable variable = retval_register.contained_variable - assert(not isinstance(variable, DummyVariable)) + assert not isinstance(variable, DummyVariable) if variable.register_dirty: # Check usage if len(variable.accessed_by_xinsts) > 0 or variable.name in simulation.mem_model.output_variables: @@ -2124,6 +2187,7 @@ def inner_computeLiveVars(register_bank): return retval, retval_register + def loadVariableHBMToSPAD(instr, variable: Variable, simulation: Simulation) -> bool: """ Loads a variable from HBM to SPAD. @@ -2167,7 +2231,7 @@ def loadVariableHBMToSPAD(instr, variable: Variable, simulation: Simulation) -> last_c_access = last_c_access[1] if last_c_access: # Need to sync to CInst - assert(last_c_access.is_scheduled) + assert last_c_access.is_scheduled msyncc = minst.MSyncc(instr.id[0], last_c_access) msyncc.schedule(simulation.current_cycle, len(simulation.minsts)) simulation.minsts.append(msyncc) @@ -2182,6 +2246,7 @@ def loadVariableHBMToSPAD(instr, variable: Variable, simulation: Simulation) -> return target_spad_addr >= 0 + def hasBankWriteConflictGeneral(ready_cycle: CycleType, latency: int, banks, simulation: Simulation) -> bool: """ Checks for bank write conflicts in general. @@ -2197,7 +2262,12 @@ def hasBankWriteConflictGeneral(ready_cycle: CycleType, latency: int, banks, sim """ retval = False if ready_cycle.bundle <= simulation.current_cycle.bundle: # Instruction has no conflicts if it is on a later bundle - instr_write_cycle = XWriteCycleTrack(cycle=CycleType(bundle=simulation.current_cycle.bundle, cycle=max(ready_cycle.cycle, simulation.current_cycle.cycle) + latency - 1), banks=set(banks)) + instr_write_cycle = XWriteCycleTrack( + cycle=CycleType( + bundle=simulation.current_cycle.bundle, cycle=max(ready_cycle.cycle, simulation.current_cycle.cycle) + latency - 1 + ), + banks=set(banks), + ) if len(instr_write_cycle.banks) > 0: for rshuffle_write_cycle in simulation.pending_write_cycles: if instr_write_cycle.cycle < rshuffle_write_cycle.cycle: @@ -2213,6 +2283,7 @@ def hasBankWriteConflictGeneral(ready_cycle: CycleType, latency: int, banks, sim return retval + def hasBankWriteConflict(instr, simulation: Simulation) -> bool: """ Checks for bank write conflicts for a specific instruction. @@ -2236,6 +2307,7 @@ def hasBankWriteConflict(instr, simulation: Simulation) -> bool: return hasBankWriteConflictGeneral(ready_cycle, instr.latency, banks, simulation) + def prepareInstruction(original_xinstr, simulation: Simulation) -> int: """ Prepares an instruction for scheduling. @@ -2278,25 +2350,31 @@ def prepareInstruction(original_xinstr, simulation: Simulation) -> int: if not src_var.register: # Needs to start at bank 0 - b_generated_keygen_var = not simulation.mem_model.isVarInMem(src_var.name) and src_var.name in simulation.mem_model.keygen_variables + b_generated_keygen_var = ( + not simulation.mem_model.isVarInMem(src_var.name) and src_var.name in simulation.mem_model.keygen_variables + ) if not b_generated_keygen_var: # Variable is not keygen or it has already been generated # Load into SPAD if src_var.spad_address < 0: - assert src_var.name not in simulation.mem_model.store_buffer, f'Attempting to load from HBM: "{src_var.name}"; already in transit in SPAD store buffer.' + assert ( + src_var.name not in simulation.mem_model.store_buffer + ), f'Attempting to load from HBM: "{src_var.name}"; already in transit in SPAD store buffer.' if not loadVariableHBMToSPAD(original_xinstr, src_var, simulation): # Could not find location in SPAD, move to next bundle retval = 0 if retval != 0: - retval, new_instr_or_reg = findRegister(original_xinstr, 0, simulation, override_replacement_policy="") # No replacement policy for bank 0 + retval, new_instr_or_reg = findRegister( + original_xinstr, 0, simulation, override_replacement_policy="" + ) # No replacement policy for bank 0 # retval == 1 => register good to go # retval == 2 => xstore needed for eviction if retval == 1: # Register ready, load from SPAD - assert(new_instr_or_reg.bank.bank_index == 0) + assert new_instr_or_reg.bank.bank_index == 0 if b_generated_keygen_var: # This is a keygen variable that has not been generated @@ -2312,7 +2390,13 @@ def prepareInstruction(original_xinstr, simulation: Simulation) -> int: csyncm = cinst.CSyncm(original_xinstr.id[0], last_mload_access) csyncm.schedule(simulation.current_cycle, len(simulation.cinsts) + 1) simulation.cinsts.append(csyncm) - cload = cinst.CLoad(original_xinstr.id[0], new_instr_or_reg, [src_var], simulation.mem_model, comment="dep id: {}".format(original_xinstr.id)) + cload = cinst.CLoad( + original_xinstr.id[0], + new_instr_or_reg, + [src_var], + simulation.mem_model, + comment="dep id: {}".format(original_xinstr.id), + ) cload.schedule(simulation.current_cycle, len(simulation.cinsts) + 1) simulation.cinsts.append(cload) if retval == 2: @@ -2333,7 +2417,9 @@ def prepareInstruction(original_xinstr, simulation: Simulation) -> int: if retval == 1: # Generate instruction to move variable from bank 0 to its suggested bank new_instr_or_reg.allocateVariable(simulation.bundle_dummy_var) - xmove = xinst.Move(original_xinstr.id[0], new_instr_or_reg, [src_var], dummy_var=simulation.bundle_dummy_var) + xmove = xinst.Move( + original_xinstr.id[0], new_instr_or_reg, [src_var], dummy_var=simulation.bundle_dummy_var + ) if xmove.cycle_ready.bundle < simulation.current_cycle.bundle: # Correct cycle ready's bundle xmove.cycle_ready = CycleType(bundle=simulation.current_cycle.bundle, cycle=0) @@ -2349,7 +2435,7 @@ def prepareInstruction(original_xinstr, simulation: Simulation) -> int: if retval == 2: # XInsts needed to prepare variable - + # Add extra dependencies in case of XStore if isinstance(new_instr_or_reg, xinst.XStore): simulation.addExtraXStoreDependencies(original_xinstr, new_instr_or_reg, src_var) @@ -2360,7 +2446,11 @@ def prepareInstruction(original_xinstr, simulation: Simulation) -> int: if retval == 1: if src_var.register.bank.bank_index != src_var.suggested_bank: - raise RuntimeError('Variable `{}` is in register `{}`, which is not in suggested bank {}.'.format(src_var.name, src_var.register.name, src_var.suggested_bank)) + raise RuntimeError( + "Variable `{}` is in register `{}`, which is not in suggested bank {}.".format( + src_var.name, src_var.register.name, src_var.suggested_bank + ) + ) if b_generated_keygen_var: # Mark register as dirty since this variable is keygen and # does not exist elsewhere: we want to preserve this value @@ -2392,9 +2482,15 @@ def prepareInstruction(original_xinstr, simulation: Simulation) -> int: if retval == 1: if dst_var.register.bank.bank_index != dst_var.suggested_bank: - raise RuntimeError('Variable `{}` is in register `{}`, which is not in suggested bank {}.'.format(dst_var.name, dst_var.register.name, dst_var.suggested_bank)) + raise RuntimeError( + "Variable `{}` is in register `{}`, which is not in suggested bank {}.".format( + dst_var.name, dst_var.register.name, dst_var.suggested_bank + ) + ) - assert retval == 0 or (retval_instr is not None and __canScheduleInBundle(retval_instr, simulation)) # We should always be able to schedule preparation instructions + assert retval == 0 or ( + retval_instr is not None and __canScheduleInBundle(retval_instr, simulation) + ) # We should always be able to schedule preparation instructions if retval == 0: retval_instr = None @@ -2403,16 +2499,17 @@ def prepareInstruction(original_xinstr, simulation: Simulation) -> int: if hasBankWriteConflict(retval_instr, simulation): assert not isinstance(retval_instr, xinst.Move) # Moves must be scheduled immediately # Write cycle conflict found, so, update found instruction cycle ready - new_cycle_ready = CycleType(bundle=simulation.current_cycle.bundle, cycle=max(retval_instr.cycle_ready.cycle, simulation.current_cycle.cycle) + 1) + new_cycle_ready = CycleType( + bundle=simulation.current_cycle.bundle, cycle=max(retval_instr.cycle_ready.cycle, simulation.current_cycle.cycle) + 1 + ) retval_instr.cycle_ready = new_cycle_ready return retval, retval_instr -def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, - max_bundle_size: int, - mem_model: MemoryModel, - replacement_policy, - progress_verbose: bool = False) -> (list, list, list, int): + +def scheduleASMISAInstructions( + dependency_graph: nx.DiGraph, max_bundle_size: int, mem_model: MemoryModel, replacement_policy, progress_verbose: bool = False +) -> (list, list, list, int): """ Schedules ASM-ISA instructions based on a dependency graph of XInsts to minimize idle cycles. @@ -2426,25 +2523,27 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, Returns: tuple: A tuple containing lists of xinst, cinst, minst, and the total idle cycles. """ - simulation = Simulation(dependency_graph, - max_bundle_size, # Max number of instructions in a bundle - mem_model, - replacement_policy, - progress_verbose) + simulation = Simulation( + dependency_graph, + max_bundle_size, # Max number of instructions in a bundle + mem_model, + replacement_policy, + progress_verbose, + ) # DEBUG iter_counter = 0 pisa_instr_counter = 0 # ENDDEBUG if progress_verbose: - print('Dependency Graph') - print(f' Initial number of dependencies: {simulation.dependency_graph.size()}') - print('Scheduling metadata preparation.') + print("Dependency Graph") + print(f" Initial number of dependencies: {simulation.dependency_graph.size()}") + print("Scheduling metadata preparation.") simulation.loadMetadata() if progress_verbose: - print('Scheduling XInstructions...') + print("Scheduling XInstructions...") try: b_flush_bundle = False @@ -2468,13 +2567,11 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, if new_bundle and len(simulation.xinsts) % simulation.max_bundles_per_xinstfetch == 0: if progress_verbose: pct = int(simulation.scheduled_xinsts_count * 100 / simulation.total_instructions) - print("{}% - {}/{}".format(pct, - simulation.scheduled_xinsts_count, - simulation.total_instructions)) + print("{}% - {}/{}".format(pct, simulation.scheduled_xinsts_count, simulation.total_instructions)) # Handle xinstfetch - xinstfetch = cinst.XInstFetch(len(simulation.xinstfetch_cinsts_buffer), - simulation.xinstfetch_xq_addr, - simulation.xinstfetch_hbm_addr) + xinstfetch = cinst.XInstFetch( + len(simulation.xinstfetch_cinsts_buffer), simulation.xinstfetch_xq_addr, simulation.xinstfetch_hbm_addr + ) xinstfetch.schedule(simulation.current_cycle, len(simulation.xinstfetch_cinsts_buffer) + 1) simulation.xinstfetch_cinsts_buffer.append(xinstfetch) @@ -2486,9 +2583,11 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, if progress_verbose: print("XInst queue filled: wrapping around...") # Flush buffered xinstfetches to cinst - simulation.cinsts = simulation.cinsts[:simulation.xinstfetch_location_idx_in_cinsts] \ - + simulation.xinstfetch_cinsts_buffer \ - + simulation.cinsts[simulation.xinstfetch_location_idx_in_cinsts:] + simulation.cinsts = ( + simulation.cinsts[: simulation.xinstfetch_location_idx_in_cinsts] + + simulation.xinstfetch_cinsts_buffer + + simulation.cinsts[simulation.xinstfetch_location_idx_in_cinsts :] + ) # Point to next location to insert xinstfetches simulation.xinstfetch_location_idx_in_cinsts = len(simulation.cinsts) simulation.xinstfetch_cinsts_buffer = [] # Buffer flushed, start new @@ -2498,15 +2597,19 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, # Remove any write cycles that have passed simulation.cleanupPendingWriteCycles() - while True: # do/while - if simulation.topo_start_idx < len(simulation.full_topo_sort) \ - and len(simulation.topo_sort) < Simulation.MIN_INSTRUCTIONS_IN_TOPO_SORT: + while True: # do/while + if ( + simulation.topo_start_idx < len(simulation.full_topo_sort) + and len(simulation.topo_sort) < Simulation.MIN_INSTRUCTIONS_IN_TOPO_SORT + ): if len(simulation.priority_queue) < Simulation.MIN_INSTRUCTIONS_IN_TOPO_SORT: - simulation.topo_sort += simulation.full_topo_sort[simulation.topo_start_idx:simulation.topo_start_idx + Simulation.INSTRUCTION_WINDOW_SIZE] + simulation.topo_sort += simulation.full_topo_sort[ + simulation.topo_start_idx : simulation.topo_start_idx + Simulation.INSTRUCTION_WINDOW_SIZE + ] simulation.topo_start_idx += Simulation.INSTRUCTION_WINDOW_SIZE simulation.b_topo_sort_changed = True # Added to topo window - assert len(simulation.priority_queue) > 0 or len(simulation.topo_sort) > 0, 'Possible infinite loop detected.' + assert len(simulation.priority_queue) > 0 or len(simulation.topo_sort) > 0, "Possible infinite loop detected." # Try to exhaust the priority queue first: # These may introduce some inefficiency to the schedule, but avoids @@ -2527,13 +2630,13 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, # Found first instruction with dependencies last_idx = idx - 1 break - instr = simulation.dependency_graph.nodes[instr_id]['instruction'] + instr = simulation.dependency_graph.nodes[instr_id]["instruction"] simulation.priority_queue_push(instr) simulation.b_priority_queue_changed = True # Remove all instructions that got queued for scheduling if last_idx >= 0: - simulation.topo_sort = simulation.topo_sort[last_idx + 1:] + simulation.topo_sort = simulation.topo_sort[last_idx + 1 :] if xinstr: # Next instruction to schedule may have changed after pulling from topo sort simulation.priority_queue_push(xinstr) @@ -2549,11 +2652,11 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, # else, attempt to refill topo sort (restart the top of do/while loop) break # There is, at least, one instruction to schedule - assert(len(simulation.xinsts_bundle) < simulation.max_bundle_size) # We should have space in current bundle for an xinstruction + assert len(simulation.xinsts_bundle) < simulation.max_bundle_size # We should have space in current bundle for an xinstruction # Find next xinstruction to schedule if not xinstr: - assert(simulation.priority_queue) + assert simulation.priority_queue xinstr = simulation.findNextInstructionToSchedule() if not xinstr: # No instruction left to schedule this bundle @@ -2569,23 +2672,26 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, if b_bundle_needs_fix: # DEBUG if GlobalConfig.debugVerbose: - print(f'---- Fixing short bundle {simulation.current_cycle.bundle}') + print(f"---- Fixing short bundle {simulation.current_cycle.bundle}") # ENDDEBUG # Flush register banks and attempt to schedule again for bank_idx in range(1, len(simulation.mem_model.register_banks)): - mem_utilities.flushRegisterBank(simulation.mem_model.register_banks[bank_idx], - simulation.current_cycle, - simulation.replacement_policy, - simulation.live_vars, - pct=0.5) + mem_utilities.flushRegisterBank( + simulation.mem_model.register_banks[bank_idx], + simulation.current_cycle, + simulation.replacement_policy, + simulation.live_vars, + pct=0.5, + ) # Attempt to schedule instructions slated for next bundle in this bundle tmp_set = set() for _, xinstr in simulation.priority_queue: if xinstr.cycle_ready.bundle == simulation.current_cycle.bundle + 1: if xinstr.cycle_ready.cycle <= 1: - xinstr.cycle_ready = CycleType(bundle=simulation.current_cycle.bundle, - cycle=xinstr.cycle_ready.cycle) + xinstr.cycle_ready = CycleType( + bundle=simulation.current_cycle.bundle, cycle=xinstr.cycle_ready.cycle + ) tmp_set.add(xinstr) for xinstr in tmp_set: simulation.priority_queue_push(xinstr) @@ -2596,7 +2702,7 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, b_flush_bundle = xinstr is None if not b_flush_bundle: - assert(xinstr is not None) # Only None if priority queue is empty + assert xinstr is not None # Only None if priority queue is empty # Attempt to schedule xinstruction @@ -2621,14 +2727,13 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, if GlobalConfig.debugVerbose: if iter_counter % int(GlobalConfig.debugVerbose) == 0: - print('prep_counter', prep_counter) + print("prep_counter", prep_counter) xinstr_prepped, xinstr = prepareInstruction(original_xinstr, simulation) if xinstr_prepped == 0: assert xinstr is None # Failed to prepare instruction in this bundle, leave it for next bundle - original_xinstr.cycle_ready = CycleType(bundle=simulation.current_cycle.bundle + 1, - cycle=0) + original_xinstr.cycle_ready = CycleType(bundle=simulation.current_cycle.bundle + 1, cycle=0) # Add back to priority queue simulation.priority_queue_push(original_xinstr) elif xinstr != original_xinstr: @@ -2649,21 +2754,23 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, # Ready to schedule xinstruction # Check if xinstruction is cycle ready for scheduling elif xinstr.cycle_ready > simulation.current_cycle: - if prep_counter > 0: # Instructions were added to prep original if original_xinstr == xinstr: - assert (xinstr_prepped == 1) + assert xinstr_prepped == 1 # Original instruction prepped in this group, but not ready to schedule yet: # Put it back in the priority queue during schedule update phase else: # Xinstr is not the original, but one needed to prepare the original - assert not isinstance(xinstr, xinst.Move), f'xinstr = {repr(xinstr)} \ncycle = {simulation.current_cycle}; iter = {iter_counter}' + assert not isinstance( + xinstr, xinst.Move + ), f"xinstr = {repr(xinstr)} \ncycle = {simulation.current_cycle}; iter = {iter_counter}" # Cycle for xinstr is not ready yet, so, # put it back in the correct place in the simulation pipeline - assert xinstr.id in simulation.dependency_graph \ - and simulation.dependency_graph.in_degree(xinstr.id) <= 0 + assert ( + xinstr.id in simulation.dependency_graph and simulation.dependency_graph.in_degree(xinstr.id) <= 0 + ) simulation.addXInstrBackIntoPipeline(xinstr) # This will cause the schedule update phase below to put original instruction @@ -2676,15 +2783,12 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, # Nop required idle_cycles_required = xinstr.cycle_ready.cycle - simulation.current_cycle.cycle - if scheduleXNOP(xinstr, - idle_cycles_required, - simulation): + if scheduleXNOP(xinstr, idle_cycles_required, simulation): simulation.total_idle_cycles += idle_cycles_required else: # Could not schedule required NOP in this bundle: # Leave xinstruction for next bundle - xinstr.cycle_ready = CycleType(bundle=simulation.current_cycle.bundle + 1, - cycle=1) + xinstr.cycle_ready = CycleType(bundle=simulation.current_cycle.bundle + 1, cycle=1) # Add back to pipeline during schedule update phase xinstr = None @@ -2692,14 +2796,17 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, # We are still valid for scheduling # At this point, xinstruction should be in ready cycle - assert(__canScheduleInBundle(xinstr, simulation, padding=0)) - assert(simulation.current_cycle >= xinstr.cycle_ready) + assert __canScheduleInBundle(xinstr, simulation, padding=0) + assert simulation.current_cycle >= xinstr.cycle_ready # Simulate schedule of xinstruction simulation.current_cycle += xinstr.schedule(simulation.current_cycle, len(simulation.xinsts_bundle) + 1) # Mark the used lives - xinstr_var_names = set(v.name for v in xinstr.sources + xinstr.dests \ - if isinstance(v, Variable) and not isinstance(v, DummyVariable)) + xinstr_var_names = set( + v.name + for v in xinstr.sources + xinstr.dests + if isinstance(v, Variable) and not isinstance(v, DummyVariable) + ) if isinstance(xinstr, xinst.XStore): simulation.live_outs.update(xinstr_var_names) for var_name in xinstr_var_names: @@ -2714,16 +2821,17 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, pisa_instr_counter += 1 if GlobalConfig.debugVerbose: if iter_counter % int(GlobalConfig.debugVerbose) == 0: - print(f'P-ISA scheduled: {pisa_instr_counter}') + print(f"P-ISA scheduled: {pisa_instr_counter}") # Check for completed outputs to flush for variable in original_xinstr.dests: # This assertion may be broken if move instructions end up back in the topo sort - assert(variable.name not in simulation.mem_model.store_buffer \ - or isinstance(original_xinstr, xinst.XStore)) - if variable.name in simulation.mem_model.output_variables \ - and not variable.accessed_by_xinsts \ - and variable.name not in simulation.mem_model.store_buffer: + assert variable.name not in simulation.mem_model.store_buffer or isinstance(original_xinstr, xinst.XStore) + if ( + variable.name in simulation.mem_model.output_variables + and not variable.accessed_by_xinsts + and variable.name not in simulation.mem_model.store_buffer + ): # Variable is an output variable # and it is no longer needed # and it is not in-flight to be stored already @@ -2756,7 +2864,7 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, last_xinstr = original_xinstr for output_var_name in simulation.mem_model.output_variables: variable = simulation.mem_model.variables[output_var_name] - assert(not variable.accessed_by_xinsts) # Variable should not be accessed any more + assert not variable.accessed_by_xinsts # Variable should not be accessed any more if not simulation.flushOutputVariableFromRegister(variable): break # Continue next bundle @@ -2771,9 +2879,11 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, # Flush buffered xinstfetches to cinst if GlobalConfig.useXInstFetch: if len(simulation.xinstfetch_cinsts_buffer) > 0: - simulation.cinsts = simulation.cinsts[:simulation.xinstfetch_location_idx_in_cinsts] \ - + simulation.xinstfetch_cinsts_buffer \ - + simulation.cinsts[simulation.xinstfetch_location_idx_in_cinsts:] + simulation.cinsts = ( + simulation.cinsts[: simulation.xinstfetch_location_idx_in_cinsts] + + simulation.xinstfetch_cinsts_buffer + + simulation.cinsts[simulation.xinstfetch_location_idx_in_cinsts :] + ) # TODO: ################################# @@ -2792,7 +2902,7 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, simulation.cinsts.append(cexit) # Rule: last instruction in MInstQ must be a sync pointing to cexit + 1 - last_msyncc = minst.MSyncc(cexit.id[0], cexit, comment='terminating MInstQ') + last_msyncc = minst.MSyncc(cexit.id[0], cexit, comment="terminating MInstQ") last_msyncc.schedule(simulation.current_cycle, len(simulation.minsts) + 1) simulation.minsts.append(last_msyncc) @@ -2809,20 +2919,21 @@ def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, cnt = 0 while cnt < 10 and simulation.priority_queue: _, xinstr = simulation.priority_queue.pop() - print('Cycle ready', xinstr.cycle_ready) + print("Cycle ready", xinstr.cycle_ready) print(repr(xinstr)) cnt += 1 if len(simulation.priority_queue) > 10: - print('...') - print('priority_queue', len(simulation.priority_queue)) - print('topo_sort', len(simulation.topo_sort)) - print('current cycle', simulation.current_cycle) + print("...") + print("priority_queue", len(simulation.priority_queue)) + print("topo_sort", len(simulation.topo_sort)) + print("current cycle", simulation.current_cycle) simulation.mem_model.dump() import traceback + traceback.print_exc() print(ex) else: raise - return simulation.minsts, simulation.cinsts, simulation.xinsts, simulation.total_idle_cycles \ No newline at end of file + return simulation.minsts, simulation.cinsts, simulation.xinsts, simulation.total_idle_cycles diff --git a/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py b/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py index 84719396..6a0e0f42 100644 --- a/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py +++ b/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py @@ -3,17 +3,16 @@ """Preprocessing utilities for HERACLES assembler stages.""" -from typing import Tuple import networkx as nx from assembler.common.constants import Constants from assembler.instructions import xinst -from assembler.instructions.xinst.xinstruction import XInstruction from assembler.instructions.xinst import parse_xntt +from assembler.instructions.xinst.xinstruction import XInstruction from assembler.memory_model import MemoryModel -def dependency_graph_for_vars(insts_list: list) -> Tuple[nx.Graph, set, set]: +def dependency_graph_for_vars(insts_list: list) -> tuple[nx.Graph, set, set]: """ Given the listing of instructions, this method returns the dependency graph for the variables in the listing and the sets of destination and source variables. @@ -50,9 +49,7 @@ def dependency_graph_for_vars(insts_list: list) -> Tuple[nx.Graph, set, set]: for v_i in range(idx + 1, len(inst.dests)): v_next = inst.dests[v_i] if v.name == v_next.name: - raise RuntimeError( - f"Cannot write to the same variable in the same instruction more than once: {inst.to_pisa_format()}" - ) + raise RuntimeError(f"Cannot write to the same variable in the same instruction more than once: {inst.to_pisa_format()}") if not retval.has_edge(v.name, v_next.name): retval.add_edge(v.name, v_next.name) # Mac deps already handled in the Mac instructions themselves @@ -76,9 +73,7 @@ def dependency_graph_for_vars(insts_list: list) -> Tuple[nx.Graph, set, set]: return retval, all_dests_vars, all_sources_vars -def inject_variable_copy( - mem_model: MemoryModel, insts_list: list, instruction_idx: int, var_name: str -) -> int: +def inject_variable_copy(mem_model: MemoryModel, insts_list: list, instruction_idx: int, var_name: str) -> int: """ Injects a copy of a variable into the instruction list at the specified index. @@ -95,9 +90,7 @@ def inject_variable_copy( IndexError: If the instruction index is out of range. """ if instruction_idx < 0 or instruction_idx >= len(insts_list): - raise IndexError( - f"instruction_idx: Expected index in range [0, {len(insts_list)}), but received {instruction_idx}." - ) + raise IndexError(f"instruction_idx: Expected index in range [0, {len(insts_list)}), but received {instruction_idx}.") last_instruction: XInstruction = insts_list[instruction_idx] last_instruction_sources = last_instruction.sources[:] for idx, src_var in enumerate(last_instruction_sources): @@ -137,42 +130,27 @@ def reduce_var_deps_by_var(mem_model: MemoryModel, insts_list: list, var_name: s while last_pos < len(insts_list): if var_name in (v.name for v in insts_list[last_pos].sources): last_instruction = insts_list[last_pos] - if isinstance(last_instruction, (xinst.Mac, xinst.Maci)): + if isinstance(last_instruction, xinst.Mac | xinst.Maci): # Check if the conflicting variable is the accumulator if last_instruction.sources[0].name == var_name: # Turn all other variables into copies for src_var in last_instruction.sources[1:]: - last_pos = inject_variable_copy( - mem_model, insts_list, last_pos, src_var.name - ) + last_pos = inject_variable_copy(mem_model, insts_list, last_pos, src_var.name) assert last_instruction == insts_list[last_pos] last_instruction = None # avoid further processing of instruction last_pos += 1 continue # If conflict variable was not the accumulator, proceed to change the other variables # Skip copy, twxntt and xrshuffle - if not isinstance( - last_instruction, - ( - xinst.twiNTT, - xinst.twiNTT, - xinst.irShuffle, - xinst.rShuffle, - xinst.Copy, - ), - ): + if not isinstance(last_instruction, xinst.twiNTT | xinst.twiNTT | xinst.irShuffle | xinst.rShuffle | xinst.Copy): # Break up indicated variable in sources into a temp copy - last_pos = inject_variable_copy( - mem_model, insts_list, last_pos, var_name - ) + last_pos = inject_variable_copy(mem_model, insts_list, last_pos, var_name) assert last_instruction == insts_list[last_pos] last_pos += 1 -def assign_register_banks_to_vars( - mem_model: MemoryModel, insts_list: list, use_bank0: bool, verbose=False -) -> str: +def assign_register_banks_to_vars(mem_model: MemoryModel, insts_list: list, use_bank0: bool, verbose=False) -> str: """ Assigns register banks to variables using vertex coloring graph algorithm. @@ -212,18 +190,14 @@ def assign_register_banks_to_vars( while needs_reduction: # Extract the dependency graph for variables dep_graph_vars, dest_names, source_names = dependency_graph_for_vars(insts_list) - only_sources = ( - source_names - dest_names - ) # Find which variables are ever only used as sources + only_sources = source_names - dest_names # Find which variables are ever only used as sources color_dict = nx.greedy_color(dep_graph_vars) # Do coloring needs_reduction = False for var_name, bank in color_dict.items(): if bank > 2: if var_name in reduced_vars: - raise RuntimeError( - f"Found invalid bank {bank} > 2 for variable {var_name} already reduced." - ) + raise RuntimeError(f"Found invalid bank {bank} > 2 for variable {var_name} already reduced.") # DEBUG print if verbose: print(f"Variable {var_name} ({bank}) requires reduction.") @@ -238,9 +212,7 @@ def assign_register_banks_to_vars( bank = color_dict[v.name] assert bank < 3, f"{v.name}, {bank}" # If requested, keep vars used only as sources in bank 0 - v.suggested_bank = bank + ( - 0 if use_bank0 and (v.name in only_sources) else 1 - ) + v.suggested_bank = bank + (0 if use_bank0 and (v.name in only_sources) else 1) retval: str = mem_model.findUniqueVarName() @@ -249,21 +221,15 @@ def assign_register_banks_to_vars( def ntt_kernel_grammar(line): """Parse NTT kernel grammar from a line.""" - return parse_xntt.parseXNTTKernelLine( - line, xinst.NTT.op_name_pisa, Constants.TW_GRAMMAR_SEPARATOR - ) + return parse_xntt.parseXNTTKernelLine(line, xinst.NTT.op_name_pisa, Constants.TW_GRAMMAR_SEPARATOR) def intt_kernel_grammar(line): """Parse INTT kernel grammar from a line.""" - return parse_xntt.parseXNTTKernelLine( - line, xinst.iNTT.op_name_pisa, Constants.TW_GRAMMAR_SEPARATOR - ) + return parse_xntt.parseXNTTKernelLine(line, xinst.iNTT.op_name_pisa, Constants.TW_GRAMMAR_SEPARATOR) -def preprocess_pisa_kernel_listing( - mem_model: MemoryModel, line_iter, progress_verbose: bool = False -) -> list: +def preprocess_pisa_kernel_listing(mem_model: MemoryModel, line_iter, progress_verbose: bool = False) -> list: """ Parses a P-ISA kernel listing, given as an iterator for strings, where each is a line representing a P-ISA instruction. @@ -305,9 +271,7 @@ def preprocess_pisa_kernel_listing( parsed_op = intt_kernel_grammar(s_line) if parsed_op: # Instruction is a P-ISA xntt - parsed_insts = parse_xntt.generateXNTT( - mem_model, parsed_op, new_id=line_no - ) + parsed_insts = parse_xntt.generateXNTT(mem_model, parsed_op, new_id=line_no) if not parsed_insts: # Instruction is one that is represented by single XInst inst = xinst.createFromPISALine(mem_model, s_line, line_no) @@ -315,9 +279,7 @@ def preprocess_pisa_kernel_listing( parsed_insts = [inst] if not parsed_insts: - raise SyntaxError( - f"Line {line_no}: unable to parse kernel instruction:\n{s_line}" - ) + raise SyntaxError(f"Line {line_no}: unable to parse kernel instruction:\n{s_line}") retval += parsed_insts diff --git a/assembler_tools/hec-assembler-tools/assembler/stages/scheduler.py b/assembler_tools/hec-assembler-tools/assembler/stages/scheduler.py index 4ee8f65c..8eaad4b6 100644 --- a/assembler_tools/hec-assembler-tools/assembler/stages/scheduler.py +++ b/assembler_tools/hec-assembler-tools/assembler/stages/scheduler.py @@ -1,4 +1,7 @@ -import collections +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import collections import heapq import networkx as nx from typing import NamedTuple @@ -10,6 +13,7 @@ from assembler.memory_model.variable import Variable from pickle import TRUE + def __orderKeygenVars(mem_model: MemoryModel) -> list: """ Returns the name of the keygen variables in the order they have to be generated. @@ -30,20 +34,18 @@ def __orderKeygenVars(mem_model: MemoryModel) -> list: for var_name, (seed_idx, key_idx) in mem_model.keygen_variables.items(): assert seed_idx < len(retval) if key_idx >= len(retval[seed_idx]): - retval[seed_idx] += ((key_idx - len(retval[seed_idx]) + 1) * [None]) + retval[seed_idx] += (key_idx - len(retval[seed_idx]) + 1) * [None] retval[seed_idx][key_idx] = var_name # Validate that no key material was skipped for seed_idx, l in enumerate(retval): for key_idx, var_name in enumerate(l): if var_name is None: - raise RuntimeError(f'Detected key material {key_idx} generation skipped for seed {seed_idx}.') + raise RuntimeError(f"Detected key material {key_idx} generation skipped for seed {seed_idx}.") return retval -def __findVarInPrevDeps(deps_graph: nx.DiGraph, - instr_id: tuple, - var_name: str, - b_only_sources: bool = False) -> tuple: + +def __findVarInPrevDeps(deps_graph: nx.DiGraph, instr_id: tuple, var_name: str, b_only_sources: bool = False) -> tuple: """ Returns the ID for an instruction that uses the specified variable, and is a dependency for input instruction. @@ -64,7 +66,7 @@ def __findVarInPrevDeps(deps_graph: nx.DiGraph, retval = None if instr_id in deps_graph: - checked_instructions = set() # avoids checking same instruction multiple times + checked_instructions = set() # avoids checking same instruction multiple times dep_instructions = collections.deque() last_instr = deps_graph.nodes[instr_id]["instruction"] # Repeat while we have instructions to process and we haven't found what we need @@ -87,9 +89,8 @@ def __findVarInPrevDeps(deps_graph: nx.DiGraph, return retval -def enforceKeygenOrdering(deps_graph: nx.DiGraph, - mem_model: MemoryModel, - verbose_ostream = None): + +def enforceKeygenOrdering(deps_graph: nx.DiGraph, mem_model: MemoryModel, verbose_ostream=None): """ Given the dependency graph for instructions and a complete memory model, injects instructions and dependencies to enforce ordering required for the keygen subsystem. @@ -129,23 +130,25 @@ def enforceKeygenOrdering(deps_graph: nx.DiGraph, ordered_kg_vars = __orderKeygenVars(mem_model) if ordered_kg_vars and verbose_ostream: - print("Enforcing keygen ordering", file = verbose_ostream) + print("Enforcing keygen ordering", file=verbose_ostream) for seed_idx, kg_seed_list in enumerate(ordered_kg_vars): if verbose_ostream: - print(f"Seed {seed_idx} / {len(ordered_kg_vars)}", file = verbose_ostream) + print(f"Seed {seed_idx} / {len(ordered_kg_vars)}", file=verbose_ostream) last_copy_id = None - b_copy_deps_found = False # tracks whether we have correctly added dependencies for the new copy + b_copy_deps_found = False # tracks whether we have correctly added dependencies for the new copy for key_idx, kg_var_name in enumerate(kg_seed_list): # Create a copy instruction and make all instructions using this kg var depend on it src = mem_model.variables[kg_var_name] # Create temp target variable dst = mem_model.retrieveVarAdd(mem_model.findUniqueVarName(), src.suggested_bank) - copy_instr = xinst.Copy(0, # id - 0, # N - [ dst ], - [ src ], - comment=f'injected copy to generate keygen var {kg_var_name} (seed = {seed_idx}, key = {key_idx})') + copy_instr = xinst.Copy( + 0, # id + 0, # N + [dst], + [src], + comment=f"injected copy to generate keygen var {kg_var_name} (seed = {seed_idx}, key = {key_idx})", + ) deps_graph.add_node(copy_instr.id, instruction=copy_instr) # Enforce ordering of copies based on ordering of keygen if last_copy_id is not None: @@ -154,8 +157,7 @@ def enforceKeygenOrdering(deps_graph: nx.DiGraph, last_copy_id = copy_instr.id for instr_id in deps_graph: - if instr_id != copy_instr.id \ - and kg_var_name in set(src.name for src in deps_graph.nodes[instr_id]['instruction'].sources): + if instr_id != copy_instr.id and kg_var_name in set(src.name for src in deps_graph.nodes[instr_id]["instruction"].sources): # Found instruction that uses the kg var: if not b_copy_deps_found: @@ -169,19 +171,19 @@ def enforceKeygenOrdering(deps_graph: nx.DiGraph, # dependency -> copy_instr deps_graph.add_edge(dependency_id, copy_instr.id) - b_copy_deps_found = True # found artificial dependencies for copy + b_copy_deps_found = True # found artificial dependencies for copy # Make instruction depend on this injected copy # copy_instr -> instr deps_graph.add_edge(copy_instr.id, instr_id) if ordered_kg_vars and verbose_ostream: - print(f"Seed {len(ordered_kg_vars)} / {len(ordered_kg_vars)}", file = verbose_ostream) + print(f"Seed {len(ordered_kg_vars)} / {len(ordered_kg_vars)}", file=verbose_ostream) # We should not have introduced any cycles with these modifications assert nx.is_directed_acyclic_graph(deps_graph) -def generateInstrDependencyGraph(insts_listing: list, - verbose_ostream = None) -> nx.DiGraph: + +def generateInstrDependencyGraph(insts_listing: list, verbose_ostream=None) -> nx.DiGraph: """ Given a pre-processed P-ISA instructions listing, generates a dependency graph for the instructions based on their inputs and outputs, and any shared HW resources @@ -203,8 +205,8 @@ def generateInstrDependencyGraph(insts_listing: list, class VarTracking(NamedTuple): # Used for clarity - last_write: object # last instruction that wrote to this variable - reads_after_last_write: list # all insts that read from this variable after last write + last_write: object # last instruction that wrote to this variable + reads_after_last_write: list # all insts that read from this variable after last write retval = nx.DiGraph() @@ -215,14 +217,11 @@ class VarTracking(NamedTuple): verbose_report_every_x_insts = 1 # Look up table for already seen variables - vars2insts = {} # dict(var_name, VarTracking ) + vars2insts = {} # dict(var_name, VarTracking ) for idx, inst in enumerate(insts_listing): - if verbose_ostream: if idx % verbose_report_every_x_insts == 0: - print("{}% - {}/{}".format(idx * 100 // len(insts_listing), - idx, - len(insts_listing)), file = verbose_ostream) + print("{}% - {}/{}".format(idx * 100 // len(insts_listing), idx, len(insts_listing)), file=verbose_ostream) # Add new node # All instructions are nodes @@ -242,52 +241,54 @@ class VarTracking(NamedTuple): for inst_dep in vars2insts[variable.name].reads_after_last_write: if inst_dep.id != inst.id: retval.add_edge(inst_dep.id, inst.id) - else: # Add dep to last write - inst_dep = vars2insts[variable.name].last_write # last instruction that wrote to this variable + else: # Add dep to last write + inst_dep = vars2insts[variable.name].last_write # last instruction that wrote to this variable if inst_dep and inst_dep.id != inst.id: retval.add_edge(inst_dep.id, inst.id) # Record write - vars2insts[variable.name] = VarTracking( inst, [] ) # (last inst that wrote to this, all insts that read from it after last write) + vars2insts[variable.name] = VarTracking( + inst, [] + ) # (last inst that wrote to this, all insts that read from it after last write) for variable in inst.sources: if variable.name in vars2insts: # Add dependency to last write - inst_dep = vars2insts[variable.name].last_write # last instruction that wrote to this variable + inst_dep = vars2insts[variable.name].last_write # last instruction that wrote to this variable if inst_dep and inst_dep.id != inst.id: retval.add_edge(inst_dep.id, inst.id) else: # First time seeing this var - vars2insts[variable.name] = VarTracking( None, [] ) + vars2insts[variable.name] = VarTracking(None, []) # Record read vars2insts[variable.name].reads_after_last_write.append(inst) # Different variants to enforce ordering - #print('##### DEBUG #####') + # print('##### DEBUG #####') ### sequential instructions (no reordering) - #print('***** Sequential *****') - #for idx in range(len(insts_listing) - 1): + # print('***** Sequential *****') + # for idx in range(len(insts_listing) - 1): # retval.add_edge(insts_listing[idx].id, insts_listing[idx + 1].id) ## tw before rshuffle - #print('***** twid before rshuffle *****') - #for idx in range(len(insts_listing) - 1): + # print('***** twid before rshuffle *****') + # for idx in range(len(insts_listing) - 1): # if isinstance(insts_listing[idx], xinst.rShuffle): # if isinstance(insts_listing[idx + 1], xinst.twNTT): # print(insts_listing[idx].id) # retval.add_edge(insts_listing[idx + 1].id, insts_listing[idx].id) # rshuffle before tw - #print('***** rshuffle before twid *****') - #for idx in range(len(insts_listing) - 1): + # print('***** rshuffle before twid *****') + # for idx in range(len(insts_listing) - 1): # if isinstance(insts_listing[idx], xinst.rShuffle): # if isinstance(insts_listing[idx + 1], xinst.twNTT): # print(insts_listing[idx].id) # retval.add_edge(insts_listing[idx].id, insts_listing[idx + 1].id) # rshuffles ordered - #print('***** Ordered rshuffles *****') - #for idx in range(len(insts_listing) - 1): + # print('***** Ordered rshuffles *****') + # for idx in range(len(insts_listing) - 1): # if isinstance(insts_listing[idx], xinst.rShuffle): # for j in range(len(insts_listing) - idx): # jdx = j + idx + 1 @@ -297,8 +298,8 @@ class VarTracking(NamedTuple): # break # twid ordered - #print('***** Ordered twntt *****') - #for idx in range(len(insts_listing) - 1): + # print('***** Ordered twntt *****') + # for idx in range(len(insts_listing) - 1): # if isinstance(insts_listing[idx], xinst.twNTT): # for jdx in range(idx + 1, len(insts_listing)): # if isinstance(insts_listing[jdx], xinst.twNTT): @@ -308,16 +309,16 @@ class VarTracking(NamedTuple): # Detect cycles in result if not nx.is_directed_acyclic_graph(retval): - raise nx.NetworkXUnfeasible('Instruction listing must form a Directed Acyclic Graph dependency.') + raise nx.NetworkXUnfeasible("Instruction listing must form a Directed Acyclic Graph dependency.") if verbose_ostream: - print("100% - {0}/{0}".format(len(insts_listing)), file = verbose_ostream) + print("100% - {0}/{0}".format(len(insts_listing)), file=verbose_ostream) # retval contains the dependency graph return retval -def schedulePISAInstructions(dependency_graph: nx.DiGraph, - progress_verbose: bool = False) -> (list, int, int): + +def schedulePISAInstructions(dependency_graph: nx.DiGraph, progress_verbose: bool = False) -> (list, int, int): """ Given the dependency directed acyclic graph of XInsts, returns a schedule for the corresponding P-ISA instructions, that minimizes idle cycles. @@ -332,18 +333,16 @@ def schedulePISAInstructions(dependency_graph: nx.DiGraph, - int: The total number of idle cycles. - int: The number of NOPs inserted. """ + class PrioritizedInstruction(PrioritizedPlaceholder): - def __init__(self, - instruction, - priority_delta = (0, 0)): + def __init__(self, instruction, priority_delta=(0, 0)): super().__init__(priority_delta=priority_delta) self.__instruction = instruction def __repr__(self): - return '<{} (priority = {})>(instruction={}, priority_delta={})'.format(type(self).__name__, - self.priority, - repr(self.instruction), - self.priority_delta) + return "<{} (priority = {})>(instruction={}, priority_delta={})".format( + type(self).__name__, self.priority, repr(self.instruction), self.priority_delta + ) @property def instruction(self): @@ -354,24 +353,22 @@ def _get_priority(self): retval = [] topo_sort = buildVarAccessListFromTopoSort(dependency_graph) - dependency_graph = nx.DiGraph(dependency_graph) # make a copy of the incoming graph to avoid modifying input + dependency_graph = nx.DiGraph(dependency_graph) # make a copy of the incoming graph to avoid modifying input total_idle_cycles = 0 num_nops = 0 - set_processed_instrs = set() # track instructions that have been process to avoid encountering them after scheduling - current_cycle = CycleType(bundle = 0, cycle = 1) - p_queue = [] # Sorted list by priority: ready cycle - b_changed = True # Track when there are changes in the priority queue or dependency graph + set_processed_instrs = set() # track instructions that have been process to avoid encountering them after scheduling + current_cycle = CycleType(bundle=0, cycle=1) + p_queue = [] # Sorted list by priority: ready cycle + b_changed = True # Track when there are changes in the priority queue or dependency graph total_insts = dependency_graph.number_of_nodes() prev_report_pct = -1 while dependency_graph: - if progress_verbose: pct = int(len(retval) * 100 / total_insts) if pct > prev_report_pct and pct % 10 == 0: prev_report_pct = pct print(f"{pct}% - {len(retval)}/{total_insts}") - if b_changed: # If priority queue or dependency graph have changed since last iteration - + if b_changed: # If priority queue or dependency graph have changed since last iteration # Extract all the instructions that can be executed without dependencies # and merge to current instructions that can be executed without dependencies last_idx = -1 @@ -380,16 +377,16 @@ def _get_priority(self): if dependency_graph.in_degree(instr_id) > 0: # Found first instruction with dependencies break - instr = dependency_graph.nodes[instr_id]['instruction'] + instr = dependency_graph.nodes[instr_id]["instruction"] p_queue.append(PrioritizedInstruction(instr)) set_processed_instrs.add(instr.id) last_idx = idx # Remove all instructions that got queued for scheduling if last_idx >= 0: - topo_sort = topo_sort[last_idx + 1:] + topo_sort = topo_sort[last_idx + 1 :] # Reorder priority queue since the items' priorities may change after scheduling an instruction - assert(p_queue) + assert p_queue heapq.heapify(p_queue) # Schedule next instruction @@ -412,23 +409,23 @@ def _get_priority(self): # Make new instruction to execute a nop instr = xinst.Nop(instr.id[0], num_idle_cycles) num_nops += 1 - b_changed = False # No changes in the queue or graph + b_changed = False # No changes in the queue or graph # Do not pop actual instruction from graph or queue since we had to add nops before its scheduling else: # Instruction ready: pop instruction from queue and update dependency graph # (this breaks the heap invariant for p_queue, but we heapify # on every iteration due to priorities changing based on latency) - p_queue = p_queue[:element_idx] + p_queue[element_idx + 1:] - dependents = list(dependency_graph.neighbors(instr.id)) # find instructions that depend on this instruction - dependency_graph.remove_node(instr.id) # remove from graph to update the in_degree of dependendent instrs + p_queue = p_queue[:element_idx] + p_queue[element_idx + 1 :] + dependents = list(dependency_graph.neighbors(instr.id)) # find instructions that depend on this instruction + dependency_graph.remove_node(instr.id) # remove from graph to update the in_degree of dependent instrs # "move" dependent instrs that have no other dependencies to the top of the topo sort - topo_sort = [ instr_id for instr_id in dependents if dependency_graph.in_degree(instr_id) <= 0 ] + topo_sort + topo_sort = [instr_id for instr_id in dependents if dependency_graph.in_degree(instr_id) <= 0] + topo_sort # Do not search the topo sort to actually remove the duplicated instrs because it is O(N) costly: # set_processed_instrs will take care of skipping them once encountered. - b_changed = True # queue and/or graph changed + b_changed = True # queue and/or graph changed - cycle_throughput = instr.schedule(current_cycle, len(retval) + 1) # simulate execution to update cycle ready of dependents + cycle_throughput = instr.schedule(current_cycle, len(retval) + 1) # simulate execution to update cycle ready of dependents retval.append(instr) # Next cycle starts @@ -437,4 +434,4 @@ def _get_priority(self): if progress_verbose: print(f"100% - {total_insts}/{total_insts}") - return retval, total_idle_cycles, num_nops \ No newline at end of file + return retval, total_idle_cycles, num_nops diff --git a/assembler_tools/hec-assembler-tools/config/isa_spec.json b/assembler_tools/hec-assembler-tools/config/isa_spec.json index 3ea0b212..ff0db4ac 100644 --- a/assembler_tools/hec-assembler-tools/config/isa_spec.json +++ b/assembler_tools/hec-assembler-tools/config/isa_spec.json @@ -222,4 +222,4 @@ } } } -} \ No newline at end of file +} diff --git a/assembler_tools/hec-assembler-tools/debug_tools/deadlock_test.py b/assembler_tools/hec-assembler-tools/debug_tools/deadlock_test.py index b0009351..87797c18 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/deadlock_test.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/deadlock_test.py @@ -1,9 +1,13 @@ -import argparse +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse import os # Searches the CInstQ and MInstQ to find deadlocks caused by sync instructions. # Raises exception on first deadlock found, otherwise, completes successfully. + def makeUniquePath(path: str): """ Normalizes and expand a given file path. @@ -16,6 +20,7 @@ def makeUniquePath(path: str): """ return os.path.normcase(os.path.realpath(os.path.expanduser(path))) + def loadInstructions(istream) -> list: """ Loads instructions from an input iterator. @@ -33,21 +38,22 @@ def loadInstructions(istream) -> list: # Separate comment s_instr = "" s_comment = "" - comment_start_idx = line.find('#') + comment_start_idx = line.find("#") if comment_start_idx < 0: s_instr = line else: s_instr = line[:comment_start_idx] - s_comment = line[comment_start_idx + 1:] + s_comment = line[comment_start_idx + 1 :] # Tokenize instruction - s_instr = map(lambda s: s.strip(), s_instr.split(",")) + s_instr = (s.strip() for s in s_instr.split(",")) # Add instruction to collection retval.append((list(s_instr), s_comment)) return retval + def findDeadlock(minsts: list, cinsts: list) -> tuple: """ Searches the CInstQ and MInstQ to find the first deadlock. @@ -68,13 +74,13 @@ def findDeadlock(minsts: list, cinsts: list) -> tuple: # Remove all non-syncs from q sync_idx = len(q) for idx, instr in enumerate(q): - if 'sync' in instr[1]: + if "sync" in instr[1]: # Sync found sync_idx = idx break q = q[sync_idx:] if q: - assert 'sync' in q[0][1], 'Next instruction in queue is not a sync!' + assert "sync" in q[0][1], "Next instruction in queue is not a sync!" if sync_idx != 0: # Queue moved: restart the deadlock watcher @@ -107,6 +113,7 @@ def findDeadlock(minsts: list, cinsts: list) -> tuple: return retval + def main(input_dir: str, input_prefix: str = None): """ Main function to check for deadlocks in instruction queues. @@ -119,30 +126,31 @@ def main(input_dir: str, input_prefix: str = None): if not input_prefix: input_prefix = os.path.basename(input_dir) - print('Deadlock test.') + print("Deadlock test.") print() - print('Input dir:', input_dir) - print('Input prefix:', input_prefix) + print("Input dir:", input_dir) + print("Input prefix:", input_prefix) xinst_file = os.path.join(input_dir, input_prefix + ".xinst") cinst_file = os.path.join(input_dir, input_prefix + ".cinst") minst_file = os.path.join(input_dir, input_prefix + ".minst") - with open(xinst_file, 'r') as f_xin: + with open(xinst_file) as f_xin: xinsts = loadInstructions(f_xin) xinsts = [x for (x, _) in xinsts] - with open(cinst_file, 'r') as f_cin: + with open(cinst_file) as f_cin: cinsts = loadInstructions(f_cin) cinsts = [x for (x, _) in cinsts] - with open(minst_file, 'r') as f_min: + with open(minst_file) as f_min: minsts = loadInstructions(f_min) minsts = [x for (x, _) in minsts] deadlock_indices = findDeadlock(minsts, cinsts) if deadlock_indices is not None: - raise RuntimeError('Deadlock detected: MinstQ: {}, CInstQ: {}'.format(deadlock_indices[0], deadlock_indices[1])) + raise RuntimeError(f"Deadlock detected: MinstQ: {deadlock_indices[0]}, CInstQ: {deadlock_indices[1]}") + + print("No deadlock detected between CInstQ and MInstQ.") - print('No deadlock detected between CInstQ and MInstQ.') if __name__ == "__main__": module_name = os.path.basename(__file__) @@ -157,4 +165,4 @@ def main(input_dir: str, input_prefix: str = None): main(args.input_dir, args.input_prefix) print() - print(module_name, "- Complete") \ No newline at end of file + print(module_name, "- Complete") diff --git a/assembler_tools/hec-assembler-tools/debug_tools/isolation_test.py b/assembler_tools/hec-assembler-tools/debug_tools/isolation_test.py index 21ac5af0..6dcbaa74 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/isolation_test.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/isolation_test.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import argparse import os import re @@ -17,20 +20,32 @@ def parse_args(): argparse.Namespace: Parsed command-line arguments. """ parser = argparse.ArgumentParser( - description=("Isolation Test.\n" - "Given a set of variables in P-ISA, this script will replace all instructions that do not" - " affect the variable with appropriate NOPs.")) - parser.add_argument("--pisa_file", required= True, help="Input P-ISA prep (.csv) file.") + description=( + "Isolation Test.\n" + "Given a set of variables in P-ISA, this script will replace all instructions that do not" + " affect the variable with appropriate NOPs." + ) + ) + parser.add_argument("--pisa_file", required=True, help="Input P-ISA prep (.csv) file.") parser.add_argument("--xinst_file", required=True, help="Input (xinst) instruction file.") parser.add_argument("--out_file", default="", help="Output file name.") - parser.add_argument("--track", default="", dest="variables_set", nargs='+', help="Set of variables to track.") - parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, - help=("If enabled, extra information and progress reports are printed to stdout. " - "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) + parser.add_argument("--track", default="", dest="variables_set", nargs="+", help="Set of variables to track.") + parser.add_argument( + "-v", + "--verbose", + dest="verbose", + action="count", + default=0, + help=( + "If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv" + ), + ) args = parser.parse_args() return args + if __name__ == "__main__": module_name = os.path.basename(__file__) @@ -40,7 +55,7 @@ def parse_args(): pisa_prep_file = args.pisa_file xinst_file = args.xinst_file output_file = "" - if (args.out_file): + if args.out_file: output_file = args.out_file else: # Create the new file name @@ -53,42 +68,42 @@ def parse_args(): if args.verbose > 0: print(module_name) print() - print("P-ISA: {0}".format(pisa_prep_file)) - print("Xinst File: {0}".format(xinst_file)) - print("Output Name: {0}".format(output_file)) - print("Tracking: {0}".format(variables_set)) + print(f"P-ISA: {pisa_prep_file}") + print(f"Xinst File: {xinst_file}") + print(f"Output Name: {output_file}") + print(f"Tracking: {variables_set}") # Find all related variables pisa_instrs = [] pisa_file_contents = [] - with open(pisa_prep_file, 'r') as f_in_pisa: + with open(pisa_prep_file) as f_in_pisa: pisa_file_contents = [line for line in f_in_pisa if line] - l = [] + l = [] # noqa: E741 set_updated = True while set_updated: set_updated = False - for line_idx, line in enumerate(pisa_file_contents): + for _, line in enumerate(pisa_file_contents): # Remove comment s_split = line.split("#") line = s_split[0] # Split into components - tmp_split = map(lambda s: s.strip(), line.split(",")) + tmp_split = (s.strip() for s in line.split(",")) s_split = [] for component in tmp_split: - s_split.append(component.split('(')[0].strip()) + s_split.append(component.split("(")[0].strip()) pisa_instrs.append(s_split[1:]) if any(x in s_split for x in variables_set): # Add all other variables as dependents - if s_split[1] == 'muli' or s_split[1] == 'maci': + if s_split[1] == "muli" or s_split[1] == "maci": s_split = s_split[2:-2] else: s_split = s_split[2:-1] - new_vars = set(v for v in s_split if re.search('^[A-Za-z_][A-Za-z0-9_]*', v)) - if 'iN' in new_vars: - print('iN') + new_vars = {v for v in s_split if re.search("^[A-Za-z_][A-Za-z0-9_]*", v)} + if "iN" in new_vars: + print("iN") if new_vars - variables_set: - l += [x for x in new_vars if x not in variables_set] + l += [x for x in new_vars if x not in variables_set] # noqa: E741 variables_set |= new_vars set_updated = True @@ -101,36 +116,36 @@ def parse_args(): pisa_instr_num_set.add(idx + 1) # Keep only xinsts that are used for the kept p-isa instr - with open(xinst_file, 'r') as f_in: - with open(output_file, 'w') as f_out: + with open(xinst_file) as f_in: + with open(output_file, "w") as f_out: for line in f_in: # Remove comment s_split = line.split("#") s_line = s_split[0].strip() # Split into components - s_split = list(map(lambda s: s.strip(), line.split(","))) - out_line = '' + s_split = (s.strip() for s in line.split(",")) + out_line = "" if int(s_split[1]) in pisa_instr_num_set: # Xinstruction is needed to complete p-isa instr - if s_split[2] not in ('move', 'xstore', 'nop'): + if s_split[2] not in ("move", "xstore", "nop"): out_line = s_line + " # " + str(pisa_instrs[int(s_split[1]) - 1]) else: out_line = line.strip() - elif 'xstore' in s_line: + elif "xstore" in s_line: # All xstores are required because they are sync points with CInstQ out_line = s_line.strip() - elif 'exit' in s_line: + elif "exit" in s_line: # Keep all exits out_line = s_line.strip() - elif 'rshuffle' in s_line: + elif "rshuffle" in s_line: # Other rshuffles are converted to nops for timing - out_line = '{}, {}, nop, {} # rshuffle'.format(s_split[0], s_split[1], s_split[7]) - elif 'nop' in s_line: + out_line = f"{s_split[0]}, {s_split[1]}, nop, {s_split[7]} # rshuffle" + elif "nop" in s_line: # Keep nops timing out_line = s_line.strip() if not out_line: # Any other instructions are converted to single cycle nop - out_line = '{}, {}, nop, 0'.format(s_split[0], s_split[1]) + out_line = f"{s_split[0]}, {s_split[1]}, nop, 0" print(out_line, file=f_out) print("Done") diff --git a/assembler_tools/hec-assembler-tools/debug_tools/main.py b/assembler_tools/hec-assembler-tools/debug_tools/main.py index 31573508..e93297db 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/main.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/main.py @@ -77,9 +77,7 @@ def main_readmem(args): if args.mem_file: mem_filename = args.mem_file else: - raise argparse.ArgumentError( - None, "Please provide input memory file using `--mem_file` option." - ) + raise argparse.ArgumentError(None, "Please provide input memory file using `--mem_file` option.") mem_meta_info = None with open(mem_filename, "r") as mem_ifnum: @@ -101,9 +99,7 @@ def main_readmem(args): print("None") -def asmisa_preprocessing( - input_filename: str, output_filename: str, b_use_bank_0: bool, b_verbose=True -) -> int: +def asmisa_preprocessing(input_filename: str, output_filename: str, b_use_bank_0: bool, b_verbose=True) -> int: """ Preprocess P-ISA kernel and save the intermediate result. @@ -128,15 +124,11 @@ def asmisa_preprocessing( start_time = time.time() with open(input_filename, "r") as insts: - insts_listing = preprocessor.preprocess_pisa_kernel_listing( - hec_mem_model, insts, progress_verbose=b_verbose - ) + insts_listing = preprocessor.preprocess_pisa_kernel_listing(hec_mem_model, insts, progress_verbose=b_verbose) if b_verbose: print("Assigning register banks to variables...") - preprocessor.assign_register_banks_to_vars( - hec_mem_model, insts_listing, use_bank0=b_use_bank_0 - ) + preprocessor.assign_register_banks_to_vars(hec_mem_model, insts_listing, use_bank0=b_use_bank_0) retval_timing = time.time() - start_time @@ -189,9 +181,7 @@ def asmisa_assembly( print("Assembling!") print("Reloading kernel from intermediate...") - hec_mem_model = MemoryModel( - hbm_capacity_words, spad_capacity_words, num_register_banks, register_range - ) + hec_mem_model = MemoryModel(hbm_capacity_words, spad_capacity_words, num_register_banks, register_range) insts_listing = [] with open(input_filename, "r") as insts: @@ -206,11 +196,7 @@ def asmisa_assembly( parsed_insts = [inst] if not parsed_insts: - raise SyntaxError( - "Line {}: unable to parse kernel instruction:\n{}".format( - line_no, s_line - ) - ) + raise SyntaxError("Line {}: unable to parse kernel instruction:\n{}".format(line_no, s_line)) insts_listing += parsed_insts @@ -303,9 +289,7 @@ def main_asmisa(args): if len(args.base_names) > 0: all_base_names = args.base_names else: - raise argparse.ArgumentError( - message=f"Please provide one or more input file prefixes using `--prefix` option." - ) + raise argparse.ArgumentError(message=f"Please provide one or more input file prefixes using `--prefix` option.") for base_name in all_base_names: in_kernel = f"{base_name}.csv" @@ -363,16 +347,12 @@ def main_pisa(args): b_use_bank_0: bool = False b_verbose = True if args.verbose > 0 else False - hec_mem_model = MemoryModel( - constants.MemoryModel.HBM.MAX_CAPACITY_WORDS // 2, 16, 4, range(8) - ) + hec_mem_model = MemoryModel(constants.MemoryModel.HBM.MAX_CAPACITY_WORDS // 2, 16, 4, range(8)) if len(args.base_names) == 1: base_name = args.base_names[0] else: - raise argparse.ArgumentError( - None, f"Please provide an input file prefix using `--prefix` option." - ) + raise argparse.ArgumentError(None, f"Please provide an input file prefix using `--prefix` option.") print("HBM") print(hec_mem_model.hbm.CAPACITY / constants.Constants.GIGABYTE, "GB") @@ -389,18 +369,12 @@ def main_pisa(args): # Resulting instructions will be correctly transformed and ready to be converted into ASM-ISA instructions; # Variables used in the kernel will be automatically assigned to banks. with open(in_kernel, "r") as insts: - insts_listing = preprocessor.preprocessPISAKernelListing( - hec_mem_model, insts, progress_verbose=b_verbose - ) + insts_listing = preprocessor.preprocessPISAKernelListing(hec_mem_model, insts, progress_verbose=b_verbose) print("Assigning register banks to variables...") - preprocessor.assignRegisterBanksToVars( - hec_mem_model, insts_listing, use_bank0=b_use_bank_0 - ) + preprocessor.assignRegisterBanksToVars(hec_mem_model, insts_listing, use_bank0=b_use_bank_0) - hec_mem_model.output_variables.update( - v_name for v_name in hec_mem_model.variables if "output" in v_name - ) + hec_mem_model.output_variables.update(v_name for v_name in hec_mem_model.variables if "output" in v_name) insts_end = time.time() - start_time @@ -439,9 +413,7 @@ def main_pisa(args): print("Scheduling P-ISA instructions...") start_time = time.time() - pisa_insts_schedule, num_idle_cycles, num_nops = schedulePISAInstructions( - dep_graph, progress_verbose=b_verbose - ) + pisa_insts_schedule, num_idle_cycles, num_nops = schedulePISAInstructions(dep_graph, progress_verbose=b_verbose) sched_end = time.time() - start_time print("Saving...") diff --git a/assembler_tools/hec-assembler-tools/debug_tools/order_test.py b/assembler_tools/hec-assembler-tools/debug_tools/order_test.py index 615ff879..eb0e26e0 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/order_test.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/order_test.py @@ -1,6 +1,10 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import argparse -import re import os +import re + # Tests all registers in an XInstQ for whether a register is used out of order based on P-ISA instruction order. # This only works for kernels without evictions. @@ -15,17 +19,29 @@ def parse_args(): argparse.Namespace: Parsed command-line arguments. """ parser = argparse.ArgumentParser( - description=("Order Test.\n" - "Tests all registers in an XInstQ for whether a register is used out of order based on P-ISA instruction order.\n" - "This only works for kernels without evictions.")) - parser.add_argument("--input_file", required= True, help="Input (.xinst) file.") - parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, - help=("If enabled, extra information and progress reports are printed to stdout. " - "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) + description=( + "Order Test.\n" + "Tests all registers in an XInstQ for whether a register is used out of order based on P-ISA instruction order.\n" + "This only works for kernels without evictions." + ) + ) + parser.add_argument("--input_file", required=True, help="Input (.xinst) file.") + parser.add_argument( + "-v", + "--verbose", + dest="verbose", + action="count", + default=0, + help=( + "If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv" + ), + ) args = parser.parse_args() return args + def convertRegNameToTuple(reg_name) -> tuple: """ Converts a register name to a tuple representation. @@ -40,6 +56,7 @@ def convertRegNameToTuple(reg_name) -> tuple: tmp_s = tmp_s.split("b") return (int(tmp_s[1]), int(tmp_s[0])) + if __name__ == "__main__": module_name = os.path.basename(__file__) @@ -49,7 +66,7 @@ def convertRegNameToTuple(reg_name) -> tuple: if args.verbose > 0: print(module_name) print() - print("Xinst File: {0}".format(input_file)) + print(f"Xinst File: {input_file}") print() print("Starting") @@ -58,8 +75,8 @@ def convertRegNameToTuple(reg_name) -> tuple: my_rx = "r[0-9]+b[0-3]" prev_pisa_inst = 0 instr_counter = 0 - with open(input_file, 'r') as f_in: - for line_idx, s_line in enumerate(f_in): + with open(input_file) as f_in: + for _, s_line in enumerate(f_in): instr_regs = set() s_split = s_line.split("#") s_split = s_split[0].split(",") @@ -67,7 +84,7 @@ def convertRegNameToTuple(reg_name) -> tuple: for s in s_split: match = re.search(my_rx, s) if match: - reg_name = s[match.start():match.end()] + reg_name = s[match.start() : match.end()] if reg_name not in instr_regs: instr_regs.add(reg_name) reg = convertRegNameToTuple(reg_name) @@ -75,12 +92,12 @@ def convertRegNameToTuple(reg_name) -> tuple: register_map[reg] = [] register_map[reg].append(pisa_instr_num) - sorted_keys = [x for x in register_map] + sorted_keys = list(register_map) sorted_keys.sort() error_map = set() for reg in sorted_keys: - reg_name = f'r{reg[1]}b{reg[0]}' + reg_name = f"r{reg[1]}b{reg[0]}" print(reg_name, register_map[reg]) reg_lst = register_map[reg] inverted_map = {} @@ -91,10 +108,10 @@ def convertRegNameToTuple(reg_name) -> tuple: else: inverted_map[idx] = (prev_in, reg_lst[idx]) if inverted_map: - print('*** Ahead:', inverted_map) + print("*** Ahead:", inverted_map) error_map.add(reg_name) if error_map: - raise RuntimeError(f'Registers used out of order: {error_map}') + raise RuntimeError(f"Registers used out of order: {error_map}") - print("Done") \ No newline at end of file + print("Done") diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py index 519ef37f..0c849786 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py @@ -94,13 +94,12 @@ def main( cinst_file_o = os.path.join(output_dir, output_prefix + ".cinst") minst_file_o = os.path.join(output_dir, output_prefix + ".minst") - with open(xinst_file_i, "r") as f_xinst_file_i, open( - cinst_file_i, "r" - ) as f_cinst_file_i, open(minst_file_i, "r") as f_minst_file_i: - with open(xinst_file_o, "w") as f_xinst_file_o, open( - cinst_file_o, "w" - ) as f_cinst_file_o, open(minst_file_o, "w") as f_minst_file_o: - + with open(xinst_file_i, "r") as f_xinst_file_i, open(cinst_file_i, "r") as f_cinst_file_i, open(minst_file_i, "r") as f_minst_file_i: + with ( + open(xinst_file_o, "w") as f_xinst_file_o, + open(cinst_file_o, "w") as f_cinst_file_o, + open(minst_file_o, "w") as f_minst_file_o, + ): current_bundle = 0 # Read xinst until first bundle is over @@ -123,12 +122,8 @@ def main( num_xstores += 1 cinst_line_no = 0 - cinst_insertion_line_start = ( - 0 # Track which line we started inserting dummy bundles into CInstQ - ) - cinst_insertion_line_count = ( - 0 # Track how many lines of dummy bundles were inserted into CInstQ - ) + cinst_insertion_line_start = 0 # Track which line we started inserting dummy bundles into CInstQ + cinst_insertion_line_count = 0 # Track how many lines of dummy bundles were inserted into CInstQ # Read cinst until first bundle is over while True: # do-while @@ -246,9 +241,7 @@ def main( print(idx) tokens, comment = xinstruction.tokenize_from_line(line) - assert ( - int(tokens[0]) == idx - ), "Unexpected line number mismatch in MInstQ." + assert int(tokens[0]) == idx, "Unexpected line number mismatch in MInstQ." tokens = list(tokens) # Process sync instruction @@ -287,9 +280,7 @@ def main( parser.add_argument("-b", "--dummy_bundles", dest="nbundles", type=int, default=0) parser.add_argument("-ne", "--skip_exit", dest="b_use_exit", action="store_false") args = parser.parse_args() - args.isa_spec_file = XTCSpecConfig.initialize_isa_spec( - module_dir, args.isa_spec_file - ) + args.isa_spec_file = XTCSpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) print(f"ISA Spec: {args.isa_spec_file}") print() diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py index 51b1ce05..15bfefb5 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py @@ -65,7 +65,6 @@ def dump_isa_spec_to_json(cls, filename): @classmethod def initialize_isa_spec(cls, module_dir, isa_spec_file): - if not isa_spec_file: isa_spec_file = os.path.join(module_dir, "../../config/isa_spec.json") isa_spec_file = os.path.abspath(isa_spec_file) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/__init__.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/__init__.py index 5d49e768..4a36b9b8 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/__init__.py @@ -1,4 +1,7 @@ -from .xinstruction import XInstruction +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .xinstruction import XInstruction from . import add, mul, muli, mac, maci, ntt, intt, twntt, twintt, rshuffle, sub, move, xstore, nop from . import exit as exit_mod @@ -21,4 +24,4 @@ Nop = nop.Instruction # collection of XInstructions with P-ISA or intermediate P-ISA equivalents -ASMISA_INSTRUCTIONS = ( Add, Mul, Muli, Mac, Maci, NTT, iNTT, twNTT, twiNTT, rShuffle, Sub, Move, XStore, Exit, Nop ) +ASMISA_INSTRUCTIONS = (Add, Mul, Muli, Mac, Maci, NTT, iNTT, twNTT, twiNTT, rShuffle, Sub, Move, XStore, Exit, Nop) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/add.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/add.py index 648d0909..f0e355a5 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/add.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/add.py @@ -1,12 +1,16 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents an `add` instruction, inheriting from XInstruction. - + This instructions adds two polynomials stored in the register file and store the result in a register. - + For more information, check the specification: https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_add.md """ @@ -30,16 +34,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # PISA instruction number - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # PISA instruction number + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -52,15 +58,9 @@ def _get_name(cls) -> str: """ return "add" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Constructs a new Instruction object. diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/exit.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/exit.py index e6747286..b3606d40 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/exit.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/exit.py @@ -1,15 +1,19 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents an `bexit` instruction, inheriting from XInstruction. - + This instruction terminates execution of an instruction bundle. For more information, check the specification: https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_exit.md """ - + @classmethod def fromASMISALine(cls, line: str) -> list: """ @@ -29,15 +33,17 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # PISA instruction number - [], - [], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + raise ValueError("`line`: could not parse f{cls.name} from specified line.") + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # PISA instruction number + [], + [], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -50,15 +56,9 @@ def _get_name(cls) -> str: """ return "bexit" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Constructs a new Instruction object. @@ -72,4 +72,4 @@ def __init__(self, other (list): Additional parameters for the instruction. comment (str): Optional comment for the instruction. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/intt.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/intt.py index 18c4e7d6..9dc86e6a 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/intt.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/intt.py @@ -1,9 +1,13 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents an `intt` instruction, inheriting from XInstruction. - + The Inverse Number Theoretic Transform (iNTT), converts NTT form to positional form. For more information, check the specification: @@ -29,16 +33,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # PISA instruction number - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # PISA instruction number + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -51,15 +57,9 @@ def _get_name(cls) -> str: """ return "intt" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Constructs a new Instruction object. @@ -73,4 +73,4 @@ def __init__(self, other (list): Additional parameters for the instruction. comment (str): Optional comment for the instruction. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mac.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mac.py index c5ce986c..63da0418 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mac.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mac.py @@ -1,5 +1,9 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents a `mac` Instruction for element-wise polynomial multiplication and accumulation. @@ -30,16 +34,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # Pisa - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -52,15 +58,9 @@ def _get_name(cls) -> str: """ return "mac" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -74,4 +74,4 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/maci.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/maci.py index 3525f06d..08e53c5d 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/maci.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/maci.py @@ -1,5 +1,9 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents a `maci` Instruction. @@ -32,16 +36,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # Psisa - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # Psisa + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -54,15 +60,9 @@ def _get_name(cls) -> str: """ return "maci" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -76,4 +76,4 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/move.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/move.py index 4d9205ae..a790100b 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/move.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/move.py @@ -1,9 +1,13 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents a `move` Instruction. - + This instruction copies data from one register to a different one. For more information, check the specification: @@ -32,16 +36,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # Pisa - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -54,15 +60,9 @@ def _get_name(cls) -> str: """ return "move" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -76,4 +76,4 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mul.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mul.py index d243b0d5..f49004f8 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mul.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mul.py @@ -1,5 +1,9 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents a `mul` Instruction. @@ -8,7 +12,7 @@ class Instruction(XInstruction): For more information, check the specification: https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_mul.md - + Methods: fromASMISALine: Parses an ASM ISA line to create an Instruction instance. """ @@ -32,16 +36,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # Pisa - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -54,15 +60,9 @@ def _get_name(cls) -> str: """ return "mul" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -76,4 +76,4 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/muli.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/muli.py index e117e7cd..b9c561fc 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/muli.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/muli.py @@ -1,5 +1,9 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents a `muli` Instruction. @@ -32,16 +36,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # Pisa - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -54,15 +60,9 @@ def _get_name(cls) -> str: """ return "muli" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -76,4 +76,4 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/nop.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/nop.py index 02355b75..f9b56073 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/nop.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/nop.py @@ -1,5 +1,9 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents a `nop` Instruction. @@ -12,7 +16,7 @@ class Instruction(XInstruction): Methods: fromASMISALine: Parses an ASM ISA line to create an Instruction instance. """ - + @classmethod def fromASMISALine(cls, line: str) -> list: """ @@ -32,7 +36,7 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 4 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") idle_cycles = int(tokens[3]) + 1 retval = cls( int(tokens[0][1:]), # Bundle @@ -42,7 +46,7 @@ def fromASMISALine(cls, line: str) -> list: idle_cycles, idle_cycles, tokens[3:], - comment + comment, ) return retval @@ -56,15 +60,9 @@ def _get_name(cls) -> str: """ return "nop" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -78,4 +76,4 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/ntt.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/ntt.py index 22eac064..3a144f72 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/ntt.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/ntt.py @@ -1,19 +1,23 @@ -from argparse import Namespace +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from argparse import Namespace from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents an `ntt` instruction (Number Theoretic Transform). - + Converts positional form to NTT form. For more information, check the specification: https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_ntt.md - + Methods: fromASMISALine: Parses an ASM ISA line to create an Instruction instance. """ - + @classmethod def fromASMISALine(cls, line: str) -> list: """ @@ -33,16 +37,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # bundle - int(tokens[1]), # pisa - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # bundle + int(tokens[1]), # pisa + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -55,15 +61,9 @@ def _get_name(cls) -> str: """ return "ntt" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -77,4 +77,4 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/rshuffle.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/rshuffle.py index 94ad23bc..4441fdf1 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/rshuffle.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/rshuffle.py @@ -1,9 +1,13 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents an Instruction with specific operational parameters and special latency properties. - + Methods: fromASMISALine: Parses an ASM ISA line to create an Instruction instance. _get_name: Gets the name of the instruction. @@ -14,20 +18,20 @@ class Instruction(XInstruction): special_latency_max: Gets the special latency maximum. special_latency_increment: Gets the special latency increment. """ - + # To be initialized from ASM ISA spec - _OP_RMOVE_LATENCY : int - _OP_RMOVE_LATENCY_MAX: int - _OP_RMOVE_LATENCY_INC: int + _OP_REMOVE_LATENCY: int + _OP_REMOVE_LATENCY_MAX: int + _OP_REMOVE_LATENCY_INC: int @classmethod def SetSpecialLatencyMax(cls, val): - cls._OP_RMOVE_LATENCY_MAX = val - cls._OP_RMOVE_LATENCY = cls._OP_RMOVE_LATENCY_MAX + cls._OP_REMOVE_LATENCY_MAX = val + cls._OP_REMOVE_LATENCY = cls._OP_REMOVE_LATENCY_MAX @classmethod def SetSpecialLatencyIncrement(cls, val): - cls._OP_RMOVE_LATENCY_INC = val + cls._OP_REMOVE_LATENCY_INC = val @classmethod def fromASMISALine(cls, line: str) -> list: @@ -48,16 +52,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 9 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # bundle - int(tokens[1]), # pisa - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # bundle + int(tokens[1]), # pisa + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -70,15 +76,9 @@ def _get_name(cls) -> str: """ return "rshuffle" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -96,7 +96,7 @@ def __init__(self, ValueError: If the 'other' list does not contain at least two parameters. """ if len(other) < 2: - raise ValueError('`other`: requires two parameters after sources.') + raise ValueError("`other`: requires two parameters after sources.") super().__init__(bundle, pisa_instr, dsts, srcs, throughput + int(other[0]), latency, other, comment) @property @@ -127,7 +127,7 @@ def special_latency_max(self): Returns: int: The special latency maximum. """ - return self._OP_RMOVE_LATENCY + return self._OP_REMOVE_LATENCY @property def special_latency_increment(self): @@ -137,4 +137,4 @@ def special_latency_increment(self): Returns: int: The special latency increment. """ - return self._OP_RMOVE_LATENCY_INC + return self._OP_REMOVE_LATENCY_INC diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/sub.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/sub.py index 1d495849..0d1dea96 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/sub.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/sub.py @@ -1,5 +1,9 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents a `sub` Instruction. @@ -32,16 +36,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # Pisa - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -54,15 +60,9 @@ def _get_name(cls) -> str: """ return "sub" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -76,4 +76,4 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twintt.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twintt.py index 3334a584..0d4d1fc5 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twintt.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twintt.py @@ -1,9 +1,13 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents a `twintt` Instruction. - + This instruction performs on-die generation of twiddle factors for the next stage of iNTT. For more information, check the specification: @@ -32,16 +36,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # Pisa - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -54,15 +60,9 @@ def _get_name(cls) -> str: """ return "twntt" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -76,4 +76,4 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twntt.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twntt.py index 175db3b4..3752edef 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twntt.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twntt.py @@ -1,9 +1,13 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents a `twntt` Instruction. - + This instruction performs on-die generation of twiddle factors for the next stage of NTT. For more information, check the specification: @@ -32,16 +36,18 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') + raise ValueError("`line`: could not parse f{cls.name} from specified line.") dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # Bundle - int(tokens[1]), # Pisa - dst_src_map['dst'], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map["dst"], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -54,15 +60,9 @@ def _get_name(cls) -> str: """ return "twintt" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, bundle: int, pisa_instr: int, dsts: list, srcs: list, throughput: int, latency: int, other: list = [], comment: str = "" + ): """ Initializes an Instruction instance. @@ -76,4 +76,4 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py index 1d0aa938..01a1eb74 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py @@ -7,7 +7,6 @@ class XInstruction: - # To be initialized from ASM ISA spec _OP_NUM_DESTS: int _OP_NUM_SOURCES: int @@ -57,9 +56,7 @@ def tokenizeFromASMISALine(op_name: str, line: str) -> list: return retval @staticmethod - def parseASMISASourceDestsFromTokens( - tokens: list, num_dests: int, num_sources: int, offset: int = 0 - ) -> dict: + def parseASMISASourceDestsFromTokens(tokens: list, num_dests: int, num_sources: int, offset: int = 0) -> dict: """ Parses the sources and destinations for an instruction, given sources and destinations in tokens in P-ISA format. diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py index b1317ba9..91647847 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py @@ -40,12 +40,8 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError( - "`line`: could not parse f{cls.name} from specified line." - ) - dst_src_map = XInstruction.parseASMISASourceDestsFromTokens( - tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3 - ) + raise ValueError("`line`: could not parse f{cls.name} from specified line.") + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) retval = cls( int(tokens[0][1:]), # bundle int(tokens[1]), # pisa @@ -92,6 +88,4 @@ def __init__( other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__( - bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment - ) + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py index de3a6df4..2e228442 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py @@ -84,11 +84,7 @@ def computeXBundleLatencies(xinstrs: list) -> list: bundle_id = 0 while xinstrs: if bundle_id % 1000 == 0: - print( - f"{(total_xinstr - len(xinstrs)) * 100 // total_xinstr}% " - f"- {(total_xinstr - len(xinstrs))}" - f"/{total_xinstr}" - ) + print(f"{(total_xinstr - len(xinstrs)) * 100 // total_xinstr}% " f"- {(total_xinstr - len(xinstrs))}" f"/{total_xinstr}") bundle = xinstrs[:NUM_BUNDLE_INSTRUCTIONS] xinstrs = xinstrs[NUM_BUNDLE_INSTRUCTIONS:] assert bundle[0].bundle == bundle_id and bundle[-1].bundle == bundle_id @@ -125,9 +121,7 @@ def computeCBundleLatencies(cinstr_lines) -> list: raise RuntimeError("Invalid CInstruction detected after end of CInstQ") if "ifetch" == s_split[1]: # New bundle - assert ( - int(s_split[2]) == bundle_id - ), f"ifetch, {s_split[2]} | expected {bundle_id}" + assert int(s_split[2]) == bundle_id, f"ifetch, {s_split[2]} | expected {bundle_id}" retval.append(bundle_latency) bundle_id += 1 bundle_latency = 0 @@ -198,27 +192,19 @@ def main(input_dir: str, input_prefix: Optional[str] = None): cbundle_cycles = computeCBundleLatencies(f_in) if len(xbundle_cycles) != len(cbundle_cycles): - raise RuntimeError( - "Mismatched bundles: {} xbundles vs. {} cbundles".format( - len(xbundle_cycles), len(cbundle_cycles) - ) - ) + raise RuntimeError("Mismatched bundles: {} xbundles vs. {} cbundles".format(len(xbundle_cycles), len(cbundle_cycles))) print("Comparing latencies...") bundle_cycles_violation_list = [] for idx in range(len(xbundle_cycles)): if xbundle_cycles[idx] > cbundle_cycles[idx]: bundle_cycles_violation_list.append( - "Bundle {} | X {} cycles; C {} cycles".format( - idx, xbundle_cycles[idx], cbundle_cycles[idx] - ) + "Bundle {} | X {} cycles; C {} cycles".format(idx, xbundle_cycles[idx], cbundle_cycles[idx]) ) # Check timings for register access print("--------------") print("Checking timings for register access...") - violation_lst: List[Tuple[int, int, str, int]] = ( - [] - ) # list(tuple(xinstr_idx, violating_idx, register: str, cycle_counter)) + violation_lst: List[Tuple[int, int, str, int]] = [] # list(tuple(xinstr_idx, violating_idx, register: str, cycle_counter)) for idx, xinstr in enumerate(xinstrs): if idx % 50000 == 0: print("{}% - {}/{}".format(idx * 100 // len(xinstrs), idx, len(xinstrs))) @@ -243,13 +229,9 @@ def main(input_dir: str, input_prefix: Optional[str] = None): src_bank = xinstr.srcs[0][1] dst_bank = xinstr.dsts[0][1] if src_bank != 0: - violation_lst.append( - (idx + 1, idx + 1, f"Move bank error sources {src_bank}", 0) - ) + violation_lst.append((idx + 1, idx + 1, f"Move bank error sources {src_bank}", 0)) if dst_bank == src_bank: - violation_lst.append( - (idx + 1, idx + 1, f"Move bank error dests {dst_bank}", 0) - ) + violation_lst.append((idx + 1, idx + 1, f"Move bank error dests {dst_bank}", 0)) # Check timing cycle_counter = xinstr.throughput @@ -266,9 +248,7 @@ def main(input_dir: str, input_prefix: Optional[str] = None): for reg in xinstr.dsts: if reg in all_next_regs: # Register is not ready and still used by an instruction - violation_lst.append( - (idx + 1, jdx + 1, f"r{reg[0]}b{reg[1]}", cycle_counter) - ) + violation_lst.append((idx + 1, jdx + 1, f"r{reg[0]}b{reg[1]}", cycle_counter)) cycle_counter += next_xinstr.throughput @@ -277,9 +257,7 @@ def main(input_dir: str, input_prefix: Optional[str] = None): # Check rshuffle separation print("--------------") print("Checking rshuffle separation...") - rshuffle_violation_lst: List[Tuple[int, int, str, int]] = ( - [] - ) # list(tuple(xinstr_idx, violating_idx, data_types: str, cycle_counter)) + rshuffle_violation_lst: List[Tuple[int, int, str, int]] = [] # list(tuple(xinstr_idx, violating_idx, data_types: str, cycle_counter)) print("WARNING: No distinction between `rshuffle` and `irshuffle`.") for idx, xinstr in enumerate(xinstrs): if idx % 50000 == 0: @@ -307,10 +285,7 @@ def main(input_dir: str, input_prefix: Optional[str] = None): cycle_counter, ) ) - elif ( - cycle_counter < xinstr.special_latency_max - and cycle_counter % xinstr.special_latency_increment != 0 - ): + elif cycle_counter < xinstr.special_latency_max and cycle_counter % xinstr.special_latency_increment != 0: # Same data type rshuffle_violation_lst.append( ( @@ -328,9 +303,7 @@ def main(input_dir: str, input_prefix: Optional[str] = None): # Check bank conflicts with rshuffle print("--------------") print("Checking bank conflicts with rshuffle...") - rshuffle_bank_violation_lst: List[Tuple[int, int, str, int]] = ( - [] - ) # list(tuple(xinstr_idx, violating_idx, banks: str, cycle_counter)) + rshuffle_bank_violation_lst: List[Tuple[int, int, str, int]] = [] # list(tuple(xinstr_idx, violating_idx, banks: str, cycle_counter)) for idx, xinstr in enumerate(xinstrs): if idx % 50000 == 0: print("{}% - {}/{}".format(idx * 100 // len(xinstrs), idx, len(xinstrs))) @@ -357,9 +330,7 @@ def main(input_dir: str, input_prefix: Optional[str] = None): ( idx + 1, jdx + 1, - "{} | banks: {}".format( - next_xinstr.name, rshuffle_banks & next_xinstr_banks - ), + "{} | banks: {}".format(next_xinstr.name, rshuffle_banks & next_xinstr_banks), cycle_counter, ) ) @@ -421,9 +392,7 @@ def main(input_dir: str, input_prefix: Optional[str] = None): ) args = parser.parse_args() - args.isa_spec_file = XTCSpecConfig.initialize_isa_spec( - module_dir, args.isa_spec_file - ) + args.isa_spec_file = XTCSpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) print(f"ISA Spec: {args.isa_spec_file}") print() diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nop.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nop.md index 197720ed..b9102ece 100644 --- a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nop.md +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nop.md @@ -32,4 +32,4 @@ Note that this instruction will cause the compute flow to stall for `1 + cycles` cnop, 0 ``` -Parameter `cycles` is encoded into a 10 bits field, and thus, its value must be less than 1024. If more thatn 1024 idle cycles is required, multiple `cnop` instructions must be scheduled back to back. +Parameter `cycles` is encoded into a 10 bits field, and thus, its value must be less than 1024. If more than 1024 idle cycles is required, multiple `cnop` instructions must be scheduled back to back. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_intt.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_intt.md index b9786e42..93077d94 100644 --- a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_intt.md +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_intt.md @@ -34,4 +34,4 @@ Performs one-stage of inverse NTT. Both NTT and inverse NTT instructions are defined as one-stage of the transformation. A complete NTT/iNTT transformation is composed of LOG_N such one-stage instructions. -This instruction matches to HERACLES ISA `intt`. It requires a preceeding, matching [`rmove`](xinst_rmove.md) to shuffle the input bits into correct tile-pairs. +This instruction matches to HERACLES ISA `intt`. It requires a preceding, matching [`remove`](xinst_remove.md) to shuffle the input bits into correct tile-pairs. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_ntt.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_ntt.md index 3fccca29..22f7429d 100644 --- a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_ntt.md +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_ntt.md @@ -34,4 +34,4 @@ Performs one-stage of NTT on an input positional polynomial. Both NTT and inverse NTT instructions are defined as one-stage of the transformation. A complete NTT/iNTT transformation is composed of LOG_N such one-stage instructions. -This instruction matches to HERACLES ISA `ntt` with `store_local` bit set. i.e. it requires a subsequent, matching [`rmove`](xinst_rmove.md) to shuffle the output bits into correct tile-pairs. +This instruction matches to HERACLES ISA `ntt` with `store_local` bit set. i.e. it requires a subsequent, matching [`remove`](xinst_remove.md) to shuffle the output bits into correct tile-pairs. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_rshuffle.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_rshuffle.md index a0bc2cbc..75edb834 100644 --- a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_rshuffle.md +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_rshuffle.md @@ -73,4 +73,4 @@ Notice that while `rshuffle`'s source and destination registers are the same in Routing table metadata defining shuffling patterns is loaded by [`nload`](../cinst/cinst_nload.md). -Parameter `data_type` is intended to select the correct routing table, however, in the current HERACLES implementation, only one routing table is availabe, and this parameter is used only for book keeping and error detection during scheduling. +Parameter `data_type` is intended to select the correct routing table, however, in the current HERACLES implementation, only one routing table is available, and this parameter is used only for book keeping and error detection during scheduling. diff --git a/assembler_tools/hec-assembler-tools/docsrc/specs.md b/assembler_tools/hec-assembler-tools/docsrc/specs.md index 1c238dd7..a7964cd7 100644 --- a/assembler_tools/hec-assembler-tools/docsrc/specs.md +++ b/assembler_tools/hec-assembler-tools/docsrc/specs.md @@ -1,4 +1,4 @@ -# HERACLES Instruction Specification +# HERACLES Instruction Specification Terms used in this document are defined in the HERACLES Instruction Set Architecture (ISA). diff --git a/assembler_tools/hec-assembler-tools/he_as.py b/assembler_tools/hec-assembler-tools/he_as.py index 0b666c39..129f8fe9 100644 --- a/assembler_tools/hec-assembler-tools/he_as.py +++ b/assembler_tools/hec-assembler-tools/he_as.py @@ -24,6 +24,7 @@ to specify input and output files and configuration options for the assembly process. """ + import argparse import io import os @@ -108,9 +109,7 @@ def __init__(self, **kwargs): if not hasattr(self, config_name): setattr(self, config_name, default_value) if getattr(self, config_name) is None: - raise TypeError( - f"Expected value for configuration `{config_name}`, but `None` received." - ) + raise TypeError(f"Expected value for configuration `{config_name}`, but `None` received.") # class members self.input_prefix = "" @@ -126,9 +125,7 @@ def __init__(self, **kwargs): self.input_prefix = os.path.splitext(os.path.basename(self.input_file))[0] if not self.input_mem_file: - self.input_mem_file = "{}.{}".format( - os.path.join(input_dir, self.input_prefix), DEFAULT_MEM_FILE_EXT - ) + self.input_mem_file = "{}.{}".format(os.path.join(input_dir, self.input_prefix), DEFAULT_MEM_FILE_EXT) self.input_mem_file = makeUniquePath(self.input_mem_file) @classmethod @@ -173,12 +170,7 @@ def as_dict(self) -> dict: """ retval = super().as_dict() tmp_self_dict = vars(self) - retval.update( - { - config_name: tmp_self_dict[config_name] - for config_name in self.__default_config - } - ) + retval.update({config_name: tmp_self_dict[config_name] for config_name in self.__default_config}) return retval @@ -210,12 +202,8 @@ def asmisaAssemble( input_filename: str = run_config.input_file mem_filename: str = run_config.input_mem_file - hbm_capacity_words: int = constants.convertBytes2Words( - run_config.hbm_size * constants.Constants.KILOBYTE - ) - spad_capacity_words: int = constants.convertBytes2Words( - run_config.spad_size * constants.Constants.KILOBYTE - ) + hbm_capacity_words: int = constants.convertBytes2Words(run_config.hbm_size * constants.Constants.KILOBYTE) + spad_capacity_words: int = constants.convertBytes2Words(run_config.spad_size * constants.Constants.KILOBYTE) num_register_banks: int = constants.MemoryModel.NUM_REGISTER_BANKS register_range: range = None @@ -223,9 +211,7 @@ def asmisaAssemble( print("Assembling!") print("Reloading kernel from intermediate...") - hec_mem_model = MemoryModel( - hbm_capacity_words, spad_capacity_words, num_register_banks, register_range - ) + hec_mem_model = MemoryModel(hbm_capacity_words, spad_capacity_words, num_register_banks, register_range) insts_listing = [] with open(input_filename, "r") as insts: @@ -240,11 +226,7 @@ def asmisaAssemble( parsed_insts = [inst] if not parsed_insts: - raise SyntaxError( - "Line {}: unable to parse kernel instruction:\n{}".format( - line_no, s_line - ) - ) + raise SyntaxError("Line {}: unable to parse kernel instruction:\n{}".format(line_no, s_line)) insts_listing += parsed_insts @@ -257,12 +239,8 @@ def asmisaAssemble( if b_verbose: print("Generating dependency graph...") start_time = time.time() - dep_graph = scheduler.generateInstrDependencyGraph( - insts_listing, sys.stdout if b_verbose else None - ) - scheduler.enforceKeygenOrdering( - dep_graph, hec_mem_model, sys.stdout if b_verbose else None - ) + dep_graph = scheduler.generateInstrDependencyGraph(insts_listing, sys.stdout if b_verbose else None) + scheduler.enforceKeygenOrdering(dep_graph, hec_mem_model, sys.stdout if b_verbose else None) deps_end = time.time() - start_time if b_verbose: @@ -402,10 +380,7 @@ def parse_args(): ) parser.add_argument( "input_file", - help=( - "Input pre-processed P-ISA kernel file. " - "File must be the result of pre-processing a P-ISA kernel with he_prep.py" - ), + help=("Input pre-processed P-ISA kernel file. " "File must be the result of pre-processing a P-ISA kernel with he_prep.py"), ) parser.add_argument( "--isa_spec", @@ -439,10 +414,7 @@ def parse_args(): parser.add_argument( "--output_prefix", default="", - help=( - "Prefix for the output files. " - "Defaults to the same the input file without extension." - ), + help=("Prefix for the output files. " "Defaults to the same the input file without extension."), ) parser.add_argument("--spad_size", type=int, help="Scratchpad size in KB.") parser.add_argument("--hbm_size", type=int, help="HBM size in KB.") @@ -468,9 +440,7 @@ def parse_args(): "--no_comments", dest="suppress_comments", action="store_true", - help=( - "When enabled, no comments will be emitted on the output generated by the assembler." - ), + help=("When enabled, no comments will be emitted on the output generated by the assembler."), ) parser.add_argument( "-v", @@ -494,12 +464,8 @@ def parse_args(): # Initialize Defaults args = parse_args() - args.isa_spec_file = ISASpecConfig.initialize_isa_spec( - module_dir, args.isa_spec_file - ) - args.mem_spec_file = MemSpecConfig.initialize_mem_spec( - module_dir, args.mem_spec_file - ) + args.isa_spec_file = ISASpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) + args.mem_spec_file = MemSpecConfig.initialize_mem_spec(module_dir, args.mem_spec_file) config = AssemblerRunConfig(**vars(args)) # convert argsparser into a dictionary diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py index d2cc57b1..4adeb1a2 100644 --- a/assembler_tools/hec-assembler-tools/he_link.py +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -7,7 +7,8 @@ """ @file he_link.py -@brief This module provides functionality for linking assembled kernels into a full HERACLES program for execution queues: MINST, CINST, and XINST. +@brief This module provides functionality for linking assembled kernels + into a full HERACLES program for execution queues: MINST, CINST, and XINST. @par Classes: - LinkerRunConfig: Maintains the configuration data for the run. @@ -21,32 +22,33 @@ This script is intended to be run as a standalone program. It requires specific command-line arguments to specify input and output files and configuration options for the linking process. """ + import argparse import os import sys import warnings -from assembler.common.counter import Counter from assembler.common.config import GlobalConfig -from assembler.spec_config.mem_spec import MemSpecConfig +from assembler.common.counter import Counter from assembler.spec_config.isa_spec import ISASpecConfig -from linker.instructions import BaseInstruction -from linker.linker_run_config import LinkerRunConfig -from linker.steps.variable_discovery import scan_variables, check_unused_variables -from linker.steps import program_linker -from linker.kern_trace.trace_info import TraceInfo -from linker.loader import Loader +from assembler.spec_config.mem_spec import MemSpecConfig from linker.he_link_utils import ( NullIO, - prepare_output_files, + initialize_memory_model, prepare_input_files, - update_input_prefixes, + prepare_output_files, remap_vars, - initialize_memory_model, + update_input_prefixes, ) +from linker.instructions import BaseInstruction +from linker.kern_trace.trace_info import TraceInfo +from linker.linker_run_config import LinkerRunConfig +from linker.loader import Loader +from linker.steps import program_linker +from linker.steps.variable_discovery import check_unused_variables, scan_variables -def main(run_config: LinkerRunConfig, verbose_stream=NullIO()): +def main(run_config: LinkerRunConfig, verbose_stream=None): """ @brief Executes the linking process using the provided configuration. @@ -54,10 +56,12 @@ def main(run_config: LinkerRunConfig, verbose_stream=NullIO()): and links each kernel, writing the output to specified files. @param run_config The configuration object containing run parameters. - @param verbose_stream The stream to which verbose output is printed. Defaults to NullIO. + @param verbose_stream The stream to which verbose output is printed. Defaults to None. @return None """ + if verbose_stream is None: + verbose_stream = NullIO() if run_config.use_xinstfetch: warnings.warn("Ignoring configuration flag 'use_xinstfetch'.") @@ -101,9 +105,7 @@ def main(run_config: LinkerRunConfig, verbose_stream=NullIO()): remap_vars(kernels_info, dinstrs_per_kernel, kernel_ops, verbose_stream) # Concatenate all mem info objects into one - program_dinstrs = program_linker.LinkedProgram.join_dinst_kernels( - dinstrs_per_kernel - ) + program_dinstrs = program_linker.LinkedProgram.join_dinst_kernels(dinstrs_per_kernel) # Write new program memory model to an output file if program_info.mem is None: @@ -117,9 +119,7 @@ def main(run_config: LinkerRunConfig, verbose_stream=NullIO()): print(" Finding all program variables...", file=verbose_stream) print(" Scanning", file=verbose_stream) - scan_variables( - kernels_info=kernels_info, mem_model=mem_model, verbose_stream=verbose_stream - ) + scan_variables(kernels_info=kernels_info, mem_model=mem_model, verbose_stream=verbose_stream) check_unused_variables(mem_model) @@ -127,9 +127,7 @@ def main(run_config: LinkerRunConfig, verbose_stream=NullIO()): print("Linking started", file=verbose_stream) # Link kernels and generate outputs - program_linker.LinkedProgram.link_kernels_to_files( - kernels_info, program_info, mem_model, verbose_stream=verbose_stream - ) + program_linker.LinkedProgram.link_kernels_to_files(kernels_info, program_info, mem_model, verbose_stream=verbose_stream) # Flush cached kernels Loader.flush_cache() @@ -281,23 +279,15 @@ def parse_args(): # Enforce only if use_trace_file is not set if not p_args.using_trace_file: if p_args.input_mem_file == "": - parser.error( - "the following arguments are required: -im/--input_mem_file (unless --use_trace_file is set)" - ) + parser.error("the following arguments are required: -im/--input_mem_file (unless --use_trace_file is set)") if not p_args.input_prefixes: - parser.error( - "the following arguments are required: -ip/--input_prefixes (unless --use_trace_file is set)" - ) + parser.error("the following arguments are required: -ip/--input_prefixes (unless --use_trace_file is set)") else: # If using trace file, input_mem_file and input_prefixes are ignored if p_args.input_mem_file != "": - warnings.warn( - "Ignoring input_mem_file argument because --use_trace_file is set." - ) + warnings.warn("Ignoring input_mem_file argument because --use_trace_file is set.") if p_args.input_prefixes: - warnings.warn( - "Ignoring input_prefixes argument because --use_trace_file is set." - ) + warnings.warn("Ignoring input_prefixes argument because --use_trace_file is set.") return p_args @@ -307,12 +297,8 @@ def parse_args(): module_name = os.path.basename(__file__) args = parse_args() - args.mem_spec_file = MemSpecConfig.initialize_mem_spec( - module_dir, args.mem_spec_file - ) - args.isa_spec_file = ISASpecConfig.initialize_isa_spec( - module_dir, args.isa_spec_file - ) + args.mem_spec_file = MemSpecConfig.initialize_mem_spec(module_dir, args.mem_spec_file) + args.isa_spec_file = ISASpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) config = LinkerRunConfig(**vars(args)) # convert argsparser into a dictionary if args.verbose > 0: diff --git a/assembler_tools/hec-assembler-tools/he_prep.py b/assembler_tools/hec-assembler-tools/he_prep.py index 1dae2ff1..4c5994bb 100644 --- a/assembler_tools/hec-assembler-tools/he_prep.py +++ b/assembler_tools/hec-assembler-tools/he_prep.py @@ -21,15 +21,16 @@ arguments to specify input and output files and verbosity options for the preprocessing process. """ + import argparse import os import time from assembler.common import constants +from assembler.memory_model import MemoryModel from assembler.spec_config.isa_spec import ISASpecConfig from assembler.spec_config.mem_spec import MemSpecConfig from assembler.stages import preprocessor -from assembler.memory_model import MemoryModel def save_pisa_listing(out_stream, instr_listing: list): @@ -89,18 +90,12 @@ def main(output_file_name: str, input_file_name: str, b_verbose: bool): # read input kernel and pre-process P-ISA: # resulting instructions will be correctly transformed and ready to be converted into ASM-ISA instructions; # variables used in the kernel will be automatically assigned to banks. - with open(input_file_name, "r", encoding="utf-8") as insts: - insts_listing = preprocessor.preprocess_pisa_kernel_listing( - hec_mem_model, insts, progress_verbose=b_verbose - ) - num_input_instr: int = len( - insts_listing - ) # track number of instructions in input kernel + with open(input_file_name, encoding="utf-8") as insts: + insts_listing = preprocessor.preprocess_pisa_kernel_listing(hec_mem_model, insts, progress_verbose=b_verbose) + num_input_instr: int = len(insts_listing) # track number of instructions in input kernel if b_verbose: print("Assigning register banks to variables...") - preprocessor.assign_register_banks_to_vars( - hec_mem_model, insts_listing, use_bank0=False, verbose=b_verbose - ) + preprocessor.assign_register_banks_to_vars(hec_mem_model, insts_listing, use_bank0=False, verbose=b_verbose) insts_end = time.time() - start_time if b_verbose: @@ -127,7 +122,10 @@ def parse_args(): argparse.Namespace: Parsed command-line arguments. """ parser = argparse.ArgumentParser( - description="HERACLES Assembling Pre-processor.\nThis program performs the preprocessing of P-ISA abstract kernels before further assembling." + description=( + "HERACLES Assembling Pre-processor.\n" + "This program performs the preprocessing of P-ISA abstract kernels before further assembling." + ) ) parser.add_argument( "input_file_name", @@ -172,12 +170,8 @@ def parse_args(): args = parse_args() - args.isa_spec_file = ISASpecConfig.initialize_isa_spec( - module_dir, args.isa_spec_file - ) - args.mem_spec_file = MemSpecConfig.initialize_mem_spec( - module_dir, args.mem_spec_file - ) + args.isa_spec_file = ISASpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) + args.mem_spec_file = MemSpecConfig.initialize_mem_spec(module_dir, args.mem_spec_file) if args.verbose > 0: print(module_name) diff --git a/assembler_tools/hec-assembler-tools/linker/__init__.py b/assembler_tools/hec-assembler-tools/linker/__init__.py index 3d6174e2..c506d314 100644 --- a/assembler_tools/hec-assembler-tools/linker/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/__init__.py @@ -7,7 +7,6 @@ """@brief linker/__init__.py contains classes to encapsulate the memory model used by the linker.""" import collections.abc as collections -from typing import Dict from assembler.common.config import GlobalConfig from assembler.memory_model import mem_info @@ -81,9 +80,7 @@ def force_allocate(self, var_info: VariableInfo, hbm_address: int): ) if var_info.hbm_address != hbm_address: if var_info.hbm_address >= 0: - raise ValueError( - f"`var_info`: variable {var_info.var_name} already allocated in address {var_info.hbm_address}." - ) + raise ValueError(f"`var_info`: variable {var_info.var_name} already allocated in address {var_info.hbm_address}.") in_var_info = self.buffer[hbm_address] # Validate hbm address @@ -96,10 +93,7 @@ def force_allocate(self, var_info: VariableInfo, hbm_address: int): f"when attempting to allocate variable {var_info.var_name}" ) else: - if in_var_info and ( - in_var_info.uses > 0 - or in_var_info.last_kernel_used >= var_info.last_kernel_used - ): + if in_var_info and (in_var_info.uses > 0 or in_var_info.last_kernel_used >= var_info.last_kernel_used): raise RuntimeError( f"HBM address {hbm_address} already occupied by variable {in_var_info.var_name} " f"when attempting to allocate variable {var_info.var_name}" @@ -124,10 +118,7 @@ def allocate(self, var_info: VariableInfo): retval = idx break else: - if not in_var_info or ( - in_var_info.uses <= 0 - and in_var_info.last_kernel_used < var_info.last_kernel_used - ): + if not in_var_info or (in_var_info.uses <= 0 and in_var_info.last_kernel_used < var_info.last_kernel_used): retval = idx break if retval < 0: @@ -149,64 +140,29 @@ def __init__(self, hbm_size_words: int, mem_meta_info: mem_info.MemInfo): """ self.hbm = HBM(hbm_size_words) self.__mem_info = mem_meta_info - self.__variables: Dict[str, VariableInfo] = ( - {} - ) # dict(var_name: str, VariableInfo) + self.__variables: dict[str, VariableInfo] = {} # dict(var_name: str, VariableInfo) # Group related collections into a dictionary self.__mem_collections = { - "keygen_vars": { - var_info.var_name: var_info for var_info in self.__mem_info.keygens - }, - "inputs": { - var_info.var_name: var_info for var_info in self.__mem_info.inputs - }, - "outputs": { - var_info.var_name: var_info for var_info in self.__mem_info.outputs - }, + "keygen_vars": {var_info.var_name: var_info for var_info in self.__mem_info.keygens}, + "inputs": {var_info.var_name: var_info for var_info in self.__mem_info.inputs}, + "outputs": {var_info.var_name: var_info for var_info in self.__mem_info.outputs}, "meta": ( - { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.intt_auxiliary_table - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.intt_routing_table - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.ntt_auxiliary_table - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.ntt_routing_table - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.ones - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.twiddle - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.keygen_seeds - } + {var_info.var_name: var_info for var_info in self.__mem_info.metadata.intt_auxiliary_table} + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.intt_routing_table} + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ntt_auxiliary_table} + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ntt_routing_table} + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ones} + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.twiddle} + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.keygen_seeds} ), } # Derived collections - self.__mem_info_fixed_addr_vars = ( - self.__mem_collections["outputs"] | self.__mem_collections["meta"] - ) + self.__mem_info_fixed_addr_vars = self.__mem_collections["outputs"] | self.__mem_collections["meta"] # Keygen variables should not be part of mem_info_vars set since they # do not start in HBM - self.__mem_info_vars = ( - self.__mem_collections["inputs"] - | self.__mem_collections["outputs"] - | self.__mem_collections["meta"] - ) + self.__mem_info_vars = self.__mem_collections["inputs"] | self.__mem_collections["outputs"] | self.__mem_collections["meta"] @property def mem_info_meta(self) -> collections.Collection: @@ -258,9 +214,7 @@ def add_variable(self, var_name: str): # with predefined HBM address if var_name in self.__mem_info_fixed_addr_vars: var_info.uses = float("inf") - self.hbm.force_allocate( - var_info, self.__mem_info_vars[var_name].hbm_address - ) + self.hbm.force_allocate(var_info, self.__mem_info_vars[var_name].hbm_address) self.variables[var_name] = var_info var_info.uses += 1 @@ -286,8 +240,9 @@ def use_variable(self, var_name: str, kernel: int) -> int: self.hbm.allocate(var_info) assert var_info.hbm_address >= 0 - assert ( - self.hbm.buffer[var_info.hbm_address].var_name == var_info.var_name - ), f"Expected variable {var_info.var_name} in HBM {var_info.hbm_address}, but variable {self.hbm.buffer[var_info.hbm_address].var_name} found instead." + assert self.hbm.buffer[var_info.hbm_address].var_name == var_info.var_name, ( + f"Expected variable {var_info.var_name} in HBM {var_info.hbm_address}," + f" but variable {self.hbm.buffer[var_info.hbm_address].var_name} found instead." + ) return var_info.hbm_address diff --git a/assembler_tools/hec-assembler-tools/linker/he_link_utils.py b/assembler_tools/hec-assembler-tools/linker/he_link_utils.py index dc99672d..cc364d78 100644 --- a/assembler_tools/hec-assembler-tools/linker/he_link_utils.py +++ b/assembler_tools/hec-assembler-tools/linker/he_link_utils.py @@ -8,13 +8,14 @@ @file he_link_utils.py @brief Utility functions for the he_link module """ + import os import pathlib -import linker -from assembler.common import constants -from assembler.common import makeUniquePath +from assembler.common import constants, makeUniquePath from assembler.memory_model import mem_info + +import linker from linker.kern_trace import KernelInfo, remap_dinstrs_vars @@ -44,9 +45,7 @@ def prepare_output_files(run_config) -> KernelInfo: """ path_prefix = os.path.join(run_config.output_dir, run_config.output_prefix) pathlib.Path(run_config.output_dir).mkdir(exist_ok=True, parents=True) - out_mem_file = ( - makeUniquePath(path_prefix + ".mem") if run_config.using_trace_file else None - ) + out_mem_file = makeUniquePath(path_prefix + ".mem") if run_config.using_trace_file else None return KernelInfo( { "directory": run_config.output_dir, @@ -72,11 +71,7 @@ def prepare_input_files(run_config, output_files) -> list: input_files = [] for file_prefix in run_config.input_prefixes: path_prefix = os.path.join(run_config.input_dir, file_prefix) - mem_file = ( - makeUniquePath(path_prefix + ".mem") - if run_config.using_trace_file - else None - ) + mem_file = makeUniquePath(path_prefix + ".mem") if run_config.using_trace_file else None kernel_info = KernelInfo( { "directory": run_config.input_dir, @@ -92,9 +87,7 @@ def prepare_input_files(run_config, output_files) -> list: if not os.path.isfile(input_filename): raise FileNotFoundError(input_filename) if input_filename in output_files.files: - raise RuntimeError( - f'Input files cannot match output files: "{input_filename}"' - ) + raise RuntimeError(f'Input files cannot match output files: "{input_filename}"') return input_files @@ -115,9 +108,7 @@ def update_input_prefixes(kernel_ops, run_config): run_config.input_prefixes = prefixes -def remap_vars( - kernels_info: list[KernelInfo], kernels_dinstrs, kernel_ops, verbose_stream -): +def remap_vars(kernels_info: list[KernelInfo], kernels_dinstrs, kernel_ops, verbose_stream): """ @brief Process kernel DInstructions to remap variables based on kernel operations and update KernelInfo with remap_dict. @@ -127,22 +118,15 @@ def remap_vars( @param kernel_ops List of kernel operations. @param verbose_stream Stream for verbose output. """ - assert len(kernels_info) == len( - kernel_ops - ), "Number of kernels_files must match number of kernel operations." - assert len(kernels_dinstrs) == len( - kernel_ops - ), "Number of kernel_dinstrs must match number of kernel operations." - - for kernel_info, kernel_op, kernel_dinstrs in zip( - kernels_info, kernel_ops, kernels_dinstrs - ): + assert len(kernels_info) == len(kernel_ops), "Number of kernels_files must match number of kernel operations." + assert len(kernels_dinstrs) == len(kernel_ops), "Number of kernel_dinstrs must match number of kernel operations." + + for kernel_info, kernel_op, kernel_dinstrs in zip(kernels_info, kernel_ops, kernels_dinstrs, strict=False): print(f"\tProcessing kernel: {kernel_info.prefix}", file=verbose_stream) expected_prefix = f"{kernel_op.expected_in_kern_file_name}_pisa.tw" assert expected_prefix in kernel_info.prefix, ( - f"Kernel operation prefix {expected_prefix} does not match " - f"kernel file prefix {kernel_info.prefix}" + f"Kernel operation prefix {expected_prefix} does not match " f"kernel file prefix {kernel_info.prefix}" ) # Remap dintrs' variables in kernel_dinstrs and return a mapping dict @@ -159,15 +143,13 @@ def initialize_memory_model(run_config, kernel_dinstrs=None, verbose_stream=None @param verbose_stream Stream for verbose output. @return MemoryModel instance. """ - hbm_capacity_words = constants.convertBytes2Words( - run_config.hbm_size * constants.Constants.KILOBYTE - ) + hbm_capacity_words = constants.convertBytes2Words(run_config.hbm_size * constants.Constants.KILOBYTE) # Parse memory information if kernel_dinstrs: mem_meta_info = mem_info.MemInfo.from_dinstrs(kernel_dinstrs) else: - with open(run_config.input_mem_file, "r", encoding="utf-8") as mem_ifnum: + with open(run_config.input_mem_file, encoding="utf-8") as mem_ifnum: mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) # Initialize memory model diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py index 32e96a7a..8254584f 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py @@ -6,12 +6,12 @@ """@brief This module provides functionality to create instruction objects from a line of text.""" -from typing import Optional from assembler.instructions import tokenize_from_line + from linker.instructions.instruction import BaseInstruction -def create_from_str_line(line: str, factory) -> Optional[BaseInstruction]: +def create_from_str_line(line: str, factory) -> BaseInstruction | None: """ @brief Parses an instruction from a line of text. diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/csyncm.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/csyncm.py index bbcf7dd3..95b008b0 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/csyncm.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/csyncm.py @@ -69,7 +69,5 @@ def target(self, value: int): @throws ValueError If the value is negative. """ if value < 0: - raise ValueError( - f"`value`: expected non-negative target, but {value} received." - ) + raise ValueError(f"`value`: expected non-negative target, but {value} received.") self.tokens[2] = str(value) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/ifetch.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/ifetch.py index 18d0c485..0e92c0cf 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/ifetch.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/ifetch.py @@ -68,7 +68,5 @@ def bundle(self, value: int): @throws ValueError If the value is negative. """ if value < 0: - raise ValueError( - f"`value`: expected non-negative bundle index, but {value} received." - ) + raise ValueError(f"`value`: expected non-negative bundle index, but {value} received.") self.tokens[2] = str(value) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/xinstfetch.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/xinstfetch.py index 25782268..f2468b97 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/xinstfetch.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/xinstfetch.py @@ -50,9 +50,7 @@ def __init__(self, tokens: list, comment: str = ""): @throws ValueError If the number of tokens is invalid or the instruction name is incorrect. """ super().__init__(tokens, comment=comment) - raise NotImplementedError( - "`xinstfetch` CInstruction is not currently supported in linker." - ) + raise NotImplementedError("`xinstfetch` CInstruction is not currently supported in linker.") @property def dst_x_queue(self) -> int: @@ -72,9 +70,7 @@ def dst_x_queue(self, value: int): @throws ValueError If the value is negative. """ if value < 0: - raise ValueError( - f"`value`: expected non-negative value, but {value} received." - ) + raise ValueError(f"`value`: expected non-negative value, but {value} received.") self.tokens[2] = str(value) @property @@ -95,7 +91,5 @@ def src_hbm(self, value: int): @throws ValueError If the value is negative. """ if value < 0: - raise ValueError( - f"`value`: expected non-negative value, but {value} received." - ) + raise ValueError(f"`value`: expected non-negative value, but {value} received.") self.tokens[3] = str(value) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py index 9cbd9899..554975ec 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py @@ -6,11 +6,9 @@ """@brief This module provides functionality to create and manage data instructions""" -from typing import Optional - from assembler.instructions import tokenize_from_line -from . import dload, dstore, dkeygen -from . import dinstruction + +from . import dinstruction, dkeygen, dload, dstore DLoad = dload.Instruction DStore = dstore.Instruction @@ -35,7 +33,7 @@ def create_from_mem_line(line: str) -> dinstruction.DInstruction: parsed from the specified input line. @throws RuntimeError If no valid instruction is found or if there's an error parsing the memory map line. """ - retval: Optional[dinstruction.DInstruction] = None + retval: dinstruction.DInstruction | None = None tokens, comment = tokenize_from_line(line) for instr_type in factory(): try: diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py index bdc3d127..c1066a4b 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py @@ -11,11 +11,12 @@ assembly process, providing common functionality and interfaces. """ -from linker.instructions.instruction import BaseInstruction from assembler.common.counter import Counter from assembler.common.decorators import classproperty from assembler.memory_model.mem_info import MemInfo +from linker.instructions.instruction import BaseInstruction + class DInstruction(BaseInstruction): """ @@ -80,9 +81,7 @@ def _validate_tokens(self, tokens: list) -> None: f"Instruction {self.name} requires at least {self.num_tokens}, but {len(tokens)} received" ) if tokens[self.name_token_index] != self.name: - raise ValueError( - f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received" - ) + raise ValueError(f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received") def __init__(self, tokens: list, comment: str = ""): """ @@ -106,9 +105,7 @@ def __init__(self, tokens: list, comment: str = ""): if self.name in [MemInfo.Const.Keyword.LOAD, MemInfo.Const.Keyword.STORE]: self.address = miv_dict["hbm_address"] except RuntimeError as e: - raise ValueError( - f"Failed to parse memory info from tokens: {tokens}. Error: {str(e)}" - ) from e + raise ValueError(f"Failed to parse memory info from tokens: {tokens}. Error: {str(e)}") from e @property def id(self): diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py index 2bf91fd9..83afb501 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py @@ -12,6 +12,7 @@ """ from assembler.memory_model.mem_info import MemInfo + from .dinstruction import DInstruction diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py index 61d99546..646714c4 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py @@ -12,6 +12,7 @@ """ from assembler.memory_model.mem_info import MemInfo + from .dinstruction import DInstruction diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py index 5b1d6770..74ac6026 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py @@ -8,9 +8,9 @@ @brief Base class for all instructions in the linker. """ -from assembler.common.decorators import classproperty -from assembler.common.counter import Counter from assembler.common.config import GlobalConfig +from assembler.common.counter import Counter +from assembler.common.decorators import classproperty class BaseInstruction: @@ -28,9 +28,7 @@ class BaseInstruction: @fn to_line Retrieves the string form of the instruction to write to the instruction file. """ - __id_count = Counter.count( - 0 - ) # Internal unique sequence counter to generate unique IDs + __id_count = Counter.count(0) # Internal unique sequence counter to generate unique IDs # Class methods and properties # ---------------------------- @@ -147,9 +145,7 @@ def _validate_tokens(self, tokens: list) -> None: ) if tokens[self.name_token_index] != self.name: # pylint: disable=W0143 - raise ValueError( - f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received" - ) + raise ValueError(f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received") def __repr__(self): retval = f"<{type(self).__name__}({self.name}, id={self.id}) object at {hex(id(self))}>(tokens={self.tokens})" diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/minst/msyncc.py b/assembler_tools/hec-assembler-tools/linker/instructions/minst/msyncc.py index 5d374293..d2573040 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/minst/msyncc.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/minst/msyncc.py @@ -69,7 +69,5 @@ def target(self, value: int): @throws ValueError If the value is negative. """ if value < 0: - raise ValueError( - f"`value`: expected non-negative target, but {value} received." - ) + raise ValueError(f"`value`: expected non-negative target, but {value} received.") self.tokens[2] = str(value) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/__init__.py index 6f92646d..12febc34 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/__init__.py @@ -5,19 +5,19 @@ from . import ( add, - sub, - mul, - muli, + intt, mac, maci, + move, + mul, + muli, + nop, ntt, - intt, - twntt, - twintt, rshuffle, - move, + sub, + twintt, + twntt, xstore, - nop, ) from . import exit as exit_mod diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/intt.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/intt.py index 8236e925..0a4f128a 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/intt.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/intt.py @@ -22,7 +22,8 @@ def _get_num_tokens(cls) -> int: @brief Gets the number of tokens required for the instruction. The `intt` instruction requires 10 tokens: - F, , intt, , , , , , , + F, , intt, , , + , , , , @return The number of tokens, which is 10. """ diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/maci.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/maci.py index d8c3e846..22c89e85 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/maci.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/maci.py @@ -1,7 +1,8 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""@brief This module implements the maci X-instruction which performs element-wise polynomial scaling by an immediate value and accumulation.""" +"""@brief This module implements the maci X-instruction which performs +element-wise polynomial scaling by an immediate value and accumulation.""" from .xinstruction import XInstruction diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/ntt.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/ntt.py index 18b3875c..171a12ca 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/ntt.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/ntt.py @@ -22,7 +22,8 @@ def _get_num_tokens(cls) -> int: @brief Gets the number of tokens required for the instruction. The `ntt` instruction requires 10 tokens: - F, , ntt, , , , , , , + F, , ntt, , , + , , , , @return The number of tokens, which is 10. """ diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twintt.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twintt.py index e6e4a245..2fa79f36 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twintt.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twintt.py @@ -22,7 +22,8 @@ def _get_num_tokens(cls) -> int: @brief Gets the number of tokens required for the instruction. The `twintt` instruction requires 10 tokens: - F, , twintt, , , , , , , + F, , twintt, , , + , , , , @return The number of tokens, which is 10. """ diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twntt.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twntt.py index f01fa5ed..abcbd812 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twntt.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twntt.py @@ -22,7 +22,8 @@ def _get_num_tokens(cls) -> int: @brief Gets the number of tokens required for the instruction. The `twntt` instruction requires 10 tokens: - F, , twntt, , , , , , , + F, , twntt, , , + , , , , @return The number of tokens, which is 10. """ diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py index 465f4127..ff185226 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py @@ -75,7 +75,5 @@ def bundle(self, value: int): @throws ValueError If the value is negative. """ if value < 0: - raise ValueError( - f"`value`: expected non-negative bundle index, but {value} received." - ) + raise ValueError(f"`value`: expected non-negative bundle index, but {value} received.") self.tokens[0] = f"F{value}" diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/__init__.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/__init__.py index 4d07e3e2..b57437b8 100644 --- a/assembler_tools/hec-assembler-tools/linker/kern_trace/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/__init__.py @@ -7,12 +7,11 @@ This package provides utilities for parsing trace files and extracting kernel operation information. """ -from linker.kern_trace.kern_var import KernVar from linker.kern_trace.context_config import ContextConfig -from linker.kern_trace.kernel_op import KernelOp -from linker.kern_trace.trace_info import TraceInfo from linker.kern_trace.kern_remap import remap_dinstrs_vars, remap_m_c_instrs_vars -from linker.kern_trace.trace_info import KernelInfo +from linker.kern_trace.kern_var import KernVar +from linker.kern_trace.kernel_op import KernelOp +from linker.kern_trace.trace_info import KernelInfo, TraceInfo __all__ = [ "KernVar", diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_remap.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_remap.py index 890d22d2..1ea9bbd6 100644 --- a/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_remap.py +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_remap.py @@ -7,16 +7,15 @@ """@brief Module for remapping kernel variables in DINST files.""" import re + +from linker.instructions import cinst, minst +from linker.instructions.cinst.cinstruction import CInstruction from linker.instructions.dinst.dinstruction import DInstruction from linker.instructions.minst.minstruction import MInstruction -from linker.instructions.cinst.cinstruction import CInstruction -from linker.instructions import minst, cinst from linker.kern_trace.kernel_op import KernelOp -def remap_dinstrs_vars( - kernel_dinstrs: list[DInstruction], kernel_op: KernelOp -) -> dict[str, str]: +def remap_dinstrs_vars(kernel_dinstrs: list[DInstruction], kernel_op: KernelOp) -> dict[str, str]: """ @brief Remaps variable names in DInstructions based on KernelOp variables. @@ -44,9 +43,7 @@ def remap_dinstrs_vars( try: prefix, rest = dinstr.var.split("_", 1) except ValueError as e: - raise ValueError( - f"Unexpected format: variable name '{dinstr.var}' does not contain items to split by '_': {e}" - ) from e + raise ValueError(f"Unexpected format: variable name '{dinstr.var}' does not contain items to split by '_': {e}") from e # Skip if prefix is not 'ct' or 'pt' if not (prefix.lower().startswith("ct") or prefix.lower().startswith("pt")): @@ -56,9 +53,7 @@ def remap_dinstrs_vars( match = re.search(r"([a-zA-Z]+)(\d+)", prefix) if not match: - raise ValueError( - f"Unexpected format: variable prefix '{prefix}' does not contain a number after text." - ) + raise ValueError(f"Unexpected format: variable prefix '{prefix}' does not contain a number after text.") number_part = int(match.group(2)) @@ -68,7 +63,8 @@ def remap_dinstrs_vars( kern_var = sorted_kern_vars[number_part] except IndexError as exc: raise IndexError( - f"Number part {number_part} from prefix '{prefix}' is out of range [0, {len(sorted_kern_vars)-1}] for the KernelOp variables" + f"Number part {number_part} from prefix '{prefix}' is out of range [0, {len(sorted_kern_vars)-1}]" + "for the KernelOp variables" ) from exc old_var = dinstr.var @@ -91,20 +87,14 @@ def remap_m_c_instrs_vars(kernel_instrs: list, remap_dict: dict[str, str]) -> No """ if remap_dict: for instr in kernel_instrs: - if not isinstance(instr, (MInstruction, CInstruction)): + if not isinstance(instr, MInstruction | CInstruction): raise TypeError(f"Item {instr} is not a valid M or C Instruction.") - if isinstance( - instr, (minst.MLoad, cinst.BLoad, cinst.CLoad, cinst.BOnes, cinst.NLoad) - ): + if isinstance(instr, minst.MLoad | cinst.BLoad | cinst.CLoad | cinst.BOnes | cinst.NLoad): if instr.source in remap_dict: - instr.comment = instr.comment.replace( - instr.source, remap_dict[instr.source] - ) + instr.comment = instr.comment.replace(instr.source, remap_dict[instr.source]) instr.source = remap_dict[instr.source] - elif isinstance(instr, (minst.MStore, cinst.CStore)): + elif isinstance(instr, minst.MStore | cinst.CStore): if instr.dest in remap_dict: - instr.comment = instr.comment.replace( - instr.dest, remap_dict[instr.dest] - ) + instr.comment = instr.comment.replace(instr.dest, remap_dict[instr.dest]) instr.dest = remap_dict[instr.dest] diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_var.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_var.py index 0b886bdd..20db040e 100644 --- a/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_var.py +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_var.py @@ -40,9 +40,7 @@ def from_string(cls, var_str: str): if len(parts) != 3: raise ValueError(f"Invalid kernel variable string format: {var_str}") if not parts[1].isdigit() or not parts[2].isdigit(): - raise ValueError( - f"Invalid degree or level in kernel variable string: {var_str}" - ) + raise ValueError(f"Invalid degree or level in kernel variable string: {var_str}") if not parts[0]: raise ValueError(f"Invalid label in kernel variable string: {var_str}") diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/kernel_op.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/kernel_op.py index 1f320598..784d098c 100644 --- a/assembler_tools/hec-assembler-tools/linker/kern_trace/kernel_op.py +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/kernel_op.py @@ -79,9 +79,7 @@ def get_level(self, kern_vars: list[KernVar]) -> int: which is used to categorize the kernel operation. """ if not kern_vars: - raise ValueError( - "Kernel operation must have at least one variable to determine level." - ) + raise ValueError("Kernel operation must have at least one variable to determine level.") # Assuming all input variables have the same level for the operation return kern_vars[1].level if len(kern_vars) > 1 else kern_vars[0].level @@ -101,15 +99,9 @@ def __init__( """ if name.lower() not in self.valid_kernel_ops: - raise ValueError( - f"Invalid kernel operation name: {name}. " - f"Valid names are: {', '.join(self.valid_kernel_ops)}" - ) + raise ValueError(f"Invalid kernel operation name: {name}. " f"Valid names are: {', '.join(self.valid_kernel_ops)}") if context_config.scheme.lower() not in self.valid_schemes: - raise ValueError( - f"Invalid encryption scheme: {context_config.scheme}. " - f"Valid schemes are: {', '.join(self.valid_schemes)}" - ) + raise ValueError(f"Invalid encryption scheme: {context_config.scheme}. " f"Valid schemes are: {', '.join(self.valid_schemes)}") if len(kern_args) < 2: raise ValueError("Kernel operation must have at least two arguments.") diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/trace_info.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/trace_info.py index 010f7bf0..db150dbe 100644 --- a/assembler_tools/hec-assembler-tools/linker/kern_trace/trace_info.py +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/trace_info.py @@ -7,9 +7,9 @@ """@brief Module for parsing and analyzing trace files.""" import os -from typing import Optional from assembler.instructions import tokenize_from_line + from linker.kern_trace.context_config import ContextConfig from linker.kern_trace.kernel_op import KernelOp @@ -38,7 +38,7 @@ class KernelInfo: minst: str cinst: str xinst: str - mem: Optional[str] = None + mem: str | None = None remap_dict: dict[str, str] = {} def __init__(self, config: dict): @@ -131,13 +131,9 @@ def extract_context_and_args(self, tokens, param_idxs, line_num): return name, context_config, kern_args except KeyError as e: - raise KeyError( - f"Missing required parameter in line {line_num} with tokens: {tokens}: {e}" - ) from e + raise KeyError(f"Missing required parameter in line {line_num} with tokens: {tokens}: {e}") from e except IndexError as e: - raise ValueError( - f"Invalid number of parameters in line {line_num}: {e}" - ) from e + raise ValueError(f"Invalid number of parameters in line {line_num}: {e}") from e except ValueError as e: raise ValueError(f"Invalid value in line {line_num}: {e}") from e @@ -153,7 +149,7 @@ def parse_kernel_ops(self) -> list[KernelOp]: kernel_ops: list = [] - with open(self._trace_file, "r", encoding="utf-8") as file: + with open(self._trace_file, encoding="utf-8") as file: lines = file.readlines() if not lines: @@ -170,9 +166,7 @@ def parse_kernel_ops(self) -> list[KernelOp]: if not tokens or not tokens[0]: # Skip empty lines continue - name, context_config, kern_args = self.extract_context_and_args( - tokens, param_idxs, line_num - ) + name, context_config, kern_args = self.extract_context_and_args(tokens, param_idxs, line_num) # Create and add KernelOp with all arguments kernel_op = KernelOp(name, context_config, kern_args) diff --git a/assembler_tools/hec-assembler-tools/linker/linker_run_config.py b/assembler_tools/hec-assembler-tools/linker/linker_run_config.py index a98f7f17..6272fdb5 100644 --- a/assembler_tools/hec-assembler-tools/linker/linker_run_config.py +++ b/assembler_tools/hec-assembler-tools/linker/linker_run_config.py @@ -8,6 +8,7 @@ @file linker_run_config.py @brief This module provides configuration for the linker process. """ + import io import os from typing import Any @@ -73,9 +74,7 @@ def __init__(self, **kwargs): if not isinstance(kwargs["hbm_size"], int): raise ValueError("Invalid param: hbm_size must be an integer") if kwargs["hbm_size"] < 0: - raise ValueError( - "Invalid param: hbm_size must be a non-negative integer" - ) + raise ValueError("Invalid param: hbm_size must be a non-negative integer") if "has_hbm" in kwargs and not isinstance(kwargs["has_hbm"], bool): raise ValueError("Invalid param: has_hbm must be a boolean value") @@ -89,9 +88,7 @@ def __init__(self, **kwargs): if not hasattr(self, config_name): setattr(self, config_name, default_value) if getattr(self, config_name) is None: - raise TypeError( - f"Expected value for configuration `{config_name}`, but `None` received." - ) + raise TypeError(f"Expected value for configuration `{config_name}`, but `None` received.") # Fix file paths # E0203: Access to member 'input_mem_file' before its definition. @@ -141,10 +138,5 @@ def as_dict(self) -> dict: """ retval = super().as_dict() tmp_self_dict = vars(self) - retval.update( - { - config_name: tmp_self_dict[config_name] - for config_name in self.__default_config - } - ) + retval.update({config_name: tmp_self_dict[config_name] for config_name in self.__default_config}) return retval diff --git a/assembler_tools/hec-assembler-tools/linker/loader.py b/assembler_tools/hec-assembler-tools/linker/loader.py index 443914c4..328c9999 100644 --- a/assembler_tools/hec-assembler-tools/linker/loader.py +++ b/assembler_tools/hec-assembler-tools/linker/loader.py @@ -9,14 +9,10 @@ """ import copy - from typing import Any -from linker.instructions import minst -from linker.instructions import cinst -from linker.instructions import xinst -from linker.instructions import dinst from linker import instructions +from linker.instructions import cinst, dinst, minst, xinst class Loader: @@ -66,7 +62,7 @@ def load_minst_kernel_from_file(cls, filename: str, use_cache: bool = True) -> l if use_cache and cache_key in cls._file_cache: return copy.deepcopy(cls._file_cache[cache_key]) - with open(filename, "r", encoding="utf-8") as kernel_minsts: + with open(filename, encoding="utf-8") as kernel_minsts: try: result = cls.load_minst_kernel(kernel_minsts) if use_cache: @@ -106,7 +102,7 @@ def load_cinst_kernel_from_file(cls, filename: str, use_cache: bool = True) -> l if use_cache and cache_key in cls._file_cache: return copy.deepcopy(cls._file_cache[cache_key]) - with open(filename, "r", encoding="utf-8") as kernel_cinsts: + with open(filename, encoding="utf-8") as kernel_cinsts: try: result = cls.load_cinst_kernel(kernel_cinsts) if use_cache: @@ -146,7 +142,7 @@ def load_xinst_kernel_from_file(cls, filename: str, use_cache: bool = True) -> l if use_cache and cache_key in cls._file_cache: return copy.deepcopy(cls._file_cache[cache_key]) - with open(filename, "r", encoding="utf-8") as kernel_xinsts: + with open(filename, encoding="utf-8") as kernel_xinsts: try: result = cls.load_xinst_kernel(kernel_xinsts) if use_cache: @@ -187,7 +183,7 @@ def load_dinst_kernel_from_file(cls, filename: str, use_cache: bool = True) -> l if use_cache and cache_key in cls._file_cache: return copy.deepcopy(cls._file_cache[cache_key]) - with open(filename, "r", encoding="utf-8") as kernel_dinsts: + with open(filename, encoding="utf-8") as kernel_dinsts: try: result = cls.load_dinst_kernel(kernel_dinsts) if use_cache: diff --git a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py index 43db1a98..29a0cf35 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py @@ -6,14 +6,16 @@ """@brief This module provides functionality to link kernels into a program.""" -from typing import Dict, Any, cast +from typing import Any, cast + +from assembler.common.config import GlobalConfig +from assembler.instructions import cinst as ISACInst + from linker import MemoryModel -from linker.loader import Loader -from linker.instructions import minst, cinst, dinst +from linker.instructions import cinst, dinst, minst from linker.instructions.dinst.dinstruction import DInstruction from linker.kern_trace.kern_remap import remap_m_c_instrs_vars -from assembler.common.config import GlobalConfig -from assembler.instructions import cinst as ISACInst +from linker.loader import Loader class LinkedProgram: # pylint: disable=too-many-instance-attributes @@ -54,9 +56,7 @@ def __init__( self._minst_line_offset = 0 self._cinst_line_offset = 0 self._kernel_count = 0 # Number of kernels linked into this program - self._is_open = ( - True # Tracks whether this program is still accepting kernels to link - ) + self._is_open = True # Tracks whether this program is still accepting kernels to link @property def is_open(self) -> bool: @@ -115,20 +115,16 @@ def _validate_hbm_address(self, var_name: str, hbm_address: int): @exception RuntimeError If the HBM address is invalid or does not match the declared address. """ if hbm_address < 0: - raise RuntimeError( - f'Invalid negative HBM address for variable "{var_name}".' - ) + raise RuntimeError(f'Invalid negative HBM address for variable "{var_name}".') if var_name in self.__mem_model.mem_info_vars: # Cast to dictionary to fix the indexing error - mem_info_vars_dict = cast(Dict[str, Any], self.__mem_model.mem_info_vars) + mem_info_vars_dict = cast(dict[str, Any], self.__mem_model.mem_info_vars) if mem_info_vars_dict[var_name].hbm_address != hbm_address: raise RuntimeError( - ( - f"Declared HBM address " - f"({mem_info_vars_dict[var_name].hbm_address})" - f" of mem Variable '{var_name}'" - f" differs from allocated HBM address ({hbm_address})." - ) + f"Declared HBM address " + f"({mem_info_vars_dict[var_name].hbm_address})" + f" of mem Variable '{var_name}'" + f" differs from allocated HBM address ({hbm_address})." ) def _validate_spad_address(self, var_name: str, spad_address: int): @@ -146,20 +142,16 @@ def _validate_spad_address(self, var_name: str, spad_address: int): # this method will validate the variable SPAD address against the # original HBM address, since there is no HBM if spad_address < 0: - raise RuntimeError( - f'Invalid negative SPAD address for variable "{var_name}".' - ) + raise RuntimeError(f'Invalid negative SPAD address for variable "{var_name}".') if var_name in self.__mem_model.mem_info_vars: # Cast to dictionary to fix the indexing error - mem_info_vars_dict = cast(Dict[str, Any], self.__mem_model.mem_info_vars) + mem_info_vars_dict = cast(dict[str, Any], self.__mem_model.mem_info_vars) if mem_info_vars_dict[var_name].hbm_address != spad_address: raise RuntimeError( - ( - f"Declared HBM address" - f" ({mem_info_vars_dict[var_name].hbm_address})" - f" of mem Variable '{var_name}'" - f" differs from allocated HBM address ({spad_address})." - ) + f"Declared HBM address" + f" ({mem_info_vars_dict[var_name].hbm_address})" + f" of mem Variable '{var_name}'" + f" differs from allocated HBM address ({spad_address})." ) def _update_minsts(self, kernel_minstrs: list): @@ -180,29 +172,17 @@ def _update_minsts(self, kernel_minstrs: list): # Change mload variable names into HBM addresses if isinstance(minstr, minst.MLoad): var_name = minstr.source - hbm_address = self.__mem_model.use_variable( - var_name, self._kernel_count - ) + hbm_address = self.__mem_model.use_variable(var_name, self._kernel_count) self._validate_hbm_address(var_name, hbm_address) minstr.source = str(hbm_address) - minstr.comment = ( - f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" - if minstr.comment - else "" - ) + minstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" if minstr.comment else "" # Change mstore variable names into HBM addresses if isinstance(minstr, minst.MStore): var_name = minstr.dest - hbm_address = self.__mem_model.use_variable( - var_name, self._kernel_count - ) + hbm_address = self.__mem_model.use_variable(var_name, self._kernel_count) self._validate_hbm_address(var_name, hbm_address) minstr.dest = str(hbm_address) - minstr.comment = ( - f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" - if minstr.comment - else "" - ) + minstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" if minstr.comment else "" def _remove_and_merge_csyncm_cnop(self, kernel_cinstrs: list): """ @@ -228,7 +208,7 @@ def _remove_and_merge_csyncm_cnop(self, kernel_cinstrs: list): # Idle cycles to account for the csyncm have been added csyncm_count = 0 - if isinstance(cinstr, (cinst.IFetch, cinst.NLoad, cinst.BLoad)): + if isinstance(cinstr, cinst.IFetch | cinst.NLoad | cinst.BLoad): if csyncm_count > 0: # Extra cycles needed before scheduling next bundle # Subtract 1 because cnop n, waits for n+1 cycles @@ -269,7 +249,8 @@ def _remove_and_merge_csyncm_cnop(self, kernel_cinstrs: list): # # replace instruction by cnop # kernel_cinstrs.pop(i) # if current_bundle > 0: - # cinstr_nop = cinst.CNop([i, cinst.CNop.name, str(ISACInst.CSyncm.get_throughput())]) # Subtract 1 because cnop n, waits for n+1 cycles + # # Subtract 1 because cnop n, waits for n+1 cycles + # cinstr_nop = cinst.CNop([i, cinst.CNop.name, str(ISACInst.CSyncm.get_throughput())]) # kernel_cinstrs.insert(i, cinstr_nop) # # i += 1 # next instruction @@ -305,9 +286,7 @@ def _update_cinsts_addresses_and_offsets(self, kernel_cinstrs: list): cinstr.bundle = cinstr.bundle + self._bundle_offset # Update xinstfetch if isinstance(cinstr, cinst.XInstFetch): - raise NotImplementedError( - "`xinstfetch` not currently supported by linker." - ) + raise NotImplementedError("`xinstfetch` not currently supported by linker.") # Update csyncm if isinstance(cinstr, cinst.CSyncm): cinstr.target = cinstr.target + self._minst_line_offset @@ -315,32 +294,18 @@ def _update_cinsts_addresses_and_offsets(self, kernel_cinstrs: list): if not GlobalConfig.hasHBM: # update all SPAD instruction variable names to be SPAD addresses # change xload variable names into SPAD addresses - if isinstance( - cinstr, (cinst.BLoad, cinst.BOnes, cinst.CLoad, cinst.NLoad) - ): + if isinstance(cinstr, cinst.BLoad | cinst.BOnes | cinst.CLoad | cinst.NLoad): var_name = cinstr.source - hbm_address = self.__mem_model.use_variable( - var_name, self._kernel_count - ) + hbm_address = self.__mem_model.use_variable(var_name, self._kernel_count) self._validate_spad_address(var_name, hbm_address) cinstr.source = str(hbm_address) - cinstr.comment = ( - f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" - if cinstr.comment - else "" - ) + cinstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" if cinstr.comment else "" if isinstance(cinstr, cinst.CStore): var_name = cinstr.dest - hbm_address = self.__mem_model.use_variable( - var_name, self._kernel_count - ) + hbm_address = self.__mem_model.use_variable(var_name, self._kernel_count) self._validate_spad_address(var_name, hbm_address) cinstr.dest = str(hbm_address) - cinstr.comment = ( - f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" - if cinstr.comment - else "" - ) + cinstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" if cinstr.comment else "" def _update_cinsts(self, kernel_cinstrs: list): """ @@ -373,15 +338,11 @@ def _update_xinsts(self, kernel_xinstrs: list) -> int: for xinstr in kernel_xinstrs: xinstr.bundle = xinstr.bundle + self._bundle_offset if last_bundle > xinstr.bundle: - raise RuntimeError( - f'Detected invalid bundle. Instruction bundle is less than previous: "{xinstr.to_line()}"' - ) + raise RuntimeError(f'Detected invalid bundle. Instruction bundle is less than previous: "{xinstr.to_line()}"') last_bundle = xinstr.bundle return last_bundle - def link_kernel( - self, kernel_minstrs: list, kernel_cinstrs: list, kernel_xinstrs: list - ): + def link_kernel(self, kernel_minstrs: list, kernel_cinstrs: list, kernel_xinstrs: list): """ @brief Links a specified kernel (given by its three instruction queues) into this program. @@ -431,18 +392,12 @@ def link_kernel( print(f" #{minstr.comment}", end="", file=self._minst_ostream) print(file=self._minst_ostream) - self._minst_line_offset += ( - len(kernel_minstrs) - 1 - ) # Subtract last line that is getting removed - self._cinst_line_offset += ( - len(kernel_cinstrs) - 1 - ) # Subtract last line that is getting removed + self._minst_line_offset += len(kernel_minstrs) - 1 # Subtract last line that is getting removed + self._cinst_line_offset += len(kernel_cinstrs) - 1 # Subtract last line that is getting removed self._kernel_count += 1 # Count the appended kernel @classmethod - def join_dinst_kernels( - cls, kernels_instrs: list[list[DInstruction]] - ) -> list[DInstruction]: + def join_dinst_kernels(cls, kernels_instrs: list[list[DInstruction]]) -> list[DInstruction]: """ @brief Joins a list of dinst kernels, consolidating variables that are outputs in one kernel and inputs in the next. This ensures that variables carried across kernels are not duplicated, @@ -466,20 +421,17 @@ def join_dinst_kernels( new_kernels_instrs: list[DInstruction] = [] for kernel_instrs in kernels_instrs: for cur_dinst in kernel_instrs: - # Save the current output instruction to add at the end if isinstance(cur_dinst, dinst.DStore): key = cur_dinst.var carry_over_vars[key] = cur_dinst continue - if isinstance(cur_dinst, (dinst.DLoad, dinst.DKeyGen)): + if isinstance(cur_dinst, dinst.DLoad | dinst.DKeyGen): key = cur_dinst.var # Skip if the input is already in carry-over from previous outputs if key in carry_over_vars: - carry_over_vars.pop( - key - ) # Remove from (output) carry-overs since it's now an input + carry_over_vars.pop(key) # Remove from (output) carry-overs since it's now an input continue # If the input is not (a previous output) in carry-over, add if it's not already (loaded) in inputs @@ -500,9 +452,7 @@ def join_dinst_kernels( return new_kernels_instrs @staticmethod - def link_kernels_to_files( - input_files, output_files, mem_model, verbose_stream=None - ): + def link_kernels_to_files(input_files, output_files, mem_model, verbose_stream=None): """ @brief Links input kernels and writes the output to the specified files. @@ -511,15 +461,12 @@ def link_kernels_to_files( @param mem_model Memory model to use. @param verbose_stream Stream for verbose output. """ - with open(output_files.minst, "w", encoding="utf-8") as fnum_output_minst, open( - output_files.cinst, "w", encoding="utf-8" - ) as fnum_output_cinst, open( - output_files.xinst, "w", encoding="utf-8" - ) as fnum_output_xinst: - - result_program = LinkedProgram( - fnum_output_minst, fnum_output_cinst, fnum_output_xinst, mem_model - ) + with ( + open(output_files.minst, "w", encoding="utf-8") as fnum_output_minst, + open(output_files.cinst, "w", encoding="utf-8") as fnum_output_cinst, + open(output_files.xinst, "w", encoding="utf-8") as fnum_output_xinst, + ): + result_program = LinkedProgram(fnum_output_minst, fnum_output_cinst, fnum_output_xinst, mem_model) for idx, kernel in enumerate(input_files): if verbose_stream: @@ -535,9 +482,7 @@ def link_kernels_to_files( remap_m_c_instrs_vars(kernel_minstrs, kernel.remap_dict) remap_m_c_instrs_vars(kernel_cinstrs, kernel.remap_dict) - result_program.link_kernel( - kernel_minstrs, kernel_cinstrs, kernel_xinstrs - ) + result_program.link_kernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) if verbose_stream: print( "[ 100% ] Finalizing output", diff --git a/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py b/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py index dfd1485c..922846f7 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py @@ -4,13 +4,16 @@ """ @brief This module provides functionality to discover variable names in MInstructions and CInstructions. """ -from typing import Optional, TextIO, List -from assembler.memory_model.variable import Variable -from assembler.memory_model import MemoryModel + +from typing import TextIO + from assembler.common.config import GlobalConfig -from linker.instructions import minst, cinst -from linker.instructions.minst.minstruction import MInstruction +from assembler.memory_model import MemoryModel +from assembler.memory_model.variable import Variable + +from linker.instructions import cinst, minst from linker.instructions.cinst.cinstruction import CInstruction +from linker.instructions.minst.minstruction import MInstruction from linker.kern_trace import KernelInfo, remap_m_c_instrs_vars from linker.loader import Loader @@ -27,9 +30,7 @@ def discover_variables_spad(cinstrs: list): """ for idx, cinstr in enumerate(cinstrs): if not isinstance(cinstr, CInstruction): - raise TypeError( - f"Item {idx} in list of CInstructions is not a valid CInstruction." - ) + raise TypeError(f"Item {idx} in list of CInstructions is not a valid CInstruction.") retval = None # TODO: Implement variable counting for CInst ############### @@ -41,9 +42,7 @@ def discover_variables_spad(cinstrs: list): if retval is not None: if not Variable.validateName(retval): - raise RuntimeError( - f'Invalid Variable name "{retval}" detected in instruction "{idx}, {cinstr.to_line()}"' - ) + raise RuntimeError(f'Invalid Variable name "{retval}" detected in instruction "{idx}, {cinstr.to_line()}"') yield retval @@ -59,9 +58,7 @@ def discover_variables(minstrs: list): """ for idx, minstr in enumerate(minstrs): if not isinstance(minstr, MInstruction): - raise TypeError( - f"Item {idx} in list of MInstructions is not a valid MInstruction." - ) + raise TypeError(f"Item {idx} in list of MInstructions is not a valid MInstruction.") retval = None if isinstance(minstr, minst.MLoad): retval = minstr.source @@ -70,16 +67,14 @@ def discover_variables(minstrs: list): if retval is not None: if not Variable.validateName(retval): - raise RuntimeError( - f'Invalid Variable name "{retval}" detected in instruction "{idx}, {minstr.to_line()}"' - ) + raise RuntimeError(f'Invalid Variable name "{retval}" detected in instruction "{idx}, {minstr.to_line()}"') yield retval def scan_variables( - kernels_info: List[KernelInfo], + kernels_info: list[KernelInfo], mem_model: MemoryModel, - verbose_stream: Optional[TextIO] = None, + verbose_stream: TextIO | None = None, ): """ @brief Scans input files for variables and adds them to the memory model. @@ -89,7 +84,6 @@ def scan_variables( @param verbose_stream Stream for verbose output. """ for idx, kernel_info in enumerate(kernels_info): - if not GlobalConfig.hasHBM: if verbose_stream: print( @@ -124,6 +118,4 @@ def check_unused_variables(mem_model): for var_name in mem_model.mem_info_vars: if var_name not in mem_model.variables: if GlobalConfig.hasHBM or var_name not in mem_model.mem_info_meta: - raise RuntimeError( - f'Unused variable from input mem file: "{var_name}" not in memory model.' - ) + raise RuntimeError(f'Unused variable from input mem file: "{var_name}" not in memory model.') diff --git a/assembler_tools/hec-assembler-tools/pytest.ini b/assembler_tools/hec-assembler-tools/pytest.ini deleted file mode 100644 index 06e92c68..00000000 --- a/assembler_tools/hec-assembler-tools/pytest.ini +++ /dev/null @@ -1,4 +0,0 @@ -[pytest] -pythonpath = . -testpaths = tests -#addopts = --cov=. diff --git a/assembler_tools/hec-assembler-tools/tests/conftest.py b/assembler_tools/hec-assembler-tools/tests/conftest.py index b8bc75ec..87a1b4e3 100644 --- a/assembler_tools/hec-assembler-tools/tests/conftest.py +++ b/assembler_tools/hec-assembler-tools/tests/conftest.py @@ -11,7 +11,6 @@ from unittest.mock import patch import pytest - from assembler.spec_config.isa_spec import ISASpecConfig from assembler.spec_config.mem_spec import MemSpecConfig diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py index 6858b858..0844eac9 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py @@ -6,15 +6,15 @@ """ import unittest -from unittest.mock import patch, MagicMock, call, PropertyMock +from unittest.mock import MagicMock, PropertyMock, call, patch from assembler.memory_model import MemoryModel from assembler.memory_model.mem_info import ( - MemInfoVariable, - MemInfoKeygenVariable, MemInfo, - updateMemoryModelWithMemInfo, + MemInfoKeygenVariable, + MemInfoVariable, _allocateMemInfoVariable, + updateMemoryModelWithMemInfo, ) @@ -23,44 +23,32 @@ class TestMemInfoVariable(unittest.TestCase): def test_init_valid(self): """@brief Test initialization with valid parameters.""" - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=True - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=True): var = MemInfoVariable("test_var", 42) self.assertEqual(var.var_name, "test_var") self.assertEqual(var.hbm_address, 42) def test_init_strips_whitespace(self): """@brief Test that initialization strips whitespace from variable name.""" - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=True - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=True): var = MemInfoVariable(" test_var ", 42) self.assertEqual(var.var_name, "test_var") def test_init_invalid_name(self): """@brief Test initialization with invalid variable name.""" - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=False - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=False): with self.assertRaises(RuntimeError): MemInfoVariable("invalid!var", 42) def test_repr(self): """@brief Test the __repr__ method.""" - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=True - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=True): var = MemInfoVariable("test_var", 42) - self.assertEqual( - repr(var), repr({"var_name": "test_var", "hbm_address": 42}) - ) + self.assertEqual(repr(var), repr({"var_name": "test_var", "hbm_address": 42})) def test_as_dict(self): """@brief Test the as_dict method.""" - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=True - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=True): var = MemInfoVariable("test_var", 42) self.assertEqual(var.as_dict(), {"var_name": "test_var", "hbm_address": 42}) @@ -70,9 +58,7 @@ class TestMemInfoKeygenVariable(unittest.TestCase): def test_init_valid(self): """@brief Test initialization with valid parameters.""" - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=True - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=True): var = MemInfoKeygenVariable("test_var", 2, 3) self.assertEqual(var.var_name, "test_var") self.assertEqual(var.hbm_address, -1) # Should be initialized to -1 @@ -81,29 +67,21 @@ def test_init_valid(self): def test_init_negative_seed_index(self): """@brief Test initialization with negative seed index.""" - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=True - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=True): with self.assertRaises(IndexError): MemInfoKeygenVariable("test_var", -1, 3) def test_init_negative_key_index(self): """@brief Test initialization with negative key index.""" - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=True - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=True): with self.assertRaises(IndexError): MemInfoKeygenVariable("test_var", 2, -1) def test_as_dict(self): """@brief Test the as_dict method.""" - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=True - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=True): var = MemInfoKeygenVariable("test_var", 2, 3) - self.assertEqual( - var.as_dict(), {"var_name": "test_var", "seed_index": 2, "key_index": 3} - ) + self.assertEqual(var.as_dict(), {"var_name": "test_var", "seed_index": 2, "key_index": 3}) class TestMemInfoMetadata(unittest.TestCase): @@ -112,9 +90,7 @@ class TestMemInfoMetadata(unittest.TestCase): def test_parse_meta_field_from_mem_tokens_valid(self): """@brief Test parsing a valid metadata field.""" tokens = ["dload", "LOAD_ONES", "42", "ones_var"] - result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( - tokens, "LOAD_ONES", var_prefix="ONES" - ) + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, "LOAD_ONES", var_prefix="ONES") self.assertIsNotNone(result) self.assertEqual(result.var_name, "ones_var") self.assertEqual(result.hbm_address, 42) @@ -122,9 +98,7 @@ def test_parse_meta_field_from_mem_tokens_valid(self): def test_parse_meta_field_from_mem_tokens_no_name(self): """@brief Test parsing a metadata field without explicit name.""" tokens = ["dload", "LOAD_ONES", "42"] - result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( - tokens, "LOAD_ONES", var_prefix="ONES" - ) + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, "LOAD_ONES", var_prefix="ONES") self.assertIsNotNone(result) self.assertEqual(result.var_name, "ONES_42") self.assertEqual(result.hbm_address, 42) @@ -132,9 +106,7 @@ def test_parse_meta_field_from_mem_tokens_no_name(self): def test_parse_meta_field_from_mem_tokens_with_extra(self): """@brief Test parsing a metadata field with var_extra.""" tokens = ["dload", "LOAD_ONES", "42"] - result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( - tokens, "LOAD_ONES", var_prefix="ONES", var_extra="_extra" - ) + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, "LOAD_ONES", var_prefix="ONES", var_extra="_extra") self.assertIsNotNone(result) self.assertEqual(result.var_name, "ONES_extra") self.assertEqual(result.hbm_address, 42) @@ -143,23 +115,17 @@ def test_parse_meta_field_from_mem_tokens_invalid(self): """@brief Test parsing an invalid metadata field.""" # Not enough tokens tokens = ["dload"] - result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( - tokens, "LOAD_ONES", var_prefix="ONES" - ) + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, "LOAD_ONES", var_prefix="ONES") self.assertIsNone(result) # Wrong first token tokens = ["wrong", "LOAD_ONES", "42"] - result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( - tokens, "LOAD_ONES", var_prefix="ONES" - ) + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, "LOAD_ONES", var_prefix="ONES") self.assertIsNone(result) # Wrong second token tokens = ["dload", "WRONG", "42"] - result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( - tokens, "LOAD_ONES", var_prefix="ONES" - ) + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, "LOAD_ONES", var_prefix="ONES") self.assertIsNone(result) def test_metadata_init_and_properties(self): @@ -217,9 +183,7 @@ class TestMemInfoParsers(unittest.TestCase): def test_ones_parse_from_mem_tokens(self): """@brief Test parsing Ones metadata from tokens.""" tokens = ["dload", "LOAD_ONES", "42", "ones_var"] - with patch( - "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" - ) as mock_parse: + with patch("assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens") as mock_parse: mock_parse.return_value = MagicMock() result = MemInfo.Metadata.Ones.parse_from_mem_tokens(tokens) mock_parse.assert_called_once_with( @@ -232,9 +196,7 @@ def test_ones_parse_from_mem_tokens(self): def test_ntt_aux_table_parse_from_mem_tokens(self): """@brief Test parsing NTTAuxTable metadata from tokens.""" tokens = ["dload", "LOAD_NTT_AUX_TABLE", "42", "ntt_aux_var"] - with patch( - "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" - ) as mock_parse: + with patch("assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens") as mock_parse: mock_parse.return_value = MagicMock() result = MemInfo.Metadata.NTTAuxTable.parse_from_mem_tokens(tokens) mock_parse.assert_called_once_with( @@ -247,9 +209,7 @@ def test_ntt_aux_table_parse_from_mem_tokens(self): def test_ntt_routing_table_parse_from_mem_tokens(self): """@brief Test parsing NTTRoutingTable metadata from tokens.""" tokens = ["dload", "LOAD_NTT_ROUTING_TABLE", "42", "ntt_route_var"] - with patch( - "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" - ) as mock_parse: + with patch("assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens") as mock_parse: mock_parse.return_value = MagicMock() result = MemInfo.Metadata.NTTRoutingTable.parse_from_mem_tokens(tokens) mock_parse.assert_called_once_with( @@ -262,9 +222,7 @@ def test_ntt_routing_table_parse_from_mem_tokens(self): def test_intt_aux_table_parse_from_mem_tokens(self): """@brief Test parsing iNTTAuxTable metadata from tokens.""" tokens = ["dload", "LOAD_iNTT_AUX_TABLE", "42", "intt_aux_var"] - with patch( - "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" - ) as mock_parse: + with patch("assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens") as mock_parse: mock_parse.return_value = MagicMock() result = MemInfo.Metadata.iNTTAuxTable.parse_from_mem_tokens(tokens) mock_parse.assert_called_once_with( @@ -277,9 +235,7 @@ def test_intt_aux_table_parse_from_mem_tokens(self): def test_intt_routing_table_parse_from_mem_tokens(self): """@brief Test parsing iNTTRoutingTable metadata from tokens.""" tokens = ["dload", "LOAD_iNTT_ROUTING_TABLE", "42", "intt_route_var"] - with patch( - "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" - ) as mock_parse: + with patch("assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens") as mock_parse: mock_parse.return_value = MagicMock() result = MemInfo.Metadata.iNTTRoutingTable.parse_from_mem_tokens(tokens) mock_parse.assert_called_once_with( @@ -292,9 +248,7 @@ def test_intt_routing_table_parse_from_mem_tokens(self): def test_twiddle_parse_from_mem_tokens(self): """@brief Test parsing Twiddle metadata from tokens.""" tokens = ["dload", "LOAD_TWIDDLE", "42", "twiddle_var"] - with patch( - "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" - ) as mock_parse: + with patch("assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens") as mock_parse: mock_parse.return_value = MagicMock() result = MemInfo.Metadata.Twiddle.parse_from_mem_tokens(tokens) mock_parse.assert_called_once_with( @@ -307,9 +261,7 @@ def test_twiddle_parse_from_mem_tokens(self): def test_keygen_seed_parse_from_mem_tokens(self): """@brief Test parsing KeygenSeed metadata from tokens.""" tokens = ["dload", "LOAD_KEYGEN_SEED", "42", "keygen_seed_var"] - with patch( - "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" - ) as mock_parse: + with patch("assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens") as mock_parse: mock_parse.return_value = MagicMock() result = MemInfo.Metadata.KeygenSeed.parse_from_mem_tokens(tokens) mock_parse.assert_called_once_with( @@ -343,9 +295,7 @@ def test_keygen_parse_from_mem_tokens_invalid(self): def test_input_parse_from_mem_tokens_valid(self): """@brief Test parsing a valid input variable.""" tokens = ["dload", "poly", "42", "input_var"] - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=True - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=True): result = MemInfo.Input.parse_from_mem_tokens(tokens) self.assertIsNotNone(result) self.assertEqual(result.var_name, "input_var") @@ -370,9 +320,7 @@ def test_input_parse_from_mem_tokens_invalid(self): def test_output_parse_from_mem_tokens_valid(self): """@brief Test parsing a valid output variable.""" tokens = ["dstore", "output_var", "42"] - with patch( - "assembler.memory_model.variable.Variable.validateName", return_value=True - ): + with patch("assembler.memory_model.variable.Variable.validateName", return_value=True): result = MemInfo.Output.parse_from_mem_tokens(tokens) self.assertIsNotNone(result) self.assertEqual(result.var_name, "output_var") @@ -481,9 +429,7 @@ def test_get_meminfo_var_from_tokens_valid(self): # Mock the parse_from_mem_tokens method to return a mock variable mock_variable = MagicMock() - with patch.object( - MemInfo.Keygen, "parse_from_mem_tokens", return_value=mock_variable - ): + with patch.object(MemInfo.Keygen, "parse_from_mem_tokens", return_value=mock_variable): var, var_type = MemInfo.get_meminfo_var_from_tokens(tokens) # Verify results @@ -498,9 +444,7 @@ def test_get_meminfo_var_from_tokens_not_found(self): with patch.object( MemInfo, "mem_info_types", - return_value=[ - MagicMock(parse_from_mem_tokens=MagicMock(return_value=None)) - ], + return_value=[MagicMock(parse_from_mem_tokens=MagicMock(return_value=None))], ): var, var_type = MemInfo.get_meminfo_var_from_tokens(tokens) @@ -519,14 +463,14 @@ def test_add_meminfo_var_from_tokens_valid(self): mock_list = MagicMock() mock_dict = {mock_type: mock_list} - with patch.object( - MemInfo, - "get_meminfo_var_from_tokens", - return_value=(mock_variable, mock_type), - ), patch.object( - MemInfo, "factory_dict", new_callable=PropertyMock, return_value=mock_dict + with ( + patch.object( + MemInfo, + "get_meminfo_var_from_tokens", + return_value=(mock_variable, mock_type), + ), + patch.object(MemInfo, "factory_dict", new_callable=PropertyMock, return_value=mock_dict), ): - # Call the method mem_info.add_meminfo_var_from_tokens(tokens) @@ -539,9 +483,7 @@ def test_add_meminfo_var_from_tokens_not_found(self): mem_info = MemInfo() # Mock get_meminfo_var_from_tokens to return None - with patch.object( - MemInfo, "get_meminfo_var_from_tokens", return_value=(None, None) - ): + with patch.object(MemInfo, "get_meminfo_var_from_tokens", return_value=(None, None)): # Verify exception is raised with self.assertRaises(RuntimeError): mem_info.add_meminfo_var_from_tokens(tokens) @@ -572,15 +514,14 @@ def mock_tokenize(line): return ([], "") # Mock methods - patch the function where it's imported in mem_info - with patch( - "assembler.memory_model.mem_info.tokenize_from_line", - side_effect=mock_tokenize, - ), patch.object( - MemInfo, "add_meminfo_var_from_tokens" - ) as mock_add_var, patch.object( - MemInfo, "validate" + with ( + patch( + "assembler.memory_model.mem_info.tokenize_from_line", + side_effect=mock_tokenize, + ), + patch.object(MemInfo, "add_meminfo_var_from_tokens") as mock_add_var, + patch.object(MemInfo, "validate"), ): - # Call the method MemInfo.from_file_iter(lines) @@ -597,16 +538,15 @@ def mock_tokenize(line): return (["invalid"], line) # Mock methods - with patch( - "assembler.instructions.tokenize_from_line", side_effect=mock_tokenize - ), patch.object( - MemInfo, - "add_meminfo_var_from_tokens", - side_effect=RuntimeError("Test error"), - ), patch.object( - MemInfo, "validate" + with ( + patch("assembler.instructions.tokenize_from_line", side_effect=mock_tokenize), + patch.object( + MemInfo, + "add_meminfo_var_from_tokens", + side_effect=RuntimeError("Test error"), + ), + patch.object(MemInfo, "validate"), ): - # Verify exception is raised with line number and content with self.assertRaises(RuntimeError) as context: MemInfo.from_file_iter(lines) @@ -623,12 +563,11 @@ def test_from_dinstrs_valid(self): ] # Mock methods - with patch.object( - MemInfo, "add_meminfo_var_from_tokens" - ) as mock_add_var, patch.object(MemInfo, "validate"), patch( - "builtins.print" + with ( + patch.object(MemInfo, "add_meminfo_var_from_tokens") as mock_add_var, + patch.object(MemInfo, "validate"), + patch("builtins.print"), ): # Mock print to avoid output - # Call the method MemInfo.from_dinstrs(dinstrs) @@ -648,14 +587,15 @@ def test_from_dinstrs_error(self): dinstrs = [MagicMock(tokens=["invalid"])] # Mock methods - with patch.object( - MemInfo, - "add_meminfo_var_from_tokens", - side_effect=RuntimeError("Test error"), - ), patch.object(MemInfo, "validate"), patch( - "builtins.print" + with ( + patch.object( + MemInfo, + "add_meminfo_var_from_tokens", + side_effect=RuntimeError("Test error"), + ), + patch.object(MemInfo, "validate"), + patch("builtins.print"), ): # Mock print to avoid output - # Verify exception is raised with instruction number with self.assertRaises(RuntimeError) as context: MemInfo.from_dinstrs(dinstrs) @@ -666,7 +606,6 @@ def test_as_dict(self): """@brief Test the as_dict method.""" # Create a MemInfo with test data with patch("assembler.memory_model.mem_info.MemInfo.validate"): - # dicts keygens_dict = {"var_name": "keygen_var", "seed_index": 1, "key_index": 2} inputs_dict = {"var_name": "input_var", "hbm_address": 42} @@ -710,9 +649,7 @@ def test_validate_valid(self): ones_dict = {"var_name": "ones_var", "hbm_address": 44} twiddle_dict = {"var_name": "twiddle_var", "hbm_address": 45} - twiddle_list = [ - twiddle_dict for _ in range(MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT) - ] + twiddle_list = [twiddle_dict for _ in range(MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT)] # Create metadata dictionary for initialization metadata_dict = { @@ -747,9 +684,7 @@ def test_validate_twiddle_mismatch(self): # Initialize without validation to set up the test MemInfo(metadata=metadata_dict) - self.assertIn( - "Expected 2 times as many twiddles as ones", str(context.exception) - ) + self.assertIn("Expected 2 times as many twiddles as ones", str(context.exception)) def test_validate_duplicate_var_name(self): """@brief Test validation with duplicate variable names but different HBM addresses.""" @@ -771,9 +706,7 @@ def test_validate_duplicate_var_name(self): with self.assertRaises(RuntimeError) as context: MemInfo(metadata=metadata_dict) - self.assertIn( - 'Variable "duplicate" already allocated', str(context.exception) - ) + self.assertIn('Variable "duplicate" already allocated', str(context.exception)) class TestUpdateMemoryModelWithMemInfo(unittest.TestCase): @@ -820,9 +753,7 @@ def setUp(self): def test_update_memory_model_inputs(self): """@brief Test updating memory model with input variables.""" # Call the function - with patch( - "assembler.memory_model.mem_info._allocateMemInfoVariable" - ) as mock_allocate: + with patch("assembler.memory_model.mem_info._allocateMemInfoVariable") as mock_allocate: updateMemoryModelWithMemInfo(self.mock_mem_model, self.mock_mem_info) # Verify input variables were allocated @@ -831,23 +762,17 @@ def test_update_memory_model_inputs(self): def test_update_memory_model_outputs(self): """@brief Test updating memory model with output variables.""" # Call the function - with patch( - "assembler.memory_model.mem_info._allocateMemInfoVariable" - ) as mock_allocate: + with patch("assembler.memory_model.mem_info._allocateMemInfoVariable") as mock_allocate: updateMemoryModelWithMemInfo(self.mock_mem_model, self.mock_mem_info) # Verify output variables were allocated and added to output_variables mock_allocate.assert_any_call(self.mock_mem_model, self.vars["output"]) - self.mock_mem_model.output_variables.push.assert_called_once_with( - "output_var", None - ) + self.mock_mem_model.output_variables.push.assert_called_once_with("output_var", None) def test_update_memory_model_metadata(self): """@brief Test updating memory model with metadata variables.""" # Call the function - with patch( - "assembler.memory_model.mem_info._allocateMemInfoVariable" - ) as mock_allocate: + with patch("assembler.memory_model.mem_info._allocateMemInfoVariable") as mock_allocate: updateMemoryModelWithMemInfo(self.mock_mem_model, self.mock_mem_info) # Verify metadata variables were retrieved, allocated and added to their respective lists @@ -882,12 +807,8 @@ def test_update_memory_model_metadata(self): self.assertEqual(self.mock_mem_model.meta_ntt_routing_table, "ntt_route") self.assertEqual(self.mock_mem_model.meta_intt_aux_table, "intt_aux") self.assertEqual(self.mock_mem_model.meta_intt_routing_table, "intt_route") - self.mock_mem_model.add_meta_twiddle_var.assert_called_once_with( - "twiddle_var" - ) - self.mock_mem_model.add_meta_keygen_seed_var.assert_called_once_with( - "keygen_seed" - ) + self.mock_mem_model.add_meta_twiddle_var.assert_called_once_with("twiddle_var") + self.mock_mem_model.add_meta_keygen_seed_var.assert_called_once_with("keygen_seed") class TestAllocateMemInfoVariable(unittest.TestCase): @@ -903,9 +824,7 @@ def test_allocate_mem_info_variable_success(self): mock_mem_model.variables = {"test_var": MagicMock(hbm_address=-1)} # Call the function - with patch( - "assembler.memory_model.mem_info._allocateMemInfoVariable" - ) as mock_function: + with patch("assembler.memory_model.mem_info._allocateMemInfoVariable") as mock_function: # Make it actually call the real function - simplified without lambda mock_function.original = _allocateMemInfoVariable mock_function.side_effect = mock_function.original @@ -913,9 +832,7 @@ def test_allocate_mem_info_variable_success(self): mock_function(mock_mem_model, mock_var_info) # Verify the variable was allocated - mock_mem_model.hbm.allocateForce.assert_called_once_with( - 42, mock_mem_model.variables["test_var"] - ) + mock_mem_model.hbm.allocateForce.assert_called_once_with(42, mock_mem_model.variables["test_var"]) def test_allocate_mem_info_variable_not_in_model(self): """@brief Test allocation when the variable is not in the memory model.""" @@ -927,9 +844,7 @@ def test_allocate_mem_info_variable_not_in_model(self): mock_mem_model.variables = {} # Call the function - with patch( - "assembler.memory_model.mem_info._allocateMemInfoVariable" - ) as mock_function: + with patch("assembler.memory_model.mem_info._allocateMemInfoVariable") as mock_function: # Make it actually call the real function - simplified without lambda mock_function.original = _allocateMemInfoVariable mock_function.side_effect = mock_function.original @@ -938,9 +853,7 @@ def test_allocate_mem_info_variable_not_in_model(self): with self.assertRaises(RuntimeError) as context: mock_function(mock_mem_model, mock_var_info) - self.assertIn( - "Variable missing_var not in memory model", str(context.exception) - ) + self.assertIn("Variable missing_var not in memory model", str(context.exception)) def test_allocate_mem_info_variable_mismatch(self): """@brief Test allocation when the variable has a different HBM address.""" @@ -952,9 +865,7 @@ def test_allocate_mem_info_variable_mismatch(self): mock_mem_model.variables = {"test_var": MagicMock(hbm_address=24)} # Call the function - with patch( - "assembler.memory_model.mem_info._allocateMemInfoVariable" - ) as mock_function: + with patch("assembler.memory_model.mem_info._allocateMemInfoVariable") as mock_function: # Make it actually call the real function - simplified without lambda mock_function.original = _allocateMemInfoVariable mock_function.side_effect = mock_function.original diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py index 07b7f73a..169480d1 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py @@ -8,12 +8,13 @@ @file test_he_link.py @brief Unit tests for the he_link module """ -import io + import argparse -from unittest.mock import patch, mock_open, MagicMock -import pytest +import io +from unittest.mock import MagicMock, mock_open, patch import he_link +import pytest from linker.kern_trace import KernelInfo @@ -68,18 +69,12 @@ def test_main(self, using_trace_file): "link_kernels": MagicMock(), "from_dinstrs": MagicMock(), "from_file_iter": MagicMock(), - "load_dinst": MagicMock( - return_value=[mock_dinstr1, mock_dinstr2] - ), # Return mock DInstructions + "load_dinst": MagicMock(return_value=[mock_dinstr1, mock_dinstr2]), # Return mock DInstructions "join_dinst": MagicMock(return_value=[]), "dump_instructions": MagicMock(), "remap_dinstrs_vars": MagicMock(return_value={"old_var": "new_var"}), - "update_input_prefixes": MagicMock( - return_value={"kernel1_pisa.tw": MagicMock()} - ), - "remap_vars": MagicMock( - return_value=([mock_dinstr1, mock_dinstr2], {"key": "value"}) - ), + "update_input_prefixes": MagicMock(return_value={"kernel1_pisa.tw": MagicMock()}), + "remap_vars": MagicMock(return_value=([mock_dinstr1, mock_dinstr2], {"key": "value"})), "initialize_memory_model": MagicMock(), # Return a kernel_op with expected_in_kern_file_name that will match our input file prefix "parse_kernel_ops": MagicMock( @@ -99,50 +94,45 @@ def test_main(self, using_trace_file): mock_config.trace_file = "mock_trace.txt" if using_trace_file else "" # Act - with patch( - "assembler.common.constants.convertBytes2Words", return_value=1024 - ), patch("he_link.prepare_output_files", mocks["prepare_output"]), patch( - "he_link.prepare_input_files", mocks["prepare_input"] - ), patch( - "assembler.common.counter.Counter.reset" - ), patch( - "he_link.Loader.load_dinst_kernel_from_file", mocks["load_dinst"] - ), patch( - "linker.instructions.BaseInstruction.dump_instructions_to_file", - mocks["dump_instructions"], - ), patch( - "linker.steps.program_linker.LinkedProgram.join_dinst_kernels", - mocks["join_dinst"], - ), patch( - "assembler.memory_model.mem_info.MemInfo.from_dinstrs", - mocks["from_dinstrs"], - ), patch( - "assembler.memory_model.mem_info.MemInfo.from_file_iter", - mocks["from_file_iter"], - ), patch( - "linker.MemoryModel" - ), patch( - "he_link.scan_variables", mocks["scan_variables"] - ), patch( - "he_link.check_unused_variables", mocks["check_unused_variables"] - ), patch( - "linker.kern_trace.TraceInfo.parse_kernel_ops", mocks["parse_kernel_ops"] - ), patch( - "os.path.isfile", - return_value=True, # Make all file existence checks return True - ), patch( - "linker.steps.program_linker.LinkedProgram.link_kernels_to_files", - mocks["link_kernels"], - ), patch( - "linker.kern_trace.remap_dinstrs_vars", mocks["remap_dinstrs_vars"] - ), patch( - "he_link.update_input_prefixes", mocks["update_input_prefixes"] - ), patch( - "he_link.remap_vars", mocks["remap_vars"] - ), patch( - "he_link.initialize_memory_model", mocks["initialize_memory_model"] + with ( + patch("assembler.common.constants.convertBytes2Words", return_value=1024), + patch("he_link.prepare_output_files", mocks["prepare_output"]), + patch("he_link.prepare_input_files", mocks["prepare_input"]), + patch("assembler.common.counter.Counter.reset"), + patch("he_link.Loader.load_dinst_kernel_from_file", mocks["load_dinst"]), + patch( + "linker.instructions.BaseInstruction.dump_instructions_to_file", + mocks["dump_instructions"], + ), + patch( + "linker.steps.program_linker.LinkedProgram.join_dinst_kernels", + mocks["join_dinst"], + ), + patch( + "assembler.memory_model.mem_info.MemInfo.from_dinstrs", + mocks["from_dinstrs"], + ), + patch( + "assembler.memory_model.mem_info.MemInfo.from_file_iter", + mocks["from_file_iter"], + ), + patch("linker.MemoryModel"), + patch("he_link.scan_variables", mocks["scan_variables"]), + patch("he_link.check_unused_variables", mocks["check_unused_variables"]), + patch("linker.kern_trace.TraceInfo.parse_kernel_ops", mocks["parse_kernel_ops"]), + patch( + "os.path.isfile", + return_value=True, # Make all file existence checks return True + ), + patch( + "linker.steps.program_linker.LinkedProgram.link_kernels_to_files", + mocks["link_kernels"], + ), + patch("linker.kern_trace.remap_dinstrs_vars", mocks["remap_dinstrs_vars"]), + patch("he_link.update_input_prefixes", mocks["update_input_prefixes"]), + patch("he_link.remap_vars", mocks["remap_vars"]), + patch("he_link.initialize_memory_model", mocks["initialize_memory_model"]), ): - # Run the main function with all patches in place he_link.main(mock_config, MagicMock()) @@ -179,24 +169,18 @@ def test_warning_on_use_xinstfetch(self): mock_config.input_mem_file = "input.mem" # Act & Assert - with patch("warnings.warn") as mock_warn, patch( - "assembler.common.constants.convertBytes2Words", return_value=1024 - ), patch("linker.he_link_utils.prepare_output_files"), patch( - "linker.he_link_utils.prepare_input_files" - ), patch( - "assembler.common.counter.Counter.reset" - ), patch( - "builtins.open", mock_open() - ), patch( - "assembler.memory_model.mem_info.MemInfo.from_file_iter" - ), patch( - "linker.MemoryModel" - ), patch( - "linker.steps.variable_discovery.scan_variables" - ), patch( - "linker.steps.variable_discovery.check_unused_variables" - ), patch( - "linker.steps.program_linker.LinkedProgram.link_kernels_to_files" + with ( + patch("warnings.warn") as mock_warn, + patch("assembler.common.constants.convertBytes2Words", return_value=1024), + patch("linker.he_link_utils.prepare_output_files"), + patch("linker.he_link_utils.prepare_input_files"), + patch("assembler.common.counter.Counter.reset"), + patch("builtins.open", mock_open()), + patch("assembler.memory_model.mem_info.MemInfo.from_file_iter"), + patch("linker.MemoryModel"), + patch("linker.steps.variable_discovery.scan_variables"), + patch("linker.steps.variable_discovery.check_unused_variables"), + patch("linker.steps.program_linker.LinkedProgram.link_kernels_to_files"), ): he_link.main(mock_config, None) mock_warn.assert_called_once() @@ -281,9 +265,7 @@ def test_trace_file_with_missing_output_prefix(self): error_output = io.StringIO() # Patch sys.argv and sys.stderr - with patch("sys.argv", mock_argv), patch("sys.stderr", error_output), patch( - "sys.exit" - ) as mock_exit: + with patch("sys.argv", mock_argv), patch("sys.stderr", error_output), patch("sys.exit") as mock_exit: # When required args are missing, argparse will call sys.exit() he_link.parse_args() @@ -300,24 +282,27 @@ def test_required_args_when_trace_file_not_set(self): @brief Test that input_mem_file and input_prefixes are required when trace_file is not set """ # Case 1: Missing input_mem_file - with patch( - "argparse.ArgumentParser.parse_args", - return_value=argparse.Namespace( - input_prefixes=["input_prefix"], - output_prefix="output_prefix", - input_mem_file="", # Empty input_mem_file - trace_file="", # No trace file - input_dir="", - output_dir="", - using_trace_file=False, - mem_spec_file="", - isa_spec_file="", - has_hbm=True, - hbm_size=None, - suppress_comments=False, - verbose=0, + with ( + patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="", # Empty input_mem_file + trace_file="", # No trace file + input_dir="", + output_dir="", + using_trace_file=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), ), - ), patch("argparse.ArgumentParser.error") as mock_error: + patch("argparse.ArgumentParser.error") as mock_error, + ): he_link.parse_args() # Verify error was called for missing input_mem_file mock_error.assert_called_once_with( @@ -325,24 +310,27 @@ def test_required_args_when_trace_file_not_set(self): ) # Case 2: Missing input_prefixes - with patch( - "argparse.ArgumentParser.parse_args", - return_value=argparse.Namespace( - input_prefixes=None, # Missing input_prefixes - output_prefix="output_prefix", - input_mem_file="input.mem", - trace_file="", # No trace file - input_dir="", - output_dir="", - using_trace_file=False, - mem_spec_file="", - isa_spec_file="", - has_hbm=True, - hbm_size=None, - suppress_comments=False, - verbose=0, + with ( + patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=None, # Missing input_prefixes + output_prefix="output_prefix", + input_mem_file="input.mem", + trace_file="", # No trace file + input_dir="", + output_dir="", + using_trace_file=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), ), - ), patch("argparse.ArgumentParser.error") as mock_error: + patch("argparse.ArgumentParser.error") as mock_error, + ): he_link.parse_args() # Verify error was called for missing input_prefixes mock_error.assert_called_once_with( @@ -354,24 +342,27 @@ def test_ignored_args_when_trace_file_set(self): @brief Test that input_mem_file and input_prefixes are ignored with warnings when trace_file is set """ # Both input_mem_file and input_prefixes are provided but should be ignored - with patch( - "argparse.ArgumentParser.parse_args", - return_value=argparse.Namespace( - input_prefixes=["input_prefix"], # Will be ignored - output_prefix="output_prefix", - input_mem_file="input.mem", # Will be ignored - trace_file="trace_file_path", # Trace file is provided - input_dir="", - output_dir="", - using_trace_file=None, # Will be computed - mem_spec_file="", - isa_spec_file="", - has_hbm=True, - hbm_size=None, - suppress_comments=False, - verbose=0, + with ( + patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], # Will be ignored + output_prefix="output_prefix", + input_mem_file="input.mem", # Will be ignored + trace_file="trace_file_path", # Trace file is provided + input_dir="", + output_dir="", + using_trace_file=None, # Will be computed + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), ), - ), patch("warnings.warn") as mock_warn: + patch("warnings.warn") as mock_warn, + ): args = he_link.parse_args() # Verify using_trace_file is set based on trace_file @@ -552,24 +543,27 @@ def test_input_dir_defaults_to_trace_file_directory(self): @brief Test that input_dir defaults to the directory of trace_file when not specified """ # Test with trace_file set but input_dir not set - with patch( - "argparse.ArgumentParser.parse_args", - return_value=argparse.Namespace( - input_prefixes=None, - output_prefix="output_prefix", - input_mem_file="", - input_dir="", # Not specified - trace_file="/path/to/trace_file.txt", # Trace file with a directory path - output_dir="", - using_trace_file=None, # Will be computed - mem_spec_file="", - isa_spec_file="", - has_hbm=True, - hbm_size=None, - suppress_comments=False, - verbose=0, + with ( + patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=None, + output_prefix="output_prefix", + input_mem_file="", + input_dir="", # Not specified + trace_file="/path/to/trace_file.txt", # Trace file with a directory path + output_dir="", + using_trace_file=None, # Will be computed + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), ), - ), patch("os.path.dirname", return_value="/path/to") as mock_dirname: + patch("os.path.dirname", return_value="/path/to") as mock_dirname, + ): args = he_link.parse_args() # Verify input_dir is set to the directory of trace_file @@ -577,24 +571,27 @@ def test_input_dir_defaults_to_trace_file_directory(self): assert args.input_dir == "/path/to" # Test with both trace_file and input_dir specified - input_dir should not be overwritten - with patch( - "argparse.ArgumentParser.parse_args", - return_value=argparse.Namespace( - input_prefixes=None, - output_prefix="output_prefix", - input_mem_file="", - input_dir="/custom/path", # Specified by user - trace_file="/path/to/trace_file.txt", - output_dir="", - using_trace_file=None, - mem_spec_file="", - isa_spec_file="", - has_hbm=True, - hbm_size=None, - suppress_comments=False, - verbose=0, + with ( + patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=None, + output_prefix="output_prefix", + input_mem_file="", + input_dir="/custom/path", # Specified by user + trace_file="/path/to/trace_file.txt", + output_dir="", + using_trace_file=None, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), ), - ), patch("os.path.dirname") as mock_dirname: + patch("os.path.dirname") as mock_dirname, + ): args = he_link.parse_args() # Verify dirname was not called since input_dir was already specified diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py index 48edc8d3..6dc56fe0 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py @@ -8,12 +8,13 @@ @brief Unit tests for he_prep module. """ -from unittest import mock import os -import sys import pathlib -import pytest +import sys +from unittest import mock + import he_prep +import pytest def test_main_assigns_and_saves(monkeypatch, tmp_path): @@ -37,9 +38,7 @@ def test_main_assigns_and_saves(monkeypatch, tmp_path): "preprocess_pisa_kernel_listing", mock.Mock(return_value=dummy_insts), ) - monkeypatch.setattr( - he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock() - ) + monkeypatch.setattr(he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock()) he_prep.main(str(output_file), str(input_file), b_verbose=False) # Output file should contain the instruction @@ -51,9 +50,7 @@ def test_main_no_input_file(): @brief Test that main raises an error when no input file is provided. """ with pytest.raises(FileNotFoundError): - he_prep.main( - "", "", b_verbose=False - ) # Should raise an error due to missing input file + he_prep.main("", "", b_verbose=False) # Should raise an error due to missing input file def test_main_no_output_file(): @@ -61,9 +58,7 @@ def test_main_no_output_file(): @brief Test that main raises an error when no output file is provided. """ with pytest.raises(FileNotFoundError): - he_prep.main( - "", "input.csv", b_verbose=False - ) # Should raise an error due to missing output file + he_prep.main("", "input.csv", b_verbose=False) # Should raise an error due to missing output file def test_main_no_instructions(monkeypatch): @@ -85,18 +80,13 @@ def test_main_no_instructions(monkeypatch): "preprocess_pisa_kernel_listing", mock.Mock(return_value=[]), ) - monkeypatch.setattr( - he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock() - ) + monkeypatch.setattr(he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock()) he_prep.main(output_file, input_file, b_verbose=False) # Output file should be empty output_file_path = pathlib.Path(output_file) - assert ( - not output_file_path.exists() - or output_file_path.read_text(encoding="utf-8").strip() == "" - ) + assert not output_file_path.exists() or output_file_path.read_text(encoding="utf-8").strip() == "" def test_main_invalid_input_file(tmp_path): @@ -107,9 +97,7 @@ def test_main_invalid_input_file(tmp_path): output_file = tmp_path / "output.csv" with pytest.raises(FileNotFoundError): - he_prep.main( - str(output_file), str(input_file), b_verbose=False - ) # Should raise an error due to missing input file + he_prep.main(str(output_file), str(input_file), b_verbose=False) # Should raise an error due to missing input file def test_main_invalid_output_file(tmp_path): @@ -127,9 +115,7 @@ def test_main_invalid_output_file(tmp_path): os.chmod(output_file, 0o444) # Read-only permissions with pytest.raises(PermissionError): - he_prep.main( - str(output_file), str(input_file), b_verbose=False - ) # Should raise an error due to permission issues + he_prep.main(str(output_file), str(input_file), b_verbose=False) # Should raise an error due to permission issues def test_parse_args(): diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_he_link_utils.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_he_link_utils.py index c46b889e..4d0c5b45 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_he_link_utils.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_he_link_utils.py @@ -8,18 +8,19 @@ @file test_he_link_utils.py @brief Unit tests for the he_link_utils module """ -from unittest.mock import patch, mock_open, MagicMock -import pytest +from unittest.mock import MagicMock, mock_open, patch + +import pytest +from assembler.common import constants from linker.he_link_utils import ( - prepare_output_files, + initialize_memory_model, prepare_input_files, - update_input_prefixes, + prepare_output_files, remap_vars, - initialize_memory_model, + update_input_prefixes, ) from linker.kern_trace.trace_info import KernelInfo -from assembler.common import constants class TestHelperFunctions: @@ -39,9 +40,11 @@ def test_prepare_output_files(self): mock_config.using_trace_file = False # Act - with patch("os.path.dirname", return_value="/tmp"), patch( - "pathlib.Path.mkdir" - ), patch("assembler.common.makeUniquePath", side_effect=lambda x: x): + with ( + patch("os.path.dirname", return_value="/tmp"), + patch("pathlib.Path.mkdir"), + patch("assembler.common.makeUniquePath", side_effect=lambda x: x), + ): result = prepare_output_files(mock_config) # Assert @@ -63,9 +66,11 @@ def test_prepare_output_files_with_mem(self): mock_config.using_trace_file = True # Act - with patch("os.path.dirname", return_value="/tmp"), patch( - "pathlib.Path.mkdir" - ), patch("assembler.common.makeUniquePath", side_effect=lambda x: x): + with ( + patch("os.path.dirname", return_value="/tmp"), + patch("pathlib.Path.mkdir"), + patch("assembler.common.makeUniquePath", side_effect=lambda x: x), + ): result = prepare_output_files(mock_config) # Assert @@ -97,9 +102,7 @@ def test_prepare_input_files(self): ) # Act - with patch("os.path.isfile", return_value=True), patch( - "assembler.common.makeUniquePath", side_effect=lambda x: x - ): + with patch("os.path.isfile", return_value=True), patch("assembler.common.makeUniquePath", side_effect=lambda x: x): result = prepare_input_files(mock_config, mock_output_files) # Assert @@ -134,9 +137,7 @@ def test_prepare_input_files_file_not_found(self): ) # Act & Assert - with patch("os.path.isfile", return_value=False), patch( - "assembler.common.makeUniquePath", side_effect=lambda x: x - ): + with patch("os.path.isfile", return_value=False), patch("assembler.common.makeUniquePath", side_effect=lambda x: x): with pytest.raises(FileNotFoundError): prepare_input_files(mock_config, mock_output_files) @@ -163,9 +164,7 @@ def test_prepare_input_files_output_conflict(self): ) # Act & Assert - with patch("os.path.isfile", return_value=True), patch( - "assembler.common.makeUniquePath", side_effect=lambda x: x - ): + with patch("os.path.isfile", return_value=True), patch("assembler.common.makeUniquePath", side_effect=lambda x: x): with pytest.raises(RuntimeError): prepare_input_files(mock_config, output_files) @@ -277,7 +276,6 @@ def test_remap_vars_with_multiple_kernels(self): # Act with patch("linker.he_link_utils.remap_dinstrs_vars") as mock_remap_vars: - # Configure mocks mock_remap_vars.side_effect = [ test_data["expected_dicts"], @@ -297,14 +295,10 @@ def test_remap_vars_with_multiple_kernels(self): assert mock_remap_vars.call_count == 2 # First call - mock_remap_vars.assert_any_call( - test_data["kernel_dinstrs"][0], test_data["kernel_ops"][0] - ) + mock_remap_vars.assert_any_call(test_data["kernel_dinstrs"][0], test_data["kernel_ops"][0]) # Second call - mock_remap_vars.assert_any_call( - test_data["kernel_dinstrs"][1], test_data["kernel_ops"][1] - ) + mock_remap_vars.assert_any_call(test_data["kernel_dinstrs"][1], test_data["kernel_ops"][1]) # Verify the remap_dict was set on each kernel file assert test_data["files"][0].remap_dict == test_data["expected_dicts"] @@ -384,15 +378,14 @@ def test_initialize_memory_model_with_kernel_dinstrs(self): mock_stream = MagicMock() # Act - with patch( - "assembler.common.constants.convertBytes2Words", return_value=1024 * 1024 - ) as mock_convert, patch( - "assembler.memory_model.mem_info.MemInfo.from_dinstrs", - return_value=mock_mem_info, - ) as mock_from_dinstrs, patch( - "linker.MemoryModel" - ) as mock_memory_model_class: - + with ( + patch("assembler.common.constants.convertBytes2Words", return_value=1024 * 1024) as mock_convert, + patch( + "assembler.memory_model.mem_info.MemInfo.from_dinstrs", + return_value=mock_mem_info, + ) as mock_from_dinstrs, + patch("linker.MemoryModel") as mock_memory_model_class, + ): # Configure mock memory model mock_memory_model = mock_memory_model_class.return_value mock_memory_model.hbm.capacity = 1024 * 1024 @@ -402,9 +395,7 @@ def test_initialize_memory_model_with_kernel_dinstrs(self): # Assert # Verify convertBytes2Words was called with correct parameters - mock_convert.assert_called_once_with( - mock_config.hbm_size * constants.Constants.KILOBYTE - ) + mock_convert.assert_called_once_with(mock_config.hbm_size * constants.Constants.KILOBYTE) # Verify from_dinstrs was called with kernel_dinstrs mock_from_dinstrs.assert_called_once_with(mock_dinstrs) @@ -432,15 +423,15 @@ def test_initialize_memory_model_with_input_mem_file(self): mock_stream = MagicMock() # Act - with patch( - "assembler.common.constants.convertBytes2Words", return_value=2048 * 1024 - ) as mock_convert, patch("builtins.open", mock_open()) as mock_open_file, patch( - "assembler.memory_model.mem_info.MemInfo.from_file_iter", - return_value=mock_mem_info, - ) as mock_from_file_iter, patch( - "linker.MemoryModel" - ) as mock_memory_model_class: - + with ( + patch("assembler.common.constants.convertBytes2Words", return_value=2048 * 1024) as mock_convert, + patch("builtins.open", mock_open()) as mock_open_file, + patch( + "assembler.memory_model.mem_info.MemInfo.from_file_iter", + return_value=mock_mem_info, + ) as mock_from_file_iter, + patch("linker.MemoryModel") as mock_memory_model_class, + ): # Configure mock memory model mock_memory_model = mock_memory_model_class.return_value mock_memory_model.hbm.capacity = 2048 * 1024 @@ -450,14 +441,10 @@ def test_initialize_memory_model_with_input_mem_file(self): # Assert # Verify convertBytes2Words was called with correct parameters - mock_convert.assert_called_once_with( - mock_config.hbm_size * constants.Constants.KILOBYTE - ) + mock_convert.assert_called_once_with(mock_config.hbm_size * constants.Constants.KILOBYTE) # Verify open was called with input_mem_file - mock_open_file.assert_called_once_with( - mock_config.input_mem_file, "r", encoding="utf-8" - ) + mock_open_file.assert_called_once_with(mock_config.input_mem_file, encoding="utf-8") # Verify from_file_iter was called assert mock_from_file_iter.called @@ -486,15 +473,14 @@ def test_initialize_memory_model_with_zero_hbm_size(self): mock_mem_info = MagicMock() # Act - with patch( - "assembler.common.constants.convertBytes2Words", return_value=0 - ) as mock_convert, patch( - "assembler.memory_model.mem_info.MemInfo.from_dinstrs", - return_value=mock_mem_info, - ), patch( - "linker.MemoryModel" - ) as mock_memory_model_class: - + with ( + patch("assembler.common.constants.convertBytes2Words", return_value=0) as mock_convert, + patch( + "assembler.memory_model.mem_info.MemInfo.from_dinstrs", + return_value=mock_mem_info, + ), + patch("linker.MemoryModel") as mock_memory_model_class, + ): # Call function under test result = initialize_memory_model(mock_config, mock_dinstrs) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py index 910ceb34..67a57aa3 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py @@ -13,7 +13,7 @@ from assembler.common.config import GlobalConfig from assembler.memory_model import mem_info -from linker import VariableInfo, HBM, MemoryModel +from linker import HBM, MemoryModel, VariableInfo class TestVariableInfo(unittest.TestCase): diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py index 933f8ee1..4f999b1e 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py @@ -9,7 +9,7 @@ """ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch from linker.instructions.dinst.dinstruction import DInstruction @@ -28,9 +28,7 @@ def setUp(self): self.mock_miv.as_dict.return_value = {"var_name": "var1", "hbm_address": 123} # Patch the MemInfo.get_meminfo_var_from_tokens method - self.mem_info_patcher = patch( - "linker.instructions.dinst.dinstruction.MemInfo.get_meminfo_var_from_tokens" - ) + self.mem_info_patcher = patch("linker.instructions.dinst.dinstruction.MemInfo.get_meminfo_var_from_tokens") self.mock_get_meminfo = self.mem_info_patcher.start() self.mock_get_meminfo.return_value = (self.mock_miv, 1) @@ -65,9 +63,7 @@ def test_get_name_token_index(self): @test Verifies the name token is at index 0 """ - self.assertEqual( - self.d_instruction_class.name_token_index, 0 - ) # Updated reference + self.assertEqual(self.d_instruction_class.name_token_index, 0) # Updated reference def test_num_tokens_property(self): """@brief Test num_tokens property returns expected value diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py index f9769657..5c9d6a66 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py @@ -28,9 +28,7 @@ def setUp(self): self.seed_idx = 1 self.key_idx = 2 self.var_name = "var1" - self.inst = Instruction( - [Instruction.name, self.seed_idx, self.key_idx, self.var_name] - ) + self.inst = Instruction([Instruction.name, self.seed_idx, self.key_idx, self.var_name]) def test_get_num_tokens(self): """@brief Test that _get_num_tokens returns 4 @@ -51,9 +49,7 @@ def test_initialization_valid_input(self): @test Verifies the instruction is properly initialized with valid tokens """ - inst = Instruction( - [MemInfo.Const.Keyword.KEYGEN, self.seed_idx, self.key_idx, self.var_name] - ) + inst = Instruction([MemInfo.Const.Keyword.KEYGEN, self.seed_idx, self.key_idx, self.var_name]) self.assertEqual(inst.name, MemInfo.Const.Keyword.KEYGEN) def test_initialization_invalid_name(self): @@ -84,7 +80,7 @@ def test_tokens_with_additional_data(self): @test Verifies extra tokens are preserved in the tokens property """ - additional_token = "extra" + additional_token = "extra" # noqa: S105 (allow hardcoded string) inst_with_extra = Instruction( [ Instruction.name, diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py index 601a4aae..9d746c14 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py @@ -48,9 +48,7 @@ def test_initialization_valid_input(self): @test Verifies the instruction is properly initialized with valid tokens """ - inst = Instruction( - [MemInfo.Const.Keyword.LOAD, self.type, str(self.address), self.var_name] - ) + inst = Instruction([MemInfo.Const.Keyword.LOAD, self.type, str(self.address), self.var_name]) self.assertEqual(inst.name, MemInfo.Const.Keyword.LOAD) @@ -83,9 +81,7 @@ def test_tokens_property(self): str(self.address), self.var_name, ] - inst = Instruction( - [Instruction.name, self.type, str(self.address), self.var_name] - ) + inst = Instruction([Instruction.name, self.type, str(self.address), self.var_name]) self.assertEqual(inst.tokens, expected_tokens) @@ -94,7 +90,7 @@ def test_tokens_with_additional_data(self): @test Verifies extra tokens are preserved in the tokens property """ - additional_token = "extra" + additional_token = "extra" # noqa: S105 (allow hardcoded string) inst_with_extra = Instruction( [ Instruction.name, diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py index 938a9db9..a4300e36 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py @@ -47,9 +47,7 @@ def test_initialization_valid_input(self): @test Verifies the instruction is properly initialized with valid tokens """ - inst = Instruction( - [MemInfo.Const.Keyword.STORE, self.var_name, str(self.address)] - ) + inst = Instruction([MemInfo.Const.Keyword.STORE, self.var_name, str(self.address)]) self.assertEqual(inst.name, MemInfo.Const.Keyword.STORE) @@ -84,7 +82,7 @@ def test_tokens_with_additional_data(self): @test Verifies extra tokens are preserved in the tokens property """ - additional_token = "extra" + additional_token = "extra" # noqa: S105 (allow hardcoded string) inst_with_extra = Instruction( [ Instruction.name, diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py index 26a9a799..83710ba6 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py @@ -9,10 +9,9 @@ """ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch -from linker.instructions.dinst import factory, create_from_mem_line -from linker.instructions.dinst import DLoad, DStore, DKeyGen +from linker.instructions.dinst import DKeyGen, DLoad, DStore, create_from_mem_line, factory class TestDInstModule(unittest.TestCase): diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py index a0e1f1f0..a28bfd43 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py @@ -12,7 +12,7 @@ """ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch from linker.instructions import create_from_str_line @@ -53,9 +53,7 @@ def test_create_from_str_line_success(self, mock_tokenize): mock_tokenize.return_value = (tokens, comment) # Call function - result = create_from_str_line( - "instruction, arg1, arg2 # Test comment", self.factory - ) + result = create_from_str_line("instruction, arg1, arg2 # Test comment", self.factory) # Verify mock_tokenize.assert_called_once_with("instruction, arg1, arg2 # Test comment") @@ -78,9 +76,7 @@ def test_create_from_str_line_failure(self, mock_tokenize): self.mock_class.side_effect = ValueError("Invalid instruction") # Call function - result = create_from_str_line( - "unknown, arg1, arg2 # Test comment", self.factory - ) + result = create_from_str_line("unknown, arg1, arg2 # Test comment", self.factory) # Verify self.assertIsNone(result) @@ -132,9 +128,7 @@ def test_create_from_str_line_exception_handling(self, mock_tokenize): self.mock_class.side_effect = ValueError("Invalid values") # Call function - should handle the exception and return None - result = create_from_str_line( - "instruction, arg1, arg2 # Test comment", self.factory - ) + result = create_from_str_line("instruction, arg1, arg2 # Test comment", self.factory) # Verify self.assertIsNone(result) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py index b5842745..1d28eee7 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py @@ -9,8 +9,8 @@ """ import os -import unittest import tempfile +import unittest from unittest.mock import patch from assembler.common.config import GlobalConfig @@ -157,7 +157,7 @@ def test_dump_instructions_to_file(self): try: BaseInstruction.dump_instructions_to_file(instructions, file_path) - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: lines = f.read().splitlines() self.assertEqual(len(lines), 2) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_remap.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_remap.py index 2bbc8f09..8f93fce9 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_remap.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_remap.py @@ -8,14 +8,15 @@ @file test_kern_remap.py @brief Unit tests for the kern_remap module """ + from unittest.mock import MagicMock -import pytest +import pytest +from linker.instructions.cinst import BLoad, BOnes, CLoad, CStore, NLoad +from linker.instructions.minst import MLoad, MStore from linker.kern_trace.kern_remap import remap_dinstrs_vars, remap_m_c_instrs_vars from linker.kern_trace.kern_var import KernVar from linker.kern_trace.kernel_op import KernelOp -from linker.instructions.minst import MLoad, MStore -from linker.instructions.cinst import BLoad, CLoad, BOnes, NLoad, CStore class TestRemapDinstrsVars: diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_var.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_var.py index b079ffb6..76aa256b 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_var.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_var.py @@ -8,8 +8,8 @@ @file test_kern_var.py @brief Unit tests for the KernVar class """ -import pytest +import pytest from linker.kern_trace.kern_var import KernVar @@ -96,9 +96,7 @@ def test_label_property_immutability(self): # Act & Assert with pytest.raises(AttributeError): - kern_var.label = ( - "new_label" # Should raise AttributeError for read-only property - ) + kern_var.label = "new_label" # Should raise AttributeError for read-only property def test_degree_property_immutability(self): """ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kernel_op.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kernel_op.py index 9bebf623..771f0694 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kernel_op.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kernel_op.py @@ -8,12 +8,13 @@ @file test_kernel_op.py @brief Unit tests for the KernelOp class """ + from unittest.mock import patch -import pytest -from linker.kern_trace.kernel_op import KernelOp +import pytest from linker.kern_trace.context_config import ContextConfig from linker.kern_trace.kern_var import KernVar +from linker.kern_trace.kernel_op import KernelOp class TestKernelOp: @@ -73,9 +74,7 @@ def test_init_with_invalid_encryption_scheme(self): @brief Test initialization with invalid encryption scheme """ # Arrange - invalid_context = ContextConfig( - scheme="INVALID", poly_mod_degree=8192, keyrns_terms=2 - ) + invalid_context = ContextConfig(scheme="INVALID", poly_mod_degree=8192, keyrns_terms=2) kern_args = self._create_test_kern_args() # Act & Assert @@ -99,15 +98,11 @@ def test_get_kern_var_objs(self): @brief Test get_kern_var_objs method """ # Arrange - kernel_op = KernelOp( - "add", self._create_test_context_config(), self._create_test_kern_args() - ) + kernel_op = KernelOp("add", self._create_test_context_config(), self._create_test_kern_args()) test_var_strs = ["var1-1024-1", "var2-2048-2"] # Act - Using the private method for testing - with patch( - "linker.kern_trace.kern_var.KernVar.from_string" - ) as mock_from_string: + with patch("linker.kern_trace.kern_var.KernVar.from_string") as mock_from_string: mock_from_string.side_effect = [ KernVar("var1", 1024, 1), KernVar("var2", 2048, 2), @@ -126,9 +121,7 @@ def test_get_level(self): @brief Test get_level method """ # Arrange - kernel_op = KernelOp( - "add", self._create_test_context_config(), self._create_test_kern_args() - ) + kernel_op = KernelOp("add", self._create_test_context_config(), self._create_test_kern_args()) # Create test KernVar objects test_vars = [KernVar("var1", 1024, 1), KernVar("var2", 2048, 3)] @@ -144,9 +137,7 @@ def test_get_level_with_single_var(self): @brief Test get_level method with a single variable """ # Arrange - kernel_op = KernelOp( - "add", self._create_test_context_config(), self._create_test_kern_args() - ) + kernel_op = KernelOp("add", self._create_test_context_config(), self._create_test_kern_args()) # Create test KernVar objects test_vars = [KernVar("var1", 1024, 2)] @@ -162,9 +153,7 @@ def test_get_level_with_empty_vars(self): @brief Test get_level method with empty variables list """ # Arrange - kernel_op = KernelOp( - "add", self._create_test_context_config(), self._create_test_kern_args() - ) + kernel_op = KernelOp("add", self._create_test_context_config(), self._create_test_kern_args()) # Act & Assert with pytest.raises(ValueError, match="at least one variable"): @@ -175,9 +164,7 @@ def test_str_representation(self): @brief Test string representation of KernelOp """ # Arrange - kernel_op = KernelOp( - "add", self._create_test_context_config(), self._create_test_kern_args() - ) + kernel_op = KernelOp("add", self._create_test_context_config(), self._create_test_kern_args()) # Act result = str(kernel_op) @@ -191,9 +178,7 @@ def test_property_kern_vars(self): @brief Test kern_vars property """ # Arrange - kernel_op = KernelOp( - "add", self._create_test_context_config(), self._create_test_kern_args() - ) + kernel_op = KernelOp("add", self._create_test_context_config(), self._create_test_kern_args()) # Act result = kernel_op.kern_vars @@ -210,9 +195,7 @@ def test_property_name(self): @brief Test name property """ # Arrange - kernel_op = KernelOp( - "mul", self._create_test_context_config(), self._create_test_kern_args() - ) + kernel_op = KernelOp("mul", self._create_test_context_config(), self._create_test_kern_args()) # Act result = kernel_op.name @@ -267,9 +250,7 @@ def test_property_level(self): @brief Test level property """ # Arrange - kernel_op = KernelOp( - "add", self._create_test_context_config(), ["var1-8192-4", "var2-8192-4"] - ) + kernel_op = KernelOp("add", self._create_test_context_config(), ["var1-8192-4", "var2-8192-4"]) # Act result = kernel_op.level @@ -300,24 +281,18 @@ def test_case_insensitivity_of_operation_name(self): kern_args = self._create_test_kern_args() # Act - kernel_op = KernelOp( - "ADD", context_config, kern_args - ) # Uppercase operation name + kernel_op = KernelOp("ADD", context_config, kern_args) # Uppercase operation name # Assert assert kernel_op.name == "ADD" - assert ( - kernel_op.expected_in_kern_file_name == "ckks_add_8192_l2_m2" - ) # Note: lowercase in file name + assert kernel_op.expected_in_kern_file_name == "ckks_add_8192_l2_m2" # Note: lowercase in file name def test_case_insensitivity_of_scheme(self): """ @brief Test that scheme names are case-insensitive """ # Arrange - context = ContextConfig( - scheme="ckks", poly_mod_degree=8192, keyrns_terms=2 - ) # Lowercase scheme + context = ContextConfig(scheme="ckks", poly_mod_degree=8192, keyrns_terms=2) # Lowercase scheme kern_args = self._create_test_kern_args() # Act diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_trace_info.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_trace_info.py index de18a111..5eddc45a 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_trace_info.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_trace_info.py @@ -8,12 +8,13 @@ @file test_trace_info.py @brief Unit tests for the TraceInfo module and related classes """ -from unittest.mock import patch, mock_open -import pytest -from linker.kern_trace.trace_info import KernelInfo, TraceInfo +from unittest.mock import mock_open, patch + +import pytest from linker.kern_trace.context_config import ContextConfig from linker.kern_trace.kernel_op import KernelOp +from linker.kern_trace.trace_info import KernelInfo, TraceInfo class TestKernelInfo: @@ -145,9 +146,7 @@ def test_extract_context_and_args(self): } # Act - name, context_config, kern_args = trace_info.extract_context_and_args( - tokens, param_idxs, 1 - ) + name, context_config, kern_args = trace_info.extract_context_and_args(tokens, param_idxs, 1) # Assert assert name == "kernel1" @@ -210,10 +209,11 @@ def test_parse_kernel_ops_with_valid_trace(self): ) # Act - with patch("os.path.isfile", return_value=True), patch( - "builtins.open", mock_open(read_data=trace_content) - ), patch("linker.kern_trace.trace_info.tokenize_from_line") as mock_tokenize: - + with ( + patch("os.path.isfile", return_value=True), + patch("builtins.open", mock_open(read_data=trace_content)), + patch("linker.kern_trace.trace_info.tokenize_from_line") as mock_tokenize, + ): # Mock the tokenize_from_line function to return expected tokens mock_tokenize.side_effect = [ ( @@ -257,10 +257,7 @@ def test_parse_kernel_ops_with_empty_trace(self): trace_file = "/path/to/empty_trace.txt" # Act - with patch("os.path.isfile", return_value=True), patch( - "builtins.open", mock_open(read_data="") - ): - + with patch("os.path.isfile", return_value=True), patch("builtins.open", mock_open(read_data="")): trace_info = TraceInfo(trace_file) result = trace_info.parse_kernel_ops() @@ -296,10 +293,11 @@ def test_parse_kernel_ops_skip_empty_lines(self): ) # Act - with patch("os.path.isfile", return_value=True), patch( - "builtins.open", mock_open(read_data=trace_content) - ), patch("linker.kern_trace.trace_info.tokenize_from_line") as mock_tokenize: - + with ( + patch("os.path.isfile", return_value=True), + patch("builtins.open", mock_open(read_data=trace_content)), + patch("linker.kern_trace.trace_info.tokenize_from_line") as mock_tokenize, + ): # Mock the tokenize_from_line function to return expected tokens mock_tokenize.side_effect = [ ( diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_linker_run_config.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_linker_run_config.py index 26f126dd..f67f0d36 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_linker_run_config.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_linker_run_config.py @@ -8,13 +8,14 @@ @file test_linker_run_config.py @brief Unit tests for the LinkerRunConfig class """ + import os -from unittest.mock import patch, PropertyMock -import pytest +from unittest.mock import PropertyMock, patch -from linker.linker_run_config import LinkerRunConfig -from assembler.common.run_config import RunConfig +import pytest from assembler.common.config import GlobalConfig +from assembler.common.run_config import RunConfig +from linker.linker_run_config import LinkerRunConfig class TestLinkerRunConfig: @@ -144,16 +145,12 @@ def test_init_for_default_params(self): RunConfig.reset_class_state() # Act - with patch( - "assembler.common.makeUniquePath", side_effect=lambda x: x - ), patch.object( - RunConfig, "DEFAULT_HBM_SIZE_KB", new_callable=PropertyMock - ) as mock_hbm_size, patch.object( - GlobalConfig, "suppress_comments", new_callable=PropertyMock - ) as mock_suppress_comments, patch.object( - GlobalConfig, "useXInstFetch", new_callable=PropertyMock - ) as mock_use_xinstfetch: - + with ( + patch("assembler.common.makeUniquePath", side_effect=lambda x: x), + patch.object(RunConfig, "DEFAULT_HBM_SIZE_KB", new_callable=PropertyMock) as mock_hbm_size, + patch.object(GlobalConfig, "suppress_comments", new_callable=PropertyMock) as mock_suppress_comments, + patch.object(GlobalConfig, "useXInstFetch", new_callable=PropertyMock) as mock_use_xinstfetch, + ): # Mock the default HBM size mock_suppress_comments.return_value = False mock_use_xinstfetch.return_value = False diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py index 29f7d94b..3dff2ca6 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py @@ -9,7 +9,7 @@ """ import unittest -from unittest.mock import patch, mock_open, MagicMock, call +from unittest.mock import MagicMock, call, mock_open, patch from linker.loader import Loader @@ -72,9 +72,7 @@ def test_load_minst_kernel_failure(self, mock_factory, mock_create): with self.assertRaises(RuntimeError) as context: Loader.load_minst_kernel(self.minst_lines) - self.assertIn( - f"Error parsing line 1: {self.minst_lines[0]}", str(context.exception) - ) + self.assertIn(f"Error parsing line 1: {self.minst_lines[0]}", str(context.exception)) @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.Loader.load_minst_kernel") @@ -92,7 +90,7 @@ def test_load_minst_kernel_from_file_success(self, mock_load, mock_file): # Verify the results self.assertEqual(result, self.mock_minst) - mock_file.assert_called_once_with("test.minst", "r", encoding="utf-8") + mock_file.assert_called_once_with("test.minst", encoding="utf-8") mock_load.assert_called_once_with(self.minst_lines) @patch("builtins.open", new_callable=mock_open) @@ -110,9 +108,7 @@ def test_load_minst_kernel_from_file_failure(self, mock_load, mock_file): with self.assertRaises(RuntimeError) as context: Loader.load_minst_kernel_from_file("test.minst") - self.assertIn( - 'Error occurred loading file "test.minst"', str(context.exception) - ) + self.assertIn('Error occurred loading file "test.minst"', str(context.exception)) @patch("linker.instructions.create_from_str_line") @patch("linker.instructions.cinst.factory") @@ -155,9 +151,7 @@ def test_load_cinst_kernel_failure(self, mock_factory, mock_create): with self.assertRaises(RuntimeError) as context: Loader.load_cinst_kernel(self.cinst_lines) - self.assertIn( - f"Error parsing line 1: {self.cinst_lines[0]}", str(context.exception) - ) + self.assertIn(f"Error parsing line 1: {self.cinst_lines[0]}", str(context.exception)) @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.Loader.load_cinst_kernel") @@ -175,7 +169,7 @@ def test_load_cinst_kernel_from_file_success(self, mock_load, mock_file): # Verify the results self.assertEqual(result, self.mock_cinst) - mock_file.assert_called_once_with("test.cinst", "r", encoding="utf-8") + mock_file.assert_called_once_with("test.cinst", encoding="utf-8") mock_load.assert_called_once_with(self.cinst_lines) @patch("builtins.open", new_callable=mock_open) @@ -193,9 +187,7 @@ def test_load_cinst_kernel_from_file_failure(self, mock_load, mock_file): with self.assertRaises(RuntimeError) as context: Loader.load_cinst_kernel_from_file("test.cinst") - self.assertIn( - 'Error occurred loading file "test.cinst"', str(context.exception) - ) + self.assertIn('Error occurred loading file "test.cinst"', str(context.exception)) @patch("linker.instructions.create_from_str_line") @patch("linker.instructions.xinst.factory") @@ -238,9 +230,7 @@ def test_load_xinst_kernel_failure(self, mock_factory, mock_create): with self.assertRaises(RuntimeError) as context: Loader.load_xinst_kernel(self.xinst_lines) - self.assertIn( - f"Error parsing line 1: {self.xinst_lines[0]}", str(context.exception) - ) + self.assertIn(f"Error parsing line 1: {self.xinst_lines[0]}", str(context.exception)) @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.Loader.load_xinst_kernel") @@ -258,7 +248,7 @@ def test_load_xinst_kernel_from_file_success(self, mock_load, mock_file): # Verify the results self.assertEqual(result, self.mock_xinst) - mock_file.assert_called_once_with("test.xinst", "r", encoding="utf-8") + mock_file.assert_called_once_with("test.xinst", encoding="utf-8") mock_load.assert_called_once_with(self.xinst_lines) @patch("builtins.open", new_callable=mock_open) @@ -276,9 +266,7 @@ def test_load_xinst_kernel_from_file_failure(self, mock_load, mock_file): with self.assertRaises(RuntimeError) as context: Loader.load_xinst_kernel_from_file("test.xinst") - self.assertIn( - 'Error occurred loading file "test.xinst"', str(context.exception) - ) + self.assertIn('Error occurred loading file "test.xinst"', str(context.exception)) @patch("linker.instructions.dinst.create_from_mem_line") def test_load_dinst_kernel_success(self, mock_create): @@ -295,9 +283,7 @@ def test_load_dinst_kernel_success(self, mock_create): # Verify the results self.assertEqual(result, self.mock_dinst) self.assertEqual(mock_create.call_count, 2) - mock_create.assert_has_calls( - [call(self.dinst_lines[0]), call(self.dinst_lines[1])] - ) + mock_create.assert_has_calls([call(self.dinst_lines[0]), call(self.dinst_lines[1])]) @patch("linker.instructions.dinst.create_from_mem_line") def test_load_dinst_kernel_failure(self, mock_create): @@ -312,9 +298,7 @@ def test_load_dinst_kernel_failure(self, mock_create): with self.assertRaises(RuntimeError) as context: Loader.load_dinst_kernel(self.dinst_lines) - self.assertIn( - f"Error parsing line 1: {self.dinst_lines[0]}", str(context.exception) - ) + self.assertIn(f"Error parsing line 1: {self.dinst_lines[0]}", str(context.exception)) @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.Loader.load_dinst_kernel") @@ -332,7 +316,7 @@ def test_load_dinst_kernel_from_file_success(self, mock_load, mock_file): # Verify the results self.assertEqual(result, self.mock_dinst) - mock_file.assert_called_once_with("test.dinst", "r", encoding="utf-8") + mock_file.assert_called_once_with("test.dinst", encoding="utf-8") mock_load.assert_called_once_with(self.dinst_lines) @patch("builtins.open", new_callable=mock_open) @@ -350,9 +334,7 @@ def test_load_dinst_kernel_from_file_failure(self, mock_load, mock_file): with self.assertRaises(RuntimeError) as context: Loader.load_dinst_kernel_from_file("test.dinst") - self.assertIn( - 'Error occurred loading file "test.dinst"', str(context.exception) - ) + self.assertIn('Error occurred loading file "test.dinst"', str(context.exception)) if __name__ == "__main__": diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py index b306daf8..c73cd794 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py @@ -10,12 +10,12 @@ import io import unittest -from unittest.mock import patch, MagicMock, call, mock_open from collections import namedtuple +from unittest.mock import MagicMock, call, mock_open, patch from assembler.common.config import GlobalConfig from linker import MemoryModel -from linker.instructions import minst, cinst, dinst +from linker.instructions import cinst, dinst, minst from linker.steps.program_linker import LinkedProgram @@ -38,9 +38,7 @@ def setUp(self): self.mock_has_hbm = self.has_hbm_patcher.start() # Mock the suppress_comments property to return False by default - self.suppress_comments_patcher = patch.object( - GlobalConfig, "suppress_comments", False - ) + self.suppress_comments_patcher = patch.object(GlobalConfig, "suppress_comments", False) self.mock_suppress_comments = self.suppress_comments_patcher.start() self.program = LinkedProgram( @@ -149,16 +147,12 @@ def test_update_minsts(self): self.assertEqual(mock_msyncc.target, 15) # 5 + 10 self.assertEqual(mock_mload.source, "10") # Replaced with HBM address self.assertIn("input_var", mock_mload.comment) # Comment updated - self.assertIn( - "original comment", mock_mload.comment - ) # Original comment preserved + self.assertIn("original comment", mock_mload.comment) # Original comment preserved self.assertEqual(mock_mstore.dest, "20") # Replaced with HBM address # Verify the memory model was used correctly - self.mem_model.use_variable.assert_has_calls( - [call("input_var", 1), call("output_var", 1)] - ) + self.mem_model.use_variable.assert_has_calls([call("input_var", 1), call("output_var", 1)]) def test_remove_and_merge_csyncm_cnop(self): """@brief Test removing CSyncm instructions and merging CNop instructions. @@ -185,9 +179,7 @@ def test_remove_and_merge_csyncm_cnop(self): mock_cnop2.tokens = [0] # Set up ISACInst.CSyncm.get_throughput - with patch( - "assembler.instructions.cinst.CSyncm.get_throughput", return_value=2 - ): + with patch("assembler.instructions.cinst.CSyncm.get_throughput", return_value=2): # Execute the method kernel_cinstrs = [ mock_ifetch, @@ -263,9 +255,7 @@ def test_update_cinsts_addresses_and_offsets(self): self.assertEqual(mock_cstore.dest, "40") # Verify the memory model was used correctly - self.mem_model.use_variable.assert_has_calls( - [call("var1", 2), call("var2", 2)] - ) + self.mem_model.use_variable.assert_has_calls([call("var1", 2), call("var2", 2)]) # Test that XInstFetch raises NotImplementedError with self.assertRaises(NotImplementedError): @@ -277,12 +267,10 @@ def test_update_cinsts(self): @test Verifies that the correct update methods are called based on HBM configuration """ # Create a mock for _remove_and_merge_csyncm_cnop and _update_cinsts_addresses_and_offsets - with patch.object( - LinkedProgram, "_remove_and_merge_csyncm_cnop" - ) as mock_remove, patch.object( - LinkedProgram, "_update_cinsts_addresses_and_offsets" - ) as mock_update: - + with ( + patch.object(LinkedProgram, "_remove_and_merge_csyncm_cnop") as mock_remove, + patch.object(LinkedProgram, "_update_cinsts_addresses_and_offsets") as mock_update, + ): # Execute the method with HBM enabled kernel_cinstrs = ["cinst1", "cinst2"] self.program._update_cinsts(kernel_cinstrs) @@ -342,14 +330,11 @@ def test_link_kernel(self): @test Verifies that a kernel is correctly linked with updated instructions """ # Create mocks for the update methods - with patch.object( - LinkedProgram, "_update_minsts" - ) as mock_update_minsts, patch.object( - LinkedProgram, "_update_cinsts" - ) as mock_update_cinsts, patch.object( - LinkedProgram, "_update_xinsts" - ) as mock_update_xinsts: - + with ( + patch.object(LinkedProgram, "_update_minsts") as mock_update_minsts, + patch.object(LinkedProgram, "_update_cinsts") as mock_update_cinsts, + patch.object(LinkedProgram, "_update_xinsts") as mock_update_xinsts, + ): # Setup mock_update_xinsts to return a bundle offset mock_update_xinsts.return_value = 5 @@ -363,15 +348,11 @@ def test_link_kernel(self): xinstr.to_line.return_value = f"xinst{i}" xinstr.comment = f"xinst_comment{i}" if i % 2 == 0 else None - for i, cinstr in enumerate( - kernel_cinstrs[:-1] - ): # Skip the last one (cexit) + for i, cinstr in enumerate(kernel_cinstrs[:-1]): # Skip the last one (cexit) cinstr.to_line.return_value = f"cinst{i}" cinstr.comment = f"cinst_comment{i}" if i % 2 == 0 else None - for i, minstr in enumerate( - kernel_minstrs[:-1] - ): # Skip the last one (msyncc) + for i, minstr in enumerate(kernel_minstrs[:-1]): # Skip the last one (msyncc) minstr.to_line.return_value = f"minst{i}" minstr.comment = f"minst_comment{i}" if i % 2 == 0 else None @@ -387,12 +368,8 @@ def test_link_kernel(self): self.assertEqual(self.program._bundle_offset, 6) # 5 + 1 # Verify line offsets were updated - self.assertEqual( - self.program._minst_line_offset, 1 - ) # len(kernel_minstrs) - 1 - self.assertEqual( - self.program._cinst_line_offset, 1 - ) # len(kernel_cinstrs) - 1 + self.assertEqual(self.program._minst_line_offset, 1) # len(kernel_minstrs) - 1 + self.assertEqual(self.program._cinst_line_offset, 1) # len(kernel_cinstrs) - 1 # Verify kernel count was incremented self.assertEqual(self.program._kernel_count, 1) @@ -419,12 +396,10 @@ def test_link_kernel_with_no_hbm(self): """ with patch.object(GlobalConfig, "hasHBM", False): # Create mocks for the update methods - with patch.object( - LinkedProgram, "_update_cinsts" - ) as mock_update_cinsts, patch.object( - LinkedProgram, "_update_xinsts" - ) as mock_update_xinsts: - + with ( + patch.object(LinkedProgram, "_update_cinsts") as mock_update_cinsts, + patch.object(LinkedProgram, "_update_xinsts") as mock_update_xinsts, + ): # Setup mock_update_xinsts to return a bundle offset mock_update_xinsts.return_value = 5 @@ -476,10 +451,11 @@ def test_link_kernel_with_suppress_comments(self): """ with patch.object(GlobalConfig, "suppress_comments", True): # Create mocks for the update methods - with patch.object(LinkedProgram, "_update_minsts"), patch.object( - LinkedProgram, "_update_cinsts" - ), patch.object(LinkedProgram, "_update_xinsts"): - + with ( + patch.object(LinkedProgram, "_update_minsts"), + patch.object(LinkedProgram, "_update_cinsts"), + patch.object(LinkedProgram, "_update_xinsts"), + ): # Create mock instruction lists with comments kernel_minstrs = [MagicMock(), MagicMock()] kernel_cinstrs = [MagicMock(), MagicMock()] @@ -514,9 +490,7 @@ def test_link_kernels_to_files(self): @test Verifies that kernels are correctly linked and written to output files """ # Create a namedtuple similar to KernelInfo for testing - KernelInfo = namedtuple( - "KernelInfo", ["prefix", "minst", "cinst", "xinst", "mem", "remap_dict"] - ) + KernelInfo = namedtuple("KernelInfo", ["prefix", "minst", "cinst", "xinst", "mem", "remap_dict"]) # Arrange input_files = [ @@ -543,26 +517,25 @@ def test_link_kernels_to_files(self): mock_verbose = MagicMock() # Act - with patch("builtins.open", mock_open()), patch( - "linker.steps.program_linker.Loader.load_minst_kernel_from_file", - return_value=[], - ), patch( - "linker.steps.program_linker.Loader.load_cinst_kernel_from_file", - return_value=[], - ), patch( - "linker.steps.program_linker.Loader.load_xinst_kernel_from_file", - return_value=[], - ), patch.object( - LinkedProgram, "__init__", return_value=None - ) as mock_init, patch.object( - LinkedProgram, "link_kernel" - ) as mock_link_kernel, patch.object( - LinkedProgram, "close" - ) as mock_close: - - LinkedProgram.link_kernels_to_files( - input_files, output_files, mock_mem_model, mock_verbose - ) + with ( + patch("builtins.open", mock_open()), + patch( + "linker.steps.program_linker.Loader.load_minst_kernel_from_file", + return_value=[], + ), + patch( + "linker.steps.program_linker.Loader.load_cinst_kernel_from_file", + return_value=[], + ), + patch( + "linker.steps.program_linker.Loader.load_xinst_kernel_from_file", + return_value=[], + ), + patch.object(LinkedProgram, "__init__", return_value=None) as mock_init, + patch.object(LinkedProgram, "link_kernel") as mock_link_kernel, + patch.object(LinkedProgram, "close") as mock_close, + ): + LinkedProgram.link_kernels_to_files(input_files, output_files, mock_mem_model, mock_verbose) # Assert mock_init.assert_called_once() @@ -588,9 +561,7 @@ def setUp(self): self.mock_has_hbm = self.has_hbm_patcher.start() # Mock the suppress_comments property to return False by default - self.suppress_comments_patcher = patch.object( - GlobalConfig, "suppress_comments", False - ) + self.suppress_comments_patcher = patch.object(GlobalConfig, "suppress_comments", False) self.mock_suppress_comments = self.suppress_comments_patcher.start() self.program = LinkedProgram( @@ -731,9 +702,7 @@ def test_join_dinst_kernels_multiple_kernels(self): mock_dstore2.var = "var4" # Execute the method - result = LinkedProgram.join_dinst_kernels( - [[mock_dload1, mock_dstore1], [mock_dload2, mock_dkeygen, mock_dstore2]] - ) + result = LinkedProgram.join_dinst_kernels([[mock_dload1, mock_dstore1], [mock_dload2, mock_dkeygen, mock_dstore2]]) # Verify result - should contain load1, store1 (output), keygen, store2 (output) # dload2 should be skipped since it loads var2 which is already an output from kernel1 @@ -768,9 +737,7 @@ def test_join_dinst_kernels_with_carry_over_vars(self): mock_dstore2.var = "var2" # Same variable is also an output # Execute the method - result = LinkedProgram.join_dinst_kernels( - [[mock_dload1, mock_dstore1], [mock_dload2, mock_dstore2]] - ) + result = LinkedProgram.join_dinst_kernels([[mock_dload1, mock_dstore1], [mock_dload2, mock_dstore2]]) # Verify result - should contain load1, store2 # Both dload2 and dstore1 should be skipped since var2 is carried over diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py index 241bb3fe..7e3e121c 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py @@ -9,16 +9,16 @@ """ import unittest -from unittest.mock import patch, MagicMock from collections import namedtuple -import pytest +from unittest.mock import MagicMock, patch +import pytest from assembler.common.config import GlobalConfig from linker.steps.variable_discovery import ( + check_unused_variables, discover_variables, discover_variables_spad, scan_variables, - check_unused_variables, ) @@ -47,9 +47,7 @@ def setUp(self): @patch("linker.steps.variable_discovery.minst") @patch("linker.steps.variable_discovery.MInstruction") @patch("assembler.memory_model.variable.Variable.validateName") - def test_discover_variables_valid( - self, mock_validate, mock_minst_class, mock_minst - ): + def test_discover_variables_valid(self, mock_validate, mock_minst_class, mock_minst): """@brief Test discovering variables from valid MInstructions. @test Verifies that variables are correctly discovered from MLoad and MStore instructions @@ -81,7 +79,6 @@ def is_mstore_side_effect(obj): mock_minst.MStore: is_mstore_side_effect(obj), }.get(cls, False), ): - # Configure validateName to return True mock_validate.return_value = True @@ -154,9 +151,7 @@ def test_discover_variables_invalid_variable_name(self, mock_validate, mock_mins @patch("linker.steps.variable_discovery.cinst") @patch("linker.steps.variable_discovery.CInstruction") @patch("assembler.memory_model.variable.Variable.validateName") - def test_discover_variables_spad_valid( - self, mock_validate, mock_cinst_class, mock_cinst - ): + def test_discover_variables_spad_valid(self, mock_validate, mock_cinst_class, mock_cinst): """@brief Test discovering variables from valid CInstructions. @test Verifies that variables are correctly discovered from all relevant CInstruction types @@ -201,9 +196,7 @@ def mock_isinstance(obj, cls): return class_checks.get(cls, lambda: False)() # Patch the isinstance calls at the module level - with patch( - "linker.steps.variable_discovery.isinstance", side_effect=mock_isinstance - ): + with patch("linker.steps.variable_discovery.isinstance", side_effect=mock_isinstance): # Call the function result = list(discover_variables_spad(cinstrs)) @@ -245,9 +238,7 @@ def test_discover_variables_spad_invalid_type(self): @patch("linker.steps.variable_discovery.cinst") @patch("linker.steps.variable_discovery.CInstruction") @patch("assembler.memory_model.variable.Variable.validateName") - def test_discover_variables_spad_invalid_variable_name( - self, mock_validate, mock_cinst_class, mock_cinst - ): + def test_discover_variables_spad_invalid_variable_name(self, mock_validate, mock_cinst_class, mock_cinst): """@brief Test discovering variables with an invalid variable name. @test Verifies that a RuntimeError is raised when a variable name is invalid @@ -311,18 +302,23 @@ def test_scan_variables(self): mock_verbose = MagicMock() # Act - with patch( - "linker.steps.variable_discovery.Loader.load_minst_kernel_from_file", - return_value=[], - ), patch( - "linker.steps.variable_discovery.Loader.load_cinst_kernel_from_file", - return_value=[], - ), patch( - "linker.steps.variable_discovery.discover_variables", - return_value=["var1", "var2"], - ), patch( - "linker.steps.variable_discovery.discover_variables_spad", - return_value=["var1", "var2"], + with ( + patch( + "linker.steps.variable_discovery.Loader.load_minst_kernel_from_file", + return_value=[], + ), + patch( + "linker.steps.variable_discovery.Loader.load_cinst_kernel_from_file", + return_value=[], + ), + patch( + "linker.steps.variable_discovery.discover_variables", + return_value=["var1", "var2"], + ), + patch( + "linker.steps.variable_discovery.discover_variables_spad", + return_value=["var1", "var2"], + ), ): scan_variables(input_files, mock_mem_model, mock_verbose) diff --git a/docs/Encrypted-Computing-SDK/empty_file.txt b/docs/Encrypted-Computing-SDK/empty_file.txt index 8b137891..e69de29b 100644 --- a/docs/Encrypted-Computing-SDK/empty_file.txt +++ b/docs/Encrypted-Computing-SDK/empty_file.txt @@ -1 +0,0 @@ - diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index a20fbe73..00000000 --- a/mypy.ini +++ /dev/null @@ -1,2 +0,0 @@ -[mypy] -mypy_path=p-isa_tools/kerngen diff --git a/p-isa_tools/CPPLINT.cfg b/p-isa_tools/CPPLINT.cfg index 597aa8ae..89479da9 100644 --- a/p-isa_tools/CPPLINT.cfg +++ b/p-isa_tools/CPPLINT.cfg @@ -7,5 +7,6 @@ filter=-readability/todo filter=-runtime/references filter=-runtime/explicit filter=-build/c++11 +filter=-build/c++17 # is C++17 only filter=-build/namespaces filter=-build/include diff --git a/p-isa_tools/common/config.h.in b/p-isa_tools/common/config.h.in index dac9b916..d5afef16 100644 --- a/p-isa_tools/common/config.h.in +++ b/p-isa_tools/common/config.h.in @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #cmakedefine01 ENABLE_DATA_FORMATS diff --git a/p-isa_tools/common/graph/graph.h b/p-isa_tools/common/graph/graph.h index cf42543e..2fd24494 100644 --- a/p-isa_tools/common/graph/graph.h +++ b/p-isa_tools/common/graph/graph.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/common/p_isa/isa_instruction.h b/p-isa_tools/common/p_isa/isa_instruction.h index c1d3a742..736f6d7e 100644 --- a/p-isa_tools/common/p_isa/isa_instruction.h +++ b/p-isa_tools/common/p_isa/isa_instruction.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/common/p_isa/p_isa.h b/p-isa_tools/common/p_isa/p_isa.h index f940e8d6..c56b7034 100644 --- a/p-isa_tools/common/p_isa/p_isa.h +++ b/p-isa_tools/common/p_isa/p_isa.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/common/p_isa/p_isa_hardware_models.h b/p-isa_tools/common/p_isa/p_isa_hardware_models.h index 2c07f5b2..0d7a03c7 100644 --- a/p-isa_tools/common/p_isa/p_isa_hardware_models.h +++ b/p-isa_tools/common/p_isa/p_isa_hardware_models.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once @@ -57,9 +57,9 @@ class ExampleHardware : public PISAHardwareModel }; MemorySizesMap = { - { "MEMORY", uint64_t(1572000) }, - { "CACHE", uint64_t(1572000) }, - { "REGISTER", uint64_t(1572000) }, + { "MEMORY", static_cast(1572000) }, + { "CACHE", static_cast(1572000) }, + { "REGISTER", static_cast(1572000) }, }; } @@ -91,9 +91,9 @@ class Model1 : public PISAHardwareModel }; MemorySizesMap = { - { "MEMORY", uint64_t(1572000) }, - { "CACHE", uint64_t(1572000) }, - { "REGISTER", uint64_t(1572000) }, + { "MEMORY", static_cast(1572000) }, + { "CACHE", static_cast(1572000) }, + { "REGISTER", static_cast(1572000) }, }; } @@ -125,9 +125,9 @@ class Model2 : public PISAHardwareModel }; MemorySizesMap = { - { "MEMORY", uint64_t(1572000) }, - { "CACHE", uint64_t(2048) }, - { "REGISTER", uint64_t(256) }, + { "MEMORY", static_cast(1572000) }, + { "CACHE", static_cast(2048) }, + { "REGISTER", static_cast(256) }, }; } diff --git a/p-isa_tools/common/p_isa/p_isa_instruction.cpp b/p-isa_tools/common/p_isa/p_isa_instruction.cpp index 6bff8f40..4fb6004a 100644 --- a/p-isa_tools/common/p_isa/p_isa_instruction.cpp +++ b/p-isa_tools/common/p_isa/p_isa_instruction.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #include "p_isa_instruction.h" diff --git a/p-isa_tools/common/p_isa/p_isa_instruction.h b/p-isa_tools/common/p_isa/p_isa_instruction.h index e11d400e..9b87ffb4 100644 --- a/p-isa_tools/common/p_isa/p_isa_instruction.h +++ b/p-isa_tools/common/p_isa/p_isa_instruction.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/common/p_isa/p_isa_instructions.h b/p-isa_tools/common/p_isa/p_isa_instructions.h index 306acd62..a8e656cb 100644 --- a/p-isa_tools/common/p_isa/p_isa_instructions.h +++ b/p-isa_tools/common/p_isa/p_isa_instructions.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/common/p_isa/p_isa_performance_modeler.cpp b/p-isa_tools/common/p_isa/p_isa_performance_modeler.cpp index 1cce6f6d..0e25aba4 100644 --- a/p-isa_tools/common/p_isa/p_isa_performance_modeler.cpp +++ b/p-isa_tools/common/p_isa/p_isa_performance_modeler.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #include "p_isa_performance_modeler.h" @@ -27,8 +27,8 @@ void PISAPerformanceModeler::addGraphAnalysis(PerformanceReport &report) { depth++; auto input_nodes = p_isa_graph_instructions.getInputNodes(true, true, true); - report.graph_min_width = std::min(report.graph_min_width, (int64_t)input_nodes.size()); - report.graph_max_width = std::max(report.graph_max_width, (int64_t)input_nodes.size()); + report.graph_min_width = std::min(report.graph_min_width, static_cast(input_nodes.size())); + report.graph_max_width = std::max(report.graph_max_width, static_cast(input_nodes.size())); report.graph_average_width += input_nodes.size(); for (auto &input : input_nodes) { diff --git a/p-isa_tools/common/p_isa/p_isa_performance_modeler.h b/p-isa_tools/common/p_isa/p_isa_performance_modeler.h index 4fd0dac1..b4d1f44a 100644 --- a/p-isa_tools/common/p_isa/p_isa_performance_modeler.h +++ b/p-isa_tools/common/p_isa/p_isa_performance_modeler.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/common/p_isa/parser/p_isa_parser.cpp b/p-isa_tools/common/p_isa/parser/p_isa_parser.cpp index a133ed64..4058ceda 100644 --- a/p-isa_tools/common/p_isa/parser/p_isa_parser.cpp +++ b/p-isa_tools/common/p_isa/parser/p_isa_parser.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #include diff --git a/p-isa_tools/common/p_isa/parser/p_isa_parser.h b/p-isa_tools/common/p_isa/parser/p_isa_parser.h index 9aa5ab6f..3c3c7c35 100644 --- a/p-isa_tools/common/p_isa/parser/p_isa_parser.h +++ b/p-isa_tools/common/p_isa/parser/p_isa_parser.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/common/string.h b/p-isa_tools/common/string.h index 914080cb..8ebeef41 100644 --- a/p-isa_tools/common/string.h +++ b/p-isa_tools/common/string.h @@ -1,5 +1,5 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/common/timer/timer.h b/p-isa_tools/common/timer/timer.h index dab5f8da..31baeb95 100644 --- a/p-isa_tools/common/timer/timer.h +++ b/p-isa_tools/common/timer/timer.h @@ -1,6 +1,3 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - // Copyright (C) 2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 diff --git a/p-isa_tools/functional_modeler/data_handlers/hec_dataformats_handler.h b/p-isa_tools/functional_modeler/data_handlers/hec_dataformats_handler.h index 78f50ca9..f7db97f1 100644 --- a/p-isa_tools/functional_modeler/data_handlers/hec_dataformats_handler.h +++ b/p-isa_tools/functional_modeler/data_handlers/hec_dataformats_handler.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/functional_modeler/data_handlers/json_data_handler.h b/p-isa_tools/functional_modeler/data_handlers/json_data_handler.h index 5417514d..1652521c 100644 --- a/p-isa_tools/functional_modeler/data_handlers/json_data_handler.h +++ b/p-isa_tools/functional_modeler/data_handlers/json_data_handler.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/functional_modeler/functional_models/multiregister.h b/p-isa_tools/functional_modeler/functional_models/multiregister.h index a592a3a5..f3f1adb1 100644 --- a/p-isa_tools/functional_modeler/functional_models/multiregister.h +++ b/p-isa_tools/functional_modeler/functional_models/multiregister.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/functional_modeler/functional_models/p_isa_functional_model.h b/p-isa_tools/functional_modeler/functional_models/p_isa_functional_model.h index 0a5b1a71..51666b5a 100644 --- a/p-isa_tools/functional_modeler/functional_models/p_isa_functional_model.h +++ b/p-isa_tools/functional_modeler/functional_models/p_isa_functional_model.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/functional_modeler/functional_models/p_isa_memory_model.h b/p-isa_tools/functional_modeler/functional_models/p_isa_memory_model.h index 7d15f34a..3497b4f6 100644 --- a/p-isa_tools/functional_modeler/functional_models/p_isa_memory_model.h +++ b/p-isa_tools/functional_modeler/functional_models/p_isa_memory_model.h @@ -1,5 +1,5 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/functional_modeler/functional_models/utility_functions.h b/p-isa_tools/functional_modeler/functional_models/utility_functions.h index e5be9d35..245db1ca 100644 --- a/p-isa_tools/functional_modeler/functional_models/utility_functions.h +++ b/p-isa_tools/functional_modeler/functional_models/utility_functions.h @@ -1,5 +1,5 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/functional_modeler/main.cpp b/p-isa_tools/functional_modeler/main.cpp index fe1121c8..bd28e614 100644 --- a/p-isa_tools/functional_modeler/main.cpp +++ b/p-isa_tools/functional_modeler/main.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #include diff --git a/p-isa_tools/functional_modeler/pisa_runtime/p_isa_instruction_trace.h b/p-isa_tools/functional_modeler/pisa_runtime/p_isa_instruction_trace.h index dde3db14..cffe5487 100644 --- a/p-isa_tools/functional_modeler/pisa_runtime/p_isa_instruction_trace.h +++ b/p-isa_tools/functional_modeler/pisa_runtime/p_isa_instruction_trace.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/functional_modeler/pisa_runtime/pisaprogramruntime.h b/p-isa_tools/functional_modeler/pisa_runtime/pisaprogramruntime.h index 4e50a721..72e891b9 100644 --- a/p-isa_tools/functional_modeler/pisa_runtime/pisaprogramruntime.h +++ b/p-isa_tools/functional_modeler/pisa_runtime/pisaprogramruntime.h @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #pragma once diff --git a/p-isa_tools/kerngen/.gitignore b/p-isa_tools/kerngen/.gitignore index 14b0c1e6..a01021cc 100644 --- a/p-isa_tools/kerngen/.gitignore +++ b/p-isa_tools/kerngen/.gitignore @@ -12,4 +12,4 @@ data/ zz_playground/ # vs code local config -.vscode \ No newline at end of file +.vscode diff --git a/p-isa_tools/kerngen/README.md b/p-isa_tools/kerngen/README.md index 1f76ca50..e3433ad7 100644 --- a/p-isa_tools/kerngen/README.md +++ b/p-isa_tools/kerngen/README.md @@ -121,7 +121,7 @@ kernel instructions given in the manifest file, see - second field defines the polynomial size for the `DATA`. This is required by the generating kernels to define how many units (multiples of the native polynomial size, 8192 in HERACLES silicon case) are required and handled. -- third field defines the key RNS, i.e. the total max RNS of the relinearization key, typically the global max number of how many 32 bit prime number moduli (HERACLES silicon case) are in the modulus chain that the kernels can have or need to handle + 1. +- third field defines the key RNS, i.e. the total max RNS of the relinearization key, typically the global max number of how many 32 bit prime number moduli (HERACLES silicon case) are in the modulus chain that the kernels can have or need to handle + 1. - fourth field defines the number of RNS terms in the current polynomial. ## DATA diff --git a/p-isa_tools/kerngen/__init__.py b/p-isa_tools/kerngen/__init__.py index b5309c91..4057dc01 100644 --- a/p-isa_tools/kerngen/__init__.py +++ b/p-isa_tools/kerngen/__init__.py @@ -1 +1,2 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/p-isa_tools/kerngen/const/options.py b/p-isa_tools/kerngen/const/options.py index 16fdfc14..c712fb17 100644 --- a/p-isa_tools/kerngen/const/options.py +++ b/p-isa_tools/kerngen/const/options.py @@ -4,6 +4,7 @@ # generative artificial intelligence solutions. """Module for defining constants and enums used in the kernel generator""" + from enum import Enum diff --git a/p-isa_tools/kerngen/high_parser/__init__.py b/p-isa_tools/kerngen/high_parser/__init__.py index 9afad87e..3504aba4 100644 --- a/p-isa_tools/kerngen/high_parser/__init__.py +++ b/p-isa_tools/kerngen/high_parser/__init__.py @@ -1,13 +1,24 @@ # Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 """Imports commonly used when using this package""" from high_parser.types import ( Context, - Immediate, HighOp, - expand_ios, - Polys, + Immediate, KernelContext, KeyPolys, + Polys, + expand_ios, ) + +__all__ = [ + "Context", + "HighOp", + "Immediate", + "KernelContext", + "KeyPolys", + "Polys", + "expand_ios", +] diff --git a/p-isa_tools/kerngen/high_parser/generators.py b/p-isa_tools/kerngen/high_parser/generators.py index 0b499cef..4c2c3dd9 100644 --- a/p-isa_tools/kerngen/high_parser/generators.py +++ b/p-isa_tools/kerngen/high_parser/generators.py @@ -33,9 +33,7 @@ def from_manifest(cls, filepath: str, scheme: str): try: return cls(dirpath, manifest[scheme.upper()]) except KeyError as e: - raise GeneratorError( - f"Scheme `{scheme.upper()}` not found in manifest file" - ) from e + raise GeneratorError(f"Scheme `{scheme.upper()}` not found in manifest file") from e def available_kernels(self) -> str: """Returns a list of available pisa ops.""" @@ -55,8 +53,6 @@ def get_kernel(self, opname: str): except KeyError as e: raise GeneratorError(f"Op not found in available pisa ops: {opname}") from e except AttributeError as e: - raise GeneratorError( - f"Class for op `{opname}` name not found: {class_name}" - ) from e + raise GeneratorError(f"Class for op `{opname}` name not found: {class_name}") from e except ImportError as e: raise GeneratorError(f"Unable to import module: {module_path}") from e diff --git a/p-isa_tools/kerngen/high_parser/options_handler.py b/p-isa_tools/kerngen/high_parser/options_handler.py index fe5e27b1..dbc4b306 100644 --- a/p-isa_tools/kerngen/high_parser/options_handler.py +++ b/p-isa_tools/kerngen/high_parser/options_handler.py @@ -4,13 +4,14 @@ """A module to process optional key/value dictionary parameters""" from abc import ABC, abstractmethod +from typing import Any class OptionsDict(ABC): """Abstract class to hold the options key/value pairs""" op_name: str = "" - op_value = None + op_value: Any = None @abstractmethod def validate(self, value): @@ -42,9 +43,7 @@ def op_value(self, value: int): if self.validate(value): self._op_value = int(value) else: - raise ValueError( - "{self.op_name} must be in range ({self.min_val}, {self.max_val}): {self.op_name}={self.op_value}" - ) + raise ValueError("{self.op_name} must be in range ({self.min_val}, {self.max_val}): {self.op_name}={self.op_value}") class OptionsIntBounds: @@ -127,11 +126,7 @@ def parse(options: list[str]): for option in options: try: key, value = option.split("=") - output_dict[key] = OptionsDictFactoryDispatcher.create( - key, value - ).op_value + output_dict[key] = OptionsDictFactoryDispatcher.create(key, value).op_value except ValueError as err: - raise ValueError( - f"Options must be key/value pairs (e.g. num_digits=3): '{option}'" - ) from err + raise ValueError(f"Options must be key/value pairs (e.g. num_digits=3): '{option}'") from err return output_dict diff --git a/p-isa_tools/kerngen/high_parser/parser.py b/p-isa_tools/kerngen/high_parser/parser.py index c4d2e5b6..3c0b43dc 100644 --- a/p-isa_tools/kerngen/high_parser/parser.py +++ b/p-isa_tools/kerngen/high_parser/parser.py @@ -3,26 +3,24 @@ """Module for parsing isa commands""" +from collections.abc import Iterator from pathlib import Path -from typing import Iterator from .config import Config from .generators import Generators from .pisa_operations import PIsaOp from .types import ( - Context, - KernelContext, Comment, - EmptyLine, + Context, Data, + EmptyLine, + HighOp, Immediate, + KernelContext, Polys, - HighOp, ) -MANIFEST_FILE = str( - Path(__file__).parent.parent.absolute() / "pisa_generators/manifest.json" -) +MANIFEST_FILE = str(Path(__file__).parent.parent.absolute() / "pisa_generators/manifest.json") Symbol = str @@ -41,9 +39,7 @@ def _get_context_from_commands_list(commands): if not context_list: raise LookupError("No Context found for commands list for ParseResults") if len(context_list) > 1: - raise LookupError( - "Multiple Context found in commands list for ParseResults" - ) + raise LookupError("Multiple Context found in commands list for ParseResults") return context_list[0] @property @@ -70,10 +66,7 @@ def get_pisa_ops(self) -> Iterator[list[PIsaOp] | None]: if isinstance(command, HighOp) and hasattr(command, "context"): command.context.label = str(self.context.ntt_stages) - return ( - command.to_pisa() if isinstance(command, HighOp) else None - for command in commands - ) + return (command.to_pisa() if isinstance(command, HighOp) else None for command in commands) class Parser: @@ -132,18 +125,14 @@ def _delegate(self, command_str: str, context_seen: list[Context], symbols_map): case _: # If context has not been given yet - FAIL if len(context_seen) == 0: - raise RuntimeError( - f"No `CONTEXT` provided before `{command_str.rstrip()}`" - ) + raise RuntimeError(f"No `CONTEXT` provided before `{command_str.rstrip()}`") # Look up commands defined in manifest if self.generators is None: raise ValueError("Generator not set") cls = self.generators.get_kernel(command) - kernel_context = KernelContext.from_context( - context_seen[0], label=label - ) + kernel_context = KernelContext.from_context(context_seen[0], label=label) return cls.from_string(kernel_context, symbols_map, rest) def parse_inputs(self, lines: list[str]) -> ParseResults: diff --git a/p-isa_tools/kerngen/high_parser/pisa_operations.py b/p-isa_tools/kerngen/high_parser/pisa_operations.py index a6dc8959..1f320dec 100644 --- a/p-isa_tools/kerngen/high_parser/pisa_operations.py +++ b/p-isa_tools/kerngen/high_parser/pisa_operations.py @@ -3,8 +3,8 @@ """Module containing the low level p-isa operations""" -from dataclasses import dataclass from abc import ABC, abstractmethod +from dataclasses import dataclass from .config import Config @@ -45,9 +45,7 @@ class BinaryOp: def _op_str(self, op: str) -> str: """Return the p-isa instructions of operation `op`""" - return ( - f"{self.label}, {op}, {self.output}, {self.input0}, {self.input1}, {self.q}" - ) + return f"{self.label}, {op}, {self.output}, {self.input0}, {self.input1}, {self.q}" @dataclass @@ -146,10 +144,7 @@ def _op_str(self, op: str) -> str: f"{self.label}, {op}, {self.output0}, {self.output1}, " f"{self.input0}, {self.input1}, w_{self.q}_{self.stage}_{self.unit}, {self.q}" ) - return ( - f"{self.label}, {op}, {self.output0}, {self.output1}, " - f"{self.input0}, {self.input1}, {self.stage}, {self.unit}, {self.q}" - ) + return f"{self.label}, {op}, {self.output0}, {self.output1}, " f"{self.input0}, {self.input1}, {self.stage}, {self.unit}, {self.q}" class NTT(Butterfly, PIsaOp): diff --git a/p-isa_tools/kerngen/high_parser/types.py b/p-isa_tools/kerngen/high_parser/types.py index f6a6ead1..fac22f85 100644 --- a/p-isa_tools/kerngen/high_parser/types.py +++ b/p-isa_tools/kerngen/high_parser/types.py @@ -3,17 +3,15 @@ """Module for parsing isa commands""" -import math import itertools as it +import math from abc import ABC, abstractmethod from dataclasses import dataclass - from pydantic import BaseModel -from .pisa_operations import PIsaOp - from .options_handler import OptionsDictParser +from .pisa_operations import PIsaOp class PolyOutOfBoundsError(Exception): @@ -36,9 +34,7 @@ def expand(self, *args) -> str: part, q, unit = args # Sanity bounds checks if self.start_parts > part >= self.parts or self.start_rns > q >= self.rns: - raise PolyOutOfBoundsError( - f"part `{part}` or q `{q}` are not within the poly's range `{self!r}`" - ) + raise PolyOutOfBoundsError(f"part `{part}` or q `{q}` are not within the poly's range `{self!r}`") return f"{self.name}_{part}_{q}_{unit}" def __call__(self, *args) -> str: @@ -84,14 +80,8 @@ def expand(self, *args) -> str: """Returns a string of the expanded symbol w.r.t. digit, rns, part, and unit""" digit, part, q, unit = args # Sanity bounds checks - if ( - self.start_parts > part >= self.parts - or self.start_rns > q >= self.rns - or digit > self.digits - ): - raise PolyOutOfBoundsError( - f"part `{digit}` or `{part}` or q `{q}` are not within the key poly's range `{self!r}`" - ) + if self.start_parts > part >= self.parts or self.start_rns > q >= self.rns or digit > self.digits: + raise PolyOutOfBoundsError(f"part `{digit}` or `{part}` or q `{q}` are not within the key poly's range `{self!r}`") return f"{self.name}_{part}_{digit}_{q}_{unit}" @@ -154,7 +144,7 @@ class Context(BaseModel): key_rns: int current_rns: int # optional vars for context - num_digits: int | None + num_digits: int | None = None # calculated based on required params max_rns: int @@ -165,23 +155,15 @@ def from_string(cls, line: str): scheme, poly_order, key_rns, current_rns, *optionals = line.split() optional_dict = OptionsDictParser.parse(optionals) int_poly_order = int(poly_order) - if ( - int_poly_order < MIN_POLY_SIZE - or int_poly_order > MAX_POLY_SIZE - or not math.log2(int_poly_order).is_integer() - ): - raise ValueError( - f"Poly order `{int_poly_order}` must be power of two >= {MIN_POLY_SIZE} and < {MAX_POLY_SIZE}" - ) + if int_poly_order < MIN_POLY_SIZE or int_poly_order > MAX_POLY_SIZE or not math.log2(int_poly_order).is_integer(): + raise ValueError(f"Poly order `{int_poly_order}` must be power of two >= {MIN_POLY_SIZE} and < {MAX_POLY_SIZE}") int_key_rns = int(key_rns) int_current_rns = int(current_rns) int_max_rns = int_key_rns - 1 if int_key_rns <= int_current_rns: - raise ValueError( - f"Current RNS must be less than Key RNS: current_rns={current_rns}, key_rns={key_rns}" - ) + raise ValueError(f"Current RNS must be less than Key RNS: current_rns={current_rns}, key_rns={key_rns}") return cls( scheme=scheme.upper(), @@ -211,7 +193,7 @@ class KernelContext(Context): @classmethod def from_context(cls, context: Context, label: str = "0") -> "KernelContext": - """Create a kernel context froma context (and optionally a label)""" + """Create a kernel context from context (and optionally a label)""" return cls(label=label, **vars(context)) @@ -243,9 +225,7 @@ def __call__(self, *args, **kwargs): # Sanity bounds checks q = args[1] if q > self.rns: - raise PolyOutOfBoundsError( - f"q `{q}` is more than the immediate with RNS `{self!r}`" - ) + raise PolyOutOfBoundsError(f"q `{q}` is more than the immediate with RNS `{self!r}`") return f"{self.name}_{q}" @classmethod diff --git a/p-isa_tools/kerngen/kernel_optimization/loops.py b/p-isa_tools/kerngen/kernel_optimization/loops.py index 0f09806a..9c5907f8 100644 --- a/p-isa_tools/kerngen/kernel_optimization/loops.py +++ b/p-isa_tools/kerngen/kernel_optimization/loops.py @@ -6,8 +6,9 @@ """Module for loop interchange optimization in P-ISA operations""" import re + from const.options import LoopKey -from high_parser.pisa_operations import PIsaOp, Comment +from high_parser.pisa_operations import Comment, PIsaOp class PIsaOpGroup: @@ -66,7 +67,8 @@ def split_by_reorderable(pisa_list: list[PIsaOp]) -> list[PIsaOpGroup]: no_reoderable_group = True for pisa in pisa_list: - # if the pisa is a comment and it contains tag, treat the following pisa as reorderable until a tag is found. + # if the pisa is a comment and it contains tag, + # treat the following pisa as reorderable until a tag is found. if isinstance(pisa, Comment): if "" in pisa.line: # If current group has instructions, append it to groups first diff --git a/p-isa_tools/kerngen/kernel_parser/parser.py b/p-isa_tools/kerngen/kernel_parser/parser.py index a99c3d46..d48b56b0 100644 --- a/p-isa_tools/kerngen/kernel_parser/parser.py +++ b/p-isa_tools/kerngen/kernel_parser/parser.py @@ -6,14 +6,15 @@ """Module for parsing kernel commands from Kerngen""" import re -from high_parser.types import Immediate, KernelContext, Polys, Context -from pisa_generators.basic import Copy, HighOp, Add, Sub, Mul, Muli -from pisa_generators.ntt import NTT, INTT -from pisa_generators.square import Square -from pisa_generators.relin import Relin -from pisa_generators.rotate import Rotate + +from high_parser.types import Context, Immediate, KernelContext, Polys +from pisa_generators.basic import Add, Copy, HighOp, Mul, Muli, Sub from pisa_generators.mod import Mod, ModUp +from pisa_generators.ntt import INTT, NTT +from pisa_generators.relin import Relin from pisa_generators.rescale import Rescale +from pisa_generators.rotate import Rotate +from pisa_generators.square import Square class KernelParser: @@ -39,8 +40,7 @@ class KernelParser: def parse_context(context_str: str) -> KernelContext: """Parse the context string and return a KernelContext object.""" context_match = re.search( - r"KernelContext\(scheme='(?P\w+)', " - + r"poly_order=(?P\w+), key_rns=(?P\w+), " + r"KernelContext\(scheme='(?P\w+)', " + r"poly_order=(?P\w+), key_rns=(?P\w+), " r"current_rns=(?P\w+), .*? label='(?P