Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions crates/attestation/src/measurements.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! Measurements and policy for enforcing them when validating a remote
//! attestation
use std::{collections::HashMap, fmt, fmt::Formatter, path::PathBuf};
use std::{collections::HashMap, fmt, fmt::Formatter, net::IpAddr, path::PathBuf};

use dcap_qvl::quote::Report;
use http::{HeaderValue, header::InvalidHeaderValue};
use http::{HeaderValue, header::InvalidHeaderValue, uri::InvalidUri};
use serde::Deserialize;
use thiserror::Error;

Expand Down Expand Up @@ -268,6 +268,10 @@ pub enum MeasurementFormatError {
ParseInt(#[from] std::num::ParseIntError),
#[error("Failed to read measurements from URL: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("Invalid URL: {0}")]
InvalidUri(#[from] InvalidUri),
#[error("Refusing to load measurement policy over plain HTTP from non-loopback host: {0}")]
InsecureHttpNotLoopback(String),
#[error("Measurement entry for register '{0}' has both 'expected' and 'expected_any'")]
BothExpectedAndExpectedAny(String),
#[error("Measurement entry for register '{0}' has neither 'expected' nor 'expected_any'")]
Expand Down Expand Up @@ -428,7 +432,14 @@ impl MeasurementPolicy {
/// Given either a URL or the path to a file, parse the measurement
/// policy from JSON
pub async fn from_file_or_url(file_or_url: String) -> Result<Self, MeasurementFormatError> {
if file_or_url.starts_with("https://") || file_or_url.starts_with("http://") {
if file_or_url.starts_with("https://") {
let measurements_json = reqwest::get(file_or_url).await?.bytes().await?;
Self::from_json_bytes(measurements_json.to_vec())
} else if file_or_url.starts_with("http://") {
if !Self::is_loopback_http_url(&file_or_url)? {
return Err(MeasurementFormatError::InsecureHttpNotLoopback(file_or_url));
}

let measurements_json = reqwest::get(file_or_url).await?.bytes().await?;
Self::from_json_bytes(measurements_json.to_vec())
} else {
Expand Down Expand Up @@ -552,6 +563,20 @@ impl MeasurementPolicy {

Ok(MeasurementPolicy { accepted_measurements: measurement_policy })
}

/// Determine whether a url is local / loopback device
///
/// This is used to decide whether to allow fetching in plaintext http
fn is_loopback_http_url(url: &str) -> Result<bool, MeasurementFormatError> {
let uri: http::Uri = url.parse()?;
let Some(host) = uri.host() else {
return Ok(false);
};
let normalized_host = host.trim_start_matches('[').trim_end_matches(']');

Ok(normalized_host.eq_ignore_ascii_case("localhost") ||
normalized_host.parse::<IpAddr>().is_ok_and(|address| address.is_loopback()))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -1021,4 +1046,43 @@ mod tests {
assert!(azure_debug.contains(&hex::encode(azure_register_value)));
assert!(!azure_debug.contains(&format!("{azure_register_value:?}")));
}

#[tokio::test]
async fn test_from_file_or_url_rejects_non_loopback_http() {
let result =
MeasurementPolicy::from_file_or_url("http://example.com/measurements.json".into())
.await;

assert!(matches!(
result,
Err(MeasurementFormatError::InsecureHttpNotLoopback(url))
if url == "http://example.com/measurements.json"
));
}

#[tokio::test]
async fn test_from_file_or_url_allows_http_localhost() {
let result =
MeasurementPolicy::from_file_or_url("http://localhost:1/measurements.json".into())
.await;

assert!(matches!(result, Err(MeasurementFormatError::Reqwest(_))));
}

#[tokio::test]
async fn test_from_file_or_url_allows_http_ipv4_loopback() {
let result =
MeasurementPolicy::from_file_or_url("http://127.0.0.1:1/measurements.json".into())
.await;

assert!(matches!(result, Err(MeasurementFormatError::Reqwest(_))));
}

#[tokio::test]
async fn test_from_file_or_url_allows_http_ipv6_loopback() {
let result =
MeasurementPolicy::from_file_or_url("http://[::1]:1/measurements.json".into()).await;

assert!(matches!(result, Err(MeasurementFormatError::Reqwest(_))));
}
}
Loading