Skip to content

Commit fb00997

Browse files
author
quxiaofeng
committed
Removed the NE package from src; And run through the src
Signed-off-by: quxiaofeng <[email protected]>
1 parent 0c114e6 commit fb00997

File tree

7 files changed

+23
-16
lines changed

7 files changed

+23
-16
lines changed

src/ClassificationDPL.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ function ClassificationDPL(TtData::Matrix{Float64}, DictMat::Array{Any,1}, Encod
88
# Class-specific reconstruction error calculation
99
for i=1:ClassNum
1010
@inbounds reconstructedTtData::Matrix{Float64} = DictMat[i] * PredictCoef[(i-1)*DictSize+1:i*DictSize, :]
11-
subtract!(reconstructedTtData, TtData)
12-
@inbounds Error[i,:] = sumsq(reconstructedTtData, 1)
11+
reconstructedTtData -= TtData
12+
@inbounds Error[i,:] = sum(abs2(reconstructedTtData), 1)
1313
end
1414
Distance::Matrix{Float64}, PredictInd::Matrix{Int64} = findmin(Error, 1)
1515
PredictLabel = [ind2sub(size(Error), PredictInd[i])[1] for i = 1:size(PredictInd, 2)]

src/ProjectiveDictionaryPairLearning.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ using MAT # to load and save data
3333

3434
include("TrainDPL.jl")
3535
include("ClassificationDPL.jl")
36+
include("normcol_equal")
3637

37-
export dpldemo, TrainDPL, ClassificationDPL, updateA!, updateD!, updateP!, initialization, normcol_lessequal
38+
export dpldemo, TrainDPL, ClassificationDPL, updateA!, updateD!, updateP!, initialization, normcol_lessequal, normcol_equal
3839

3940
function dpldemo()
4041
# Load training and testing data
@@ -44,11 +45,11 @@ function dpldemo()
4445
TtData, TtLabel = data["TtData"], data["TtLabel"]
4546

4647
# Column normalization
47-
normalize!(TrData, 2, 1)
48-
normalize!(TtData, 2, 1)
48+
TrData = normcol_equal(TrData)
49+
TtData = normcol_equal(TtData)
4950

50-
TrLabel = int(TrLabel)
51-
TtLabel = int(TtLabel)
51+
TrLabel = round(Int64, TrLabel)
52+
TtLabel = round(Int64, TtLabel)
5253

5354
# Parameter setting
5455
DictSize = 30

src/initialization.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
include(joinpath(dirname(@__FILE__), "updateA!.jl"))
1+
include("updateA!.jl")
2+
include("normcol_equal.jl")
23

34
# return inv(τ*A*A' + λ*B*B' + γ*I)
45
function getinv::Float64, λ::Float64, γ::Float64, A::Matrix{Float64}, B::Matrix{Float64})
@@ -44,12 +45,12 @@ function initialization(Data::Matrix{Float64}, Label::Matrix{Int64}, DictSize::I
4445

4546
DEMO && srand(i)
4647
TempRand = randn(Dim, DictSize)
47-
normalize!(TempRand, 2, 1)
48+
TempRand = normcol_equal(TempRand)
4849
@inbounds DictMat[i] = TempRand
4950

5051
DEMO && srand(2i)
5152
TempRand = randn(Dim, DictSize)
52-
normalize!(TempRand, 2, 1)
53+
TempRand = normcol_equal(TempRand)
5354
@inbounds P[i] = TempRand'
5455

5556
@inbounds TempDataC = Data[:, find(Label .!=i)]

src/normcol_equal.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
function normcol_equal(matin)
2+
# solve the proximal problem
3+
# matout = argmin||matout-matin||_F^2, s.t. matout(:,i)=1
4+
matin ./ repmat(sqrt(sum(matin .^ 2, 1) + eps()), size(matin, 1), 1)
5+
end

src/normcol_lessequal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# solve the proximal problem
22
# matout = argmin||matout-matin||_F^2, s.t. matout(:,i)<=1
33
function normcol_lessequal(matin::Matrix{Float64})
4-
broadcast(/, matin, max(1.0, sqrt(sumsq(matin,1))))
4+
broadcast(/, matin, max(1.0, sqrt(sum(abs2(matin), 2))))
55
end

src/updateA!.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function updateA!(A::Array{Any,1}, D::Array{Any,1}, DataMat::Array{Any,1}, P::Ar
1010
tempDictDataCoef::Matrix{Float64} = TempDict' * TempData
1111
@inbounds C::Matrix{Float64} = P[i] * TempData
1212
diagadd!(tempDictCoef, τ)
13-
fma!(tempDictDataCoef, C, τ)
13+
tempDictDataCoef += C .* τ
1414
@inbounds A[i] = tempDictCoef \ tempDictDataCoef
1515
end
16-
end
16+
end

src/updateD!.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,18 @@ function updateD!(D::Array{Any,1}, A::Array{Any,1}, DataMat::Array{Any,1})
2424
while ERROR > 1e-8 && Iter < 100
2525

2626
tempMat::Matrix{Float64} = TempData*TempCoef'
27-
fma!(tempMat, TempS-TempT, ρ) # tempMat <- tempMat + ρ(TempS - TempT)
27+
tempMat += (TempS-TempT) .* ρ # tempMat <- tempMat + ρ(TempS - TempT)
2828
tempMatCoef::Matrix{Float64} = TempCoef*TempCoef'
2929
diagadd!(tempMatCoef, ρ)
3030
TempD::Matrix{Float64} = tempMat/tempMatCoef
3131

3232
TempS = normcol_lessequal(TempD+TempT)
3333
add_sub!(TempT, TempD, TempS) # TemP <- TemP + (TempD-TempS)
3434
ρ *= rate_ρ
35-
ERROR = meansq(preD-TempD)
35+
ERROR = mean(abs2(preD-TempD))
3636
preD = TempD
3737
Iter += 1
3838
end
3939
@inbounds D[i] = preD
4040
end
41-
end
41+
end

0 commit comments

Comments
 (0)