use std::net::IpAddr; use std::time::Duration; use reqwest::{StatusCode, Url}; use tokio::net::lookup_host; use crate::core::network::{NetworkPolicy, check_remote_fetch}; use crate::errors::SwaggerCliError; const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); const DEFAULT_OVERALL_TIMEOUT: Duration = Duration::from_secs(10); const DEFAULT_MAX_BYTES: u64 = 25 * 1024 * 1024; // 25 MB const DEFAULT_MAX_RETRIES: u32 = 2; const RETRY_BASE_DELAY: Duration = Duration::from_millis(500); // --------------------------------------------------------------------------- // SSRF protection // --------------------------------------------------------------------------- fn is_ip_blocked(ip: &IpAddr) -> bool { match ip { IpAddr::V4(v4) => { v4.is_loopback() // 127.0.0.0/8 || v4.is_link_local() // 169.254.0.0/16 || v4.is_broadcast() // 255.255.255.255 || v4.is_unspecified() // 0.0.0.0 || v4.is_multicast() // 224.0.0.0/4 || is_private_v4(v4) } IpAddr::V6(v6) => { v6.is_loopback() // ::1 || v6.is_unspecified() // :: || v6.is_multicast() // ff00::/8 || is_link_local_v6(v6) // fe80::/10 || is_blocked_mapped_v4(v6) } } } fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool { let octets = ip.octets(); // 10.0.0.0/8 octets[0] == 10 // 172.16.0.0/12 || (octets[0] == 172 && (16..=31).contains(&octets[1])) // 192.168.0.0/16 || (octets[0] == 192 && octets[1] == 168) } fn is_link_local_v6(ip: &std::net::Ipv6Addr) -> bool { let segments = ip.segments(); // fe80::/10 — first 10 bits are 1111_1110_10 (segments[0] & 0xffc0) == 0xfe80 } fn is_blocked_mapped_v4(v6: &std::net::Ipv6Addr) -> bool { // ::ffff:x.x.x.x — IPv4-mapped IPv6 let segments = v6.segments(); if segments[0..5] == [0, 0, 0, 0, 0] && segments[5] == 0xffff { let v4 = std::net::Ipv4Addr::new( (segments[6] >> 8) as u8, segments[6] as u8, (segments[7] >> 8) as u8, segments[7] as u8, ); return is_ip_blocked(&IpAddr::V4(v4)); } false } // --------------------------------------------------------------------------- // URL validation // --------------------------------------------------------------------------- fn validate_url(url: &str, allow_insecure_http: bool) -> Result { let parsed = Url::parse(url) .map_err(|e| SwaggerCliError::InvalidSpec(format!("invalid URL '{url}': {e}")))?; match parsed.scheme() { "https" => Ok(parsed), "http" if allow_insecure_http => Ok(parsed), "http" => Err(SwaggerCliError::PolicyBlocked(format!( "HTTP is not allowed for '{url}'. Use --allow-insecure-http to override." ))), other => Err(SwaggerCliError::InvalidSpec(format!( "unsupported scheme '{other}' in URL '{url}'" ))), } } // --------------------------------------------------------------------------- // DNS resolution + SSRF check // --------------------------------------------------------------------------- async fn resolve_and_check( host: &str, port: u16, allowed_private_hosts: &[String], ) -> Result<(), SwaggerCliError> { if allowed_private_hosts.iter().any(|h| h == host) { return Ok(()); } let addr = format!("{host}:{port}"); let addrs: Vec<_> = match lookup_host(&addr).await { Ok(iter) => iter.collect(), Err(e) => { return Err(SwaggerCliError::InvalidSpec(format!( "DNS resolution failed for '{host}': {e}" ))); } }; if addrs.is_empty() { return Err(SwaggerCliError::InvalidSpec(format!( "DNS resolution returned no addresses for '{host}'" ))); } for socket_addr in &addrs { if is_ip_blocked(&socket_addr.ip()) { return Err(SwaggerCliError::PolicyBlocked(format!( "resolved IP {} for host '{host}' is in a blocked range. \ Use --allow-private-host {host} to override.", socket_addr.ip() ))); } } Ok(()) } // --------------------------------------------------------------------------- // FetchResult // --------------------------------------------------------------------------- #[derive(Debug, Clone)] pub struct FetchResult { pub bytes: Vec, pub content_type: Option, pub etag: Option, pub last_modified: Option, } /// Result of a conditional fetch (If-None-Match / If-Modified-Since). #[derive(Debug, Clone)] pub enum ConditionalFetchResult { /// Server returned 304 Not Modified -- cached content is still current. NotModified, /// Server returned new content. Modified(FetchResult), } // --------------------------------------------------------------------------- // AsyncHttpClient builder // --------------------------------------------------------------------------- pub struct AsyncHttpClient { connect_timeout: Duration, overall_timeout: Duration, max_bytes: u64, max_retries: u32, allow_insecure_http: bool, allowed_private_hosts: Vec, auth_headers: Vec<(String, String)>, network_policy: NetworkPolicy, } impl Default for AsyncHttpClient { fn default() -> Self { Self { connect_timeout: DEFAULT_CONNECT_TIMEOUT, overall_timeout: DEFAULT_OVERALL_TIMEOUT, max_bytes: DEFAULT_MAX_BYTES, max_retries: DEFAULT_MAX_RETRIES, allow_insecure_http: false, allowed_private_hosts: Vec::new(), auth_headers: Vec::new(), network_policy: NetworkPolicy::Auto, } } } impl AsyncHttpClient { pub fn builder() -> AsyncHttpClientBuilder { AsyncHttpClientBuilder::default() } /// Fetch a spec with conditional request headers. /// /// Sends If-None-Match (for ETag) and If-Modified-Since (for Last-Modified) /// when provided. Returns `NotModified` on 304, `Modified` on 200. pub async fn fetch_conditional( &self, url: &str, etag: Option<&str>, last_modified: Option<&str>, ) -> Result { let parsed = validate_url(url, self.allow_insecure_http)?; let host = parsed .host_str() .ok_or_else(|| SwaggerCliError::InvalidSpec(format!("URL '{url}' has no host")))?; let port = parsed.port_or_known_default().unwrap_or(443); resolve_and_check(host, port, &self.allowed_private_hosts).await?; let client = self.build_reqwest_client()?; let mut attempts = 0u32; loop { let mut request = client.get(parsed.clone()); for (name, value) in &self.auth_headers { request = request.header(name.as_str(), value.as_str()); } if let Some(etag_val) = etag { request = request.header(reqwest::header::IF_NONE_MATCH, etag_val); } if let Some(lm_val) = last_modified { request = request.header(reqwest::header::IF_MODIFIED_SINCE, lm_val); } let response = request.send().await.map_err(SwaggerCliError::Network)?; let status = response.status(); match status { StatusCode::NOT_MODIFIED => { return Ok(ConditionalFetchResult::NotModified); } s if s.is_success() => { let result = self.read_response(response).await?; return Ok(ConditionalFetchResult::Modified(result)); } StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { return Err(SwaggerCliError::Auth(format!( "server returned {status} for '{url}'" ))); } StatusCode::NOT_FOUND => { return Err(SwaggerCliError::InvalidSpec(format!( "spec not found at '{url}' (404)" ))); } s if s == StatusCode::TOO_MANY_REQUESTS || s.is_server_error() => { attempts += 1; if attempts > self.max_retries { return Err(SwaggerCliError::Network( client.get(url).send().await.unwrap_err(), )); } let delay = self.retry_delay(&response, attempts); tokio::time::sleep(delay).await; } _ => { return Err(SwaggerCliError::InvalidSpec(format!( "unexpected status {status} fetching '{url}'" ))); } } } } pub async fn fetch_spec(&self, url: &str) -> Result { // Check network policy before any HTTP request check_remote_fetch(self.network_policy)?; let parsed = validate_url(url, self.allow_insecure_http)?; let host = parsed .host_str() .ok_or_else(|| SwaggerCliError::InvalidSpec(format!("URL '{url}' has no host")))?; let port = parsed.port_or_known_default().unwrap_or(443); resolve_and_check(host, port, &self.allowed_private_hosts).await?; let client = self.build_reqwest_client()?; let mut attempts = 0u32; loop { let mut request = client.get(parsed.clone()); for (name, value) in &self.auth_headers { request = request.header(name.as_str(), value.as_str()); } let response = request.send().await.map_err(SwaggerCliError::Network)?; let status = response.status(); match status { s if s.is_success() => { return self.read_response(response).await; } StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { return Err(SwaggerCliError::Auth(format!( "server returned {status} for '{url}'" ))); } StatusCode::NOT_FOUND => { return Err(SwaggerCliError::InvalidSpec(format!( "spec not found at '{url}' (404)" ))); } s if s == StatusCode::TOO_MANY_REQUESTS || s.is_server_error() => { attempts += 1; if attempts > self.max_retries { return Err(SwaggerCliError::Network( client.get(url).send().await.unwrap_err(), )); } let delay = self.retry_delay(&response, attempts); tokio::time::sleep(delay).await; } _ => { return Err(SwaggerCliError::InvalidSpec(format!( "unexpected status {status} fetching '{url}'" ))); } } } } fn build_reqwest_client(&self) -> Result { reqwest::Client::builder() .connect_timeout(self.connect_timeout) .timeout(self.overall_timeout) .https_only(!self.allow_insecure_http) .build() .map_err(SwaggerCliError::Network) } async fn read_response( &self, response: reqwest::Response, ) -> Result { let content_type = response .headers() .get(reqwest::header::CONTENT_TYPE) .and_then(|v| v.to_str().ok()) .map(String::from); let etag = response .headers() .get(reqwest::header::ETAG) .and_then(|v| v.to_str().ok()) .map(String::from); let last_modified = response .headers() .get(reqwest::header::LAST_MODIFIED) .and_then(|v| v.to_str().ok()) .map(String::from); // Stream the body with a size limit let mut bytes = Vec::new(); let mut stream = response; while let Some(chunk) = stream.chunk().await.map_err(SwaggerCliError::Network)? { bytes.extend_from_slice(&chunk); if bytes.len() as u64 > self.max_bytes { return Err(SwaggerCliError::PolicyBlocked(format!( "response exceeds maximum size of {} bytes", self.max_bytes ))); } } Ok(FetchResult { bytes, content_type, etag, last_modified, }) } fn retry_delay(&self, _response: &reqwest::Response, attempt: u32) -> Duration { // TODO: parse Retry-After header when present RETRY_BASE_DELAY * 2u32.saturating_pow(attempt.saturating_sub(1)) } } // --------------------------------------------------------------------------- // Builder // --------------------------------------------------------------------------- #[derive(Default)] pub struct AsyncHttpClientBuilder { connect_timeout: Option, overall_timeout: Option, max_bytes: Option, max_retries: Option, allow_insecure_http: bool, allowed_private_hosts: Vec, auth_headers: Vec<(String, String)>, network_policy: NetworkPolicy, } impl AsyncHttpClientBuilder { pub fn connect_timeout(mut self, d: Duration) -> Self { self.connect_timeout = Some(d); self } pub fn overall_timeout(mut self, d: Duration) -> Self { self.overall_timeout = Some(d); self } pub fn max_bytes(mut self, n: u64) -> Self { self.max_bytes = Some(n); self } pub fn max_retries(mut self, n: u32) -> Self { self.max_retries = Some(n); self } pub fn allow_insecure_http(mut self, allow: bool) -> Self { self.allow_insecure_http = allow; self } pub fn allowed_private_hosts(mut self, hosts: Vec) -> Self { self.allowed_private_hosts = hosts; self } pub fn auth_header(mut self, name: String, value: String) -> Self { self.auth_headers.push((name, value)); self } pub fn network_policy(mut self, policy: NetworkPolicy) -> Self { self.network_policy = policy; self } pub fn build(self) -> AsyncHttpClient { AsyncHttpClient { connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT), overall_timeout: self.overall_timeout.unwrap_or(DEFAULT_OVERALL_TIMEOUT), max_bytes: self.max_bytes.unwrap_or(DEFAULT_MAX_BYTES), max_retries: self.max_retries.unwrap_or(DEFAULT_MAX_RETRIES), allow_insecure_http: self.allow_insecure_http, allowed_private_hosts: self.allowed_private_hosts, auth_headers: self.auth_headers, network_policy: self.network_policy, } } } // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; use std::net::{Ipv4Addr, Ipv6Addr}; // -- SSRF IP blocking --------------------------------------------------- #[test] fn test_ssrf_blocks_loopback() { assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))); assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new( 127, 255, 255, 254 )))); assert!(is_ip_blocked(&IpAddr::V6(Ipv6Addr::LOCALHOST))); } #[test] fn test_ssrf_blocks_private() { // 10.0.0.0/8 assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))); assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(10, 255, 255, 255)))); // 172.16.0.0/12 assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1)))); assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(172, 31, 255, 255)))); // 192.168.0.0/16 assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)))); assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)))); } #[test] fn test_ssrf_blocks_link_local() { // IPv4 link-local (169.254.x.x) -- includes the AWS metadata endpoint assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new( 169, 254, 169, 254 )))); assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(169, 254, 0, 1)))); // IPv6 link-local (fe80::/10) assert!(is_ip_blocked(&IpAddr::V6(Ipv6Addr::new( 0xfe80, 0, 0, 0, 0, 0, 0, 1 )))); } #[test] fn test_ssrf_blocks_multicast() { assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1)))); assert!(is_ip_blocked(&IpAddr::V6(Ipv6Addr::new( 0xff02, 0, 0, 0, 0, 0, 0, 1 )))); } #[test] fn test_ssrf_blocks_mapped_v4() { // ::ffff:127.0.0.1 let mapped = Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x7f00, 0x0001); assert!(is_ip_blocked(&IpAddr::V6(mapped))); // ::ffff:10.0.0.1 let mapped_private = Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x0a00, 0x0001); assert!(is_ip_blocked(&IpAddr::V6(mapped_private))); } #[test] fn test_ssrf_allows_public() { assert!(!is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))); assert!(!is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)))); assert!(!is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)))); } // -- URL validation ----------------------------------------------------- #[test] fn test_url_rejects_http() { let result = validate_url("http://example.com/spec.json", false); assert!(result.is_err()); let err = result.unwrap_err(); assert!(matches!(err, SwaggerCliError::PolicyBlocked(_))); } #[test] fn test_url_allows_https() { let result = validate_url("https://example.com/spec.json", false); assert!(result.is_ok()); assert_eq!(result.unwrap().as_str(), "https://example.com/spec.json"); } #[test] fn test_url_allows_http_when_opted_in() { let result = validate_url("http://example.com/spec.json", true); assert!(result.is_ok()); } #[test] fn test_url_rejects_unsupported_scheme() { let result = validate_url("ftp://example.com/spec.json", false); assert!(result.is_err()); assert!(matches!( result.unwrap_err(), SwaggerCliError::InvalidSpec(_) )); } #[test] fn test_url_rejects_garbage() { let result = validate_url("not a url at all", false); assert!(result.is_err()); } // -- Builder defaults --------------------------------------------------- #[test] fn test_builder_defaults() { let client = AsyncHttpClient::builder().build(); assert_eq!(client.connect_timeout, DEFAULT_CONNECT_TIMEOUT); assert_eq!(client.overall_timeout, DEFAULT_OVERALL_TIMEOUT); assert_eq!(client.max_bytes, DEFAULT_MAX_BYTES); assert_eq!(client.max_retries, DEFAULT_MAX_RETRIES); assert!(!client.allow_insecure_http); assert!(client.allowed_private_hosts.is_empty()); assert!(client.auth_headers.is_empty()); } #[test] fn test_builder_custom() { let client = AsyncHttpClient::builder() .connect_timeout(Duration::from_secs(3)) .overall_timeout(Duration::from_secs(30)) .max_bytes(1024) .max_retries(5) .allow_insecure_http(true) .allowed_private_hosts(vec!["internal.corp".into()]) .auth_header("Authorization".into(), "Bearer tok".into()) .build(); assert_eq!(client.connect_timeout, Duration::from_secs(3)); assert_eq!(client.overall_timeout, Duration::from_secs(30)); assert_eq!(client.max_bytes, 1024); assert_eq!(client.max_retries, 5); assert!(client.allow_insecure_http); assert_eq!(client.allowed_private_hosts, vec!["internal.corp"]); assert_eq!(client.auth_headers.len(), 1); } // -- DNS + SSRF integration (async) ------------------------------------- #[tokio::test] async fn test_resolve_and_check_skips_allowed_host() { let result = resolve_and_check("localhost", 80, &["localhost".into()]).await; assert!(result.is_ok()); } }