diff --git a/lib/bindings/python/rust/llm/block_manager.rs b/lib/bindings/python/rust/llm/block_manager.rs index e877ec021c..4e266ab191 100644 --- a/lib/bindings/python/rust/llm/block_manager.rs +++ b/lib/bindings/python/rust/llm/block_manager.rs @@ -14,18 +14,18 @@ // limitations under the License. #![cfg(feature = "block-manager")] -// Silence warnings about deprecated features (like pyo3::IntoPy::into_py) -#![allow(deprecated)] use super::*; use pyo3::PyResult; -use tokio; mod block; mod block_list; +mod dlpack; +mod layer; /// Add bingings from this crate to the provided module pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -34,9 +34,6 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> { #[pyclass] pub struct BlockManager { - // TODO: Can this be implicitly created and referenced? - tokio_runtime: tokio::runtime::Runtime, - // Block manager inner: Arc, // TODO: Metadata should be stored in the block manager? dtype: dynamo_llm::common::dtype::DType, @@ -62,7 +59,7 @@ impl BlockManager { dynamo_llm::block_manager::KvManagerRuntimeConfig::builder() .worker_id(worker_id) .build() - .unwrap(), + .map_err(to_pyerr)?, ); let mut model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder() .num_layers(num_layer) @@ -93,14 +90,17 @@ impl BlockManager { }; } model_config = model_config.dtype(dtype_.clone()); - config = config.model(model_config.build().unwrap()); + config = config.model(model_config.build().map_err(to_pyerr)?); if let Some(host_num_blocks) = host_num_blocks { config = config.host_layout( dynamo_llm::block_manager::KvManagerLayoutConfig::builder() .num_blocks(host_num_blocks) - .allocator(dynamo_llm::block_manager::storage::PinnedAllocator::new().unwrap()) + .allocator( + dynamo_llm::block_manager::storage::PinnedAllocator::new() + .map_err(to_pyerr)?, + ) .build() - .unwrap(), + .map_err(to_pyerr)?, ); } if let Some(device_num_blocks) = device_num_blocks { @@ -109,23 +109,22 @@ impl BlockManager { .num_blocks(device_num_blocks) .allocator( dynamo_llm::block_manager::storage::DeviceAllocator::new(device_id) - .unwrap(), + .map_err(to_pyerr)?, ) .build() - .unwrap(), + .map_err(to_pyerr)?, ); } - let config = config.build().unwrap(); - let tokio_runtime = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); - let block_manager = tokio_runtime.block_on(async { - dynamo_llm::block_manager::ReferenceBlockManager::new(config).unwrap() - }); + let config = config.build().map_err(to_pyerr)?; + let tokio_runtime = pyo3_async_runtimes::tokio::get_runtime(); Ok(BlockManager { - tokio_runtime: tokio_runtime, - inner: Arc::from(block_manager), + inner: Arc::from( + tokio_runtime + .block_on(async { + dynamo_llm::block_manager::ReferenceBlockManager::new(config) + }) + .map_err(to_pyerr)?, + ), dtype: dtype_, device_id: device_id, }) @@ -135,9 +134,11 @@ impl BlockManager { let blocks = self .inner .host() - .unwrap() + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Host allocator not available") + })? 
.allocate_blocks_blocking(count) - .unwrap(); + .map_err(to_pyerr)?; // Wrap each block in an enum accounting for Pinned & Device block let blocks = blocks .into_iter() @@ -150,13 +151,42 @@ impl BlockManager { )) } + #[pyo3(signature = (count))] + fn allocate_host_blocks<'py>( + &self, + py: Python<'py>, + count: usize, + ) -> PyResult> { + let inner = self.inner.clone(); + let dtype = self.dtype.clone(); + let device_id = self.device_id; + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let blocks = inner + .host() + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Host allocator not available") + })? + .allocate_blocks(count) + .await + .map_err(to_pyerr)?; + // Wrap each block in an enum accounting for Pinned & Device block + let blocks = blocks + .into_iter() + .map(|b| block::BlockType::Pinned(b)) + .collect(); + Ok(block_list::BlockList::from_rust(blocks, dtype, device_id)) + }) + } + fn allocate_device_blocks_blocking(&self, count: usize) -> PyResult { let blocks = self .inner .device() - .unwrap() + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Device allocator not available") + })? .allocate_blocks_blocking(count) - .unwrap(); + .map_err(to_pyerr)?; // Wrap each block in an enum accounting for Pinned & Device block let blocks = blocks .into_iter() @@ -168,4 +198,31 @@ impl BlockManager { self.device_id, )) } + + #[pyo3(signature = (count))] + fn allocate_device_blocks<'py>( + &self, + py: Python<'py>, + count: usize, + ) -> PyResult> { + let inner = self.inner.clone(); + let dtype = self.dtype.clone(); + let device_id = self.device_id; + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let blocks = inner + .device() + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Device allocator not available") + })? + .allocate_blocks(count) + .await + .map_err(to_pyerr)?; + // Wrap each block in an enum accounting for Pinned & Device block + let blocks = blocks + .into_iter() + .map(|b| block::BlockType::Device(b)) + .collect(); + Ok(block_list::BlockList::from_rust(blocks, dtype, device_id)) + }) + } } diff --git a/lib/bindings/python/rust/llm/block_manager/block.rs b/lib/bindings/python/rust/llm/block_manager/block.rs index f89e26159b..25e8874bf6 100644 --- a/lib/bindings/python/rust/llm/block_manager/block.rs +++ b/lib/bindings/python/rust/llm/block_manager/block.rs @@ -14,16 +14,14 @@ // limitations under the License. #![cfg(feature = "block-manager")] -// Silence warnings about deprecated features (like pyo3::IntoPy::into_py) -#![allow(deprecated)] use super::*; - -use dlpark::prelude::{DataType, Device, ManagerCtx, ShapeAndStrides, ToTensor}; -use pyo3::{ffi::c_str, prelude::IntoPy, types::PyTuple, PyObject, PyResult, Python}; -use std::sync::{Arc, Mutex}; - use dynamo_llm::block_manager::block::BlockDataExt; +use pyo3::{ + types::{PyList, PyTuple}, + PyObject, PyResult, Python, +}; +use std::sync::{Arc, Mutex}; pub enum BlockType { Pinned( @@ -40,111 +38,14 @@ pub enum BlockType { ), } -struct DlPackTensor { - block: Arc>, - // TODO: Metadata should be stored in the block manager? 
- dtype: dynamo_llm::common::dtype::DType, - device_id: usize, -} - -impl ToTensor for DlPackTensor { - fn data_ptr(&self) -> *mut std::ffi::c_void { - let mut mutable_block = self.block.lock().unwrap(); - let ptr = match &mut *mutable_block { - BlockType::Pinned(block) => { - let mut block_view_mut = block - .block_view_mut() - .expect("Failed to get mutable Pinned block view"); - unsafe { block_view_mut.as_mut_ptr() } - } - BlockType::Device(block) => { - let mut block_view_mut = block - .block_view_mut() - .expect("Failed to get mutable Device block view"); - unsafe { block_view_mut.as_mut_ptr() } - } - }; - ptr as *mut std::ffi::c_void - } - - fn byte_offset(&self) -> u64 { - 0 - } - - fn device(&self) -> Device { - let mutable_block = self.block.lock().unwrap(); - match &*mutable_block { - BlockType::Pinned(_) => { - // TODO: Why torch does not support CPU_PINNED here? - /*Device { - device_type: DeviceType::CudaHost, - device_id: 0, - }*/ - Device::CPU - } - BlockType::Device(_) => Device::cuda(self.device_id), - } - } - - fn dtype(&self) -> DataType { - // Map from dynamo_llm::common::dtype::DType to dlpark::prelude::DataType - match self.dtype { - dynamo_llm::common::dtype::DType::FP8 => { - // No direct FP8 equivalent, use U8 as closest alternative - DataType::U8 - } - dynamo_llm::common::dtype::DType::FP16 => DataType::F16, - dynamo_llm::common::dtype::DType::BF16 => DataType::BF16, - dynamo_llm::common::dtype::DType::FP32 => DataType::F32, - dynamo_llm::common::dtype::DType::U8 => DataType::U8, - dynamo_llm::common::dtype::DType::U16 => DataType::U16, - dynamo_llm::common::dtype::DType::U32 => DataType::U32, - dynamo_llm::common::dtype::DType::U64 => DataType::U64, - dynamo_llm::common::dtype::DType::I8 => DataType::I8, - dynamo_llm::common::dtype::DType::I16 => DataType::I16, - dynamo_llm::common::dtype::DType::I32 => DataType::I32, - dynamo_llm::common::dtype::DType::I64 => DataType::I64, - } - } - - fn shape_and_strides(&self) -> ShapeAndStrides { - let mutable_block = self.block.lock().unwrap(); - let (num_blocks, num_layers, page_size, inner_dim) = match &*mutable_block { - BlockType::Pinned(block) => ( - block.num_blocks(), - block.num_layers(), - block.page_size(), - block.inner_dim(), - ), - BlockType::Device(block) => ( - block.num_blocks(), - block.num_layers(), - block.page_size(), - block.inner_dim(), - ), - }; - let shape_i64: Vec = vec![ - num_blocks as i64, - num_layers as i64, - page_size as i64, - inner_dim as i64, - ]; - ShapeAndStrides::new_contiguous(&shape_i64) - } -} - -/*impl Drop for DlPackTensor { - fn drop(&mut self) { - println!("Dropping DlPackTensor"); - } -}*/ - #[pyclass] pub struct Block { inner: Arc>, // TODO: Metadata should be stored in the block manager? 
dtype: dynamo_llm::common::dtype::DType, device_id: usize, + // Python iterator state + py_itr_idx: usize, } impl Block { @@ -157,69 +58,161 @@ impl Block { inner: block, dtype: dtype, device_id: device_id, + py_itr_idx: 0, + } + } + + fn num_layers(&self) -> usize { + let mutable_block = self.inner.lock().unwrap(); + match &*mutable_block { + BlockType::Pinned(block) => block.num_layers(), + BlockType::Device(block) => block.num_layers(), } } } #[pymethods] impl Block { + #[pyo3(signature = ())] + fn to_list<'py>(&self, py: Python<'py>) -> PyResult> { + let layers: Vec = (0..self.num_layers()) + .map(|layer_idx| { + layer::Layer::from_rust( + self.inner.clone(), + layer_idx, + self.dtype.clone(), + self.device_id, + ) + }) + .collect(); + PyList::new(py, layers) + } + + fn __len__(&self) -> PyResult { + Ok(self.num_layers()) + } + + fn __getitem__(&self, index: usize) -> PyResult { + let num_layers = self.num_layers(); + if index >= num_layers { + return Err(pyo3::exceptions::PyIndexError::new_err(format!( + "Index {} out of range for Block with {} layers", + index, num_layers + ))); + } + let layer = layer::Layer::from_rust( + self.inner.clone(), + index, + self.dtype.clone(), + self.device_id, + ); + Ok(layer) + } + + fn __iter__(mut slf: PyRefMut<'_, Self>) -> PyResult> { + // Reset iterator index at the beginning of each iteration + // Use to_list() for iterating concurrently + slf.py_itr_idx = 0; + Ok(slf) + } + + fn __next__(&mut self) -> PyResult { + if self.py_itr_idx >= self.num_layers() { + return Err(pyo3::exceptions::PyStopIteration::new_err( + "No more items in Block", + )); + } + let layer = layer::Layer::from_rust( + self.inner.clone(), + self.py_itr_idx, + self.dtype.clone(), + self.device_id, + ); + self.py_itr_idx += 1; + Ok(layer) + } + #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))] - fn __dlpack__( + fn __dlpack__<'py>( &self, + py: Python<'py>, stream: Option, max_version: Option, dl_device: Option, copy: Option, ) -> PyResult { - // Panic if any arguments are provided + // Return error if any arguments are provided if stream.is_some() { - panic!("stream argument is not supported"); + return Err(pyo3::exceptions::PyNotImplementedError::new_err( + "stream argument is not supported", + )); } if max_version.is_some() { - panic!("max_version argument is not supported"); + return Err(pyo3::exceptions::PyNotImplementedError::new_err( + "max_version argument is not supported", + )); } if dl_device.is_some() { - panic!("dl_device argument is not supported"); + return Err(pyo3::exceptions::PyNotImplementedError::new_err( + "dl_device argument is not supported", + )); } if copy.is_some() { - panic!("copy argument is not supported"); + return Err(pyo3::exceptions::PyNotImplementedError::new_err( + "copy argument is not supported", + )); } - // Create DLPack PyCapsule - let manager_ctx = ManagerCtx::new(DlPackTensor { - block: self.inner.clone(), - dtype: self.dtype.clone(), - device_id: self.device_id, - }); - let py_capsule = Python::with_gil(|py| manager_ctx.into_py(py)); - Ok(py_capsule) - } - - fn __dlpack_device__(&self) -> PyResult> { - let dlpack_device = Python::with_gil(|py| { - let device_type_list = py.eval(c_str!("[('CPU', 1), ('CUDA', 2), ('CPU_PINNED', 3), ('OPENCL', 4), ('VULKAN', 7), ('METAL', 8), ('VPI', 9), ('ROCM', 10)]"), None, None).unwrap(); - let device_type_enum = py - .import("enum") - .unwrap() - .getattr("Enum") - .unwrap() - .call1(("DLDeviceType", device_type_list)) - .unwrap(); - let block = 
self.inner.lock().unwrap(); - let device_type = match &*block { - BlockType::Pinned(_) => device_type_enum.getattr("CPU_PINNED").unwrap(), - BlockType::Device(_) => device_type_enum.getattr("CUDA").unwrap(), + // Extract all necessary data for dlpack + let ptr: *mut std::ffi::c_void; + let num_blocks: i64; + let num_layers: i64; + let num_outer_dims: i64; + let page_size: i64; + let inner_dim: i64; + { + let mut mutable_block = self.inner.lock().unwrap(); + ptr = match &mut *mutable_block { + BlockType::Pinned(block) => { + let mut block_view_mut = block.block_view_mut().map_err(to_pyerr)?; + (unsafe { block_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void + } + BlockType::Device(block) => { + let mut block_view_mut = block.block_view_mut().map_err(to_pyerr)?; + (unsafe { block_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void + } + }; + (num_blocks, num_layers, num_outer_dims, page_size, inner_dim) = match &*mutable_block { + BlockType::Pinned(block) => ( + block.num_blocks() as i64, + block.num_layers() as i64, + block.num_outer_dims() as i64, + block.page_size() as i64, + block.inner_dim() as i64, + ), + BlockType::Device(block) => ( + block.num_blocks() as i64, + block.num_layers() as i64, + block.num_outer_dims() as i64, + block.page_size() as i64, + block.inner_dim() as i64, + ), }; - let device_id = self.device_id.into_py(py).into_bound(py); - let device = vec![device_type, device_id]; - PyTuple::new(py, device).unwrap().unbind() - }); - Ok(dlpack_device) + } + + // Create the DLPack tensor + dlpack::dlpack( + py, + self.inner.clone(), + ptr, + vec![num_blocks, num_layers, num_outer_dims, page_size, inner_dim], + self.dtype.clone(), + self.device_id, + ) } -} -/*impl Drop for Block { - fn drop(&mut self) { - println!("Dropping Block"); + #[pyo3(signature = ())] + fn __dlpack_device__<'py>(&self, py: Python<'py>) -> PyResult> { + dlpack::dlpack_device(py, self.inner.clone(), self.device_id) } -}*/ +} diff --git a/lib/bindings/python/rust/llm/block_manager/block_list.rs b/lib/bindings/python/rust/llm/block_manager/block_list.rs index 06378f1524..d0a5a2d848 100644 --- a/lib/bindings/python/rust/llm/block_manager/block_list.rs +++ b/lib/bindings/python/rust/llm/block_manager/block_list.rs @@ -14,11 +14,8 @@ // limitations under the License. 
#![cfg(feature = "block-manager")] -// Silence warnings about deprecated features (like pyo3::IntoPy::into_py) -#![allow(deprecated)] use super::*; - use pyo3::{types::PyList, PyResult, Python}; use std::sync::{Arc, Mutex}; @@ -52,16 +49,14 @@ impl BlockList { #[pymethods] impl BlockList { - fn to_list(&self) -> PyResult> { - let py_list = Python::with_gil(|py| { - let blocks: Vec = self - .inner - .iter() - .map(|b| block::Block::from_rust(b.clone(), self.dtype.clone(), self.device_id)) - .collect(); - PyList::new(py, blocks).unwrap().unbind() - }); - Ok(py_list) + #[pyo3(signature = ())] + fn to_list<'py>(&self, py: Python<'py>) -> PyResult> { + let blocks: Vec = self + .inner + .iter() + .map(|b| block::Block::from_rust(b.clone(), self.dtype.clone(), self.device_id)) + .collect(); + PyList::new(py, blocks) } fn __len__(&self) -> PyResult { @@ -84,13 +79,10 @@ impl BlockList { Ok(block) } - fn __iter__(slf: Py) -> PyResult> { - Python::with_gil(|py| { - let mut slf = slf.borrow_mut(py); - // Reset iterator index at the beginning of each iteration - // Use to_list() for iterating concurrently - slf.py_itr_idx = 0; - }); + fn __iter__(mut slf: PyRefMut<'_, Self>) -> PyResult> { + // Reset iterator index at the beginning of each iteration + // Use to_list() for iterating concurrently + slf.py_itr_idx = 0; Ok(slf) } diff --git a/lib/bindings/python/rust/llm/block_manager/dlpack.rs b/lib/bindings/python/rust/llm/block_manager/dlpack.rs new file mode 100644 index 0000000000..41c7b23fb6 --- /dev/null +++ b/lib/bindings/python/rust/llm/block_manager/dlpack.rs @@ -0,0 +1,129 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![cfg(feature = "block-manager")] +// Silence warnings about deprecated features (like pyo3::IntoPy::into_py) +#![allow(deprecated)] + +use super::*; +use dlpark::prelude::{DataType, Device, ManagerCtx, ShapeAndStrides, ToTensor}; +use pyo3::{ffi::c_str, prelude::IntoPy, types::PyTuple, PyObject, PyResult, Python}; +use std::sync::{Arc, Mutex}; + +struct DlPackTensor { + block: Arc>, + ptr: *mut std::ffi::c_void, + shape: Vec, + // TODO: Metadata should be stored in the block? + dtype: dynamo_llm::common::dtype::DType, + device_id: usize, +} + +impl ToTensor for DlPackTensor { + fn data_ptr(&self) -> *mut std::ffi::c_void { + self.ptr + } + + fn byte_offset(&self) -> u64 { + 0 + } + + fn device(&self) -> Device { + let mutable_block = self.block.lock().unwrap(); + match &*mutable_block { + block::BlockType::Pinned(_) => { + // TODO: Why torch does not support CPU_PINNED here? 
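+                // Reporting pinned host memory as plain CPU for now; per the TODO
+                // above, torch.from_dlpack appears not to accept the CudaHost /
+                // CPU_PINNED device type for this capsule.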
+ /*Device { + device_type: DeviceType::CudaHost, + device_id: 0, + }*/ + Device::CPU + } + block::BlockType::Device(_) => Device::cuda(self.device_id), + } + } + + fn dtype(&self) -> DataType { + // Map from dynamo_llm::common::dtype::DType to dlpark::prelude::DataType + match self.dtype { + dynamo_llm::common::dtype::DType::FP8 => { + // No direct FP8 equivalent, use U8 as closest alternative + DataType::U8 + } + dynamo_llm::common::dtype::DType::FP16 => DataType::F16, + dynamo_llm::common::dtype::DType::BF16 => DataType::BF16, + dynamo_llm::common::dtype::DType::FP32 => DataType::F32, + dynamo_llm::common::dtype::DType::U8 => DataType::U8, + dynamo_llm::common::dtype::DType::U16 => DataType::U16, + dynamo_llm::common::dtype::DType::U32 => DataType::U32, + dynamo_llm::common::dtype::DType::U64 => DataType::U64, + dynamo_llm::common::dtype::DType::I8 => DataType::I8, + dynamo_llm::common::dtype::DType::I16 => DataType::I16, + dynamo_llm::common::dtype::DType::I32 => DataType::I32, + dynamo_llm::common::dtype::DType::I64 => DataType::I64, + } + } + + fn shape_and_strides(&self) -> ShapeAndStrides { + ShapeAndStrides::new_contiguous(&self.shape) + } +} + +/*impl Drop for DlPackTensor { + fn drop(&mut self) { + println!("Dropping DlPackTensor"); + } +}*/ + +pub fn dlpack<'py>( + py: Python<'py>, + block: Arc>, + ptr: *mut std::ffi::c_void, + shape: Vec, + dtype: dynamo_llm::common::dtype::DType, + device_id: usize, +) -> PyResult { + let manager_ctx = ManagerCtx::new(DlPackTensor { + block: block, + ptr: ptr, + shape: shape, + dtype: dtype, + device_id: device_id, + }); + let py_capsule = manager_ctx.into_py(py); + Ok(py_capsule) +} + +pub fn dlpack_device<'py>( + py: Python<'py>, + block: Arc>, + device_id: usize, +) -> PyResult> { + let dev_type_list = py.eval(c_str!("[('CPU', 1), ('CUDA', 2), ('CPU_PINNED', 3), ('OPENCL', 4), ('VULKAN', 7), ('METAL', 8), ('VPI', 9), ('ROCM', 10)]"), None, None).unwrap(); + let dev_type_enum = py + .import("enum") + .unwrap() + .getattr("Enum") + .unwrap() + .call1(("DLDeviceType", dev_type_list)) + .unwrap(); + let dev_type = match &*block.lock().unwrap() { + block::BlockType::Pinned(_) => dev_type_enum.getattr("CPU_PINNED").unwrap(), + block::BlockType::Device(_) => dev_type_enum.getattr("CUDA").unwrap(), + }; + let dev_id = device_id.into_py(py).into_bound(py); + let dev = vec![dev_type, dev_id]; + PyTuple::new(py, dev) +} diff --git a/lib/bindings/python/rust/llm/block_manager/layer.rs b/lib/bindings/python/rust/llm/block_manager/layer.rs new file mode 100644 index 0000000000..8a1475900d --- /dev/null +++ b/lib/bindings/python/rust/llm/block_manager/layer.rs @@ -0,0 +1,129 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
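+// `Layer` exposes a single layer of a KV cache block through the DLPack
+// protocol; the resulting tensor has shape
+// (1, 1, num_outer_dims, page_size, inner_dim).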
+ +#![cfg(feature = "block-manager")] + +use super::*; +use dynamo_llm::block_manager::block::BlockDataExt; +use pyo3::{types::PyTuple, PyObject, PyResult, Python}; +use std::sync::{Arc, Mutex}; + +// Layer struct that represents a layer within a block +#[pyclass] +pub struct Layer { + inner: Arc>, + layer_idx: usize, + dtype: dynamo_llm::common::dtype::DType, + device_id: usize, +} + +impl Layer { + pub fn from_rust( + block: Arc>, + layer_idx: usize, + dtype: dynamo_llm::common::dtype::DType, + device_id: usize, + ) -> Self { + Self { + inner: block, + layer_idx, + dtype, + device_id, + } + } +} + +#[pymethods] +impl Layer { + #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))] + fn __dlpack__<'py>( + &self, + py: Python<'py>, + stream: Option, + max_version: Option, + dl_device: Option, + copy: Option, + ) -> PyResult { + // Return error if any arguments are provided + if stream.is_some() { + return Err(pyo3::exceptions::PyNotImplementedError::new_err( + "stream argument is not supported", + )); + } + if max_version.is_some() { + return Err(pyo3::exceptions::PyNotImplementedError::new_err( + "max_version argument is not supported", + )); + } + if dl_device.is_some() { + return Err(pyo3::exceptions::PyNotImplementedError::new_err( + "dl_device argument is not supported", + )); + } + if copy.is_some() { + return Err(pyo3::exceptions::PyNotImplementedError::new_err( + "copy argument is not supported", + )); + } + + // Extract all necessary data for dlpack + let ptr: *mut std::ffi::c_void; + let num_outer_dims: i64; + let page_size: i64; + let inner_dim: i64; + { + let mut mutable_block = self.inner.lock().unwrap(); + ptr = match &mut *mutable_block { + block::BlockType::Pinned(block) => { + let mut layer_view_mut = + block.layer_view_mut(self.layer_idx, 0).map_err(to_pyerr)?; + (unsafe { layer_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void + } + block::BlockType::Device(block) => { + let mut layer_view_mut = + block.layer_view_mut(self.layer_idx, 0).map_err(to_pyerr)?; + (unsafe { layer_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void + } + }; + (num_outer_dims, page_size, inner_dim) = match &*mutable_block { + block::BlockType::Pinned(block) => ( + block.num_outer_dims() as i64, + block.page_size() as i64, + block.inner_dim() as i64, + ), + block::BlockType::Device(block) => ( + block.num_outer_dims() as i64, + block.page_size() as i64, + block.inner_dim() as i64, + ), + }; + } + + // Create the DLPack tensor + dlpack::dlpack( + py, + self.inner.clone(), + ptr, + vec![1, 1, num_outer_dims, page_size, inner_dim], + self.dtype.clone(), + self.device_id, + ) + } + + #[pyo3(signature = ())] + fn __dlpack_device__<'py>(&self, py: Python<'py>) -> PyResult> { + dlpack::dlpack_device(py, self.inner.clone(), self.device_id) + } +} diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index 1d7f52799b..6bac5c97be 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -672,6 +672,25 @@ class NatsQueue: """ ... +class Layer: + """ + A KV cache block layer + """ + + ... + + def __dlpack__(self, stream: Optional[Any] = None, max_version: Optional[Any] = None, dl_device: Optional[Any] = None, copy: Optional[bool] = None) -> Any: + """ + Get a dlpack capsule of the layer + """ + ... + + def __dlpack_device__(self) -> Any: + """ + Get the dlpack device of the layer + """ + ... + class Block: """ A KV cache block @@ -679,9 +698,40 @@ class Block: ... 
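+    # Usage sketch (illustrative only): `block` is assumed to be a Block taken
+    # from a BlockList, and torch is assumed to be importable.
+    #
+    #   assert len(block) == num_layers          # layers in this block
+    #   first_layer = block[0]                   # __getitem__ -> Layer
+    #   tensors = [torch.from_dlpack(layer) for layer in block]
+    #   # each tensor has shape (1, 1, outer_dim, page_size, inner_dim)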
+ def __len__(self) -> int: + """ + Get the number of layers in the list + """ + ... + + def __getitem__(self, index: int) -> Layer: + """ + Get a layer by index + """ + ... + + def __iter__(self) -> 'Block': + """ + Get an iterator over the layers + """ + ... + + def __next__(self) -> Block: + """ + Get the next layer in the iterator + """ + ... + + def to_list(self) -> List[Layer]: + """ + Get a list of layers + """ + ... + def __dlpack__(self, stream: Optional[Any] = None, max_version: Optional[Any] = None, dl_device: Optional[Any] = None, copy: Optional[bool] = None) -> Any: """ - Get a dlpack capsule from the block + Get a dlpack capsule of the block + Exception raised if the block is not contiguous """ ... @@ -784,6 +834,22 @@ class BlockManager: """ ... + async def allocate_host_blocks(self, count: int) -> BlockList: + """ + Allocate a list of host blocks + + Parameters: + ----------- + count: int + Number of blocks to allocate + + Returns: + -------- + BlockList + List of allocated blocks + """ + ... + def allocate_device_blocks_blocking(self, count: int) -> BlockList: """ Allocate a list of device blocks (blocking call) @@ -799,3 +865,19 @@ class BlockManager: List of allocated blocks """ ... + + async def allocate_device_blocks(self, count: int) -> BlockList: + """ + Allocate a list of device blocks + + Parameters: + ----------- + count: int + Number of blocks to allocate + + Returns: + -------- + BlockList + List of allocated blocks + """ + ... diff --git a/lib/bindings/python/tests/test_block_manager.py b/lib/bindings/python/tests/test_block_manager.py index dc7c803f04..94c7b455db 100644 --- a/lib/bindings/python/tests/test_block_manager.py +++ b/lib/bindings/python/tests/test_block_manager.py @@ -35,9 +35,7 @@ DEVICE_ID = 0 -@pytest.fixture -def block_manager(): - """Pytest fixture for creating a BlockManager instance.""" +def new_block_manager(): return BlockManager( WORKER_ID, NUM_LAYER, @@ -51,6 +49,11 @@ def block_manager(): ) +@pytest.fixture +def block_manager(): + return new_block_manager() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") async def test_block_manager_initialization(): # Python should drop the BlockManager instance as soon as it goes out of scope, but @@ -106,22 +109,22 @@ async def test_block_manager_initialization(): async def test_cpu_block_access(block_manager: BlockManager): block_count = 2 block_list = block_manager.allocate_host_blocks_blocking(block_count) - py_blocks = block_list.to_list() - assert len(py_blocks) == block_count - tensors = [torch.from_dlpack(b) for b in py_blocks] + blocks = block_list.to_list() + assert len(blocks) == block_count + tensors = [torch.from_dlpack(b) for b in blocks] for tensor in tensors: assert tensor.get_device() == -1 # CPU - assert tensor.shape == (1, NUM_LAYER, PAGE_SIZE, INNER_DIM) + assert tensor.shape == (1, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM) assert tensor.dtype == TORCH_DTYPE # print(tensors) for tensor in tensors: - tensor[0][0][0][0] = 1.0 - tensor[0][NUM_LAYER - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 + tensor[0][0][0][0][0] = 1.0 + tensor[0][NUM_LAYER - 1][OUTER_DIM - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 # print(tensors) - py_blocks_ = block_list.to_list() - assert py_blocks is not py_blocks_ - assert len(py_blocks) == len(py_blocks_) - tensors_ = [torch.from_dlpack(b) for b in py_blocks_] + blocks_ = block_list.to_list() + assert blocks is not blocks_ + assert len(blocks) == len(blocks_) + tensors_ = [torch.from_dlpack(b) for b in blocks_] for tensor, 
tensor_ in zip(tensors, tensors_): assert tensor is not tensor_ assert tensor.shape == tensor_.shape @@ -133,22 +136,22 @@ async def test_cpu_block_access(block_manager: BlockManager): async def test_gpu_block_access(block_manager: BlockManager): block_count = 6 block_list = block_manager.allocate_device_blocks_blocking(block_count) - py_blocks = block_list.to_list() - assert len(py_blocks) == block_count - tensors = [torch.from_dlpack(b) for b in py_blocks] + blocks = block_list.to_list() + assert len(blocks) == block_count + tensors = [torch.from_dlpack(b) for b in blocks] for tensor in tensors: assert tensor.get_device() == DEVICE_ID # GPU - assert tensor.shape == (1, NUM_LAYER, PAGE_SIZE, INNER_DIM) + assert tensor.shape == (1, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM) assert tensor.dtype == TORCH_DTYPE # print(tensors) for tensor in tensors: - tensor[0][0][0][0] = 1.0 - tensor[0][NUM_LAYER - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 + tensor[0][0][0][0][0] = 1.0 + tensor[0][NUM_LAYER - 1][OUTER_DIM - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 # print(tensors) - py_blocks_ = block_list.to_list() - assert py_blocks is not py_blocks_ - assert len(py_blocks) == len(py_blocks_) - tensors_ = [torch.from_dlpack(b) for b in py_blocks_] + blocks_ = block_list.to_list() + assert blocks is not blocks_ + assert len(blocks) == len(blocks_) + tensors_ = [torch.from_dlpack(b) for b in blocks_] for tensor, tensor_ in zip(tensors, tensors_): assert tensor is not tensor_ assert tensor.shape == tensor_.shape @@ -159,27 +162,27 @@ async def test_gpu_block_access(block_manager: BlockManager): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") async def test_block_list_iteration(block_manager: BlockManager): block_count = 4 - block_list = block_manager.allocate_host_blocks_blocking(block_count) + block_list = await block_manager.allocate_host_blocks(block_count) # Test __len__() assert len(block_list) == block_count # Test __getitem__() for i in range(block_count): block = block_list[i] tensor = torch.from_dlpack(block) - tensor[0][0][0][0] = 1.0 + i + tensor[0][0][0][0][0] = 1.0 + i # Test __iter__() and __next__() idx = 1.0 for block in block_list: tensor = torch.from_dlpack(block) - assert tensor[0][0][0][0] == idx - tensor[0][0][0][0] += 0.5 + assert tensor[0][0][0][0][0] == idx + tensor[0][0][0][0][0] += 0.5 idx += 1.0 assert idx == 1.0 + block_count # Test __iter__() should reset current index idx = 1.0 for block in block_list: tensor = torch.from_dlpack(block) - assert tensor[0][0][0][0] == idx + 0.5 + assert tensor[0][0][0][0][0] == idx + 0.5 idx += 1.0 assert idx == 1.0 + block_count @@ -187,27 +190,37 @@ async def test_block_list_iteration(block_manager: BlockManager): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") async def test_block_copy_g1_g2(block_manager: BlockManager): # Allocate device (G1) and host (G2) block - host_block_list = block_manager.allocate_host_blocks_blocking(1) - device_block_list = block_manager.allocate_device_blocks_blocking(1) + host_block_list = await block_manager.allocate_host_blocks(1) + device_block_list = await block_manager.allocate_device_blocks(1) # Populate host block with unique values host_tensor = torch.from_dlpack(host_block_list[0]) for i in range(NUM_LAYER): - for j in range(PAGE_SIZE): - for k in range(INNER_DIM): - host_tensor[0][i][j][k] = i * PAGE_SIZE * INNER_DIM + j * INNER_DIM + k + for j in range(OUTER_DIM): + for k in range(PAGE_SIZE): + for w in range(INNER_DIM): + host_tensor[0][i][j][k][w] 
= ( + i * OUTER_DIM * PAGE_SIZE * INNER_DIM + + j * PAGE_SIZE * INNER_DIM + + k * INNER_DIM + + w + ) # Copy host block to device block after permuting - permute_dims = (0, 2, 3, 1) + permute_dims = (0, 2, 4, 3, 1) device_tensor_ = torch.from_dlpack(device_block_list[0]).permute(*permute_dims) device_tensor_.copy_(host_tensor.permute(*permute_dims)) # Assert device block is contiguous and updated in block manager device_tensor = torch.from_dlpack(device_block_list[0]) for i in range(NUM_LAYER): - for j in range(PAGE_SIZE): - for k in range(INNER_DIM): - assert ( - device_tensor[0][i][j][k] - == i * PAGE_SIZE * INNER_DIM + j * INNER_DIM + k - ) + for j in range(OUTER_DIM): + for k in range(PAGE_SIZE): + for w in range(INNER_DIM): + assert ( + device_tensor[0][i][j][k][w] + == i * OUTER_DIM * PAGE_SIZE * INNER_DIM + + j * PAGE_SIZE * INNER_DIM + + k * INNER_DIM + + w + ) # Set host block to zero and assert updated in block manager host_tensor_ = torch.from_dlpack(host_block_list[0]).permute(*permute_dims) host_tensor_.zero_() @@ -216,22 +229,166 @@ async def test_block_copy_g1_g2(block_manager: BlockManager): host_tensor_.copy_(device_tensor_) # Assert host block is updated in block manager for i in range(NUM_LAYER): - for j in range(PAGE_SIZE): - for k in range(INNER_DIM): - assert ( - host_tensor[0][i][j][k] - == i * PAGE_SIZE * INNER_DIM + j * INNER_DIM + k - ) + for j in range(OUTER_DIM): + for k in range(PAGE_SIZE): + for w in range(INNER_DIM): + assert ( + host_tensor[0][i][j][k][w] + == i * OUTER_DIM * PAGE_SIZE * INNER_DIM + + j * PAGE_SIZE * INNER_DIM + + k * INNER_DIM + + w + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") +async def test_cpu_layer_access(block_manager: BlockManager): + block_list = block_manager.allocate_host_blocks_blocking(1) + block = block_list[0] + layers = block.to_list() + assert len(layers) == NUM_LAYER + tensors = [torch.from_dlpack(bl) for bl in layers] + for tensor in tensors: + assert tensor.get_device() == -1 # CPU + assert tensor.shape == (1, 1, OUTER_DIM, PAGE_SIZE, INNER_DIM) + assert tensor.dtype == TORCH_DTYPE + # print(tensors) + for tensor in tensors: + tensor[0][0][0][0][0] = 1.0 + tensor[0][0][OUTER_DIM - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 + # print(tensors) + layers_ = block.to_list() + assert layers is not layers_ + assert len(layers) == len(layers_) + tensors_ = [torch.from_dlpack(bl) for bl in layers_] + for tensor, tensor_ in zip(tensors, tensors_): + assert tensor is not tensor_ + assert tensor.shape == tensor_.shape + assert tensor.dtype == tensor_.dtype + assert torch.allclose(tensor, tensor_) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") +async def test_gpu_layer_access(block_manager: BlockManager): + block_list = block_manager.allocate_device_blocks_blocking(1) + block = block_list[0] + layers = block.to_list() + assert len(layers) == NUM_LAYER + tensors = [torch.from_dlpack(bl) for bl in layers] + for tensor in tensors: + assert tensor.get_device() == DEVICE_ID # GPU + assert tensor.shape == (1, 1, OUTER_DIM, PAGE_SIZE, INNER_DIM) + assert tensor.dtype == TORCH_DTYPE + # print(tensors) + for tensor in tensors: + tensor[0][0][0][0][0] = 1.0 + tensor[0][0][OUTER_DIM - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 + # print(tensors) + layers_ = block.to_list() + assert layers is not layers_ + assert len(layers) == len(layers_) + tensors_ = [torch.from_dlpack(bl) for bl in layers_] + for tensor, tensor_ in zip(tensors, tensors_): + assert tensor is not tensor_ + 
assert tensor.shape == tensor_.shape + assert tensor.dtype == tensor_.dtype + assert torch.allclose(tensor, tensor_) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") +async def test_block_iteration(block_manager: BlockManager): + block = (await block_manager.allocate_host_blocks(1))[0] + # Test __len__() + assert len(block) == NUM_LAYER + # Test __getitem__() + for i in range(NUM_LAYER): + layer = block[i] + tensor = torch.from_dlpack(layer) + tensor[0][0][0][0][0] = 1.0 + i + # Test __iter__() and __next__() + idx = 1.0 + for layer in block: + tensor = torch.from_dlpack(layer) + assert tensor[0][0][0][0][0] == idx + tensor[0][0][0][0][0] += 0.5 + idx += 1.0 + assert idx == 1.0 + NUM_LAYER + # Test __iter__() should reset current index + idx = 1.0 + for layer in block: + tensor = torch.from_dlpack(layer) + assert tensor[0][0][0][0][0] == idx + 0.5 + idx += 1.0 + assert idx == 1.0 + NUM_LAYER + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") +async def test_block_layer_copy_g1_g2(block_manager: BlockManager): + # Allocate device (G1) and host (G2) block + host_block = (await block_manager.allocate_host_blocks(1))[0] + device_block = (await block_manager.allocate_device_blocks(1))[0] + # Populate host block at layer level with unique values + host_layer_tensors = [torch.from_dlpack(bl) for bl in host_block] + for i in range(NUM_LAYER): + host_layer_tensor = host_layer_tensors[i] + for j in range(OUTER_DIM): + for k in range(PAGE_SIZE): + for w in range(INNER_DIM): + host_layer_tensor[0][0][j][k][w] = ( + i * OUTER_DIM * PAGE_SIZE * INNER_DIM + + j * PAGE_SIZE * INNER_DIM + + k * INNER_DIM + + w + ) + # Copy host block to device block after permuting + permute_dims = (0, 2, 4, 3, 1) + host_block_tensor_ = torch.from_dlpack(host_block).permute(*permute_dims) + device_block_tensor_ = torch.from_dlpack(device_block).permute(*permute_dims) + device_block_tensor_.copy_(host_block_tensor_) + # Assert device block is contiguous and updated in block manager at layer level + device_layer_tensors = [torch.from_dlpack(bl) for bl in device_block] + for i in range(NUM_LAYER): + device_layer_tensor = device_layer_tensors[i] + for j in range(OUTER_DIM): + for k in range(PAGE_SIZE): + for w in range(INNER_DIM): + assert ( + device_layer_tensor[0][0][j][k][w] + == i * OUTER_DIM * PAGE_SIZE * INNER_DIM + + j * PAGE_SIZE * INNER_DIM + + k * INNER_DIM + + w + ) + # Set host block to zero and assert updated in block manager + host_block_tensor = torch.from_dlpack(host_block) + host_block_tensor.zero_() + assert torch.all(host_block_tensor_ == 0) + # Copy device block back to host block + host_block_tensor_.copy_(device_block_tensor_) + # Assert host block is updated in block manager + for i in range(NUM_LAYER): + for j in range(OUTER_DIM): + for k in range(PAGE_SIZE): + for w in range(INNER_DIM): + assert ( + host_block_tensor[0][i][j][k][w] + == i * OUTER_DIM * PAGE_SIZE * INNER_DIM + + j * PAGE_SIZE * INNER_DIM + + k * INNER_DIM + + w + ) async def main(): await test_block_manager_initialization() - - # todo: revise these tests to index into the block via block_id, layer_id, outer_id (k/v) - # await test_cpu_block_access() - # await test_gpu_block_access() - # await test_block_list_iteration() - # await test_block_copy_g1_g2() + await test_cpu_block_access(new_block_manager()) + await test_gpu_block_access(new_block_manager()) + await test_block_list_iteration(new_block_manager()) + await test_block_copy_g1_g2(new_block_manager()) + await 
test_cpu_layer_access(new_block_manager()) + await test_gpu_layer_access(new_block_manager()) + await test_block_iteration(new_block_manager()) + await test_block_layer_copy_g1_g2(new_block_manager()) if __name__ == "__main__":
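# Usage sketch for the new async allocation API (illustrative only; assumes a
# running event loop and the constants defined above):
#
#   bm = new_block_manager()
#   device_blocks = await bm.allocate_device_blocks(2)   # -> BlockList
#   layer = device_blocks[0][0]                          # Block -> Layer
#   tensor = torch.from_dlpack(layer)                    # zero-copy DLPack view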