Files
swagger-cli/src/core/http.rs
teernisse 346fef9135 Wave 5: Schemas command, sync command, network policy, test fixtures (bd-x15, bd-3f4, bd-1cv, bd-lx6)
- Implement schemas command with list/show modes, regex filtering, ref expansion
- Implement sync command with conditional fetch, content hash diffing, dry-run
- Add NetworkPolicy enum (Auto/Offline/OnlineOnly) with env var + CLI flag resolution
- Integrate network policy into AsyncHttpClient and fetch command
- Create test fixtures (petstore.json/yaml, minimal.json) and integration test helpers
- Fix clippy lints: derivable_impls, len_zero, borrow-after-move, deprecated API
- 192 tests passing (179 unit + 13 integration), all quality gates green
2026-02-12 14:37:14 -05:00

608 lines
20 KiB
Rust

use std::net::IpAddr;
use std::time::Duration;
use reqwest::{StatusCode, Url};
use tokio::net::lookup_host;
use crate::core::network::{NetworkPolicy, check_remote_fetch};
use crate::errors::SwaggerCliError;
/// Default limit for establishing a connection (passed to reqwest's `connect_timeout`).
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Default limit for a whole request (passed to reqwest's `timeout`).
const DEFAULT_OVERALL_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Default cap on a response body, in bytes (25 MiB).
const DEFAULT_MAX_BYTES: u64 = 25 << 20;
/// Default number of retries after a 429 / 5xx response.
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Backoff before the first retry; it doubles for each subsequent retry.
const RETRY_BASE_DELAY: Duration = Duration::from_millis(500);
// ---------------------------------------------------------------------------
// SSRF protection
// ---------------------------------------------------------------------------
/// Returns `true` when `ip` must never be contacted by the spec fetcher.
///
/// This is the SSRF filter applied to every address a host resolves to. It
/// rejects addresses that reach the local machine or non-public networks:
///
/// * IPv4: loopback, link-local (incl. the `169.254.169.254` metadata
///   endpoint), broadcast, unspecified, multicast, the RFC 1918 private
///   ranges, `0.0.0.0/8` ("this network") and `100.64.0.0/10` (shared space).
/// * IPv6: loopback, unspecified, multicast, link-local `fe80::/10`,
///   unique-local `fc00::/7`, and IPv4-mapped `::ffff:a.b.c.d` addresses whose
///   embedded IPv4 address is itself blocked.
fn is_ip_blocked(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            v4.is_loopback() // 127.0.0.0/8
                || v4.is_link_local() // 169.254.0.0/16
                || v4.is_broadcast() // 255.255.255.255
                || v4.is_unspecified() // 0.0.0.0
                || v4.is_multicast() // 224.0.0.0/4
                || is_private_v4(v4)
                || is_this_network_v4(v4)
                || is_shared_v4(v4)
        }
        IpAddr::V6(v6) => {
            v6.is_loopback() // ::1
                || v6.is_unspecified() // ::
                || v6.is_multicast() // ff00::/8
                || is_link_local_v6(v6) // fe80::/10
                || is_unique_local_v6(v6) // fc00::/7
                || is_blocked_mapped_v4(v6)
        }
    }
}

/// RFC 1918 private ranges: `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`.
fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool {
    let octets = ip.octets();
    // 10.0.0.0/8
    octets[0] == 10
        // 172.16.0.0/12
        || (octets[0] == 172 && (16..=31).contains(&octets[1]))
        // 192.168.0.0/16
        || (octets[0] == 192 && octets[1] == 168)
}

/// `0.0.0.0/8` ("this network", RFC 1122): never a legitimate destination,
/// so anything resolving here is refused rather than handed to the OS.
fn is_this_network_v4(ip: &std::net::Ipv4Addr) -> bool {
    ip.octets()[0] == 0
}

/// `100.64.0.0/10` shared address space (RFC 6598).
///
/// It is used for carrier-grade NAT and by some cloud providers for internal
/// services, e.g. Alibaba Cloud's instance metadata at `100.100.100.200`.
fn is_shared_v4(ip: &std::net::Ipv4Addr) -> bool {
    let octets = ip.octets();
    // Second octet 64..=127, i.e. its top two bits are `01`.
    octets[0] == 100 && (octets[1] & 0xc0) == 0x40
}

/// `fe80::/10` link-local unicast.
fn is_link_local_v6(ip: &std::net::Ipv6Addr) -> bool {
    let segments = ip.segments();
    // fe80::/10 — first 10 bits are 1111_1110_10
    (segments[0] & 0xffc0) == 0xfe80
}

/// `fc00::/7` unique local addresses (RFC 4193), the IPv6 counterpart of the
/// RFC 1918 ranges. AWS's IPv6 instance metadata endpoint (`fd00:ec2::254`)
/// lives inside this range.
fn is_unique_local_v6(ip: &std::net::Ipv6Addr) -> bool {
    // First 7 bits are 1111_110.
    (ip.segments()[0] & 0xfe00) == 0xfc00
}

/// IPv4-mapped IPv6 (`::ffff:a.b.c.d`): blocked iff the embedded IPv4 address
/// is blocked, so the v4 rules cannot be bypassed through a v6 literal.
fn is_blocked_mapped_v4(v6: &std::net::Ipv6Addr) -> bool {
    let segments = v6.segments();
    if segments[0..5] == [0, 0, 0, 0, 0] && segments[5] == 0xffff {
        let v4 = std::net::Ipv4Addr::new(
            (segments[6] >> 8) as u8,
            segments[6] as u8,
            (segments[7] >> 8) as u8,
            segments[7] as u8,
        );
        return is_ip_blocked(&IpAddr::V4(v4));
    }
    false
}
// ---------------------------------------------------------------------------
// URL validation
// ---------------------------------------------------------------------------
/// Parses `url` and enforces the scheme policy.
///
/// `https` is always accepted and `http` only when `allow_insecure_http` is
/// set. Any other scheme, or text that is not a URL, is rejected.
///
/// # Errors
/// - `InvalidSpec` for unparsable URLs and unsupported schemes.
/// - `PolicyBlocked` for `http` without the opt-in.
fn validate_url(url: &str, allow_insecure_http: bool) -> Result<Url, SwaggerCliError> {
    let parsed = match Url::parse(url) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(SwaggerCliError::InvalidSpec(format!("invalid URL '{url}': {e}")));
        }
    };

    // Decide first (borrowing the scheme), then hand back the owned URL.
    let rejection = match (parsed.scheme(), allow_insecure_http) {
        ("https", _) | ("http", true) => None,
        ("http", false) => Some(SwaggerCliError::PolicyBlocked(format!(
            "HTTP is not allowed for '{url}'. Use --allow-insecure-http to override."
        ))),
        (other, _) => Some(SwaggerCliError::InvalidSpec(format!(
            "unsupported scheme '{other}' in URL '{url}'"
        ))),
    };

    match rejection {
        None => Ok(parsed),
        Some(err) => Err(err),
    }
}
// ---------------------------------------------------------------------------
// DNS resolution + SSRF check
// ---------------------------------------------------------------------------
async fn resolve_and_check(
host: &str,
port: u16,
allowed_private_hosts: &[String],
) -> Result<(), SwaggerCliError> {
if allowed_private_hosts.iter().any(|h| h == host) {
return Ok(());
}
let addr = format!("{host}:{port}");
let addrs: Vec<_> = match lookup_host(&addr).await {
Ok(iter) => iter.collect(),
Err(e) => {
return Err(SwaggerCliError::InvalidSpec(format!(
"DNS resolution failed for '{host}': {e}"
)));
}
};
if addrs.is_empty() {
return Err(SwaggerCliError::InvalidSpec(format!(
"DNS resolution returned no addresses for '{host}'"
)));
}
for socket_addr in &addrs {
if is_ip_blocked(&socket_addr.ip()) {
return Err(SwaggerCliError::PolicyBlocked(format!(
"resolved IP {} for host '{host}' is in a blocked range. \
Use --allow-private-host {host} to override.",
socket_addr.ip()
)));
}
}
Ok(())
}
// ---------------------------------------------------------------------------
// FetchResult
// ---------------------------------------------------------------------------
/// Body and cache-validation metadata of a successful spec download.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchResult {
    /// Raw response body.
    pub bytes: Vec<u8>,
    /// `Content-Type` response header, when present and representable as text.
    pub content_type: Option<String>,
    /// `ETag` response header, usable later as `If-None-Match`.
    pub etag: Option<String>,
    /// `Last-Modified` response header, usable later as `If-Modified-Since`.
    pub last_modified: Option<String>,
}

/// Result of a conditional fetch (If-None-Match / If-Modified-Since).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConditionalFetchResult {
    /// Server returned 304 Not Modified -- cached content is still current.
    NotModified,
    /// Server returned new content.
    Modified(FetchResult),
}
// ---------------------------------------------------------------------------
// AsyncHttpClient builder
// ---------------------------------------------------------------------------
/// HTTP client used to download OpenAPI specs.
///
/// It applies URL scheme validation, an SSRF address check, a response size
/// limit, retries on 429/5xx and the configured [`NetworkPolicy`]. Construct
/// it with [`AsyncHttpClient::builder`] or `Default`.
pub struct AsyncHttpClient {
    // -- access policy ------------------------------------------------------
    /// Whether remote fetches are permitted at all.
    network_policy: NetworkPolicy,
    /// Permit plain `http://` URLs.
    allow_insecure_http: bool,
    /// Hosts exempt from the private/blocked address check.
    allowed_private_hosts: Vec<String>,
    /// Extra `(name, value)` headers attached to requests.
    auth_headers: Vec<(String, String)>,

    // -- limits -------------------------------------------------------------
    /// Connection-establishment timeout.
    connect_timeout: Duration,
    /// Whole-request timeout.
    overall_timeout: Duration,
    /// Largest accepted response body, in bytes.
    max_bytes: u64,
    /// Retries allowed after a 429 / 5xx response.
    max_retries: u32,
}
impl Default for AsyncHttpClient {
    /// Client built from the module defaults: HTTPS only, no private-host
    /// exceptions, no auth headers and `NetworkPolicy::Auto`.
    fn default() -> Self {
        let (connect_timeout, overall_timeout) = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_OVERALL_TIMEOUT);
        Self {
            network_policy: NetworkPolicy::Auto,
            allow_insecure_http: false,
            allowed_private_hosts: vec![],
            auth_headers: vec![],
            connect_timeout,
            overall_timeout,
            max_bytes: DEFAULT_MAX_BYTES,
            max_retries: DEFAULT_MAX_RETRIES,
        }
    }
}
impl AsyncHttpClient {
pub fn builder() -> AsyncHttpClientBuilder {
AsyncHttpClientBuilder::default()
}
/// Fetch a spec with conditional request headers.
///
/// Sends If-None-Match (for ETag) and If-Modified-Since (for Last-Modified)
/// when provided. Returns `NotModified` on 304, `Modified` on 200.
pub async fn fetch_conditional(
&self,
url: &str,
etag: Option<&str>,
last_modified: Option<&str>,
) -> Result<ConditionalFetchResult, SwaggerCliError> {
let parsed = validate_url(url, self.allow_insecure_http)?;
let host = parsed
.host_str()
.ok_or_else(|| SwaggerCliError::InvalidSpec(format!("URL '{url}' has no host")))?;
let port = parsed.port_or_known_default().unwrap_or(443);
resolve_and_check(host, port, &self.allowed_private_hosts).await?;
let client = self.build_reqwest_client()?;
let mut attempts = 0u32;
loop {
let mut request = client.get(parsed.clone());
for (name, value) in &self.auth_headers {
request = request.header(name.as_str(), value.as_str());
}
if let Some(etag_val) = etag {
request = request.header(reqwest::header::IF_NONE_MATCH, etag_val);
}
if let Some(lm_val) = last_modified {
request = request.header(reqwest::header::IF_MODIFIED_SINCE, lm_val);
}
let response = request.send().await.map_err(SwaggerCliError::Network)?;
let status = response.status();
match status {
StatusCode::NOT_MODIFIED => {
return Ok(ConditionalFetchResult::NotModified);
}
s if s.is_success() => {
let result = self.read_response(response).await?;
return Ok(ConditionalFetchResult::Modified(result));
}
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
return Err(SwaggerCliError::Auth(format!(
"server returned {status} for '{url}'"
)));
}
StatusCode::NOT_FOUND => {
return Err(SwaggerCliError::InvalidSpec(format!(
"spec not found at '{url}' (404)"
)));
}
s if s == StatusCode::TOO_MANY_REQUESTS || s.is_server_error() => {
attempts += 1;
if attempts > self.max_retries {
return Err(SwaggerCliError::Network(
client.get(url).send().await.unwrap_err(),
));
}
let delay = self.retry_delay(&response, attempts);
tokio::time::sleep(delay).await;
}
_ => {
return Err(SwaggerCliError::InvalidSpec(format!(
"unexpected status {status} fetching '{url}'"
)));
}
}
}
}
pub async fn fetch_spec(&self, url: &str) -> Result<FetchResult, SwaggerCliError> {
// Check network policy before any HTTP request
check_remote_fetch(self.network_policy)?;
let parsed = validate_url(url, self.allow_insecure_http)?;
let host = parsed
.host_str()
.ok_or_else(|| SwaggerCliError::InvalidSpec(format!("URL '{url}' has no host")))?;
let port = parsed.port_or_known_default().unwrap_or(443);
resolve_and_check(host, port, &self.allowed_private_hosts).await?;
let client = self.build_reqwest_client()?;
let mut attempts = 0u32;
loop {
let mut request = client.get(parsed.clone());
for (name, value) in &self.auth_headers {
request = request.header(name.as_str(), value.as_str());
}
let response = request.send().await.map_err(SwaggerCliError::Network)?;
let status = response.status();
match status {
s if s.is_success() => {
return self.read_response(response).await;
}
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
return Err(SwaggerCliError::Auth(format!(
"server returned {status} for '{url}'"
)));
}
StatusCode::NOT_FOUND => {
return Err(SwaggerCliError::InvalidSpec(format!(
"spec not found at '{url}' (404)"
)));
}
s if s == StatusCode::TOO_MANY_REQUESTS || s.is_server_error() => {
attempts += 1;
if attempts > self.max_retries {
return Err(SwaggerCliError::Network(
client.get(url).send().await.unwrap_err(),
));
}
let delay = self.retry_delay(&response, attempts);
tokio::time::sleep(delay).await;
}
_ => {
return Err(SwaggerCliError::InvalidSpec(format!(
"unexpected status {status} fetching '{url}'"
)));
}
}
}
}
fn build_reqwest_client(&self) -> Result<reqwest::Client, SwaggerCliError> {
reqwest::Client::builder()
.connect_timeout(self.connect_timeout)
.timeout(self.overall_timeout)
.https_only(!self.allow_insecure_http)
.build()
.map_err(SwaggerCliError::Network)
}
async fn read_response(
&self,
response: reqwest::Response,
) -> Result<FetchResult, SwaggerCliError> {
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(String::from);
let etag = response
.headers()
.get(reqwest::header::ETAG)
.and_then(|v| v.to_str().ok())
.map(String::from);
let last_modified = response
.headers()
.get(reqwest::header::LAST_MODIFIED)
.and_then(|v| v.to_str().ok())
.map(String::from);
// Stream the body with a size limit
let mut bytes = Vec::new();
let mut stream = response;
while let Some(chunk) = stream.chunk().await.map_err(SwaggerCliError::Network)? {
bytes.extend_from_slice(&chunk);
if bytes.len() as u64 > self.max_bytes {
return Err(SwaggerCliError::PolicyBlocked(format!(
"response exceeds maximum size of {} bytes",
self.max_bytes
)));
}
}
Ok(FetchResult {
bytes,
content_type,
etag,
last_modified,
})
}
fn retry_delay(&self, _response: &reqwest::Response, attempt: u32) -> Duration {
// TODO: parse Retry-After header when present
RETRY_BASE_DELAY * 2u32.saturating_pow(attempt.saturating_sub(1))
}
}
// ---------------------------------------------------------------------------
// Builder
// ---------------------------------------------------------------------------
/// Builder for [`AsyncHttpClient`].
///
/// Limits left unset fall back to the module defaults when `build` is called.
#[derive(Default)]
#[must_use = "a builder does nothing until `build()` is called"]
pub struct AsyncHttpClientBuilder {
    /// Connection-establishment timeout override.
    connect_timeout: Option<Duration>,
    /// Whole-request timeout override.
    overall_timeout: Option<Duration>,
    /// Maximum body size override, in bytes.
    max_bytes: Option<u64>,
    /// Retry count override for 429 / 5xx responses.
    max_retries: Option<u32>,
    /// Permit plain `http://` URLs.
    allow_insecure_http: bool,
    /// Hosts exempt from the private/blocked address check.
    allowed_private_hosts: Vec<String>,
    /// Extra `(name, value)` headers for requests.
    auth_headers: Vec<(String, String)>,
    /// Network policy handed to the built client.
    network_policy: NetworkPolicy,
}
impl AsyncHttpClientBuilder {
    /// Overrides the connection-establishment timeout.
    pub fn connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: Some(timeout),
            ..self
        }
    }

    /// Overrides the whole-request timeout.
    pub fn overall_timeout(self, timeout: Duration) -> Self {
        Self {
            overall_timeout: Some(timeout),
            ..self
        }
    }

    /// Overrides the maximum accepted response size, in bytes.
    pub fn max_bytes(self, limit: u64) -> Self {
        Self {
            max_bytes: Some(limit),
            ..self
        }
    }

    /// Overrides how many times 429 / 5xx responses are retried.
    pub fn max_retries(self, retries: u32) -> Self {
        Self {
            max_retries: Some(retries),
            ..self
        }
    }

    /// Permits plain `http://` URLs when `allow` is `true`.
    pub fn allow_insecure_http(self, allow: bool) -> Self {
        Self {
            allow_insecure_http: allow,
            ..self
        }
    }

    /// Replaces the list of hosts exempt from the private-address check.
    pub fn allowed_private_hosts(self, hosts: Vec<String>) -> Self {
        Self {
            allowed_private_hosts: hosts,
            ..self
        }
    }

    /// Appends one `(name, value)` header to the client's auth headers.
    pub fn auth_header(mut self, name: String, value: String) -> Self {
        self.auth_headers.push((name, value));
        self
    }

    /// Sets the network policy the client enforces.
    pub fn network_policy(self, policy: NetworkPolicy) -> Self {
        Self {
            network_policy: policy,
            ..self
        }
    }

    /// Produces the client, substituting module defaults for unset limits.
    pub fn build(self) -> AsyncHttpClient {
        let Self {
            connect_timeout,
            overall_timeout,
            max_bytes,
            max_retries,
            allow_insecure_http,
            allowed_private_hosts,
            auth_headers,
            network_policy,
        } = self;

        AsyncHttpClient {
            connect_timeout: connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT),
            overall_timeout: overall_timeout.unwrap_or(DEFAULT_OVERALL_TIMEOUT),
            max_bytes: max_bytes.unwrap_or(DEFAULT_MAX_BYTES),
            max_retries: max_retries.unwrap_or(DEFAULT_MAX_RETRIES),
            allow_insecure_http,
            allowed_private_hosts,
            auth_headers,
            network_policy,
        }
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Builds an IPv4 [`IpAddr`] from its four octets.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Builds an IPv6 [`IpAddr`] from its eight 16-bit segments.
    fn v6(segments: [u16; 8]) -> IpAddr {
        IpAddr::V6(Ipv6Addr::from(segments))
    }

    /// Asserts that the SSRF filter rejects every address in `ips`.
    fn assert_blocked(ips: &[IpAddr]) {
        for ip in ips {
            assert!(is_ip_blocked(ip), "{ip} should be blocked");
        }
    }

    // -- SSRF IP blocking ---------------------------------------------------

    #[test]
    fn test_ssrf_blocks_loopback() {
        assert_blocked(&[
            v4(127, 0, 0, 1),
            v4(127, 255, 255, 254),
            IpAddr::V6(Ipv6Addr::LOCALHOST),
        ]);
    }

    #[test]
    fn test_ssrf_blocks_private() {
        assert_blocked(&[
            // 10.0.0.0/8
            v4(10, 0, 0, 1),
            v4(10, 255, 255, 255),
            // 172.16.0.0/12
            v4(172, 16, 0, 1),
            v4(172, 31, 255, 255),
            // 192.168.0.0/16
            v4(192, 168, 1, 1),
            v4(192, 168, 0, 1),
        ]);
    }

    #[test]
    fn test_ssrf_blocks_link_local() {
        assert_blocked(&[
            // IPv4 link-local (169.254.x.x) -- includes the AWS metadata endpoint
            v4(169, 254, 169, 254),
            v4(169, 254, 0, 1),
            // IPv6 link-local (fe80::/10)
            v6([0xfe80, 0, 0, 0, 0, 0, 0, 1]),
        ]);
    }

    #[test]
    fn test_ssrf_blocks_multicast() {
        assert_blocked(&[v4(224, 0, 0, 1), v6([0xff02, 0, 0, 0, 0, 0, 0, 1])]);
    }

    #[test]
    fn test_ssrf_blocks_mapped_v4() {
        assert_blocked(&[
            // ::ffff:127.0.0.1
            v6([0, 0, 0, 0, 0, 0xffff, 0x7f00, 0x0001]),
            // ::ffff:10.0.0.1
            v6([0, 0, 0, 0, 0, 0xffff, 0x0a00, 0x0001]),
        ]);
    }

    #[test]
    fn test_ssrf_allows_public() {
        for ip in &[v4(8, 8, 8, 8), v4(1, 1, 1, 1), v4(93, 184, 216, 34)] {
            assert!(!is_ip_blocked(ip), "{ip} should be allowed");
        }
    }

    // -- URL validation -----------------------------------------------------

    #[test]
    fn test_url_rejects_http() {
        let err = validate_url("http://example.com/spec.json", false)
            .expect_err("plain HTTP must be rejected without opt-in");
        assert!(matches!(err, SwaggerCliError::PolicyBlocked(_)));
    }

    #[test]
    fn test_url_allows_https() {
        let url = validate_url("https://example.com/spec.json", false)
            .expect("HTTPS must be accepted");
        assert_eq!(url.as_str(), "https://example.com/spec.json");
    }

    #[test]
    fn test_url_allows_http_when_opted_in() {
        assert!(validate_url("http://example.com/spec.json", true).is_ok());
    }

    #[test]
    fn test_url_rejects_unsupported_scheme() {
        let outcome = validate_url("ftp://example.com/spec.json", false);
        assert!(matches!(outcome, Err(SwaggerCliError::InvalidSpec(_))));
    }

    #[test]
    fn test_url_rejects_garbage() {
        assert!(validate_url("not a url at all", false).is_err());
    }

    // -- Builder defaults ---------------------------------------------------

    #[test]
    fn test_builder_defaults() {
        let client = AsyncHttpClient::builder().build();
        assert_eq!(client.connect_timeout, DEFAULT_CONNECT_TIMEOUT);
        assert_eq!(client.overall_timeout, DEFAULT_OVERALL_TIMEOUT);
        assert_eq!(client.max_bytes, DEFAULT_MAX_BYTES);
        assert_eq!(client.max_retries, DEFAULT_MAX_RETRIES);
        assert!(!client.allow_insecure_http);
        assert!(client.allowed_private_hosts.is_empty());
        assert!(client.auth_headers.is_empty());
    }

    #[test]
    fn test_builder_custom() {
        let client = AsyncHttpClient::builder()
            .allow_insecure_http(true)
            .max_retries(5)
            .max_bytes(1024)
            .overall_timeout(Duration::from_secs(30))
            .connect_timeout(Duration::from_secs(3))
            .allowed_private_hosts(vec!["internal.corp".into()])
            .auth_header("Authorization".into(), "Bearer tok".into())
            .build();

        assert_eq!(client.connect_timeout, Duration::from_secs(3));
        assert_eq!(client.overall_timeout, Duration::from_secs(30));
        assert_eq!(client.max_bytes, 1024);
        assert_eq!(client.max_retries, 5);
        assert!(client.allow_insecure_http);
        assert_eq!(client.allowed_private_hosts, vec!["internal.corp"]);
        assert_eq!(client.auth_headers.len(), 1);
    }

    // -- DNS + SSRF integration (async) -------------------------------------

    #[tokio::test]
    async fn test_resolve_and_check_skips_allowed_host() {
        let allowed: Vec<String> = vec!["localhost".into()];
        assert!(resolve_and_check("localhost", 80, &allowed).await.is_ok());
    }
}