Skip to content

Commit f37d338

Browse files
lkdvosJutho
andauthored
[Fix] include alpha when determining the scalartype of a linear combination (#225)
* Include alpha for determining scalartype of linear combination * Bump v5.3.1 * some unification of scaltype determination * Some more consistency changes --------- Co-authored-by: Jutho <[email protected]>
1 parent cd34d31 commit f37d338

File tree

3 files changed

+29
-24
lines changed

3 files changed

+29
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperations"
22
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
33
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
4-
version = "5.3.0"
4+
version = "5.3.1"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/indexnotation/instantiators.jl

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,15 @@ function instantiate_generaltensor(
122122
β = βsym
123123
end
124124
if alloc (NewTensor, TemporaryTensor)
125-
TC = gensym("T_" * string(dst))
125+
TCsym = gensym("T_" * string(dst))
126126
istemporary = Val(alloc === TemporaryTensor)
127-
if scaltype === nothing
128-
TCval = α === One() ? instantiate_scalartype(src) :
129-
instantiate_scalartype(Expr(:call, :*, α, src))
130-
else
131-
TCval = scaltype
132-
end
133-
push!(out.args, Expr(:(=), TC, TCval))
127+
TCval = @something(
128+
scaltype, instantiate_scalartype=== One() ? src : Expr(:call, :*, α, src))
129+
)
130+
push!(out.args, Expr(:(=), TCsym, TCval))
134131
push!(
135132
out.args,
136-
Expr(:(=), dst, :(tensoralloc_add($TC, $src, $p, $conj, $istemporary)))
133+
Expr(:(=), dst, :(tensoralloc_add($TCsym, $src, $p, $conj, $istemporary)))
137134
)
138135
end
139136

@@ -167,9 +164,9 @@ function instantiate_linearcombination(
167164
)
168165
out = Expr(:block)
169166
if alloc (NewTensor, TemporaryTensor)
170-
if scaltype === nothing
171-
scaltype = instantiate_scalartype(ex)
172-
end
167+
scaltype = @something(
168+
scaltype, instantiate_scalartype=== One() ? ex : Expr(:call, :*, α, ex))
169+
)
173170
push!(
174171
out.args,
175172
instantiate(dst, β, ex.args[2], α, leftind, rightind, alloc, scaltype)
@@ -275,18 +272,15 @@ function instantiate_contraction(
275272
end
276273
if alloc (NewTensor, TemporaryTensor)
277274
TCsym = gensym("T_" * string(dst))
278-
if scaltype === nothing
279-
Atype = instantiate_scalartype(A)
280-
Btype = instantiate_scalartype(B)
281-
TCval = Expr(:call, :promote_contract, Atype, Btype)
282-
if α !== One()
283-
TCval = Expr(
284-
:call, :(Base.promote_op), :*, instantiate_scalartype(α), TCval
285-
)
275+
TCval = @something(
276+
scaltype,
277+
begin
278+
TA = instantiate_scalartype(A)
279+
TB = instantiate_scalartype(B)
280+
TAB = :(promote_contract($TA, $TB))
281+
α === One() ? TAB : :(Base.promote_op(*, $(instantiate_scalartype(α)), $TAB))
286282
end
287-
else
288-
TCval = scaltype
289-
end
283+
)
290284
istemporary = Val(alloc === TemporaryTensor)
291285
initC = Expr(
292286
:block, Expr(:(=), TCsym, TCval),

test/tensor.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,4 +581,15 @@ end
581581
@test isblascontractable(pA, p)
582582
@test isblascontractable(conj(pA), p)
583583
end
584+
585+
@testset "Issue 220" begin
586+
A = rand(2, 2)
587+
B = rand(2, 2)
588+
C = rand(2, 2)
589+
D = rand(2, 2)
590+
c = 1im
591+
@tensor E[a; c] := c * (A[a b] * B[b c] + C[a b] * D[b c])
592+
@test scalartype(E) == ComplexF64
593+
@test E c * (A * B + C * D)
594+
end
584595
end

0 commit comments

Comments
 (0)