Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
58c12f8
WIP
tbittar Jun 21, 2024
0f285c6
Operator simplification on construction
tbittar Jun 21, 2024
ef9c8b6
Improve common operations on linear expr
tbittar Jun 25, 2024
4f0542c
Handle constraints
tbittar Jun 25, 2024
946c366
Equality comparator for linear expression
tbittar Jun 26, 2024
e1ebb29
WIP
tbittar Jun 26, 2024
10c2066
Sum shift
tbittar Jun 27, 2024
d13a649
WIP
tbittar Jul 11, 2024
75a5881
Shift and eval implementation in progress
tbittar Jul 12, 2024
2860242
TimeShift hashing
tbittar Jul 12, 2024
b640b22
Implement shift, eval and time sum of linear expressions
tbittar Jul 15, 2024
5f2823f
Test sum of linear expressions
tbittar Jul 15, 2024
e90254e
Fix printing term tests
tbittar Jul 16, 2024
c5214eb
More online simplifications, start handling constraints and ports
tbittar Jul 16, 2024
9ca614a
Make param and literal return expression node to later handle Express…
tbittar Jul 22, 2024
3d03d86
Fix model test
tbittar Jul 22, 2024
3251495
Fix test imports and port definition validation
tbittar Jul 23, 2024
6f85db0
Design problem between time operator / expression node
tbittar Jul 23, 2024
13ea7de
Fix circular imports
tbittar Jul 25, 2024
d092b1b
Fix syntax
tbittar Jul 26, 2024
48c96ff
Resolve linear expression
tbittar Jul 26, 2024
9de9cae
Fix circular imports
tbittar Aug 14, 2024
3a4004c
Parameter evaluation visitor implemented
tbittar Aug 14, 2024
8007306
Be able to create variables
tbittar Aug 14, 2024
c03705c
Merge branch 'main' into feature/better_expression_linearization
tbittar Aug 16, 2024
1e1ac18
Test resolve coefficients, update test evaluation context
tbittar Aug 16, 2024
1be9d84
Improve resolve coefficient API
tbittar Aug 16, 2024
b532a41
Start resolve variables, separate optimization context from optimizat…
tbittar Aug 19, 2024
b429782
Set objective
tbittar Aug 19, 2024
bda088c
Fix shift distribution and add component context for time operators
tbittar Aug 20, 2024
2c723cd
Temporary API for single shift over ExpressionNodeEfficient
tbittar Aug 21, 2024
27f8ad8
Fix variable get structure
tbittar Aug 21, 2024
bcafb95
Fix expectation computation
tbittar Aug 21, 2024
b0197da
Feature/update yaml parsing (#51)
tbittar Aug 21, 2024
1911e29
Fix interaction resolve ports / add component context
tbittar Aug 21, 2024
a904d9d
Uniformize imports
tbittar Aug 21, 2024
8b7a244
Fix some type checking issues, remove useless code
tbittar Aug 21, 2024
85d1628
Remove useless commented code
tbittar Aug 21, 2024
b5c3328
Remove useless commented code
tbittar Aug 21, 2024
344ac65
Improve type checking for constraint
tbittar Aug 22, 2024
600d1bc
Improve type checking for port field def
tbittar Aug 22, 2024
d3317ef
Improve type checking for **kwargs
tbittar Aug 22, 2024
7c9d427
Type checking and reformatting
tbittar Aug 22, 2024
6022781
Remove useless code
tbittar Aug 23, 2024
93eda7e
Rename files
tbittar Aug 27, 2024
952bce4
Rename file
tbittar Aug 27, 2024
85f3953
Remove 'efficient' suffix
tbittar Aug 27, 2024
f50f57f
Rename test files
tbittar Aug 27, 2024
abd1eed
Remove useless comments
tbittar Aug 27, 2024
f6a03c6
Comment and reformatting
tbittar Aug 27, 2024
9431fe7
WIP for parsing
tbittar Aug 27, 2024
02f8a39
WIP for yaml parsing
tbittar Sep 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Operator simplification on construction
  • Loading branch information
tbittar committed Jun 21, 2024
commit 0f285c6513b753efd7b4e0af7a6ffcf266f8c161
97 changes: 88 additions & 9 deletions src/andromede/expression/expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import enum
import inspect
from dataclasses import dataclass, field
import math
from typing import Any, Callable, List, Optional, Sequence, Union

import andromede.expression.port_operator
Expand Down Expand Up @@ -43,31 +44,31 @@ class ExpressionNodeEfficient:
instances: Instances = field(init=False, default=Instances.SIMPLE)

def __neg__(self) -> "ExpressionNodeEfficient":
return NegationNode(self)
return _negate_node(self)

def __add__(self, rhs: Any) -> "ExpressionNodeEfficient":
return _apply_if_node(rhs, lambda x: AdditionNode(self, x))
return _apply_if_node(rhs, lambda x: _add_node(self, x))

def __radd__(self, lhs: Any) -> "ExpressionNodeEfficient":
return _apply_if_node(lhs, lambda x: AdditionNode(x, self))
return _apply_if_node(lhs, lambda x: _add_node(x, self))

def __sub__(self, rhs: Any) -> "ExpressionNodeEfficient":
return _apply_if_node(rhs, lambda x: SubstractionNode(self, x))
return _apply_if_node(rhs, lambda x: _substract_node(self, x))

def __rsub__(self, lhs: Any) -> "ExpressionNodeEfficient":
return _apply_if_node(lhs, lambda x: SubstractionNode(x, self))
return _apply_if_node(lhs, lambda x: _substract_node(x, self))

def __mul__(self, rhs: Any) -> "ExpressionNodeEfficient":
return _apply_if_node(rhs, lambda x: MultiplicationNode(self, x))
return _apply_if_node(rhs, lambda x: _multiply_node(self, x))

def __rmul__(self, lhs: Any) -> "ExpressionNodeEfficient":
return _apply_if_node(lhs, lambda x: MultiplicationNode(x, self))
return _apply_if_node(lhs, lambda x: _multiply_node(x, self))

def __truediv__(self, rhs: Any) -> "ExpressionNodeEfficient":
return _apply_if_node(rhs, lambda x: DivisionNode(self, x))
return _apply_if_node(rhs, lambda x: _divide_node(self, x))

def __rtruediv__(self, lhs: Any) -> "ExpressionNodeEfficient":
return _apply_if_node(lhs, lambda x: DivisionNode(x, self))
return _apply_if_node(lhs, lambda x: _divide_node(x, self))

def __le__(self, rhs: Any) -> "ExpressionNodeEfficient":
return _apply_if_node(
Expand Down Expand Up @@ -150,6 +151,84 @@ def _apply_if_node(
else:
return NotImplemented

def is_zero(node: ExpressionNodeEfficient) -> bool:
# Could we use expressions equal
# TODO: Change hard coded 1e-16 abs_tol
return isinstance(node, LiteralNode) and math.isclose(node.value, 0, abs_tol=1e-16)

def is_one(node: ExpressionNodeEfficient) -> bool:
# Could we use expressions equal
return isinstance(node, LiteralNode) and math.isclose(node.value, 1)

def is_minus_one(node: ExpressionNodeEfficient) -> bool:
# Could we use expressions equal
return isinstance(node, LiteralNode) and math.isclose(node.value, -1)

def _negate_node(node: ExpressionNodeEfficient):
if isinstance(node, LiteralNode):
return LiteralNode(-node.value)
else:
return NegationNode(node)

def _add_node(
lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient
) -> ExpressionNodeEfficient:
if is_zero(lhs):
return rhs
if is_zero(rhs):
return lhs
if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode):
return LiteralNode(lhs.value + rhs.value)
# TODO : Si noeuds gauche droite même clé de param -> 2 * param
else:
return AdditionNode(lhs, rhs)


def _substract_node(
lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient
) -> ExpressionNodeEfficient:
if is_zero(lhs):
return -rhs
if is_zero(rhs):
return lhs
if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode):
return LiteralNode(lhs.value - rhs.value)
# TODO : Si noeuds gauche droite même clé de param -> 0
else:
return SubstractionNode(lhs, rhs)


def _multiply_node(
lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient
) -> ExpressionNodeEfficient:
if is_one(lhs):
return rhs
if is_one(rhs):
return lhs
if is_minus_one(lhs):
return -rhs
if is_minus_one(rhs):
return -lhs
if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode):
return LiteralNode(lhs.value * rhs.value)
else:
return MultiplicationNode(lhs, rhs)


def _divide_node(
lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient
) -> ExpressionNodeEfficient:
if is_one(rhs):
return lhs
if is_minus_one(rhs):
return -lhs
if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode):
# This could raise division by 0 error
return LiteralNode(lhs.value / rhs.value)

else:
return DivisionNode(lhs, rhs)


@dataclass(frozen=True, eq=False)
class PortFieldNode(ExpressionNodeEfficient):
Expand Down
66 changes: 43 additions & 23 deletions src/andromede/expression/linear_expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
with only variables and literal coefficients.
"""
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from andromede.expression.equality import expressions_equal
from andromede.expression.evaluate import ValueProvider, evaluate
Expand All @@ -35,8 +35,8 @@
EPS = 10 ** (-16)


def is_close_abs(value: float, other_value: float, eps: float) -> bool:
return abs(value - other_value) < eps
# def is_close_abs(value: float, other_value: float, eps: float) -> bool:
# return abs(value - other_value) < eps


def is_zero(value: ExpressionNodeEfficient) -> bool:
Expand Down Expand Up @@ -296,36 +296,37 @@ def __str__(self) -> str:
def __eq__(self, rhs: object) -> bool:
return (
isinstance(rhs, LinearExpressionEfficient)
and is_close_abs(self.constant, rhs.constant, EPS)
and expressions_equal(self.constant, rhs.constant)
and self.terms
== rhs.terms # /!\ There may be float equality comparison in the terms values
)

def __iadd__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient":
if not isinstance(rhs, LinearExpressionEfficient):
return NotImplemented
def __iadd__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient":
rhs = _wrap_in_linear_expr(rhs)
self.constant += rhs.constant
aggregated_terms = _merge_dicts(self.terms, rhs.terms, _add_terms, 0)
self.terms = aggregated_terms
self.remove_zeros_from_terms()
return self

def __add__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient":
def __add__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient":
result = LinearExpressionEfficient()
result += self
result += rhs
return result

def __isub__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient":
if not isinstance(rhs, LinearExpressionEfficient):
return NotImplemented
def __radd__(self, rhs: int) -> "LinearExpressionEfficient":
return self.__add__(rhs)

def __isub__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient":
rhs = _wrap_in_linear_expr(rhs)
self.constant -= rhs.constant
aggregated_terms = _merge_dicts(self.terms, rhs.terms, _substract_terms, 0)
self.terms = aggregated_terms
self.remove_zeros_from_terms()
return self

def __sub__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient":
def __sub__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient":
result = LinearExpressionEfficient()
result += self
result -= rhs
Expand All @@ -336,9 +337,8 @@ def __neg__(self) -> "LinearExpressionEfficient":
result -= self
return result

def __imul__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient":
if not isinstance(rhs, LinearExpressionEfficient):
return NotImplemented
def __imul__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient":
rhs = _wrap_in_linear_expr(rhs)

if self.terms and rhs.terms:
raise ValueError("Cannot multiply two non constant expression")
Expand All @@ -350,9 +350,9 @@ def __imul__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficie
# It is possible that both expr are constant
left_expr = rhs
const_expr = self
if expressions_equal(const_expr.constant, LiteralNode(0), EPS):
if is_zero(const_expr.constant):
return LinearExpressionEfficient()
elif expressions_equal(const_expr.constant, LiteralNode(1), EPS):
elif is_one(const_expr.constant):
_copy_expression(left_expr, self)
else:
left_expr.constant *= const_expr.constant
Expand All @@ -369,17 +369,19 @@ def __imul__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficie
_copy_expression(left_expr, self)
return self

def __mul__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient":
def __mul__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient":
result = LinearExpressionEfficient()
result += self
result *= rhs
return result

def __rmul__(self, rhs: int) -> "LinearExpressionEfficient":
return self.__mul__(rhs)

def __itruediv__(
self, rhs: "LinearExpressionEfficient"
self, rhs: Union["LinearExpressionEfficient", int, float]
) -> "LinearExpressionEfficient":
if not isinstance(rhs, LinearExpressionEfficient):
return NotImplemented
rhs = _wrap_in_linear_expr(rhs)

if rhs.terms:
raise ValueError("Cannot divide by a non constant expression")
Expand All @@ -403,18 +405,21 @@ def __itruediv__(
return self

def __truediv__(
self, rhs: "LinearExpressionEfficient"
self, rhs: Union["LinearExpressionEfficient", int, float]
) -> "LinearExpressionEfficient":
result = LinearExpressionEfficient()
result += self
result /= rhs

return result

def __rtruediv__(self, rhs: Union[int, float]) -> "LinearExpressionEfficient":
return self.__truediv__(rhs)

def remove_zeros_from_terms(self) -> None:
# TODO: Not optimized, checks could be done directly when doing operations on self.linear_term to avoid copies
for term_key, term in self.terms.copy().items():
if is_close_abs(term.coefficient, 0, EPS):
if is_zero(term.coefficient):
del self.terms[term_key]

def is_valid(self) -> bool:
Expand Down Expand Up @@ -446,6 +451,21 @@ def is_constant(self) -> bool:
# Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree...
return not self.terms

def _wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient:
if isinstance(obj, LinearExpressionEfficient):
return obj
elif isinstance(obj, float) or isinstance(obj, int):
return LinearExpressionEfficient([], LiteralNode(float(obj)))
raise TypeError(f"Unable to wrap {obj} into a linear expression")

def _apply_if_node(
obj: Any, func: Callable[[LinearExpressionEfficient], LinearExpressionEfficient]
) -> LinearExpressionEfficient:
if as_linear_expr := _wrap_in_linear_expr(obj):
return func(as_linear_expr)
else:
return NotImplemented


def _copy_expression(
src: LinearExpressionEfficient, dst: LinearExpressionEfficient
Expand Down
6 changes: 3 additions & 3 deletions src/andromede/expression/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ def addition(self, node: AdditionNode) -> T_op:
return left_value + right_value

def substraction(self, node: SubstractionNode) -> T_op:
left_value = visit(node.left, self)
right_value = visit(node.right, self)
return left_value - right_value
left_value = visit(node.left, self)
right_value = visit(node.right, self)
return left_value - right_value

def multiplication(self, node: MultiplicationNode) -> T_op:
left_value = visit(node.left, self)
Expand Down