diff --git a/src/config.rs b/src/config.rs index 849d9160..88105e3c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,5 @@ use super::banner::BANNER; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Error, Result}; use serde::{Deserialize, Serialize}; use std::{ fs, @@ -127,14 +127,8 @@ impl ClientConfig { number += 1; } - // TODO: Handle empty input? - let mut client_id = String::new(); - println!("\nEnter your Client ID: "); - stdin().read_line(&mut client_id)?; - - let mut client_secret = String::new(); - println!("\nEnter your Client Secret: "); - stdin().read_line(&mut client_secret)?; + let client_id = ClientConfig::get_client_key_from_input("Client ID")?; + let client_secret = ClientConfig::get_client_key_from_input("Client Secret")?; let mut port = String::new(); println!("\nEnter port of redirect uri (default {}): ", DEFAULT_PORT); @@ -142,8 +136,8 @@ impl ClientConfig { let port = port.trim().parse::().unwrap_or(DEFAULT_PORT); let config_yml = ClientConfig { - client_id: client_id.trim().to_string(), - client_secret: client_secret.trim().to_string(), + client_id, + client_secret, device_id: None, port: Some(port), }; @@ -161,4 +155,46 @@ impl ClientConfig { Ok(()) } } + + fn get_client_key_from_input(type_label: &'static str) -> Result { + let mut client_key = String::new(); + const MAX_RETRIES: u8 = 5; + let mut num_retries = 0; + loop { + println!("\nEnter your {}: ", type_label); + stdin().read_line(&mut client_key)?; + client_key = client_key.trim().to_string(); + match ClientConfig::validate_client_key(&client_key) { + Ok(_) => return Ok(client_key), + Err(error_string) => { + println!("{}", error_string); + client_key.clear(); + num_retries += 1; + if num_retries == MAX_RETRIES { + return Err(Error::from(std::io::Error::new( + std::io::ErrorKind::Other, + format!("Maximum retries ({}) exceeded.", MAX_RETRIES), + ))); + } + } + }; + } + } + + fn validate_client_key(key: &str) -> Result<()> { + const EXPECTED_LEN: usize = 32; + if key.len() != EXPECTED_LEN { + Err(Error::from(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("invalid length: {} (must be {})", key.len(), EXPECTED_LEN,), + ))) + } else if !key.chars().all(|c| c.is_digit(16)) { + Err(Error::from(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "invalid character found (must be hex digits)", + ))) + } else { + Ok(()) + } + } }