diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c882ab7..8bd51cd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ All notable changes to `mcp/sdk` will be documented in this file. * Add client component for building MCP clients * Add `Builder::setReferenceHandler()` to allow custom `ReferenceHandlerInterface` implementations (e.g. authorization decorators) * Add elicitation enum schema types per SEP-1330: `TitledEnumSchemaDefinition`, `MultiSelectEnumSchemaDefinition`, `TitledMultiSelectEnumSchemaDefinition` +* Add `DnsRebindingProtectionMiddleware` enabled by default on `StreamableHttpTransport` to validate Origin headers against allowed hostnames 0.4.0 ----- diff --git a/docs/transports.md b/docs/transports.md index a68875d9..d0331f37 100644 --- a/docs/transports.md +++ b/docs/transports.md @@ -219,6 +219,28 @@ $transport = new StreamableHttpTransport( If middleware returns a response, the transport will still ensure CORS headers are present unless you set them yourself. +#### DNS Rebinding Protection + +`StreamableHttpTransport` automatically includes `DnsRebindingProtectionMiddleware`, which validates `Origin` and `Host` +headers to prevent [DNS rebinding attacks](https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#security-warning). +By default it only allows localhost variants (`localhost`, `127.0.0.1`, `[::1]`, `::1`). + +To allow additional hosts, pass your own instance — the transport will use it instead of the default: + +```php +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; + +$transport = new StreamableHttpTransport( + $request, + middleware: [ + new DnsRebindingProtectionMiddleware(allowedHosts: ['localhost', '127.0.0.1', '[::1]', '::1', 'myapp.local']), + ], +); +``` + +Requests with a non-allowed `Origin` or `Host` header receive a `403 Forbidden` response. +When the `Origin` header is present it takes precedence; otherwise the `Host` header is validated. + ### Architecture The HTTP transport doesn't run its own web server. Instead, it processes PSR-7 requests and returns PSR-7 responses that diff --git a/src/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddleware.php b/src/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddleware.php new file mode 100644 index 00000000..b3d94783 --- /dev/null +++ b/src/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddleware.php @@ -0,0 +1,114 @@ + */ + private readonly array $allowedHosts; + + /** + * @param string[] $allowedHosts Allowed hostnames (without port). Defaults to localhost variants. + * @param ResponseFactoryInterface|null $responseFactory PSR-17 response factory + * @param StreamFactoryInterface|null $streamFactory PSR-17 stream factory + */ + public function __construct( + array $allowedHosts = ['localhost', '127.0.0.1', '[::1]', '::1'], + ?ResponseFactoryInterface $responseFactory = null, + ?StreamFactoryInterface $streamFactory = null, + ) { + $this->allowedHosts = array_values(array_map('strtolower', $allowedHosts)); + $this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory(); + $this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory(); + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + $origin = $request->getHeaderLine('Origin'); + if ('' !== $origin) { + if (!$this->isAllowedOrigin($origin)) { + return $this->createForbiddenResponse('Forbidden: Invalid Origin header.'); + } + + return $handler->handle($request); + } + + $host = $request->getHeaderLine('Host'); + if ('' !== $host && !$this->isAllowedHost($host)) { + return $this->createForbiddenResponse('Forbidden: Invalid Host header.'); + } + + return $handler->handle($request); + } + + private function isAllowedOrigin(string $origin): bool + { + $parsed = parse_url($origin); + if (false === $parsed || !isset($parsed['host'])) { + return false; + } + + return \in_array(strtolower($parsed['host']), $this->allowedHosts, true); + } + + /** + * Validates the Host header value (host or host:port) against the allowed list. + */ + private function isAllowedHost(string $host): bool + { + // IPv6 host with port: [::1]:8080 + if (str_starts_with($host, '[')) { + $closingBracket = strpos($host, ']'); + if (false === $closingBracket) { + return false; + } + $hostname = substr($host, 0, $closingBracket + 1); + } else { + // Strip port if present (host:port) + $hostname = explode(':', $host, 2)[0]; + } + + return \in_array(strtolower($hostname), $this->allowedHosts, true); + } + + private function createForbiddenResponse(string $message): ResponseInterface + { + $body = json_encode(Error::forInvalidRequest($message), \JSON_THROW_ON_ERROR); + + return $this->responseFactory + ->createResponse(403) + ->withHeader('Content-Type', 'application/json') + ->withBody($this->streamFactory->createStream($body)); + } +} diff --git a/src/Server/Transport/StreamableHttpTransport.php b/src/Server/Transport/StreamableHttpTransport.php index 3c9b2f67..f4d36372 100644 --- a/src/Server/Transport/StreamableHttpTransport.php +++ b/src/Server/Transport/StreamableHttpTransport.php @@ -14,6 +14,7 @@ use Http\Discovery\Psr17FactoryDiscovery; use Mcp\Exception\InvalidArgumentException; use Mcp\Schema\JsonRpc\Error; +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; use Mcp\Server\Transport\Http\MiddlewareRequestHandler; use Psr\Http\Message\ResponseFactoryInterface; use Psr\Http\Message\ResponseInterface; @@ -77,12 +78,23 @@ public function __construct( 'Access-Control-Expose-Headers' => self::SESSION_HEADER, ], $corsHeaders); + $hasDnsRebindingProtection = false; foreach ($middleware as $m) { if (!$m instanceof MiddlewareInterface) { throw new InvalidArgumentException('Streamable HTTP middleware must implement Psr\\Http\\Server\\MiddlewareInterface.'); } + if ($m instanceof DnsRebindingProtectionMiddleware) { + $hasDnsRebindingProtection = true; + } $this->middleware[] = $m; } + + if (!$hasDnsRebindingProtection) { + array_unshift($this->middleware, new DnsRebindingProtectionMiddleware( + responseFactory: $this->responseFactory, + streamFactory: $this->streamFactory, + )); + } } public function send(string $data, array $context): void diff --git a/tests/Conformance/conformance-baseline.yml b/tests/Conformance/conformance-baseline.yml index de676e85..03939d5f 100644 --- a/tests/Conformance/conformance-baseline.yml +++ b/tests/Conformance/conformance-baseline.yml @@ -1,3 +1 @@ -server: - - dns-rebinding-protection - +server: [] diff --git a/tests/Unit/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddlewareTest.php b/tests/Unit/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddlewareTest.php new file mode 100644 index 00000000..e2f6052a --- /dev/null +++ b/tests/Unit/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddlewareTest.php @@ -0,0 +1,244 @@ +factory = new Psr17Factory(); + $this->handler = new class($this->factory) implements RequestHandlerInterface { + public function __construct(private ResponseFactoryInterface $factory) + { + } + + public function handle(ServerRequestInterface $request): ResponseInterface + { + return $this->factory->createResponse(200); + } + }; + } + + #[TestDox('allows request with valid localhost Origin header')] + public function testAllowsLocalhostOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://localhost:8000'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('allows request with 127.0.0.1 Origin header')] + public function testAllows127001Origin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://127.0.0.1:3000'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('allows request with [::1] Origin header')] + public function testAllowsIpv6LocalhostOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://[::1]:8000'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('allows request with no Origin header')] + public function testAllowsEmptyOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('rejects request with evil Origin header')] + public function testRejectsEvilOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://evil.example.com'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(403, $response->getStatusCode()); + $this->assertStringContainsString('Origin', (string) $response->getBody()); + } + + #[TestDox('rejects request with evil Origin header even with port')] + public function testRejectsEvilOriginWithPort(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://evil.example.com:8000'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(403, $response->getStatusCode()); + } + + #[TestDox('rejects malformed Origin header')] + public function testRejectsMalformedOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'not-a-url'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(403, $response->getStatusCode()); + } + + #[TestDox('Origin matching is case-insensitive')] + public function testOriginMatchingIsCaseInsensitive(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://LOCALHOST:8000'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('supports custom allowed hosts')] + public function testCustomAllowedHosts(): void + { + $middleware = new DnsRebindingProtectionMiddleware( + allowedHosts: ['myapp.local'], + responseFactory: $this->factory, + ); + + $allowed = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://myapp.local:9000'); + $this->assertSame(200, $middleware->process($allowed, $this->handler)->getStatusCode()); + + $rejected = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://localhost'); + $this->assertSame(403, $middleware->process($rejected, $this->handler)->getStatusCode()); + } + + #[TestDox('rejects request with evil Host header when no Origin is present')] + public function testRejectsEvilHostWithoutOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://evil.example.com/') + ->withHeader('Host', 'evil.example.com'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(403, $response->getStatusCode()); + $this->assertStringContainsString('Host', (string) $response->getBody()); + } + + #[TestDox('rejects request with evil Host header including port when no Origin is present')] + public function testRejectsEvilHostWithPortWithoutOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://evil.example.com:8000/') + ->withHeader('Host', 'evil.example.com:8000'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(403, $response->getStatusCode()); + } + + #[TestDox('allows request with localhost Host header when no Origin is present')] + public function testAllowsLocalhostHostWithoutOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost:8000/') + ->withHeader('Host', 'localhost:8000'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('allows request with IPv6 Host header when no Origin is present')] + public function testAllowsIpv6HostWithoutOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://[::1]:8000/') + ->withHeader('Host', '[::1]:8000'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('Origin takes precedence over Host header')] + public function testOriginTakesPrecedenceOverHost(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + // Valid Origin but evil Host — should pass because Origin is checked first + $request = $this->factory->createServerRequest('POST', 'http://evil.example.com/') + ->withHeader('Origin', 'http://localhost:8000') + ->withHeader('Host', 'evil.example.com'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('allowed hosts are normalized to lowercase')] + public function testAllowedHostsAreCaseInsensitive(): void + { + $middleware = new DnsRebindingProtectionMiddleware( + allowedHosts: ['MyApp.Local'], + responseFactory: $this->factory, + ); + + $request = $this->factory->createServerRequest('POST', 'http://myapp.local/') + ->withHeader('Origin', 'http://myapp.local:9000'); + + $this->assertSame(200, $middleware->process($request, $this->handler)->getStatusCode()); + } + + #[TestDox('Host matching is case-insensitive')] + public function testHostMatchingIsCaseInsensitive(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://LOCALHOST:8000/') + ->withHeader('Host', 'LOCALHOST:8000'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(200, $response->getStatusCode()); + } +} diff --git a/tests/Unit/Server/Transport/StreamableHttpTransportTest.php b/tests/Unit/Server/Transport/StreamableHttpTransportTest.php index 7d9cd484..bda7f28c 100644 --- a/tests/Unit/Server/Transport/StreamableHttpTransportTest.php +++ b/tests/Unit/Server/Transport/StreamableHttpTransportTest.php @@ -40,7 +40,7 @@ public static function corsHeaderProvider(): iterable public function testCorsHeader(string $method, bool $middlewareDelegatesToTransport, int $expectedStatusCode): void { $factory = new Psr17Factory(); - $request = $factory->createServerRequest($method, 'https://example.com'); + $request = $factory->createServerRequest($method, 'http://localhost:8000'); $middleware = new class($factory, $expectedStatusCode, $middlewareDelegatesToTransport) implements MiddlewareInterface { public function __construct( @@ -90,7 +90,7 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface public function testCorsHeadersAreReplacedWhenAlreadyPresent(): void { $factory = new Psr17Factory(); - $request = $factory->createServerRequest('GET', 'https://example.com'); + $request = $factory->createServerRequest('GET', 'http://localhost:8000'); $middleware = new class($factory) implements MiddlewareInterface { public function __construct(private ResponseFactoryInterface $responses) @@ -130,7 +130,7 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface public function testMiddlewareRunsBeforeTransportHandlesRequest(): void { $factory = new Psr17Factory(); - $request = $factory->createServerRequest('OPTIONS', 'https://example.com'); + $request = $factory->createServerRequest('OPTIONS', 'http://localhost:8000'); $state = new \stdClass(); $state->called = false;