diff --git a/esp-hal-embassy/Cargo.toml b/esp-hal-embassy/Cargo.toml index 83ceebff460..6c6b67723c3 100644 --- a/esp-hal-embassy/Cargo.toml +++ b/esp-hal-embassy/Cargo.toml @@ -17,7 +17,7 @@ defmt = { version = "0.3.8", optional = true } document-features = "0.2.10" embassy-executor = { version = "0.6.0", optional = true } embassy-time-driver = { version = "0.1.0", features = [ "tick-hz-1_000_000" ] } -esp-hal = { version = "0.19.0", path = "../esp-hal" } +esp-hal = { version = "0.19.0", path = "../esp-hal", features = ["__esp_hal_embassy"] } log = { version = "0.4.22", optional = true } macros = { version = "0.12.0", features = ["embassy"], package = "esp-hal-procmacros", path = "../esp-hal-procmacros" } portable-atomic = "1.6.0" diff --git a/esp-hal/CHANGELOG.md b/esp-hal/CHANGELOG.md index c3e38928585..19ca541783f 100644 --- a/esp-hal/CHANGELOG.md +++ b/esp-hal/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Introduce DMA buffer objects (#1856) +- Introduce DMA buffer objects (#1856, #1985) - Added new `Io::new_no_bind_interrupt` constructor (#1861) - Added touch pad support for esp32 (#1873, #1956) - Allow configuration of period updating method for MCPWM timers (#1898) @@ -23,7 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Peripheral driver constructors don't take `InterruptHandler`s anymore. Use `set_interrupt_handler` to explicitly set the interrupt handler now. (#1819) -- Migrate SPI driver to use DMA buffer objects (#1856) +- Migrate SPI driver to use DMA buffer objects (#1856, #1985) - Use the peripheral ref pattern for `OneShotTimer` and `PeriodicTimer` (#1855) - Improve SYSTIMER API (#1871) - DMA buffers now don't require a static lifetime. Make sure to never `mem::forget` an in-progress DMA transfer (consider using `#[deny(clippy::mem_forget)]`) (#1837) @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Remove `fn free(self)` in HMAC which goes against esp-hal API guidelines (#1972) - PARL_IO use ReadBuffer and WriteBuffer for Async DMA (#1996) - `AnyPin`, `AnyInputOnyPin` and `DummyPin` are now accessible from `gpio` module (#1918) +- Changed the RSA modular multiplication API to be consistent across devices (#2002) ### Fixed @@ -44,6 +45,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - We should no longer generate 1GB .elf files for ESP32C2 and ESP32C3 (#1962) - Reset peripherals in driver constructors where missing (#1893, #1961) - Fixed ESP32-S2 systimer interrupts (#1979) +- Software interrupt 3 is no longer available when it is required by `esp-hal-embassy`. (#2011) +- ESP32: Fixed async RSA (#2002) ### Removed @@ -51,6 +54,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - The `AesFlavour` trait no longer has the `ENCRYPT_MODE`/`DECRYPT_MODE` associated constants (#1849) - Removed `FlashSafeDma` (#1856) - Remove redundant WithDmaSpi traits (#1975) +- `IsFullDuplex` and `IsHalfDuplex` traits (#1985) ## [0.19.0] - 2024-07-15 diff --git a/esp-hal/Cargo.toml b/esp-hal/Cargo.toml index a20f8d0fa5d..fe7cde90dac 100644 --- a/esp-hal/Cargo.toml +++ b/esp-hal/Cargo.toml @@ -85,6 +85,8 @@ bluetooth = [] usb-otg = ["esp-synopsys-usb-otg", "usb-device"] +__esp_hal_embassy = [] + ## Enable debug features in the HAL (used for development). debug = [ "esp32?/impl-register-debug", diff --git a/esp-hal/src/dma/mod.rs b/esp-hal/src/dma/mod.rs index 303b7f46063..82ddb07730f 100644 --- a/esp-hal/src/dma/mod.rs +++ b/esp-hal/src/dma/mod.rs @@ -19,7 +19,7 @@ #![doc = crate::before_snippet!()] //! # use esp_hal::dma_buffers; //! # use esp_hal::gpio::Io; -//! # use esp_hal::spi::{master::{Spi, prelude::*}, SpiMode}; +//! # use esp_hal::spi::{master::Spi, SpiMode}; //! # use esp_hal::dma::{Dma, DmaPriority}; //! # use crate::esp_hal::prelude::_fugit_RateExtU32; //! let dma = Dma::new(peripherals.DMA); diff --git a/esp-hal/src/rsa/esp32.rs b/esp-hal/src/rsa/esp32.rs index bb9a3796dc8..c5b9938d824 100644 --- a/esp-hal/src/rsa/esp32.rs +++ b/esp-hal/src/rsa/esp32.rs @@ -1,8 +1,4 @@ -use core::{ - convert::Infallible, - marker::PhantomData, - ptr::{copy_nonoverlapping, write_bytes}, -}; +use core::convert::Infallible; use crate::rsa::{ implement_op, @@ -37,35 +33,30 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } /// Starts the modular exponentiation operation. - pub(super) fn write_modexp_start(&mut self) { + pub(super) fn write_modexp_start(&self) { self.rsa .modexp_start() .write(|w| w.modexp_start().set_bit()); } /// Starts the multiplication operation. - pub(super) fn write_multi_start(&mut self) { + pub(super) fn write_multi_start(&self) { self.rsa.mult_start().write(|w| w.mult_start().set_bit()); } + /// Starts the modular multiplication operation. + pub(super) fn write_modmulti_start(&self) { + self.write_multi_start(); + } + /// Clears the RSA interrupt flag. pub(super) fn clear_interrupt(&mut self) { self.rsa.interrupt().write(|w| w.interrupt().set_bit()); } /// Checks if the RSA peripheral is idle. - pub(super) fn is_idle(&mut self) -> bool { - self.rsa.interrupt().read().bits() == 1 - } - - unsafe fn write_multi_operand_a(&mut self, operand_a: &[u32; N]) { - copy_nonoverlapping(operand_a.as_ptr(), self.rsa.x_mem(0).as_ptr(), N); - write_bytes(self.rsa.x_mem(0).as_ptr().add(N), 0, N); - } - - unsafe fn write_multi_operand_b(&mut self, operand_b: &[u32; N]) { - write_bytes(self.rsa.z_mem(0).as_ptr(), 0, N); - copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N); + pub(super) fn is_idle(&self) -> bool { + self.rsa.interrupt().read().interrupt().bit_is_set() } } @@ -92,59 +83,18 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplicati where T: RsaMode, { - /// Creates an instance of `RsaMultiplication`. - /// - /// `m_prime` can be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 24.3.2 of . - pub fn new(rsa: &'a mut Rsa<'d, DM>, modulus: &T::InputType, m_prime: u32) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_modulus(modulus); - } - rsa.write_mprime(m_prime); - - Self { - rsa, - phantom: PhantomData, - } - } - - fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_multi_mode((N / 16 - 1) as u32) } - /// Starts the first step of modular multiplication operation. - /// - /// `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`. + /// Starts the modular multiplication operation. /// /// For more information refer to 24.3.2 of . - pub fn start_step1(&mut self, operand_a: &T::InputType, r: &T::InputType) { - unsafe { - self.rsa.write_operand_a(operand_a); - self.rsa.write_r(r); - } - self.start(); - } - - /// Starts the second step of modular multiplication operation. - /// - /// This is a non blocking function that returns without an error if - /// operation is completed successfully. `start_step1` must be called - /// before calling this function. - pub fn start_step2(&mut self, operand_b: &T::InputType) { - while !self.rsa.is_idle() {} - - self.rsa.clear_interrupt(); - unsafe { - self.rsa.write_operand_a(operand_b); - } - self.start(); - } - - fn start(&mut self) { + pub(super) fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) { self.rsa.write_multi_start(); + self.rsa.wait_for_idle(); + + self.rsa.write_operand_a(operand_b); } } @@ -152,70 +102,22 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati where T: RsaMode, { - /// Creates an instance of `RsaModularExponentiation`. - /// - /// `m_prime` can be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 24.3.2 of . - pub fn new( - rsa: &'a mut Rsa<'d, DM>, - exponent: &T::InputType, - modulus: &T::InputType, - m_prime: u32, - ) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_operand_b(exponent); - rsa.write_modulus(modulus); - } - rsa.write_mprime(m_prime); - Self { - rsa, - phantom: PhantomData, - } - } - /// Sets the modular exponentiation mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_modexp_mode((N / 16 - 1) as u32) } - - /// Starts the modular exponentiation operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_modexp_start(); - } } impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplication<'a, 'd, T, DM> where T: RsaMode, { - /// Creates an instance of `RsaMultiplication`. - pub fn new(rsa: &'a mut Rsa<'d, DM>) -> Self { - Self::set_mode(rsa); - Self { - rsa, - phantom: PhantomData, - } - } - - /// Starts the multiplication operation. - pub fn start_multiplication(&mut self, operand_a: &T::InputType, operand_b: &T::InputType) { - unsafe { - self.rsa.write_multi_operand_a(operand_a); - self.rsa.write_multi_operand_b(operand_b); - } - self.start(); - } - /// Sets the multiplication mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_multi_mode(((N * 2) / 16 + 7) as u32) } - /// Starts the multiplication operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_multi_start(); + pub(super) fn set_up_multiplication(&mut self, operand_b: &T::InputType) { + self.rsa.write_multi_operand_b(operand_b); } } diff --git a/esp-hal/src/rsa/esp32cX.rs b/esp-hal/src/rsa/esp32cX.rs index c09fa73902e..ea6d79d44c2 100644 --- a/esp-hal/src/rsa/esp32cX.rs +++ b/esp-hal/src/rsa/esp32cX.rs @@ -1,4 +1,4 @@ -use core::{convert::Infallible, marker::PhantomData, ptr::copy_nonoverlapping}; +use core::convert::Infallible; use crate::rsa::{ implement_op, @@ -94,21 +94,21 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } /// Starts the modular exponentiation operation. - pub(super) fn write_modexp_start(&mut self) { + pub(super) fn write_modexp_start(&self) { self.rsa .set_start_modexp() .write(|w| w.set_start_modexp().set_bit()); } /// Starts the multiplication operation. - pub(super) fn write_multi_start(&mut self) { + pub(super) fn write_multi_start(&self) { self.rsa .set_start_mult() .write(|w| w.set_start_mult().set_bit()); } /// Starts the modular multiplication operation. - fn write_modmulti_start(&mut self) { + pub(super) fn write_modmulti_start(&self) { self.rsa .set_start_modmult() .write(|w| w.set_start_modmult().set_bit()); @@ -120,13 +120,9 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } /// Checks if the RSA peripheral is idle. - pub(super) fn is_idle(&mut self) -> bool { + pub(super) fn is_idle(&self) -> bool { self.rsa.query_idle().read().query_idle().bit_is_set() } - - unsafe fn write_multi_operand_b(&mut self, operand_b: &[u32; N]) { - copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N); - } } /// Module defining marker types for various RSA operand sizes. @@ -240,34 +236,7 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati where T: RsaMode, { - /// Creates an instance of `RsaModularExponentiation`. - /// - /// `m_prime` could be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 19.3.1 of . - pub fn new( - rsa: &'a mut Rsa<'d, DM>, - exponent: &T::InputType, - modulus: &T::InputType, - m_prime: u32, - ) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_operand_b(exponent); - rsa.write_modulus(modulus); - } - rsa.write_mprime(m_prime); - if rsa.is_search_enabled() { - rsa.write_search_position(Self::find_search_pos(exponent)); - } - Self { - rsa, - phantom: PhantomData, - } - } - - fn find_search_pos(exponent: &T::InputType) -> u32 { + pub(super) fn find_search_pos(exponent: &T::InputType) -> u32 { for (i, byte) in exponent.iter().rev().enumerate() { if *byte == 0 { continue; @@ -278,64 +247,21 @@ where } /// Sets the modular exponentiation mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N - 1) as u32) } - - /// Starts the modular exponentiation operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_modexp_start(); - } } impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplication<'a, 'd, T, DM> where T: RsaMode, { - fn write_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N - 1) as u32) } - /// Creates an instance of `RsaModularMultiplication`. - /// - /// `m_prime` can be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 19.3.1 of . - pub fn new( - rsa: &'a mut Rsa<'d, DM>, - operand_a: &T::InputType, - operand_b: &T::InputType, - modulus: &T::InputType, - m_prime: u32, - ) -> Self { - Self::write_mode(rsa); - rsa.write_mprime(m_prime); - unsafe { - rsa.write_modulus(modulus); - rsa.write_operand_a(operand_a); - rsa.write_operand_b(operand_b); - } - Self { - rsa, - phantom: PhantomData, - } - } - - /// Starts the modular multiplication operation. - /// - /// `r` could be calculated using `2 ^ ( bitlength * 2 ) mod modulus`. - /// - /// For more information refer to 19.3.1 of . - pub fn start_modular_multiplication(&mut self, r: &T::InputType) { - unsafe { - self.rsa.write_r(r); - } - self.start(); - } - - fn start(&mut self) { - self.rsa.write_modmulti_start(); + pub(super) fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) { + self.rsa.write_operand_b(operand_b); } } @@ -343,33 +269,12 @@ impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplicat where T: RsaMode, { - /// Creates an instance of `RsaMultiplication`. - pub fn new(rsa: &'a mut Rsa<'d, DM>, operand_a: &T::InputType) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_operand_a(operand_a); - } - Self { - rsa, - phantom: PhantomData, - } - } - - /// Starts the multiplication operation. - pub fn start_multiplication(&mut self, operand_b: &T::InputType) { - unsafe { - self.rsa.write_multi_operand_b(operand_b); - } - self.start(); + pub(super) fn set_up_multiplication(&mut self, operand_b: &T::InputType) { + self.rsa.write_multi_operand_b(operand_b); } /// Sets the multiplication mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N * 2 - 1) as u32) } - - /// Starts the multiplication operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_multi_start(); - } } diff --git a/esp-hal/src/rsa/esp32sX.rs b/esp-hal/src/rsa/esp32sX.rs index aefab89a668..956ec3fe4e8 100644 --- a/esp-hal/src/rsa/esp32sX.rs +++ b/esp-hal/src/rsa/esp32sX.rs @@ -1,4 +1,4 @@ -use core::{convert::Infallible, marker::PhantomData, ptr::copy_nonoverlapping}; +use core::convert::Infallible; use crate::rsa::{ implement_op, @@ -101,19 +101,19 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } /// Starts the modular exponentiation operation. - pub(super) fn write_modexp_start(&mut self) { + pub(super) fn write_modexp_start(&self) { self.rsa .modexp_start() .write(|w| w.modexp_start().set_bit()); } /// Starts the multiplication operation. - pub(super) fn write_multi_start(&mut self) { + pub(super) fn write_multi_start(&self) { self.rsa.mult_start().write(|w| w.mult_start().set_bit()); } /// Starts the modular multiplication operation. - fn write_modmulti_start(&mut self) { + pub(super) fn write_modmulti_start(&self) { self.rsa .modmult_start() .write(|w| w.modmult_start().set_bit()); @@ -127,13 +127,9 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } /// Checks if the RSA peripheral is idle. - pub(super) fn is_idle(&mut self) -> bool { + pub(super) fn is_idle(&self) -> bool { self.rsa.idle().read().idle().bit_is_set() } - - unsafe fn write_multi_operand_b(&mut self, operand_b: &[u32; N]) { - copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N); - } } pub mod operand_sizes { @@ -281,34 +277,7 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati where T: RsaMode, { - /// Creates an instance of `RsaModularExponentiation`. - /// - /// `m_prime` can be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 20.3.1 of . - pub fn new( - rsa: &'a mut Rsa<'d, DM>, - exponent: &T::InputType, - modulus: &T::InputType, - m_prime: u32, - ) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_operand_b(exponent); - rsa.write_modulus(modulus); - } - rsa.write_mprime(m_prime); - if rsa.is_search_enabled() { - rsa.write_search_position(Self::find_search_pos(exponent)); - } - Self { - rsa, - phantom: PhantomData, - } - } - - fn find_search_pos(exponent: &T::InputType) -> u32 { + pub(super) fn find_search_pos(exponent: &T::InputType) -> u32 { for (i, byte) in exponent.iter().rev().enumerate() { if *byte == 0 { continue; @@ -319,64 +288,21 @@ where } /// Sets the modular exponentiation mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N - 1) as u32) } - - /// Starts the modular exponentiation operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_modexp_start(); - } } impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplication<'a, 'd, T, DM> where T: RsaMode, { - /// Creates an instance of `RsaModularMultiplication`. - /// - /// `m_prime` could be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 20.3.1 of . - pub fn new( - rsa: &'a mut Rsa<'d, DM>, - operand_a: &T::InputType, - operand_b: &T::InputType, - modulus: &T::InputType, - m_prime: u32, - ) -> Self { - Self::write_mode(rsa); - rsa.write_mprime(m_prime); - unsafe { - rsa.write_modulus(modulus); - rsa.write_operand_a(operand_a); - rsa.write_operand_b(operand_b); - } - Self { - rsa, - phantom: PhantomData, - } - } - - fn write_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N - 1) as u32) } - /// Starts the modular multiplication operation. - /// - /// `r` could be calculated using `2 ^ ( bitlength * 2 ) mod modulus`. - /// - /// For more information refer to 19.3.1 of . - pub fn start_modular_multiplication(&mut self, r: &T::InputType) { - unsafe { - self.rsa.write_r(r); - } - self.start(); - } - - fn start(&mut self) { - self.rsa.write_modmulti_start(); + pub(super) fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) { + self.rsa.write_operand_b(operand_b); } } @@ -384,33 +310,12 @@ impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplicat where T: RsaMode, { - /// Creates an instance of `RsaMultiplication`. - pub fn new(rsa: &'a mut Rsa<'d, DM>, operand_a: &T::InputType) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_operand_a(operand_a); - } - Self { - rsa, - phantom: PhantomData, - } - } - - /// Starts the multiplication operation. - pub fn start_multiplication(&mut self, operand_b: &T::InputType) { - unsafe { - self.rsa.write_multi_operand_b(operand_b); - } - self.start(); - } - /// Sets the multiplication mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N * 2 - 1) as u32) } - /// Starts the multiplication operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_multi_start(); + pub(super) fn set_up_multiplication(&mut self, operand_b: &T::InputType) { + self.rsa.write_multi_operand_b(operand_b); } } diff --git a/esp-hal/src/rsa/mod.rs b/esp-hal/src/rsa/mod.rs index e8c54c550b2..8bf11d0eafa 100644 --- a/esp-hal/src/rsa/mod.rs +++ b/esp-hal/src/rsa/mod.rs @@ -16,16 +16,10 @@ //! ## Examples //! //! ### Modular Exponentiation, Modular Multiplication, and Multiplication -//! Visit the [RSA test] for an example of using the peripheral. -//! -//! ## Implementation State -//! -//! - The [nb] crate is used to handle non-blocking operations. -//! - This peripheral supports `async` on every available chip except of `esp32` -//! (to be solved). +//! Visit the [RSA test suite] for an example of using the peripheral. //! //! [nb]: https://docs.rs/nb/1.1.0/nb/ -//! [RSA test]: https://github.com/esp-rs/esp-hal/blob/main/hil-test/tests/rsa.rs +//! [RSA test suite]: https://github.com/esp-rs/esp-hal/blob/main/hil-test/tests/rsa.rs use core::{marker::PhantomData, ptr::copy_nonoverlapping}; @@ -53,24 +47,6 @@ pub struct Rsa<'d, DM: crate::Mode> { phantom: PhantomData, } -impl<'d, DM: crate::Mode> Rsa<'d, DM> { - fn internal_set_interrupt_handler(&mut self, handler: InterruptHandler) { - unsafe { - crate::interrupt::bind_interrupt(crate::peripherals::Interrupt::RSA, handler.handler()); - crate::interrupt::enable(crate::peripherals::Interrupt::RSA, handler.priority()) - .unwrap(); - } - } - - fn read_results(&mut self, outbuf: &mut [u32; N]) { - while !self.is_idle() {} - unsafe { - self.read_out(outbuf); - } - self.clear_interrupt(); - } -} - impl<'d> Rsa<'d, crate::Blocking> { /// Create a new instance in [crate::Blocking] mode. /// @@ -111,32 +87,66 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } } - unsafe fn write_operand_b(&mut self, operand_b: &[u32; N]) { - copy_nonoverlapping(operand_b.as_ptr(), self.rsa.y_mem(0).as_ptr(), N); + fn write_operand_b(&mut self, operand_b: &[u32; N]) { + unsafe { + copy_nonoverlapping(operand_b.as_ptr(), self.rsa.y_mem(0).as_ptr(), N); + } } - unsafe fn write_modulus(&mut self, modulus: &[u32; N]) { - copy_nonoverlapping(modulus.as_ptr(), self.rsa.m_mem(0).as_ptr(), N); + fn write_modulus(&mut self, modulus: &[u32; N]) { + unsafe { + copy_nonoverlapping(modulus.as_ptr(), self.rsa.m_mem(0).as_ptr(), N); + } } fn write_mprime(&mut self, m_prime: u32) { self.rsa.m_prime().write(|w| unsafe { w.bits(m_prime) }); } - unsafe fn write_operand_a(&mut self, operand_a: &[u32; N]) { - copy_nonoverlapping(operand_a.as_ptr(), self.rsa.x_mem(0).as_ptr(), N); + fn write_operand_a(&mut self, operand_a: &[u32; N]) { + unsafe { + copy_nonoverlapping(operand_a.as_ptr(), self.rsa.x_mem(0).as_ptr(), N); + } } - unsafe fn write_r(&mut self, r: &[u32; N]) { - copy_nonoverlapping(r.as_ptr(), self.rsa.z_mem(0).as_ptr(), N); + fn write_multi_operand_b(&mut self, operand_b: &[u32; N]) { + unsafe { + copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N); + } } - unsafe fn read_out(&mut self, outbuf: &mut [u32; N]) { - copy_nonoverlapping( - self.rsa.z_mem(0).as_ptr() as *const u32, - outbuf.as_ptr() as *mut u32, - N, - ); + fn write_r(&mut self, r: &[u32; N]) { + unsafe { + copy_nonoverlapping(r.as_ptr(), self.rsa.z_mem(0).as_ptr(), N); + } + } + + fn read_out(&self, outbuf: &mut [u32; N]) { + unsafe { + copy_nonoverlapping( + self.rsa.z_mem(0).as_ptr() as *const u32, + outbuf.as_ptr() as *mut u32, + N, + ); + } + } + + fn internal_set_interrupt_handler(&mut self, handler: InterruptHandler) { + unsafe { + crate::interrupt::bind_interrupt(crate::peripherals::Interrupt::RSA, handler.handler()); + crate::interrupt::enable(crate::peripherals::Interrupt::RSA, handler.priority()) + .unwrap(); + } + } + + fn wait_for_idle(&mut self) { + while !self.is_idle() {} + self.clear_interrupt(); + } + + fn read_results(&mut self, outbuf: &mut [u32; N]) { + self.wait_for_idle(); + self.read_out(outbuf); } } @@ -155,7 +165,7 @@ pub trait Multi: RsaMode { macro_rules! implement_op { (($x:literal, multi)) => { paste! { - /// Represents an RSA operation for the given bit size with multi-output. + #[doc = concat!($x, "-bit RSA operation.")] pub struct []; impl Multi for [] { @@ -204,17 +214,47 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati where T: RsaMode, { + /// Creates an instance of `RsaModularExponentiation`. + /// + /// `m_prime` could be calculated using `-(modular multiplicative inverse of + /// modulus) mod 2^32`. + /// + /// For more information refer to 24.3.2 of . + pub fn new( + rsa: &'a mut Rsa<'d, DM>, + exponent: &T::InputType, + modulus: &T::InputType, + m_prime: u32, + ) -> Self { + Self::write_mode(rsa); + rsa.write_operand_b(exponent); + rsa.write_modulus(modulus); + rsa.write_mprime(m_prime); + + #[cfg(not(esp32))] + if rsa.is_search_enabled() { + rsa.write_search_position(Self::find_search_pos(exponent)); + } + + Self { + rsa, + phantom: PhantomData, + } + } + + fn set_up_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) { + self.rsa.write_operand_a(base); + self.rsa.write_r(r); + } + /// Starts the modular exponentiation operation. /// /// `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`. /// /// For more information refer to 24.3.2 of . pub fn start_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) { - unsafe { - self.rsa.write_operand_a(base); - self.rsa.write_r(r); - } - self.start(); + self.set_up_exponentiation(base, r); + self.rsa.write_modexp_start(); } /// Reads the result to the given buffer. @@ -240,6 +280,40 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplicati where T: RsaMode, { + /// Creates an instance of `RsaModularMultiplication`. + /// + /// - `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`. + /// - `m_prime` can be calculated using `-(modular multiplicative inverse of + /// modulus) mod 2^32`. + /// + /// For more information refer to 20.3.1 of . + pub fn new( + rsa: &'a mut Rsa<'d, DM>, + operand_a: &T::InputType, + modulus: &T::InputType, + r: &T::InputType, + m_prime: u32, + ) -> Self { + Self::write_mode(rsa); + rsa.write_mprime(m_prime); + rsa.write_modulus(modulus); + rsa.write_operand_a(operand_a); + rsa.write_r(r); + + Self { + rsa, + phantom: PhantomData, + } + } + + /// Starts the modular multiplication operation. + /// + /// For more information refer to 19.3.1 of . + pub fn start_modular_multiplication(&mut self, operand_b: &T::InputType) { + self.set_up_modular_multiplication(operand_b); + self.rsa.write_modmulti_start(); + } + /// Reads the result to the given buffer. /// This is a non blocking function that returns without an error if /// operation is completed successfully. @@ -261,6 +335,23 @@ impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplicat where T: RsaMode, { + /// Creates an instance of `RsaMultiplication`. + pub fn new(rsa: &'a mut Rsa<'d, DM>, operand_a: &T::InputType) -> Self { + Self::write_mode(rsa); + rsa.write_operand_a(operand_a); + + Self { + rsa, + phantom: PhantomData, + } + } + + /// Starts the multiplication operation. + pub fn start_multiplication(&mut self, operand_b: &T::InputType) { + self.set_up_multiplication(operand_b); + self.rsa.write_multi_start(); + } + /// Reads the result to the given buffer. /// This is a non blocking function that returns without an error if /// operation is completed successfully. `start_multiplication` must be @@ -279,59 +370,67 @@ pub(crate) mod asynch { use core::task::Poll; use embassy_sync::waitqueue::AtomicWaker; + use portable_atomic::{AtomicBool, Ordering}; use procmacros::handler; - use crate::rsa::{ - Multi, - RsaMode, - RsaModularExponentiation, - RsaModularMultiplication, - RsaMultiplication, + use crate::{ + rsa::{ + Multi, + Rsa, + RsaMode, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, + }, + Async, }; static WAKER: AtomicWaker = AtomicWaker::new(); + static SIGNALED: AtomicBool = AtomicBool::new(false); + /// `Future` that waits for the RSA operation to complete. #[must_use = "futures do nothing unless you `.await` or poll them"] - pub(crate) struct RsaFuture<'d> { - instance: &'d crate::peripherals::RSA, + struct RsaFuture<'a, 'd> { + #[cfg_attr(esp32, allow(dead_code))] + instance: &'a Rsa<'d, Async>, } - impl<'d> RsaFuture<'d> { - /// Asynchronously initializes the RSA peripheral. - pub fn new(instance: &'d crate::peripherals::RSA) -> Self { + impl<'a, 'd> RsaFuture<'a, 'd> { + fn new(instance: &'a Rsa<'d, Async>) -> Self { + SIGNALED.store(false, Ordering::Relaxed); + cfg_if::cfg_if! { if #[cfg(esp32)] { - instance.interrupt().modify(|_, w| w.interrupt().set_bit()); } else if #[cfg(any(esp32s2, esp32s3))] { - instance.interrupt_ena().modify(|_, w| w.interrupt_ena().set_bit()); + instance.rsa.interrupt_ena().write(|w| w.interrupt_ena().set_bit()); } else { - instance.int_ena().modify(|_, w| w.int_ena().set_bit()); + instance.rsa.int_ena().write(|w| w.int_ena().set_bit()); } } Self { instance } } - fn event_bit_is_clear(&self) -> bool { + fn is_done(&self) -> bool { + SIGNALED.load(Ordering::Acquire) + } + } + + impl Drop for RsaFuture<'_, '_> { + fn drop(&mut self) { cfg_if::cfg_if! { if #[cfg(esp32)] { - self.instance.interrupt().read().interrupt().bit_is_clear() } else if #[cfg(any(esp32s2, esp32s3))] { - self - .instance - .interrupt_ena() - .read() - .interrupt_ena() - .bit_is_clear() + self.instance.rsa.interrupt_ena().write(|w| w.interrupt_ena().clear_bit()); } else { - self.instance.int_ena().read().int_ena().bit_is_clear() + self.instance.rsa.int_ena().write(|w| w.int_ena().clear_bit()); } } } } - impl<'d> core::future::Future for RsaFuture<'d> { + impl core::future::Future for RsaFuture<'_, '_> { type Output = (); fn poll( @@ -339,7 +438,7 @@ pub(crate) mod asynch { cx: &mut core::task::Context<'_>, ) -> core::task::Poll { WAKER.register(cx.waker()); - if self.event_bit_is_clear() { + if self.is_done() { Poll::Ready(()) } else { Poll::Pending @@ -347,7 +446,7 @@ pub(crate) mod asynch { } } - impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T, crate::Async> + impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T, Async> where T: RsaMode, { @@ -358,49 +457,47 @@ pub(crate) mod asynch { r: &T::InputType, outbuf: &mut T::InputType, ) { - self.start_exponentiation(base, r); - RsaFuture::new(&self.rsa.rsa).await; - self.read_results(outbuf); + self.set_up_exponentiation(base, r); + let fut = RsaFuture::new(self.rsa); + self.rsa.write_modexp_start(); + fut.await; + self.rsa.read_out(outbuf); } } - impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T, crate::Async> + impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T, Async> where T: RsaMode, { - #[cfg(not(esp32))] /// Asynchronously performs an RSA modular multiplication operation. pub async fn modular_multiplication( &mut self, - r: &T::InputType, - outbuf: &mut T::InputType, - ) { - self.start_modular_multiplication(r); - RsaFuture::new(&self.rsa.rsa).await; - self.read_results(outbuf); - } - - #[cfg(esp32)] - /// Asynchronously performs an RSA modular multiplication operation. - pub async fn modular_multiplication( - &mut self, - operand_a: &T::InputType, operand_b: &T::InputType, - r: &T::InputType, outbuf: &mut T::InputType, ) { - self.start_step1(operand_a, r); - self.start_step2(operand_b); - RsaFuture::new(&self.rsa.rsa).await; - self.read_results(outbuf); + cfg_if::cfg_if! { + if #[cfg(esp32)] { + let fut = RsaFuture::new(self.rsa); + self.rsa.write_multi_start(); + fut.await; + + self.rsa.write_operand_a(operand_b); + } else { + self.set_up_modular_multiplication(operand_b); + } + } + + let fut = RsaFuture::new(self.rsa); + self.rsa.write_modmulti_start(); + fut.await; + self.rsa.read_out(outbuf); } } - impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T, crate::Async> + impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T, Async> where T: RsaMode, { - #[cfg(not(esp32))] /// Asynchronously performs an RSA multiplication operation. pub async fn multiplication<'b, const O: usize>( &mut self, @@ -409,44 +506,28 @@ pub(crate) mod asynch { ) where T: Multi, { - self.start_multiplication(operand_b); - RsaFuture::new(&self.rsa.rsa).await; - self.read_results(outbuf); - } - - #[cfg(esp32)] - /// Asynchronously performs an RSA multiplication operation. - pub async fn multiplication<'b, const O: usize>( - &mut self, - operand_a: &T::InputType, - operand_b: &T::InputType, - outbuf: &mut T::OutputType, - ) where - T: Multi, - { - self.start_multiplication(operand_a, operand_b); - RsaFuture::new(&self.rsa.rsa).await; - self.read_results(outbuf); + self.set_up_multiplication(operand_b); + let fut = RsaFuture::new(self.rsa); + self.rsa.write_multi_start(); + fut.await; + self.rsa.read_out(outbuf); } } #[handler] /// Interrupt handler for RSA. pub(super) fn rsa_interrupt_handler() { - #[cfg(not(any(esp32, esp32s2, esp32s3)))] - unsafe { &*crate::peripherals::RSA::ptr() } - .int_ena() - .modify(|_, w| w.int_ena().clear_bit()); - - #[cfg(esp32)] - unsafe { &*crate::peripherals::RSA::ptr() } - .interrupt() - .modify(|_, w| w.interrupt().clear_bit()); - - #[cfg(any(esp32s2, esp32s3))] - unsafe { &*crate::peripherals::RSA::ptr() } - .interrupt_ena() - .modify(|_, w| w.interrupt_ena().clear_bit()); + let rsa = unsafe { &*crate::peripherals::RSA::ptr() }; + SIGNALED.store(true, Ordering::Release); + cfg_if::cfg_if! { + if #[cfg(esp32)] { + rsa.interrupt().write(|w| w.interrupt().set_bit()); + } else if #[cfg(any(esp32s2, esp32s3))] { + rsa.clear_interrupt().write(|w| w.clear_interrupt().set_bit()); + } else { + rsa.int_clr().write(|w| w.clear_interrupt().set_bit()); + } + } WAKER.wake(); } diff --git a/esp-hal/src/spi/master.rs b/esp-hal/src/spi/master.rs index d820f245791..f0533d9cfd1 100644 --- a/esp-hal/src/spi/master.rs +++ b/esp-hal/src/spi/master.rs @@ -63,6 +63,7 @@ use core::marker::PhantomData; +pub use dma::*; #[cfg(not(any(esp32, esp32s2)))] use enumset::EnumSet; #[cfg(not(any(esp32, esp32s2)))] @@ -72,12 +73,11 @@ use fugit::HertzU32; use procmacros::ram; use super::{ + DmaError, DuplexMode, Error, FullDuplexMode, HalfDuplexMode, - IsFullDuplex, - IsHalfDuplex, SpiBitOrder, SpiDataMode, SpiMode, @@ -93,17 +93,10 @@ use crate::{ system::PeripheralClockControl, }; -/// Prelude for the SPI (Master) driver -pub mod prelude { - pub use super::{ - Instance as _esp_hal_spi_master_Instance, - InstanceDma as _esp_hal_spi_master_InstanceDma, - }; -} - /// Enumeration of possible SPI interrupt events. #[cfg(not(any(esp32, esp32s2)))] #[derive(EnumSetType)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum SpiInterrupt { /// Indicates that the SPI transaction has completed successfully. /// @@ -128,6 +121,8 @@ const MAX_DMA_SIZE: usize = 32736; /// /// Used to define specific commands sent over the SPI bus. /// Can be [Command::None] if command phase should be suppressed. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Command { /// No command is sent. None, @@ -241,6 +236,8 @@ impl Command { /// /// This can be used to specify the address phase of SPI transactions. /// Can be [Address::None] if address phase should be suppressed. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Address { /// No address phase. None, @@ -463,10 +460,9 @@ pub struct Spi<'d, T, M> { _mode: PhantomData, } -impl<'d, T, M> Spi<'d, T, M> +impl<'d, T> Spi<'d, T, FullDuplexMode> where T: Instance, - M: IsFullDuplex, { /// Read bytes from SPI. /// @@ -859,10 +855,9 @@ where } } -impl HalfDuplexReadWrite for Spi<'_, T, M> +impl HalfDuplexReadWrite for Spi<'_, T, HalfDuplexMode> where T: Instance, - M: IsHalfDuplex, { type Error = Error; @@ -907,10 +902,9 @@ where } #[cfg(feature = "embedded-hal-02")] -impl embedded_hal_02::spi::FullDuplex for Spi<'_, T, M> +impl embedded_hal_02::spi::FullDuplex for Spi<'_, T, FullDuplexMode> where T: Instance, - M: IsFullDuplex, { type Error = Error; @@ -924,10 +918,9 @@ where } #[cfg(feature = "embedded-hal-02")] -impl embedded_hal_02::blocking::spi::Transfer for Spi<'_, T, M> +impl embedded_hal_02::blocking::spi::Transfer for Spi<'_, T, FullDuplexMode> where T: Instance, - M: IsFullDuplex, { type Error = Error; @@ -937,10 +930,9 @@ where } #[cfg(feature = "embedded-hal-02")] -impl embedded_hal_02::blocking::spi::Write for Spi<'_, T, M> +impl embedded_hal_02::blocking::spi::Write for Spi<'_, T, FullDuplexMode> where T: Instance, - M: IsFullDuplex, { type Error = Error; @@ -950,8 +942,7 @@ where } } -/// DMA (Direct Memory Access) funtionality (Master). -pub mod dma { +mod dma { use core::{ cmp::min, sync::atomic::{fence, Ordering}, @@ -973,7 +964,6 @@ pub mod dma { SpiPeripheral, TxPrivate, }, - Blocking, InterruptConfigurable, Mode, }; @@ -1036,24 +1026,30 @@ pub mod dma { } /// A DMA capable SPI instance. - pub struct SpiDma<'d, T, C, M, DmaMode> + /// + /// Using `SpiDma` is not recommended unless you wish + /// to manage buffers yourself. It's recommended to use + /// [`SpiDmaBus`] via `with_buffers` to get access + /// to a DMA capable SPI bus that implements the + /// embedded-hal traits. + pub struct SpiDma<'d, T, C, D, M> where C: DmaChannel, C::P: SpiPeripheral, - M: DuplexMode, - DmaMode: Mode, + D: DuplexMode, + M: Mode, { pub(crate) spi: PeripheralRef<'d, T>, - pub(crate) channel: Channel<'d, C, DmaMode>, - _mode: PhantomData, + pub(crate) channel: Channel<'d, C, M>, + _mode: PhantomData, } - impl<'d, T, C, M, DmaMode> core::fmt::Debug for SpiDma<'d, T, C, M, DmaMode> + impl<'d, T, C, D, M> core::fmt::Debug for SpiDma<'d, T, C, D, M> where C: DmaChannel, C::P: SpiPeripheral, - M: DuplexMode, - DmaMode: Mode, + D: DuplexMode, + M: Mode, { /// Formats the `SpiDma` instance for debugging purposes. /// @@ -1064,13 +1060,13 @@ pub mod dma { } } - impl<'d, T, C, M, DmaMode> SpiDma<'d, T, C, M, DmaMode> + impl<'d, T, C, D, M> SpiDma<'d, T, C, D, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, - M: DuplexMode, - DmaMode: Mode, + D: DuplexMode, + M: Mode, { /// Sets the interrupt handler /// @@ -1104,23 +1100,23 @@ pub mod dma { } } - impl<'d, T, C, M, DmaMode> crate::private::Sealed for SpiDma<'d, T, C, M, DmaMode> + impl<'d, T, C, D, M> crate::private::Sealed for SpiDma<'d, T, C, D, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, - M: DuplexMode, - DmaMode: Mode, + D: DuplexMode, + M: Mode, { } - impl<'d, T, C, M, DmaMode> InterruptConfigurable for SpiDma<'d, T, C, M, DmaMode> + impl<'d, T, C, D, M> InterruptConfigurable for SpiDma<'d, T, C, D, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, - M: DuplexMode, - DmaMode: Mode, + D: DuplexMode, + M: Mode, { /// Configures the interrupt handler for the DMA-enabled SPI instance. fn set_interrupt_handler(&mut self, handler: crate::interrupt::InterruptHandler) { @@ -1128,13 +1124,13 @@ pub mod dma { } } - impl<'d, T, C, M, DmaMode> SpiDma<'d, T, C, M, DmaMode> + impl<'d, T, C, D, M> SpiDma<'d, T, C, D, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, - M: DuplexMode, - DmaMode: Mode, + D: DuplexMode, + M: Mode, { /// Changes the SPI bus frequency for the DMA-enabled SPI instance. pub fn change_bus_frequency(&mut self, frequency: HertzU32, clocks: &Clocks<'d>) { @@ -1142,11 +1138,13 @@ pub mod dma { } } - impl<'d, T, C> SpiDma<'d, T, C, FullDuplexMode, Blocking> + impl<'d, T, C, D, M> SpiDma<'d, T, C, D, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, + D: DuplexMode, + M: Mode, { /// Configures the DMA buffers for the SPI instance. /// @@ -1157,44 +1155,23 @@ pub mod dma { self, dma_tx_buf: DmaTxBuf, dma_rx_buf: DmaRxBuf, - ) -> SpiDmaBus<'d, T, C> { + ) -> SpiDmaBus<'d, T, C, D, M> { SpiDmaBus::new(self, dma_tx_buf, dma_rx_buf) } } - #[cfg(feature = "async")] - impl<'d, T, C> SpiDma<'d, T, C, FullDuplexMode, crate::Async> - where - T: InstanceDma, - C: DmaChannel, - C::P: SpiPeripheral, - { - /// Configures the DMA buffers for asynchronous SPI communication. - /// - /// This method sets up both TX and RX buffers for DMA transfers. - /// It eturns an instance of `SpiDmaAsyncBus` to be used for - /// asynchronous SPI operations. - pub fn with_buffers( - self, - dma_tx_buf: DmaTxBuf, - dma_rx_buf: DmaRxBuf, - ) -> asynch::SpiDmaAsyncBus<'d, T, C> { - asynch::SpiDmaAsyncBus::new(self, dma_tx_buf, dma_rx_buf) - } - } - /// A structure representing a DMA transfer for SPI. /// /// This structure holds references to the SPI instance, DMA buffers, and /// transfer status. - pub struct SpiDmaTransfer<'d, T, C, M, DmaMode, Buf> + pub struct SpiDmaTransfer<'d, T, C, D, M, Buf> where C: DmaChannel, C::P: SpiPeripheral, - M: DuplexMode, - DmaMode: Mode, + D: DuplexMode, + M: Mode, { - spi_dma: SpiDma<'d, T, C, M, DmaMode>, + spi_dma: SpiDma<'d, T, C, D, M>, dma_buf: Buf, is_rx: bool, is_tx: bool, @@ -1203,20 +1180,15 @@ pub mod dma { tx_future_awaited: bool, } - impl<'d, T, C, M, DmaMode, Buf> SpiDmaTransfer<'d, T, C, M, DmaMode, Buf> + impl<'d, T, C, D, M, Buf> SpiDmaTransfer<'d, T, C, D, M, Buf> where T: Instance, C: DmaChannel, C::P: SpiPeripheral, - M: DuplexMode, - DmaMode: Mode, + D: DuplexMode, + M: Mode, { - fn new( - spi_dma: SpiDma<'d, T, C, M, DmaMode>, - dma_buf: Buf, - is_rx: bool, - is_tx: bool, - ) -> Self { + fn new(spi_dma: SpiDma<'d, T, C, D, M>, dma_buf: Buf, is_rx: bool, is_tx: bool) -> Self { Self { spi_dma, dma_buf, @@ -1259,7 +1231,7 @@ pub mod dma { /// /// This method blocks until the transfer is finished and returns the /// `SpiDma` instance and the associated buffer. - pub fn wait(mut self) -> (SpiDma<'d, T, C, M, DmaMode>, Buf) { + pub fn wait(mut self) -> (SpiDma<'d, T, C, D, M>, Buf) { self.spi_dma.spi.flush().ok(); fence(Ordering::Acquire); (self.spi_dma, self.dma_buf) @@ -1267,12 +1239,12 @@ pub mod dma { } #[cfg(feature = "async")] - impl<'d, T, C, M, Buf> SpiDmaTransfer<'d, T, C, M, crate::Async, Buf> + impl<'d, T, C, D, Buf> SpiDmaTransfer<'d, T, C, D, crate::Async, Buf> where T: Instance, C: DmaChannel, C::P: SpiPeripheral, - M: DuplexMode, + D: DuplexMode, { /// Waits for the DMA transfer to complete asynchronously. /// @@ -1292,13 +1264,12 @@ pub mod dma { } } - impl<'d, T, C, M, DmaMode> SpiDma<'d, T, C, M, DmaMode> + impl<'d, T, C, M> SpiDma<'d, T, C, FullDuplexMode, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, - M: IsFullDuplex, - DmaMode: Mode, + M: Mode, { /// Perform a DMA write. /// @@ -1310,7 +1281,7 @@ pub mod dma { pub fn dma_write( mut self, buffer: DmaTxBuf, - ) -> Result, (Error, Self, DmaTxBuf)> + ) -> Result, (Error, Self, DmaTxBuf)> { let bytes_to_write = buffer.len(); if bytes_to_write > MAX_DMA_SIZE { @@ -1338,7 +1309,7 @@ pub mod dma { pub fn dma_read( mut self, buffer: DmaRxBuf, - ) -> Result, (Error, Self, DmaRxBuf)> + ) -> Result, (Error, Self, DmaRxBuf)> { let bytes_to_read = buffer.len(); if bytes_to_read > MAX_DMA_SIZE { @@ -1367,7 +1338,7 @@ pub mod dma { tx_buffer: DmaTxBuf, rx_buffer: DmaRxBuf, ) -> Result< - SpiDmaTransfer<'d, T, C, M, DmaMode, (DmaTxBuf, DmaRxBuf)>, + SpiDmaTransfer<'d, T, C, FullDuplexMode, M, (DmaTxBuf, DmaRxBuf)>, (Error, Self, DmaTxBuf, DmaRxBuf), > { let bytes_to_write = tx_buffer.len(); @@ -1405,13 +1376,12 @@ pub mod dma { } } - impl<'d, T, C, M, DmaMode> SpiDma<'d, T, C, M, DmaMode> + impl<'d, T, C, M> SpiDma<'d, T, C, HalfDuplexMode, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, - M: IsHalfDuplex, - DmaMode: Mode, + M: Mode, { /// Perform a half-duplex read operation using DMA. #[allow(clippy::type_complexity)] @@ -1423,7 +1393,7 @@ pub mod dma { address: Address, dummy: u8, buffer: DmaRxBuf, - ) -> Result, (Error, Self, DmaRxBuf)> + ) -> Result, (Error, Self, DmaRxBuf)> { let bytes_to_read = buffer.len(); if bytes_to_read > MAX_DMA_SIZE { @@ -1501,7 +1471,7 @@ pub mod dma { address: Address, dummy: u8, buffer: DmaTxBuf, - ) -> Result, (Error, Self, DmaTxBuf)> + ) -> Result, (Error, Self, DmaTxBuf)> { let bytes_to_write = buffer.len(); if bytes_to_write > MAX_DMA_SIZE { @@ -1570,96 +1540,135 @@ pub mod dma { } } - /// A DMA-capable SPI bus that handles full-duplex transfers. + #[derive(Default)] + enum State<'d, T, C, D, M> + where + T: InstanceDma, + C: DmaChannel, + C::P: SpiPeripheral, + D: DuplexMode, + M: Mode, + { + Idle(SpiDma<'d, T, C, D, M>, DmaTxBuf, DmaRxBuf), + Reading(SpiDmaTransfer<'d, T, C, D, M, DmaRxBuf>, DmaTxBuf), + Writing(SpiDmaTransfer<'d, T, C, D, M, DmaTxBuf>, DmaRxBuf), + Transferring(SpiDmaTransfer<'d, T, C, D, M, (DmaTxBuf, DmaRxBuf)>), + #[default] + TemporarilyRemoved, + } + + /// A DMA-capable SPI bus. /// /// This structure is responsible for managing SPI transfers using DMA /// buffers. - pub struct SpiDmaBus<'d, T, C> + pub struct SpiDmaBus<'d, T, C, D, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, + D: DuplexMode, + M: Mode, { - spi_dma: Option>, - buffers: Option<(DmaTxBuf, DmaRxBuf)>, + state: State<'d, T, C, D, M>, } - impl<'d, T, C> SpiDmaBus<'d, T, C> + impl<'d, T, C, D, M> SpiDmaBus<'d, T, C, D, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, + D: DuplexMode, + M: Mode, { /// Creates a new `SpiDmaBus` with the specified SPI instance and DMA /// buffers. pub fn new( - spi_dma: SpiDma<'d, T, C, FullDuplexMode, crate::Blocking>, - tx_buffer: DmaTxBuf, - rx_buffer: DmaRxBuf, + spi: SpiDma<'d, T, C, D, M>, + dma_tx_buf: DmaTxBuf, + dma_rx_buf: DmaRxBuf, ) -> Self { Self { - spi_dma: Some(spi_dma), - buffers: Some((tx_buffer, rx_buffer)), + state: State::Idle(spi, dma_tx_buf, dma_rx_buf), + } + } + + fn wait_for_idle(&mut self) -> (SpiDma<'d, T, C, D, M>, DmaTxBuf, DmaRxBuf) { + match core::mem::take(&mut self.state) { + State::Idle(spi, tx_buf, rx_buf) => (spi, tx_buf, rx_buf), + State::Reading(transfer, tx_buf) => { + let (spi, rx_buf) = transfer.wait(); + (spi, tx_buf, rx_buf) + } + State::Writing(transfer, rx_buf) => { + let (spi, tx_buf) = transfer.wait(); + (spi, tx_buf, rx_buf) + } + State::Transferring(transfer) => { + let (spi, (tx_buf, rx_buf)) = transfer.wait(); + (spi, tx_buf, rx_buf) + } + State::TemporarilyRemoved => unreachable!(), } } + } + impl<'d, T, C, M> SpiDmaBus<'d, T, C, FullDuplexMode, M> + where + T: InstanceDma, + C: DmaChannel, + C::P: SpiPeripheral, + M: Mode, + { /// Reads data from the SPI bus using DMA. pub fn read(&mut self, words: &mut [u8]) -> Result<(), Error> { - let mut spi_dma = self.spi_dma.take().unwrap(); - let (tx_buf, mut rx_buf) = self.buffers.take().unwrap(); + let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle(); for chunk in words.chunks_mut(rx_buf.capacity()) { rx_buf.set_length(chunk.len()); - let transfer = match spi_dma.dma_read(rx_buf) { - Ok(transfer) => transfer, + match spi_dma.dma_read(rx_buf) { + Ok(transfer) => self.state = State::Reading(transfer, tx_buf), Err((e, spi, rx)) => { - self.spi_dma = Some(spi); - self.buffers = Some((tx_buf, rx)); + self.state = State::Idle(spi, tx_buf, rx); return Err(e); } - }; - (spi_dma, rx_buf) = transfer.wait(); + } + (spi_dma, tx_buf, rx_buf) = self.wait_for_idle(); let bytes_read = rx_buf.read_received_data(chunk); debug_assert_eq!(bytes_read, chunk.len()); } - self.spi_dma = Some(spi_dma); - self.buffers = Some((tx_buf, rx_buf)); + self.state = State::Idle(spi_dma, tx_buf, rx_buf); Ok(()) } /// Writes data to the SPI bus using DMA. pub fn write(&mut self, words: &[u8]) -> Result<(), Error> { - let mut spi_dma = self.spi_dma.take().unwrap(); - let (mut tx_buf, rx_buf) = self.buffers.take().unwrap(); + let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle(); for chunk in words.chunks(tx_buf.capacity()) { tx_buf.fill(chunk); - let transfer = match spi_dma.dma_write(tx_buf) { - Ok(transfer) => transfer, + match spi_dma.dma_write(tx_buf) { + Ok(transfer) => self.state = State::Writing(transfer, rx_buf), Err((e, spi, tx)) => { - self.spi_dma = Some(spi); - self.buffers = Some((tx, rx_buf)); + self.state = State::Idle(spi, tx, rx_buf); return Err(e); } - }; - (spi_dma, tx_buf) = transfer.wait(); + } + (spi_dma, tx_buf, rx_buf) = self.wait_for_idle(); } - self.spi_dma = Some(spi_dma); - self.buffers = Some((tx_buf, rx_buf)); + self.state = State::Idle(spi_dma, tx_buf, rx_buf); Ok(()) } /// Transfers data to and from the SPI bus simultaneously using DMA. pub fn transfer(&mut self, read: &mut [u8], write: &[u8]) -> Result<(), Error> { - let mut spi_dma = self.spi_dma.take().unwrap(); - let (mut tx_buf, mut rx_buf) = self.buffers.take().unwrap(); + let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle(); let chunk_size = min(tx_buf.capacity(), rx_buf.capacity()); @@ -1674,22 +1683,20 @@ pub mod dma { tx_buf.fill(write_chunk); rx_buf.set_length(read_chunk.len()); - let transfer = match spi_dma.dma_transfer(tx_buf, rx_buf) { - Ok(transfer) => transfer, + match spi_dma.dma_transfer(tx_buf, rx_buf) { + Ok(transfer) => self.state = State::Transferring(transfer), Err((e, spi, tx, rx)) => { - self.spi_dma = Some(spi); - self.buffers = Some((tx, rx)); + self.state = State::Idle(spi, tx, rx); return Err(e); } - }; - (spi_dma, (tx_buf, rx_buf)) = transfer.wait(); + } + (spi_dma, tx_buf, rx_buf) = self.wait_for_idle(); let bytes_read = rx_buf.read_received_data(read_chunk); debug_assert_eq!(bytes_read, read_chunk.len()); } - self.spi_dma = Some(spi_dma); - self.buffers = Some((tx_buf, rx_buf)); + self.state = State::Idle(spi_dma, tx_buf, rx_buf); if !read_remainder.is_empty() { self.read(read_remainder) @@ -1702,8 +1709,7 @@ pub mod dma { /// Transfers data in place on the SPI bus using DMA. pub fn transfer_in_place(&mut self, words: &mut [u8]) -> Result<(), Error> { - let mut spi_dma = self.spi_dma.take().unwrap(); - let (mut tx_buf, mut rx_buf) = self.buffers.take().unwrap(); + let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle(); let chunk_size = min(tx_buf.capacity(), rx_buf.capacity()); @@ -1711,29 +1717,100 @@ pub mod dma { tx_buf.fill(chunk); rx_buf.set_length(chunk.len()); - let transfer = match spi_dma.dma_transfer(tx_buf, rx_buf) { - Ok(transfer) => transfer, + match spi_dma.dma_transfer(tx_buf, rx_buf) { + Ok(transfer) => self.state = State::Transferring(transfer), Err((e, spi, tx, rx)) => { - self.spi_dma = Some(spi); - self.buffers = Some((tx, rx)); + self.state = State::Idle(spi, tx, rx); return Err(e); } - }; - (spi_dma, (tx_buf, rx_buf)) = transfer.wait(); + } + (spi_dma, tx_buf, rx_buf) = self.wait_for_idle(); let bytes_read = rx_buf.read_received_data(chunk); debug_assert_eq!(bytes_read, chunk.len()); } - self.spi_dma = Some(spi_dma); - self.buffers = Some((tx_buf, rx_buf)); + self.state = State::Idle(spi_dma, tx_buf, rx_buf); + Ok(()) + } + } + + impl<'d, T, C, M> HalfDuplexReadWrite for SpiDmaBus<'d, T, C, HalfDuplexMode, M> + where + T: InstanceDma, + C: DmaChannel, + C::P: SpiPeripheral, + M: Mode, + { + type Error = super::Error; + + /// Half-duplex read. + fn read( + &mut self, + data_mode: SpiDataMode, + cmd: Command, + address: Address, + dummy: u8, + buffer: &mut [u8], + ) -> Result<(), Self::Error> { + let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle(); + if buffer.len() > rx_buf.capacity() { + return Err(super::Error::DmaError(DmaError::Overflow)); + } + + rx_buf.set_length(buffer.len()); + + match spi_dma.read(data_mode, cmd, address, dummy, rx_buf) { + Ok(transfer) => self.state = State::Reading(transfer, tx_buf), + Err((e, spi, rx)) => { + self.state = State::Idle(spi, tx_buf, rx); + return Err(e); + } + } + (spi_dma, tx_buf, rx_buf) = self.wait_for_idle(); + + let bytes_read = rx_buf.read_received_data(buffer); + debug_assert_eq!(bytes_read, buffer.len()); + + self.state = State::Idle(spi_dma, tx_buf, rx_buf); + + Ok(()) + } + + /// Half-duplex write. + fn write( + &mut self, + data_mode: SpiDataMode, + cmd: Command, + address: Address, + dummy: u8, + buffer: &[u8], + ) -> Result<(), Self::Error> { + let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle(); + if buffer.len() > tx_buf.capacity() { + return Err(super::Error::DmaError(DmaError::Overflow)); + } + + tx_buf.fill(buffer); + + match spi_dma.write(data_mode, cmd, address, dummy, tx_buf) { + Ok(transfer) => self.state = State::Writing(transfer, rx_buf), + Err((e, spi, tx)) => { + self.state = State::Idle(spi, tx, rx_buf); + return Err(e); + } + } + (spi_dma, tx_buf, rx_buf) = self.wait_for_idle(); + + self.state = State::Idle(spi_dma, tx_buf, rx_buf); Ok(()) } } #[cfg(feature = "embedded-hal-02")] - impl<'d, T, C> embedded_hal_02::blocking::spi::Transfer for SpiDmaBus<'d, T, C> + impl<'d, T, C> embedded_hal_02::blocking::spi::Transfer + for SpiDmaBus<'d, T, C, FullDuplexMode, crate::Blocking> where T: InstanceDma, C: DmaChannel, @@ -1748,7 +1825,8 @@ pub mod dma { } #[cfg(feature = "embedded-hal-02")] - impl<'d, T, C> embedded_hal_02::blocking::spi::Write for SpiDmaBus<'d, T, C> + impl<'d, T, C> embedded_hal_02::blocking::spi::Write + for SpiDmaBus<'d, T, C, FullDuplexMode, crate::Blocking> where T: InstanceDma, C: DmaChannel, @@ -1764,76 +1842,18 @@ pub mod dma { /// Async functionality #[cfg(feature = "async")] - pub mod asynch { + mod asynch { use core::{cmp::min, mem::take}; - use embedded_hal::spi::ErrorType; - use super::*; - #[derive(Default)] - enum State<'d, T, C> + impl<'d, T, C> SpiDmaBus<'d, T, C, FullDuplexMode, crate::Async> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, { - Idle( - SpiDma<'d, T, C, FullDuplexMode, crate::Async>, - DmaTxBuf, - DmaRxBuf, - ), - Reading( - SpiDmaTransfer<'d, T, C, FullDuplexMode, crate::Async, DmaRxBuf>, - DmaTxBuf, - ), - Writing( - SpiDmaTransfer<'d, T, C, FullDuplexMode, crate::Async, DmaTxBuf>, - DmaRxBuf, - ), - Transferring( - SpiDmaTransfer<'d, T, C, FullDuplexMode, crate::Async, (DmaTxBuf, DmaRxBuf)>, - ), - #[default] - InUse, - } - - /// An asynchronous DMA-capable SPI bus for full-duplex operations. - /// - /// This struct provides an interface for SPI operations using DMA in an - /// asynchronous way. - pub struct SpiDmaAsyncBus<'d, T, C> - where - T: InstanceDma, - C: DmaChannel, - C::P: SpiPeripheral, - { - state: State<'d, T, C>, - } - - impl<'d, T, C> SpiDmaAsyncBus<'d, T, C> - where - T: InstanceDma, - C: DmaChannel, - C::P: SpiPeripheral, - { - /// Creates a new asynchronous DMA SPI bus instance. - /// - /// Initializes the bus with the provided SPI instance and DMA - /// buffers for transmit and receive operations. - pub fn new( - spi: SpiDma<'d, T, C, FullDuplexMode, crate::Async>, - dma_tx_buf: DmaTxBuf, - dma_rx_buf: DmaRxBuf, - ) -> Self { - Self { - state: State::Idle(spi, dma_tx_buf, dma_rx_buf), - } - } - - /// Waits for the current SPI DMA transfer to complete, ensuring the - /// bus is idle. - async fn wait_for_idle( + async fn wait_for_idle_async( &mut self, ) -> ( SpiDma<'d, T, C, FullDuplexMode, crate::Async>, @@ -1845,7 +1865,7 @@ pub mod dma { State::Reading(transfer, _) => transfer.wait_for_done().await, State::Writing(transfer, _) => transfer.wait_for_done().await, State::Transferring(transfer) => transfer.wait_for_done().await, - State::InUse => unreachable!(), + State::TemporarilyRemoved => unreachable!(), } match take(&mut self.state) { State::Idle(spi, tx_buf, rx_buf) => (spi, tx_buf, rx_buf), @@ -1861,31 +1881,14 @@ pub mod dma { let (spi, (tx_buf, rx_buf)) = transfer.wait(); (spi, tx_buf, rx_buf) } - State::InUse => unreachable!(), + State::TemporarilyRemoved => unreachable!(), } } - } - impl<'d, T, C> ErrorType for SpiDmaAsyncBus<'d, T, C> - where - T: InstanceDma, - C: DmaChannel, - C::P: SpiPeripheral, - { - type Error = Error; - } - - impl<'d, T, C> embedded_hal_async::spi::SpiBus for SpiDmaAsyncBus<'d, T, C> - where - T: InstanceDma, - C: DmaChannel, - C::P: SpiPeripheral, - { - /// Asynchronously reads data from the SPI bus into the provided - /// buffer. - async fn read(&mut self, words: &mut [u8]) -> Result<(), Self::Error> { + /// Fill the given buffer with data from the bus. + pub async fn read_async(&mut self, words: &mut [u8]) -> Result<(), super::Error> { // Get previous transfer. - let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle().await; + let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle_async().await; for chunk in words.chunks_mut(rx_buf.capacity()) { rx_buf.set_length(chunk.len()); @@ -1900,17 +1903,7 @@ pub mod dma { } }; - match &mut self.state { - State::Reading(transfer, _) => transfer.wait_for_done().await, - _ => unreachable!(), - }; - (spi_dma, tx_buf, rx_buf) = match take(&mut self.state) { - State::Reading(transfer, tx_buf) => { - let (spi, rx_buf) = transfer.wait(); - (spi, tx_buf, rx_buf) - } - _ => unreachable!(), - }; + (spi_dma, tx_buf, rx_buf) = self.wait_for_idle_async().await; let bytes_read = rx_buf.read_received_data(chunk); debug_assert_eq!(bytes_read, chunk.len()); @@ -1921,11 +1914,10 @@ pub mod dma { Ok(()) } - /// Asynchronously writes data to the SPI bus from the provided - /// buffer. - async fn write(&mut self, words: &[u8]) -> Result<(), Self::Error> { + /// Transmit the given buffer to the bus. + pub async fn write_async(&mut self, words: &[u8]) -> Result<(), super::Error> { // Get previous transfer. - let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle().await; + let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle_async().await; for chunk in words.chunks(tx_buf.capacity()) { tx_buf.fill(chunk); @@ -1940,18 +1932,7 @@ pub mod dma { } }; - match &mut self.state { - State::Writing(transfer, _) => transfer.wait_for_done().await, - _ => unreachable!(), - }; - - (spi_dma, tx_buf, rx_buf) = match take(&mut self.state) { - State::Writing(transfer, rx_buf) => { - let (spi, tx_buf) = transfer.wait(); - (spi, tx_buf, rx_buf) - } - _ => unreachable!(), - }; + (spi_dma, tx_buf, rx_buf) = self.wait_for_idle_async().await; } self.state = State::Idle(spi_dma, tx_buf, rx_buf); @@ -1959,15 +1940,15 @@ pub mod dma { Ok(()) } - /// Asynchronously performs a full-duplex transfer over the SPI bus. - /// - /// This method splits the transfer operation into chunks and - /// processes it asynchronously. It simultaneously - /// writes data from the `write` buffer and reads data into the - /// `read` buffer. - async fn transfer(&mut self, read: &mut [u8], write: &[u8]) -> Result<(), Self::Error> { + /// Transfer by writing out a buffer and reading the response from + /// the bus into another buffer. + pub async fn transfer_async( + &mut self, + read: &mut [u8], + write: &[u8], + ) -> Result<(), super::Error> { // Get previous transfer. - let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle().await; + let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle_async().await; let chunk_size = min(tx_buf.capacity(), rx_buf.capacity()); @@ -1992,18 +1973,7 @@ pub mod dma { } }; - match &mut self.state { - State::Transferring(transfer) => transfer.wait_for_done().await, - _ => unreachable!(), - }; - - (spi_dma, tx_buf, rx_buf) = match take(&mut self.state) { - State::Transferring(transfer) => { - let (spi, (tx_buf, rx_buf)) = transfer.wait(); - (spi, tx_buf, rx_buf) - } - _ => unreachable!(), - }; + (spi_dma, tx_buf, rx_buf) = self.wait_for_idle_async().await; let bytes_read = rx_buf.read_received_data(read_chunk); assert_eq!(bytes_read, read_chunk.len()); @@ -2012,19 +1982,22 @@ pub mod dma { self.state = State::Idle(spi_dma, tx_buf, rx_buf); if !read_remainder.is_empty() { - self.read(read_remainder).await + self.read_async(read_remainder).await } else if !write_remainder.is_empty() { - self.write(write_remainder).await + self.write_async(write_remainder).await } else { Ok(()) } } - /// Asynchronously performs an in-place full-duplex transfer over - /// the SPI bus. - async fn transfer_in_place(&mut self, words: &mut [u8]) -> Result<(), Self::Error> { + /// Transfer by writing out a buffer and reading the response from + /// the bus into the same buffer. + pub async fn transfer_in_place_async( + &mut self, + words: &mut [u8], + ) -> Result<(), super::Error> { // Get previous transfer. - let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle().await; + let (mut spi_dma, mut tx_buf, mut rx_buf) = self.wait_for_idle_async().await; for chunk in words.chunks_mut(tx_buf.capacity()) { tx_buf.fill(chunk); @@ -2062,13 +2035,41 @@ pub mod dma { Ok(()) } - async fn flush(&mut self) -> Result<(), Self::Error> { + /// Flush any pending data in the SPI peripheral. + pub async fn flush_async(&mut self) -> Result<(), super::Error> { // Get previous transfer. - let (spi_dma, tx_buf, rx_buf) = self.wait_for_idle().await; + let (spi_dma, tx_buf, rx_buf) = self.wait_for_idle_async().await; self.state = State::Idle(spi_dma, tx_buf, rx_buf); Ok(()) } } + + impl<'d, T, C> embedded_hal_async::spi::SpiBus for SpiDmaBus<'d, T, C, FullDuplexMode, crate::Async> + where + T: InstanceDma, + C: DmaChannel, + C::P: SpiPeripheral, + { + async fn read(&mut self, words: &mut [u8]) -> Result<(), Self::Error> { + self.read_async(words).await + } + + async fn write(&mut self, words: &[u8]) -> Result<(), Self::Error> { + self.write_async(words).await + } + + async fn transfer(&mut self, read: &mut [u8], write: &[u8]) -> Result<(), Self::Error> { + self.transfer_async(read, write).await + } + + async fn transfer_in_place(&mut self, words: &mut [u8]) -> Result<(), Self::Error> { + self.transfer_in_place_async(words).await + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + self.flush_async().await + } + } } #[cfg(feature = "embedded-hal")] @@ -2077,20 +2078,22 @@ pub mod dma { use super::*; - impl<'d, T, C> ErrorType for SpiDmaBus<'d, T, C> + impl<'d, T, C, M> ErrorType for SpiDmaBus<'d, T, C, FullDuplexMode, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, + M: Mode, { type Error = Error; } - impl<'d, T, C> SpiBus for SpiDmaBus<'d, T, C> + impl<'d, T, C, M> SpiBus for SpiDmaBus<'d, T, C, FullDuplexMode, M> where T: InstanceDma, C: DmaChannel, C::P: SpiPeripheral, + M: Mode, { fn read(&mut self, words: &mut [u8]) -> Result<(), Self::Error> { self.read(words) @@ -2127,10 +2130,9 @@ mod ehal1 { type Error = super::Error; } - impl FullDuplex for Spi<'_, T, M> + impl FullDuplex for Spi<'_, T, FullDuplexMode> where T: Instance, - M: IsFullDuplex, { fn read(&mut self) -> nb::Result { self.spi.read_byte() @@ -2141,10 +2143,9 @@ mod ehal1 { } } - impl SpiBus for Spi<'_, T, M> + impl SpiBus for Spi<'_, T, FullDuplexMode> where T: Instance, - M: IsFullDuplex, { fn read(&mut self, words: &mut [u8]) -> Result<(), Self::Error> { self.spi.read_bytes(words) diff --git a/esp-hal/src/spi/mod.rs b/esp-hal/src/spi/mod.rs index 9f30272438a..3c73581c159 100644 --- a/esp-hal/src/spi/mod.rs +++ b/esp-hal/src/spi/mod.rs @@ -79,14 +79,10 @@ pub enum SpiBitOrder { } /// Trait marker for defining SPI duplex modes. -pub trait DuplexMode {} -/// Trait marker for SPI full-duplex mode. -pub trait IsFullDuplex: DuplexMode {} -/// Trait marker for SPI half-duplex mode. -pub trait IsHalfDuplex: DuplexMode {} +pub trait DuplexMode: crate::private::Sealed {} /// SPI data mode -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum SpiDataMode { /// `Single` Data Mode - 1 bit, 2 wires. @@ -100,9 +96,9 @@ pub enum SpiDataMode { /// Full-duplex operation pub struct FullDuplexMode {} impl DuplexMode for FullDuplexMode {} -impl IsFullDuplex for FullDuplexMode {} +impl crate::private::Sealed for FullDuplexMode {} /// Half-duplex operation pub struct HalfDuplexMode {} impl DuplexMode for HalfDuplexMode {} -impl IsHalfDuplex for HalfDuplexMode {} +impl crate::private::Sealed for HalfDuplexMode {} diff --git a/esp-hal/src/system.rs b/esp-hal/src/system.rs index cca5b43d92a..1c31a0e611b 100755 --- a/esp-hal/src/system.rs +++ b/esp-hal/src/system.rs @@ -292,9 +292,10 @@ impl InterruptConfigurable for SoftwareInterrupt { #[cfg_attr( multi_core, doc = r#" + Please note: Software interrupt 3 is reserved -for inter-processor communication when the `embassy` -feature is enabled."# +for inter-processor communication when using +`esp-hal-embassy`."# )] #[non_exhaustive] pub struct SoftwareInterruptControl { @@ -304,7 +305,9 @@ pub struct SoftwareInterruptControl { pub software_interrupt1: SoftwareInterrupt<1>, /// Software interrupt 2. pub software_interrupt2: SoftwareInterrupt<2>, - /// Software interrupt 3. + #[cfg(not(all(feature = "__esp_hal_embassy", multi_core)))] + /// Software interrupt 3. Only available when not using `esp-hal-embassy`, + /// or on single-core systems. pub software_interrupt3: SoftwareInterrupt<3>, } @@ -314,6 +317,7 @@ impl SoftwareInterruptControl { software_interrupt0: SoftwareInterrupt {}, software_interrupt1: SoftwareInterrupt {}, software_interrupt2: SoftwareInterrupt {}, + #[cfg(not(all(feature = "__esp_hal_embassy", multi_core)))] software_interrupt3: SoftwareInterrupt {}, } } diff --git a/hil-test/Cargo.toml b/hil-test/Cargo.toml index ba898bfadc4..f6e0287abd6 100644 --- a/hil-test/Cargo.toml +++ b/hil-test/Cargo.toml @@ -107,6 +107,10 @@ harness = false name = "rsa" harness = false +[[test]] +name = "rsa_async" +harness = false + [[test]] name = "sha" harness = false diff --git a/hil-test/tests/rsa.rs b/hil-test/tests/rsa.rs index d3e728c7659..c102dd025a5 100644 --- a/hil-test/tests/rsa.rs +++ b/hil-test/tests/rsa.rs @@ -10,7 +10,7 @@ use esp_hal::{ peripherals::Peripherals, prelude::*, rsa::{ - operand_sizes, + operand_sizes::*, Rsa, RsaModularExponentiation, RsaModularMultiplication, @@ -37,16 +37,6 @@ struct Context<'a> { rsa: Rsa<'a, Blocking>, } -impl Context<'_> { - pub fn init() -> Self { - let peripherals = Peripherals::take(); - let mut rsa = Rsa::new(peripherals.RSA); - nb::block!(rsa.ready()).unwrap(); - - Context { rsa } - } -} - const fn compute_r(modulus: &U512) -> U512 { let mut d = [0_u32; U512::LIMBS * 2 + 1]; d[d.len() - 1] = 1; @@ -68,10 +58,15 @@ mod tests { #[init] fn init() -> Context<'static> { - Context::init() + let peripherals = Peripherals::take(); + let mut rsa = Rsa::new(peripherals.RSA); + nb::block!(rsa.ready()).unwrap(); + + Context { rsa } } #[test] + #[timeout(5)] fn test_modular_exponentiation(mut ctx: Context<'static>) { const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [ 1601059419, 3994655875, 2600857657, 1530060852, 64828275, 4221878473, 2751381085, @@ -85,20 +80,20 @@ mod tests { ctx.rsa.enable_disable_search_acceleration(true); } let mut outbuf = [0_u32; U512::LIMBS]; - let mut mod_exp = RsaModularExponentiation::::new( + let mut mod_exp = RsaModularExponentiation::::new( &mut ctx.rsa, BIGNUM_2.as_words(), BIGNUM_3.as_words(), compute_mprime(&BIGNUM_3), ); let r = compute_r(&BIGNUM_3); - let base = &BIGNUM_1.as_words(); - mod_exp.start_exponentiation(&base, r.as_words()); + mod_exp.start_exponentiation(BIGNUM_1.as_words(), r.as_words()); mod_exp.read_results(&mut outbuf); assert_eq!(EXPECTED_OUTPUT, outbuf); } #[test] + #[timeout(5)] fn test_modular_multiplication(mut ctx: Context<'static>) { const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [ 1868256644, 833470784, 4187374062, 2684021027, 191862388, 1279046003, 1929899870, @@ -107,31 +102,21 @@ mod tests { ]; let mut outbuf = [0_u32; U512::LIMBS]; - let mut mod_multi = - RsaModularMultiplication::::new( - &mut ctx.rsa, - #[cfg(not(feature = "esp32"))] - BIGNUM_1.as_words(), - #[cfg(not(feature = "esp32"))] - BIGNUM_2.as_words(), - BIGNUM_3.as_words(), - compute_mprime(&BIGNUM_3), - ); let r = compute_r(&BIGNUM_3); - #[cfg(feature = "esp32")] - { - mod_multi.start_step1(BIGNUM_1.as_words(), r.as_words()); - mod_multi.start_step2(BIGNUM_2.as_words()); - } - #[cfg(not(feature = "esp32"))] - { - mod_multi.start_modular_multiplication(r.as_words()); - } + let mut mod_multi = RsaModularMultiplication::::new( + &mut ctx.rsa, + BIGNUM_1.as_words(), + BIGNUM_3.as_words(), + r.as_words(), + compute_mprime(&BIGNUM_3), + ); + mod_multi.start_modular_multiplication(BIGNUM_2.as_words()); mod_multi.read_results(&mut outbuf); assert_eq!(EXPECTED_OUTPUT, outbuf); } #[test] + #[timeout(5)] fn test_multiplication(mut ctx: Context<'static>) { const EXPECTED_OUTPUT: [u32; U1024::LIMBS] = [ 1264702968, 3552243420, 2602501218, 498422249, 2431753435, 2307424767, 349202767, @@ -145,21 +130,10 @@ mod tests { let operand_a = BIGNUM_1.as_words(); let operand_b = BIGNUM_2.as_words(); - cfg_if::cfg_if! { - if #[cfg(feature = "esp32")] { - let mut rsamulti = - RsaMultiplication::::new(&mut ctx.rsa); - rsamulti.start_multiplication(operand_a, operand_b); - rsamulti.read_results(&mut outbuf); - } else { - let mut rsamulti = RsaMultiplication::::new( - &mut ctx.rsa, - operand_a, - ); - rsamulti.start_multiplication(operand_b); - rsamulti.read_results(&mut outbuf); - } - } + let mut rsamulti = RsaMultiplication::::new(&mut ctx.rsa, operand_a); + rsamulti.start_multiplication(operand_b); + rsamulti.read_results(&mut outbuf); + assert_eq!(EXPECTED_OUTPUT, outbuf) } } diff --git a/hil-test/tests/rsa_async.rs b/hil-test/tests/rsa_async.rs new file mode 100644 index 00000000000..5730ea13c57 --- /dev/null +++ b/hil-test/tests/rsa_async.rs @@ -0,0 +1,140 @@ +//! Async RSA Test + +//% CHIPS: esp32 esp32c3 esp32c6 esp32h2 esp32s2 esp32s3 + +#![no_std] +#![no_main] + +use crypto_bigint::{Uint, U1024, U512}; +use esp_hal::{ + peripherals::Peripherals, + prelude::*, + rsa::{ + operand_sizes::*, + Rsa, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, + }, + Async, +}; +use hil_test as _; + +const BIGNUM_1: U512 = Uint::from_be_hex( + "c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\ + 7dc353c5e0ec11f5fc8ce7f6073796cc8f73fa878", +); +const BIGNUM_2: U512 = Uint::from_be_hex( + "1763db3344e97be15d04de4868badb12a38046bb793f7630d87cf100aa1c759afac15a01f3c4c83ec2d2f66\ + 6bd22f71c3c1f075ec0e2cb0cb29994d091b73f51", +); +const BIGNUM_3: U512 = Uint::from_be_hex( + "6b6bb3d2b6cbeb45a769eaa0384e611e1b89b0c9b45a045aca1c5fd6e8785b38df7118cf5dd45b9b63d293b\ + 67aeafa9ba25feb8712f188cb139b7d9b9af1c361", +); + +struct Context<'a> { + rsa: Rsa<'a, Async>, +} + +const fn compute_r(modulus: &U512) -> U512 { + let mut d = [0_u32; U512::LIMBS * 2 + 1]; + d[d.len() - 1] = 1; + let d = Uint::from_words(d); + d.const_rem(&modulus.resize()).0.resize() +} + +const fn compute_mprime(modulus: &U512) -> u32 { + let m_inv = modulus.inv_mod2k(32).to_words()[0]; + (-1 * m_inv as i64 % 4294967296) as u32 +} + +#[cfg(test)] +#[embedded_test::tests(executor = esp_hal_embassy::Executor::new())] +mod tests { + use defmt::assert_eq; + + use super::*; + + #[init] + fn init() -> Context<'static> { + let peripherals = Peripherals::take(); + let mut rsa = Rsa::new_async(peripherals.RSA); + nb::block!(rsa.ready()).unwrap(); + + Context { rsa } + } + + #[test] + #[timeout(5)] + async fn modular_exponentiation(mut ctx: Context<'static>) { + const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [ + 1601059419, 3994655875, 2600857657, 1530060852, 64828275, 4221878473, 2751381085, + 1938128086, 625895085, 2087010412, 2133352910, 101578249, 3798099415, 3357588690, + 2065243474, 330914193, + ]; + + #[cfg(not(feature = "esp32"))] + { + ctx.rsa.enable_disable_constant_time_acceleration(true); + ctx.rsa.enable_disable_search_acceleration(true); + } + let mut outbuf = [0_u32; U512::LIMBS]; + let mut mod_exp = RsaModularExponentiation::::new( + &mut ctx.rsa, + BIGNUM_2.as_words(), + BIGNUM_3.as_words(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3); + mod_exp + .exponentiation(BIGNUM_1.as_words(), r.as_words(), &mut outbuf) + .await; + assert_eq!(EXPECTED_OUTPUT, outbuf); + } + + #[test] + #[timeout(5)] + async fn test_modular_multiplication(mut ctx: Context<'static>) { + const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [ + 1868256644, 833470784, 4187374062, 2684021027, 191862388, 1279046003, 1929899870, + 4209598061, 3830489207, 1317083344, 2666864448, 3701382766, 3232598924, 2904609522, + 747558855, 479377985, + ]; + + let mut outbuf = [0_u32; U512::LIMBS]; + let r = compute_r(&BIGNUM_3); + let mut mod_multi = RsaModularMultiplication::::new( + &mut ctx.rsa, + BIGNUM_1.as_words(), + BIGNUM_3.as_words(), + r.as_words(), + compute_mprime(&BIGNUM_3), + ); + mod_multi + .modular_multiplication(BIGNUM_2.as_words(), &mut outbuf) + .await; + assert_eq!(EXPECTED_OUTPUT, outbuf); + } + + #[test] + #[timeout(5)] + async fn test_multiplication(mut ctx: Context<'static>) { + const EXPECTED_OUTPUT: [u32; U1024::LIMBS] = [ + 1264702968, 3552243420, 2602501218, 498422249, 2431753435, 2307424767, 349202767, + 2269697177, 1525551459, 3623276361, 3146383138, 191420847, 4252021895, 9176459, + 301757643, 4220806186, 434407318, 3722444851, 1850128766, 928651940, 107896699, + 563405838, 1834067613, 1289630401, 3145128058, 3300293535, 3077505758, 1926648662, + 1264151247, 3626086486, 3701894076, 306518743, + ]; + let mut outbuf = [0_u32; U1024::LIMBS]; + + let operand_a = BIGNUM_1.as_words(); + let operand_b = BIGNUM_2.as_words(); + + let mut rsamulti = RsaMultiplication::::new(&mut ctx.rsa, operand_a); + rsamulti.multiplication(operand_b, &mut outbuf).await; + + assert_eq!(EXPECTED_OUTPUT, outbuf) + } +} diff --git a/hil-test/tests/spi_full_duplex_dma.rs b/hil-test/tests/spi_full_duplex_dma.rs index 3988d59f249..7810c4de7b2 100644 --- a/hil-test/tests/spi_full_duplex_dma.rs +++ b/hil-test/tests/spi_full_duplex_dma.rs @@ -21,7 +21,7 @@ use esp_hal::{ peripherals::{Peripherals, SPI2}, prelude::*, spi::{ - master::{dma::SpiDma, Spi}, + master::{Spi, SpiDma}, FullDuplexMode, SpiMode, }, @@ -49,6 +49,7 @@ struct Context { #[embedded_test::tests] mod tests { use defmt::assert_eq; + use esp_hal::dma::{DmaRxBuf, DmaTxBuf}; use super::*; diff --git a/hil-test/tests/spi_full_duplex_dma_async.rs b/hil-test/tests/spi_full_duplex_dma_async.rs index 75fee9b2d00..3dd9bc45998 100644 --- a/hil-test/tests/spi_full_duplex_dma_async.rs +++ b/hil-test/tests/spi_full_duplex_dma_async.rs @@ -34,10 +34,12 @@ use esp_hal::{ peripherals::{Peripherals, SPI2}, prelude::*, spi::{ - master::{dma::asynch::SpiDmaAsyncBus, Spi}, + master::{Spi, SpiDmaBus}, + FullDuplexMode, SpiMode, }, system::SystemControl, + Async, }; use hil_test as _; @@ -55,7 +57,7 @@ cfg_if::cfg_if! { const DMA_BUFFER_SIZE: usize = 5; struct Context { - spi: SpiDmaAsyncBus<'static, SPI2, DmaChannel0>, + spi: SpiDmaBus<'static, SPI2, DmaChannel0, FullDuplexMode, Async>, pcnt_unit: Unit<'static, 0>, out_pin: Output<'static, GpioPin<5>>, mosi_mirror: GpioPin<2>, diff --git a/hil-test/tests/spi_full_duplex_dma_pcnt.rs b/hil-test/tests/spi_full_duplex_dma_pcnt.rs index e304784240d..50e78fea918 100644 --- a/hil-test/tests/spi_full_duplex_dma_pcnt.rs +++ b/hil-test/tests/spi_full_duplex_dma_pcnt.rs @@ -30,7 +30,7 @@ use esp_hal::{ peripherals::{Peripherals, SPI2}, prelude::*, spi::{ - master::{dma::SpiDma, Spi}, + master::{Spi, SpiDma}, FullDuplexMode, SpiMode, }, diff --git a/hil-test/tests/spi_half_duplex_read.rs b/hil-test/tests/spi_half_duplex_read.rs index ab392a520b3..eda40ba5e55 100644 --- a/hil-test/tests/spi_half_duplex_read.rs +++ b/hil-test/tests/spi_half_duplex_read.rs @@ -15,13 +15,13 @@ use esp_hal::{ clock::ClockControl, - dma::{Dma, DmaPriority, DmaRxBuf}, + dma::{Dma, DmaPriority, DmaRxBuf, DmaTxBuf}, dma_buffers, gpio::{GpioPin, Io, Level, Output}, peripherals::{Peripherals, SPI2}, prelude::*, spi::{ - master::{dma::SpiDma, Address, Command, Spi}, + master::{Address, Command, HalfDuplexReadWrite, Spi, SpiDma}, HalfDuplexMode, SpiDataMode, SpiMode, @@ -129,4 +129,45 @@ mod tests { assert_eq!(dma_rx_buf.as_slice(), &[0xFF; DMA_BUFFER_SIZE]); } + + #[test] + #[timeout(3)] + fn test_spidmabus_reads_correctly_from_gpio_pin(mut ctx: Context) { + const DMA_BUFFER_SIZE: usize = 4; + + let (buffer, descriptors, tx, txd) = dma_buffers!(DMA_BUFFER_SIZE, 1); + let dma_rx_buf = DmaRxBuf::new(descriptors, buffer).unwrap(); + let dma_tx_buf = DmaTxBuf::new(txd, tx).unwrap(); + + let mut spi = ctx.spi.with_buffers(dma_tx_buf, dma_rx_buf); + + // SPI should read '0's from the MISO pin + ctx.miso_mirror.set_low(); + + let mut buffer = [0xAA; DMA_BUFFER_SIZE]; + spi.read( + SpiDataMode::Single, + Command::None, + Address::None, + 0, + &mut buffer, + ) + .unwrap(); + + assert_eq!(buffer.as_slice(), &[0x00; DMA_BUFFER_SIZE]); + + // SPI should read '1's from the MISO pin + ctx.miso_mirror.set_high(); + + spi.read( + SpiDataMode::Single, + Command::None, + Address::None, + 0, + &mut buffer, + ) + .unwrap(); + + assert_eq!(buffer.as_slice(), &[0xFF; DMA_BUFFER_SIZE]); + } } diff --git a/hil-test/tests/spi_half_duplex_write.rs b/hil-test/tests/spi_half_duplex_write.rs index 1acd34b17ce..e62ecd94d3d 100644 --- a/hil-test/tests/spi_half_duplex_write.rs +++ b/hil-test/tests/spi_half_duplex_write.rs @@ -15,7 +15,7 @@ use esp_hal::{ clock::ClockControl, - dma::{Dma, DmaPriority, DmaTxBuf}, + dma::{Dma, DmaPriority, DmaRxBuf, DmaTxBuf}, dma_buffers, gpio::{GpioPin, Io, Pull}, pcnt::{ @@ -26,7 +26,7 @@ use esp_hal::{ peripherals::{Peripherals, SPI2}, prelude::*, spi::{ - master::{dma::SpiDma, Address, Command, Spi}, + master::{Address, Command, HalfDuplexReadWrite, Spi, SpiDma}, HalfDuplexMode, SpiDataMode, SpiMode, @@ -143,4 +143,48 @@ mod tests { assert_eq!(unit.get_value(), (6 * DMA_BUFFER_SIZE) as _); } + + #[test] + #[timeout(3)] + fn test_spidmabus_writes_are_correctly_by_pcnt(ctx: Context) { + const DMA_BUFFER_SIZE: usize = 4; + + let (buffer, descriptors, rx, rxd) = dma_buffers!(DMA_BUFFER_SIZE, 1); + let dma_tx_buf = DmaTxBuf::new(descriptors, buffer).unwrap(); + let dma_rx_buf = DmaRxBuf::new(rxd, rx).unwrap(); + + let unit = ctx.pcnt_unit; + let mut spi = ctx.spi.with_buffers(dma_tx_buf, dma_rx_buf); + + unit.channel0.set_edge_signal(PcntSource::from_pin( + ctx.mosi_mirror, + PcntInputConfig { pull: Pull::Down }, + )); + unit.channel0 + .set_input_mode(EdgeMode::Hold, EdgeMode::Increment); + + let buffer = [0b0110_1010; DMA_BUFFER_SIZE]; + // Write the buffer where each byte has 3 pos edges. + spi.write( + SpiDataMode::Single, + Command::None, + Address::None, + 0, + &buffer, + ) + .unwrap(); + + assert_eq!(unit.get_value(), (3 * DMA_BUFFER_SIZE) as _); + + spi.write( + SpiDataMode::Single, + Command::None, + Address::None, + 0, + &buffer, + ) + .unwrap(); + + assert_eq!(unit.get_value(), (6 * DMA_BUFFER_SIZE) as _); + } }