From 22ceb8c883644790c2aa36c9140027f21550d276 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 30 Mar 2026 13:25:47 +0200 Subject: [PATCH] Only allow measurements to be fetch with plain HTTP from localhost / loopback device --- crates/attestation/src/measurements.rs | 70 ++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 3 deletions(-) diff --git a/crates/attestation/src/measurements.rs b/crates/attestation/src/measurements.rs index 3484be6..74e9ea7 100644 --- a/crates/attestation/src/measurements.rs +++ b/crates/attestation/src/measurements.rs @@ -1,9 +1,9 @@ //! Measurements and policy for enforcing them when validating a remote //! attestation -use std::{collections::HashMap, fmt, fmt::Formatter, path::PathBuf}; +use std::{collections::HashMap, fmt, fmt::Formatter, net::IpAddr, path::PathBuf}; use dcap_qvl::quote::Report; -use http::{HeaderValue, header::InvalidHeaderValue}; +use http::{HeaderValue, header::InvalidHeaderValue, uri::InvalidUri}; use serde::Deserialize; use thiserror::Error; @@ -268,6 +268,10 @@ pub enum MeasurementFormatError { ParseInt(#[from] std::num::ParseIntError), #[error("Failed to read measurements from URL: {0}")] Reqwest(#[from] reqwest::Error), + #[error("Invalid URL: {0}")] + InvalidUri(#[from] InvalidUri), + #[error("Refusing to load measurement policy over plain HTTP from non-loopback host: {0}")] + InsecureHttpNotLoopback(String), #[error("Measurement entry for register '{0}' has both 'expected' and 'expected_any'")] BothExpectedAndExpectedAny(String), #[error("Measurement entry for register '{0}' has neither 'expected' nor 'expected_any'")] @@ -428,7 +432,14 @@ impl MeasurementPolicy { /// Given either a URL or the path to a file, parse the measurement /// policy from JSON pub async fn from_file_or_url(file_or_url: String) -> Result { - if file_or_url.starts_with("https://") || file_or_url.starts_with("http://") { + if file_or_url.starts_with("https://") { + let measurements_json = reqwest::get(file_or_url).await?.bytes().await?; + Self::from_json_bytes(measurements_json.to_vec()) + } else if file_or_url.starts_with("http://") { + if !Self::is_loopback_http_url(&file_or_url)? { + return Err(MeasurementFormatError::InsecureHttpNotLoopback(file_or_url)); + } + let measurements_json = reqwest::get(file_or_url).await?.bytes().await?; Self::from_json_bytes(measurements_json.to_vec()) } else { @@ -552,6 +563,20 @@ impl MeasurementPolicy { Ok(MeasurementPolicy { accepted_measurements: measurement_policy }) } + + /// Determine whether a url is local / loopback device + /// + /// This is used to decide whether to allow fetching in plaintext http + fn is_loopback_http_url(url: &str) -> Result { + let uri: http::Uri = url.parse()?; + let Some(host) = uri.host() else { + return Ok(false); + }; + let normalized_host = host.trim_start_matches('[').trim_end_matches(']'); + + Ok(normalized_host.eq_ignore_ascii_case("localhost") || + normalized_host.parse::().is_ok_and(|address| address.is_loopback())) + } } #[cfg(test)] @@ -1021,4 +1046,43 @@ mod tests { assert!(azure_debug.contains(&hex::encode(azure_register_value))); assert!(!azure_debug.contains(&format!("{azure_register_value:?}"))); } + + #[tokio::test] + async fn test_from_file_or_url_rejects_non_loopback_http() { + let result = + MeasurementPolicy::from_file_or_url("http://example.com/measurements.json".into()) + .await; + + assert!(matches!( + result, + Err(MeasurementFormatError::InsecureHttpNotLoopback(url)) + if url == "http://example.com/measurements.json" + )); + } + + #[tokio::test] + async fn test_from_file_or_url_allows_http_localhost() { + let result = + MeasurementPolicy::from_file_or_url("http://localhost:1/measurements.json".into()) + .await; + + assert!(matches!(result, Err(MeasurementFormatError::Reqwest(_)))); + } + + #[tokio::test] + async fn test_from_file_or_url_allows_http_ipv4_loopback() { + let result = + MeasurementPolicy::from_file_or_url("http://127.0.0.1:1/measurements.json".into()) + .await; + + assert!(matches!(result, Err(MeasurementFormatError::Reqwest(_)))); + } + + #[tokio::test] + async fn test_from_file_or_url_allows_http_ipv6_loopback() { + let result = + MeasurementPolicy::from_file_or_url("http://[::1]:1/measurements.json".into()).await; + + assert!(matches!(result, Err(MeasurementFormatError::Reqwest(_)))); + } }