Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update docs, name some magic values
  • Loading branch information
mateiz committed Apr 15, 2014
commit 154f45db4299aff8ae4fe085950d0911cfbb187d
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ import org.apache.spark.rdd.RDD
*/
@DeveloperApi
class PythonMLLibAPI extends Serializable {
private val DENSE_VECTOR_MAGIC = 1
private val SPARSE_VECTOR_MAGIC = 2
private val DENSE_MATRIX_MAGIC = 3

private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
val packetLength = bytes.length
if (packetLength < 5) {
Expand All @@ -44,7 +48,7 @@ class PythonMLLibAPI extends Serializable {
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
val magic = bb.get()
if (magic != 1) {
if (magic != DENSE_VECTOR_MAGIC) {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
val length = bb.getInt()
Expand Down Expand Up @@ -77,7 +81,7 @@ class PythonMLLibAPI extends Serializable {
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
val magic = bb.get()
if (magic != 2) {
if (magic != DENSE_MATRIX_MAGIC) {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
val rows = bb.getInt()
Expand Down
32 changes: 19 additions & 13 deletions python/pyspark/mllib/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,33 @@

# Dense double vector format:
#
# [8-byte 1] [8-byte length] [length*8 bytes of data]
# [1-byte 1] [4-byte length] [length*8 bytes of data]
#
# Sparse double vector format:
#
# [8-byte 2] [8-byte size] [8-byte entries] [entries*4 bytes of indices] [entries*8 bytes of values]
# [1-byte 2] [4-byte length] [4-byte nonzeros] [nonzeros*4 bytes of indices] [nonzeros*8 bytes of values]
#
# Double matrix format:
#
# [8-byte 3] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data]
# [1-byte 3] [4-byte rows] [4-byte cols] [rows*cols*8 bytes of data]
#
# This is all in machine-endian. That means that the Java interpreter and the
# Python interpreter must agree on what endian the machine is.

def _deserialize_byte_array(shape, ba, offset):
"""Wrapper around ndarray aliasing hack.
DENSE_VECTOR_MAGIC = 1
SPARSE_VECTOR_MAGIC = 2
DENSE_MATRIX_MAGIC = 3

def _deserialize_numpy_array(shape, ba, offset):
"""
Deserialize a numpy array of float64s from a given offset in
bytearray ba, assigning it the given shape.

>>> x = array([1.0, 2.0, 3.0, 4.0, 5.0])
>>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
>>> array_equal(x, _deserialize_numpy_array(x.shape, x.data, 0))
True
>>> x = array([1.0, 2.0, 3.0, 4.0]).reshape(2,2)
>>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
>>> array_equal(x, _deserialize_numpy_array(x.shape, x.data, 0))
True
"""
ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", order='C')
Expand All @@ -71,7 +77,7 @@ def _serialize_double_vector(v):
v = v.astype('float64')
length = v.shape[0]
ba = bytearray(5 + 8 * length)
ba[0] = 1
ba[0] = DENSE_VECTOR_MAGIC
length_bytes = ndarray(shape=[1], buffer=ba, offset=1, dtype="int32")
length_bytes[0] = length
arr_mid = ndarray(shape=[length], buffer=ba, offset=5, dtype="float64")
Expand All @@ -91,14 +97,14 @@ def _deserialize_double_vector(ba):
if len(ba) < 5:
raise TypeError("_deserialize_double_vector called on a %d-byte array, "
"which is too short" % len(ba))
if ba[0] != 1:
if ba[0] != DENSE_VECTOR_MAGIC:
raise TypeError("_deserialize_double_vector called on bytearray "
"with wrong magic")
length = ndarray(shape=[1], buffer=ba, offset=1, dtype="int32")[0]
if len(ba) != 8*length + 5:
raise TypeError("_deserialize_double_vector called on bytearray "
"with wrong length")
return _deserialize_byte_array([length], ba, 5)
return _deserialize_numpy_array([length], ba, 5)

def _serialize_double_matrix(m):
"""Serialize a double matrix into a mutually understood format."""
Expand All @@ -111,7 +117,7 @@ def _serialize_double_matrix(m):
rows = m.shape[0]
cols = m.shape[1]
ba = bytearray(9 + 8 * rows * cols)
ba[0] = 2
ba[0] = DENSE_MATRIX_MAGIC
lengths = ndarray(shape=[3], buffer=ba, offset=1, dtype="int32")
lengths[0] = rows
lengths[1] = cols
Expand All @@ -130,7 +136,7 @@ def _deserialize_double_matrix(ba):
if len(ba) < 9:
raise TypeError("_deserialize_double_matrix called on a %d-byte array, "
"which is too short" % len(ba))
if ba[0] != 2:
if ba[0] != DENSE_MATRIX_MAGIC:
raise TypeError("_deserialize_double_matrix called on bytearray "
"with wrong magic")
lengths = ndarray(shape=[2], buffer=ba, offset=1, dtype="int32")
Expand All @@ -139,7 +145,7 @@ def _deserialize_double_matrix(ba):
if (len(ba) != 8 * rows * cols + 9):
raise TypeError("_deserialize_double_matrix called on bytearray "
"with wrong length")
return _deserialize_byte_array([rows, cols], ba, 9)
return _deserialize_numpy_array([rows, cols], ba, 9)

def _linear_predictor_typecheck(x, coeffs):
"""Check that x is a one-dimensional vector of the right shape.
Expand Down