Wave 3: Cache read path with integrity validation, async HTTP client with SSRF protection (bd-3ea, bd-3b6)
This commit is contained in:
@@ -229,6 +229,176 @@ impl CacheManager {
|
||||
|
||||
Ok(meta)
|
||||
}
|
||||
|
||||
/// Load a cached spec index with integrity validation.
|
||||
///
|
||||
/// Reads `meta.json` first (as commit marker), then `index.json`.
|
||||
/// Validates that index_version, generation, and index_hash all match
|
||||
/// between meta and the on-disk index. Returns `AliasNotFound` if
|
||||
/// meta.json is missing, `CacheIntegrity` on any mismatch.
|
||||
pub fn load_index(
|
||||
&self,
|
||||
alias: &str,
|
||||
) -> Result<(SpecIndex, CacheMetadata), SwaggerCliError> {
|
||||
validate_alias(alias)?;
|
||||
let dir = self.alias_dir(alias);
|
||||
|
||||
let meta_path = dir.join("meta.json");
|
||||
let meta_bytes = fs::read(&meta_path).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
SwaggerCliError::AliasNotFound(alias.to_string())
|
||||
} else {
|
||||
SwaggerCliError::Cache(format!(
|
||||
"Failed to read {}: {e}",
|
||||
meta_path.display()
|
||||
))
|
||||
}
|
||||
})?;
|
||||
let meta: CacheMetadata = serde_json::from_slice(&meta_bytes).map_err(|e| {
|
||||
SwaggerCliError::CacheIntegrity(format!(
|
||||
"Corrupt meta.json for alias '{alias}': {e}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let index_path = dir.join("index.json");
|
||||
let index_bytes = fs::read(&index_path).map_err(|e| {
|
||||
SwaggerCliError::Cache(format!(
|
||||
"Failed to read {}: {e}",
|
||||
index_path.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
let actual_hash = compute_hash(&index_bytes);
|
||||
if meta.index_hash != actual_hash {
|
||||
return Err(SwaggerCliError::CacheIntegrity(format!(
|
||||
"Index hash mismatch for '{alias}': expected {}, got {actual_hash}",
|
||||
meta.index_hash
|
||||
)));
|
||||
}
|
||||
|
||||
let index: SpecIndex = serde_json::from_slice(&index_bytes).map_err(|e| {
|
||||
SwaggerCliError::CacheIntegrity(format!(
|
||||
"Corrupt index.json for alias '{alias}': {e}"
|
||||
))
|
||||
})?;
|
||||
|
||||
if meta.index_version != index.index_version {
|
||||
return Err(SwaggerCliError::CacheIntegrity(format!(
|
||||
"index_version mismatch for '{alias}': meta={}, index={}",
|
||||
meta.index_version, index.index_version
|
||||
)));
|
||||
}
|
||||
|
||||
if meta.generation != index.generation {
|
||||
return Err(SwaggerCliError::CacheIntegrity(format!(
|
||||
"generation mismatch for '{alias}': meta={}, index={}",
|
||||
meta.generation, index.generation
|
||||
)));
|
||||
}
|
||||
|
||||
// Best-effort coalesced last_accessed update (no lock, ignore errors)
|
||||
let age = Utc::now() - meta.last_accessed;
|
||||
if age.num_minutes() > 10 {
|
||||
let mut updated_meta = meta.clone();
|
||||
updated_meta.last_accessed = Utc::now();
|
||||
if let Ok(bytes) = serde_json::to_vec_pretty(&updated_meta) {
|
||||
let _ = write_atomic(&meta_path, &bytes);
|
||||
}
|
||||
}
|
||||
|
||||
Ok((index, meta))
|
||||
}
|
||||
|
||||
/// Load the raw spec JSON with hash validation against metadata.
|
||||
pub fn load_raw(
|
||||
&self,
|
||||
alias: &str,
|
||||
meta: &CacheMetadata,
|
||||
) -> Result<serde_json::Value, SwaggerCliError> {
|
||||
let raw_path = self.alias_dir(alias).join("raw.json");
|
||||
let raw_bytes = fs::read(&raw_path).map_err(|e| {
|
||||
SwaggerCliError::Cache(format!(
|
||||
"Failed to read {}: {e}",
|
||||
raw_path.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
let actual_hash = compute_hash(&raw_bytes);
|
||||
if meta.raw_hash != actual_hash {
|
||||
return Err(SwaggerCliError::CacheIntegrity(format!(
|
||||
"Raw hash mismatch for '{}': expected {}, got {actual_hash}",
|
||||
alias, meta.raw_hash
|
||||
)));
|
||||
}
|
||||
|
||||
let value: serde_json::Value =
|
||||
serde_json::from_slice(&raw_bytes).map_err(|e| {
|
||||
SwaggerCliError::Cache(format!(
|
||||
"Failed to parse raw.json for '{}': {e}",
|
||||
alias
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
/// List all cached aliases by reading meta.json from each subdirectory.
|
||||
///
|
||||
/// Skips directories with missing or unreadable metadata (no panic).
|
||||
pub fn list_aliases(&self) -> Result<Vec<CacheMetadata>, SwaggerCliError> {
|
||||
let entries = fs::read_dir(&self.cache_dir).map_err(|e| {
|
||||
SwaggerCliError::Cache(format!(
|
||||
"Failed to read cache directory {}: {e}",
|
||||
self.cache_dir.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
for entry in entries {
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let path = entry.path();
|
||||
if !path.is_dir() {
|
||||
continue;
|
||||
}
|
||||
let meta_path = path.join("meta.json");
|
||||
let bytes = match fs::read(&meta_path) {
|
||||
Ok(b) => b,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if let Ok(meta) = serde_json::from_slice::<CacheMetadata>(&bytes) {
|
||||
results.push(meta);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Check whether a cached alias exists (meta.json present).
|
||||
pub fn alias_exists(&self, alias: &str) -> bool {
|
||||
self.alias_dir(alias).join("meta.json").exists()
|
||||
}
|
||||
|
||||
/// Delete a cached alias directory (requires lock).
|
||||
pub fn delete_alias(&self, alias: &str) -> Result<(), SwaggerCliError> {
|
||||
validate_alias(alias)?;
|
||||
let dir = self.alias_dir(alias);
|
||||
if !dir.exists() {
|
||||
return Err(SwaggerCliError::AliasNotFound(alias.to_string()));
|
||||
}
|
||||
|
||||
let _lock = self.acquire_lock(alias)?;
|
||||
fs::remove_dir_all(&dir).map_err(|e| {
|
||||
SwaggerCliError::Cache(format!(
|
||||
"Failed to delete alias directory {}: {e}",
|
||||
dir.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Write `data` to `path.tmp`, fsync, then rename to `path`.
|
||||
@@ -433,4 +603,124 @@ mod tests {
|
||||
assert_eq!(meta.source_format, "yaml");
|
||||
assert!(meta.content_hash.starts_with("sha256:"));
|
||||
}
|
||||
|
||||
/// Helper: write a cache entry and return the manager + tempdir for further testing.
|
||||
fn write_test_cache(alias: &str) -> (CacheManager, tempfile::TempDir, CacheMetadata) {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let manager = CacheManager::new(tmp.path().to_path_buf());
|
||||
let index = make_test_index();
|
||||
|
||||
let meta = manager
|
||||
.write_cache(
|
||||
alias,
|
||||
b"openapi: 3.0.3",
|
||||
b"{\"openapi\":\"3.0.3\"}",
|
||||
&index,
|
||||
Some("https://example.com/api.json".into()),
|
||||
"1.0.0",
|
||||
"Test API",
|
||||
"yaml",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
(manager, tmp, meta)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_index_success() {
|
||||
let (manager, _tmp, written_meta) = write_test_cache("loadtest");
|
||||
|
||||
let (index, loaded_meta) = manager.load_index("loadtest").unwrap();
|
||||
assert_eq!(loaded_meta.alias, "loadtest");
|
||||
assert_eq!(loaded_meta.generation, written_meta.generation);
|
||||
assert_eq!(loaded_meta.index_hash, written_meta.index_hash);
|
||||
assert_eq!(index.index_version, 1);
|
||||
assert_eq!(index.openapi, "3.0.3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_index_integrity_check() {
|
||||
let (manager, tmp, _meta) = write_test_cache("tampered");
|
||||
|
||||
// Tamper with index.json
|
||||
let index_path = tmp.path().join("tampered").join("index.json");
|
||||
fs::write(&index_path, b"{\"corrupted\": true}").unwrap();
|
||||
|
||||
let result = manager.load_index("tampered");
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, SwaggerCliError::CacheIntegrity(_)),
|
||||
"Expected CacheIntegrity, got: {err:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_index_missing_meta() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let manager = CacheManager::new(tmp.path().to_path_buf());
|
||||
|
||||
let result = manager.load_index("nonexistent");
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, SwaggerCliError::AliasNotFound(_)),
|
||||
"Expected AliasNotFound, got: {err:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_raw_validates_hash() {
|
||||
let (manager, tmp, meta) = write_test_cache("rawtest");
|
||||
|
||||
// Tamper with raw.json
|
||||
let raw_path = tmp.path().join("rawtest").join("raw.json");
|
||||
fs::write(&raw_path, b"{\"tampered\": true}").unwrap();
|
||||
|
||||
let result = manager.load_raw("rawtest", &meta);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, SwaggerCliError::CacheIntegrity(_)),
|
||||
"Expected CacheIntegrity, got: {err:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_aliases() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let manager = CacheManager::new(tmp.path().to_path_buf());
|
||||
let index = make_test_index();
|
||||
|
||||
manager
|
||||
.write_cache(
|
||||
"api1", b"src1", b"{}", &index, None, "1.0", "API 1", "json",
|
||||
None, None, None,
|
||||
)
|
||||
.unwrap();
|
||||
manager
|
||||
.write_cache(
|
||||
"api2", b"src2", b"{}", &index, None, "2.0", "API 2", "yaml",
|
||||
None, None, None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let aliases = manager.list_aliases().unwrap();
|
||||
assert_eq!(aliases.len(), 2);
|
||||
|
||||
let names: Vec<&str> = aliases.iter().map(|m| m.alias.as_str()).collect();
|
||||
assert!(names.contains(&"api1"));
|
||||
assert!(names.contains(&"api2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_alias_exists() {
|
||||
let (manager, _tmp, _meta) = write_test_cache("exists");
|
||||
|
||||
assert!(manager.alias_exists("exists"));
|
||||
assert!(!manager.alias_exists("nope"));
|
||||
}
|
||||
}
|
||||
|
||||
512
src/core/http.rs
Normal file
512
src/core/http.rs
Normal file
@@ -0,0 +1,512 @@
|
||||
use std::net::IpAddr;
|
||||
use std::time::Duration;
|
||||
|
||||
use reqwest::{StatusCode, Url};
|
||||
use tokio::net::lookup_host;
|
||||
|
||||
use crate::errors::SwaggerCliError;
|
||||
|
||||
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const DEFAULT_OVERALL_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const DEFAULT_MAX_BYTES: u64 = 25 * 1024 * 1024; // 25 MB
|
||||
const DEFAULT_MAX_RETRIES: u32 = 2;
|
||||
const RETRY_BASE_DELAY: Duration = Duration::from_millis(500);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SSRF protection
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn is_ip_blocked(ip: &IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(v4) => {
|
||||
v4.is_loopback() // 127.0.0.0/8
|
||||
|| v4.is_link_local() // 169.254.0.0/16
|
||||
|| v4.is_broadcast() // 255.255.255.255
|
||||
|| v4.is_unspecified() // 0.0.0.0
|
||||
|| v4.is_multicast() // 224.0.0.0/4
|
||||
|| is_private_v4(v4)
|
||||
}
|
||||
IpAddr::V6(v6) => {
|
||||
v6.is_loopback() // ::1
|
||||
|| v6.is_unspecified() // ::
|
||||
|| v6.is_multicast() // ff00::/8
|
||||
|| is_link_local_v6(v6) // fe80::/10
|
||||
|| is_blocked_mapped_v4(v6)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool {
|
||||
let octets = ip.octets();
|
||||
// 10.0.0.0/8
|
||||
octets[0] == 10
|
||||
// 172.16.0.0/12
|
||||
|| (octets[0] == 172 && (16..=31).contains(&octets[1]))
|
||||
// 192.168.0.0/16
|
||||
|| (octets[0] == 192 && octets[1] == 168)
|
||||
}
|
||||
|
||||
fn is_link_local_v6(ip: &std::net::Ipv6Addr) -> bool {
|
||||
let segments = ip.segments();
|
||||
// fe80::/10 — first 10 bits are 1111_1110_10
|
||||
(segments[0] & 0xffc0) == 0xfe80
|
||||
}
|
||||
|
||||
fn is_blocked_mapped_v4(v6: &std::net::Ipv6Addr) -> bool {
|
||||
// ::ffff:x.x.x.x — IPv4-mapped IPv6
|
||||
let segments = v6.segments();
|
||||
if segments[0..5] == [0, 0, 0, 0, 0] && segments[5] == 0xffff {
|
||||
let v4 = std::net::Ipv4Addr::new(
|
||||
(segments[6] >> 8) as u8,
|
||||
segments[6] as u8,
|
||||
(segments[7] >> 8) as u8,
|
||||
segments[7] as u8,
|
||||
);
|
||||
return is_ip_blocked(&IpAddr::V4(v4));
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// URL validation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn validate_url(url: &str, allow_insecure_http: bool) -> Result<Url, SwaggerCliError> {
|
||||
let parsed = Url::parse(url).map_err(|e| {
|
||||
SwaggerCliError::InvalidSpec(format!("invalid URL '{url}': {e}"))
|
||||
})?;
|
||||
|
||||
match parsed.scheme() {
|
||||
"https" => Ok(parsed),
|
||||
"http" if allow_insecure_http => Ok(parsed),
|
||||
"http" => Err(SwaggerCliError::PolicyBlocked(
|
||||
format!("HTTP is not allowed for '{url}'. Use --allow-insecure-http to override."),
|
||||
)),
|
||||
other => Err(SwaggerCliError::InvalidSpec(
|
||||
format!("unsupported scheme '{other}' in URL '{url}'"),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DNS resolution + SSRF check
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn resolve_and_check(
|
||||
host: &str,
|
||||
port: u16,
|
||||
allowed_private_hosts: &[String],
|
||||
) -> Result<(), SwaggerCliError> {
|
||||
if allowed_private_hosts.iter().any(|h| h == host) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let addr = format!("{host}:{port}");
|
||||
let addrs: Vec<_> = match lookup_host(&addr).await {
|
||||
Ok(iter) => iter.collect(),
|
||||
Err(e) => {
|
||||
return Err(SwaggerCliError::InvalidSpec(
|
||||
format!("DNS resolution failed for '{host}': {e}"),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if addrs.is_empty() {
|
||||
return Err(SwaggerCliError::InvalidSpec(
|
||||
format!("DNS resolution returned no addresses for '{host}'"),
|
||||
));
|
||||
}
|
||||
|
||||
for socket_addr in &addrs {
|
||||
if is_ip_blocked(&socket_addr.ip()) {
|
||||
return Err(SwaggerCliError::PolicyBlocked(format!(
|
||||
"resolved IP {} for host '{host}' is in a blocked range. \
|
||||
Use --allow-private-host {host} to override.",
|
||||
socket_addr.ip()
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// FetchResult
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FetchResult {
|
||||
pub bytes: Vec<u8>,
|
||||
pub content_type: Option<String>,
|
||||
pub etag: Option<String>,
|
||||
pub last_modified: Option<String>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AsyncHttpClient builder
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct AsyncHttpClient {
|
||||
connect_timeout: Duration,
|
||||
overall_timeout: Duration,
|
||||
max_bytes: u64,
|
||||
max_retries: u32,
|
||||
allow_insecure_http: bool,
|
||||
allowed_private_hosts: Vec<String>,
|
||||
auth_headers: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl Default for AsyncHttpClient {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
|
||||
overall_timeout: DEFAULT_OVERALL_TIMEOUT,
|
||||
max_bytes: DEFAULT_MAX_BYTES,
|
||||
max_retries: DEFAULT_MAX_RETRIES,
|
||||
allow_insecure_http: false,
|
||||
allowed_private_hosts: Vec::new(),
|
||||
auth_headers: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncHttpClient {
|
||||
pub fn builder() -> AsyncHttpClientBuilder {
|
||||
AsyncHttpClientBuilder::default()
|
||||
}
|
||||
|
||||
pub async fn fetch_spec(&self, url: &str) -> Result<FetchResult, SwaggerCliError> {
|
||||
let parsed = validate_url(url, self.allow_insecure_http)?;
|
||||
|
||||
let host = parsed.host_str().ok_or_else(|| {
|
||||
SwaggerCliError::InvalidSpec(format!("URL '{url}' has no host"))
|
||||
})?;
|
||||
let port = parsed.port_or_known_default().unwrap_or(443);
|
||||
|
||||
resolve_and_check(host, port, &self.allowed_private_hosts).await?;
|
||||
|
||||
let client = self.build_reqwest_client()?;
|
||||
|
||||
let mut attempts = 0u32;
|
||||
loop {
|
||||
let mut request = client.get(parsed.clone());
|
||||
for (name, value) in &self.auth_headers {
|
||||
request = request.header(name.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
let response = request.send().await.map_err(SwaggerCliError::Network)?;
|
||||
let status = response.status();
|
||||
|
||||
match status {
|
||||
s if s.is_success() => {
|
||||
return self.read_response(response).await;
|
||||
}
|
||||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
|
||||
return Err(SwaggerCliError::Auth(format!(
|
||||
"server returned {status} for '{url}'"
|
||||
)));
|
||||
}
|
||||
StatusCode::NOT_FOUND => {
|
||||
return Err(SwaggerCliError::InvalidSpec(format!(
|
||||
"spec not found at '{url}' (404)"
|
||||
)));
|
||||
}
|
||||
s if s == StatusCode::TOO_MANY_REQUESTS || s.is_server_error() => {
|
||||
attempts += 1;
|
||||
if attempts > self.max_retries {
|
||||
return Err(SwaggerCliError::Network(
|
||||
client
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.unwrap_err(),
|
||||
));
|
||||
}
|
||||
let delay = self.retry_delay(&response, attempts);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
_ => {
|
||||
return Err(SwaggerCliError::InvalidSpec(format!(
|
||||
"unexpected status {status} fetching '{url}'"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_reqwest_client(&self) -> Result<reqwest::Client, SwaggerCliError> {
|
||||
reqwest::Client::builder()
|
||||
.connect_timeout(self.connect_timeout)
|
||||
.timeout(self.overall_timeout)
|
||||
.https_only(!self.allow_insecure_http)
|
||||
.build()
|
||||
.map_err(SwaggerCliError::Network)
|
||||
}
|
||||
|
||||
async fn read_response(
|
||||
&self,
|
||||
response: reqwest::Response,
|
||||
) -> Result<FetchResult, SwaggerCliError> {
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let etag = response
|
||||
.headers()
|
||||
.get(reqwest::header::ETAG)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let last_modified = response
|
||||
.headers()
|
||||
.get(reqwest::header::LAST_MODIFIED)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
// Stream the body with a size limit
|
||||
let mut bytes = Vec::new();
|
||||
let mut stream = response;
|
||||
while let Some(chunk) = stream.chunk().await.map_err(SwaggerCliError::Network)? {
|
||||
bytes.extend_from_slice(&chunk);
|
||||
if bytes.len() as u64 > self.max_bytes {
|
||||
return Err(SwaggerCliError::PolicyBlocked(format!(
|
||||
"response exceeds maximum size of {} bytes",
|
||||
self.max_bytes
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(FetchResult {
|
||||
bytes,
|
||||
content_type,
|
||||
etag,
|
||||
last_modified,
|
||||
})
|
||||
}
|
||||
|
||||
fn retry_delay(&self, _response: &reqwest::Response, attempt: u32) -> Duration {
|
||||
// TODO: parse Retry-After header when present
|
||||
RETRY_BASE_DELAY * 2u32.saturating_pow(attempt.saturating_sub(1))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Builder
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct AsyncHttpClientBuilder {
|
||||
connect_timeout: Option<Duration>,
|
||||
overall_timeout: Option<Duration>,
|
||||
max_bytes: Option<u64>,
|
||||
max_retries: Option<u32>,
|
||||
allow_insecure_http: bool,
|
||||
allowed_private_hosts: Vec<String>,
|
||||
auth_headers: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl AsyncHttpClientBuilder {
|
||||
pub fn connect_timeout(mut self, d: Duration) -> Self {
|
||||
self.connect_timeout = Some(d);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn overall_timeout(mut self, d: Duration) -> Self {
|
||||
self.overall_timeout = Some(d);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_bytes(mut self, n: u64) -> Self {
|
||||
self.max_bytes = Some(n);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_retries(mut self, n: u32) -> Self {
|
||||
self.max_retries = Some(n);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn allow_insecure_http(mut self, allow: bool) -> Self {
|
||||
self.allow_insecure_http = allow;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn allowed_private_hosts(mut self, hosts: Vec<String>) -> Self {
|
||||
self.allowed_private_hosts = hosts;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn auth_header(mut self, name: String, value: String) -> Self {
|
||||
self.auth_headers.push((name, value));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> AsyncHttpClient {
|
||||
AsyncHttpClient {
|
||||
connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT),
|
||||
overall_timeout: self.overall_timeout.unwrap_or(DEFAULT_OVERALL_TIMEOUT),
|
||||
max_bytes: self.max_bytes.unwrap_or(DEFAULT_MAX_BYTES),
|
||||
max_retries: self.max_retries.unwrap_or(DEFAULT_MAX_RETRIES),
|
||||
allow_insecure_http: self.allow_insecure_http,
|
||||
allowed_private_hosts: self.allowed_private_hosts,
|
||||
auth_headers: self.auth_headers,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
// -- SSRF IP blocking ---------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_blocks_loopback() {
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(127, 255, 255, 254))));
|
||||
assert!(is_ip_blocked(&IpAddr::V6(Ipv6Addr::LOCALHOST)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_blocks_private() {
|
||||
// 10.0.0.0/8
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(10, 255, 255, 255))));
|
||||
|
||||
// 172.16.0.0/12
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1))));
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(172, 31, 255, 255))));
|
||||
|
||||
// 192.168.0.0/16
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))));
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_blocks_link_local() {
|
||||
// IPv4 link-local (169.254.x.x) -- includes the AWS metadata endpoint
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))));
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(169, 254, 0, 1))));
|
||||
|
||||
// IPv6 link-local (fe80::/10)
|
||||
assert!(is_ip_blocked(&IpAddr::V6(Ipv6Addr::new(
|
||||
0xfe80, 0, 0, 0, 0, 0, 0, 1
|
||||
))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_blocks_multicast() {
|
||||
assert!(is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1))));
|
||||
assert!(is_ip_blocked(&IpAddr::V6(Ipv6Addr::new(
|
||||
0xff02, 0, 0, 0, 0, 0, 0, 1
|
||||
))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_blocks_mapped_v4() {
|
||||
// ::ffff:127.0.0.1
|
||||
let mapped = Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x7f00, 0x0001);
|
||||
assert!(is_ip_blocked(&IpAddr::V6(mapped)));
|
||||
|
||||
// ::ffff:10.0.0.1
|
||||
let mapped_private = Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x0a00, 0x0001);
|
||||
assert!(is_ip_blocked(&IpAddr::V6(mapped_private)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_allows_public() {
|
||||
assert!(!is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
|
||||
assert!(!is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))));
|
||||
assert!(!is_ip_blocked(&IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34))));
|
||||
}
|
||||
|
||||
// -- URL validation -----------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_url_rejects_http() {
|
||||
let result = validate_url("http://example.com/spec.json", false);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(matches!(err, SwaggerCliError::PolicyBlocked(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_allows_https() {
|
||||
let result = validate_url("https://example.com/spec.json", false);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(
|
||||
result.unwrap().as_str(),
|
||||
"https://example.com/spec.json"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_allows_http_when_opted_in() {
|
||||
let result = validate_url("http://example.com/spec.json", true);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_rejects_unsupported_scheme() {
|
||||
let result = validate_url("ftp://example.com/spec.json", false);
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), SwaggerCliError::InvalidSpec(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_rejects_garbage() {
|
||||
let result = validate_url("not a url at all", false);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// -- Builder defaults ---------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_builder_defaults() {
|
||||
let client = AsyncHttpClient::builder().build();
|
||||
assert_eq!(client.connect_timeout, DEFAULT_CONNECT_TIMEOUT);
|
||||
assert_eq!(client.overall_timeout, DEFAULT_OVERALL_TIMEOUT);
|
||||
assert_eq!(client.max_bytes, DEFAULT_MAX_BYTES);
|
||||
assert_eq!(client.max_retries, DEFAULT_MAX_RETRIES);
|
||||
assert!(!client.allow_insecure_http);
|
||||
assert!(client.allowed_private_hosts.is_empty());
|
||||
assert!(client.auth_headers.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_custom() {
|
||||
let client = AsyncHttpClient::builder()
|
||||
.connect_timeout(Duration::from_secs(3))
|
||||
.overall_timeout(Duration::from_secs(30))
|
||||
.max_bytes(1024)
|
||||
.max_retries(5)
|
||||
.allow_insecure_http(true)
|
||||
.allowed_private_hosts(vec!["internal.corp".into()])
|
||||
.auth_header("Authorization".into(), "Bearer tok".into())
|
||||
.build();
|
||||
|
||||
assert_eq!(client.connect_timeout, Duration::from_secs(3));
|
||||
assert_eq!(client.overall_timeout, Duration::from_secs(30));
|
||||
assert_eq!(client.max_bytes, 1024);
|
||||
assert_eq!(client.max_retries, 5);
|
||||
assert!(client.allow_insecure_http);
|
||||
assert_eq!(client.allowed_private_hosts, vec!["internal.corp"]);
|
||||
assert_eq!(client.auth_headers.len(), 1);
|
||||
}
|
||||
|
||||
// -- DNS + SSRF integration (async) -------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_resolve_and_check_skips_allowed_host() {
|
||||
let result =
|
||||
resolve_and_check("localhost", 80, &["localhost".into()]).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod cache;
|
||||
pub mod config;
|
||||
pub mod http;
|
||||
pub mod indexer;
|
||||
pub mod spec;
|
||||
|
||||
Reference in New Issue
Block a user