Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions src/mxarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,42 @@ function mxsparse(ty::Type{Float64}, m::Integer, n::Integer, nzmax::Integer)
MxArray(pm)
end

function mxsparse(ty::Type{ComplexF64}, m::Integer, n::Integer, nzmax::Integer)
pm = ccall(mx_create_sparse[], Ptr{Cvoid},
(mwSize, mwSize, mwSize, mxComplexity), m, n, nzmax, mxCOMPLEX)
MxArray(pm)
end

function mxsparse(ty::Type{Bool}, m::Integer, n::Integer, nzmax::Integer)
pm = ccall(mx_create_sparse_logical[], Ptr{Cvoid},
(mwSize, mwSize, mwSize), m, n, nzmax)
MxArray(pm)
end

function _copy_sparse_mat(a::SparseMatrixCSC{V,I}, ir_p::Ptr{mwIndex}, jc_p::Ptr{mwIndex}, pr_p::Ptr{Float64}, pi_p::Ptr{Float64}) where {V<:ComplexF64,I}
colptr::Vector{I} = a.colptr
rinds::Vector{I} = a.rowval
vr::Vector{Float64} = real(a.nzval)
vi::Vector{Float64} = imag(a.nzval)
n::Int = a.n
nnz::Int = length(vr)

# Note: ir and jc contain zero-based indices

ir = unsafe_wrap(Array, ir_p, (nnz,))
for i = 1:nnz
ir[i] = rinds[i] - 1
end

jc = unsafe_wrap(Array, jc_p, (n+1,))
for i = 1:n+1
jc[i] = colptr[i] - 1
end

copyto!(unsafe_wrap(Array, pr_p, (nnz,)), vr)
copyto!(unsafe_wrap(Array, pi_p, (nnz,)), vi)
end

function _copy_sparse_mat(a::SparseMatrixCSC{V,I}, ir_p::Ptr{mwIndex}, jc_p::Ptr{mwIndex}, pr_p::Ptr{V}) where {V,I}
colptr::Vector{I} = a.colptr
rinds::Vector{I} = a.rowval
Expand All @@ -315,19 +345,24 @@ function _copy_sparse_mat(a::SparseMatrixCSC{V,I}, ir_p::Ptr{mwIndex}, jc_p::Ptr
copyto!(unsafe_wrap(Array, pr_p, (nnz,)), v)
end

function mxarray(a::SparseMatrixCSC{V,I}) where {V<:Union{Float64,Bool},I}
function mxarray(a::SparseMatrixCSC{V,I}) where {V<:Union{Float64,ComplexF64,Bool},I}
m::Int = a.m
n::Int = a.n
nnz = length(a.nzval)
@assert nnz == a.colptr[n+1]-1

mx = mxsparse(V, m, n, nnz)

ir_p = ccall(mx_get_ir[], Ptr{mwIndex}, (Ptr{Cvoid},), mx)
jc_p = ccall(mx_get_jc[], Ptr{mwIndex}, (Ptr{Cvoid},), mx)
pr_p = ccall(mx_get_pr[], Ptr{V}, (Ptr{Cvoid},), mx)

_copy_sparse_mat(a, ir_p, jc_p, pr_p)
if V <: ComplexF64
pr_p = ccall(mx_get_pr[], Ptr{Float64}, (Ptr{Cvoid},), mx)
pi_p = ccall(mx_get_pi[], Ptr{Float64}, (Ptr{Cvoid},), mx)
_copy_sparse_mat(a, ir_p, jc_p, pr_p, pi_p)
else
pr_p = ccall(mx_get_pr[], Ptr{V}, (Ptr{Cvoid},), mx)
_copy_sparse_mat(a, ir_p, jc_p, pr_p)
end
return mx
end

Expand Down Expand Up @@ -537,7 +572,6 @@ function _jsparse(ty::Type{T}, mx::MxArray) where T<:MxRealNum
n = ncols(mx)
ir_ptr = ccall(mx_get_ir[], Ptr{mwIndex}, (Ptr{Cvoid},), mx)
jc_ptr = ccall(mx_get_jc[], Ptr{mwIndex}, (Ptr{Cvoid},), mx)
pr_ptr = ccall(mx_get_pr[], Ptr{T}, (Ptr{Cvoid},), mx)

jc_a::Vector{mwIndex} = unsafe_wrap(Array, jc_ptr, (n+1,))
nnz = jc_a[n+1]
Expand All @@ -555,8 +589,15 @@ function _jsparse(ty::Type{T}, mx::MxArray) where T<:MxRealNum
jc[i] = jc_x[i] + 1
end

pr_ptr = ccall(mx_get_pr[], Ptr{T}, (Ptr{Cvoid},), mx)
pr::Vector{T} = copy(unsafe_wrap(Array, pr_ptr, (nnz,)))
return SparseMatrixCSC(m, n, jc, ir, pr)
if is_complex(mx)
pi_ptr = ccall(mx_get_pi[], Ptr{T}, (Ptr{Cvoid},), mx)
pi::Vector{T} = copy(unsafe_wrap(Array, pi_ptr, (nnz,)))
return SparseMatrixCSC(m, n, jc, ir, pr + im.*pi)
else
return SparseMatrixCSC(m, n, jc, ir, pr)
end
end

function jsparse(mx::MxArray)
Expand Down
16 changes: 16 additions & 0 deletions test/mxarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,15 @@ a2 = jsparse(a_mx)
@test isequal(a2, a)
delete(a_mx)

a = sparse([1.0 1.0im])
a_mx = mxarray(a)
@test is_sparse(a_mx)
@test is_double(a_mx)
@test is_complex(a_mx)
@test nrows(a_mx) == 1
@test ncols(a_mx) == 2
delete(a_mx)

# strings

s = "MATLAB.jl"
Expand Down Expand Up @@ -345,6 +354,13 @@ delete(x)
@test isa(y, Array{Float64,3})
@test isequal(y, a)

a = sparse([1.0 2.0im; 0 -1.0im])
a_mx = mxarray(a)
a_jl = jvalue(a_mx)
delete(a_mx)
@test a == a_jl
@test isa(a_jl, SparseMatrixCSC{Complex{Float64}})

a = "MATLAB"
x = mxarray(a)
y = jvalue(x)
Expand Down