diff --git a/cmd/src/main.go b/cmd/src/main.go index 93be07c4bf..84f044c5a0 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -201,7 +201,8 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { return api.NewClient(opts) } -// readConfig reads the config file from the given path. +// readConfig reads the config from the standard config file, the (deprecated) user-specified config file, +// the environment variables, and the (deprecated) command-line flags. func readConfig() (*config, error) { cfgFile := *configPath userSpecified := *configPath != "" diff --git a/internal/api/proxy.go b/internal/api/proxy.go index 9589b9beb5..38b489367f 100644 --- a/internal/api/proxy.go +++ b/internal/api/proxy.go @@ -1,11 +1,8 @@ package api import ( - "bufio" "context" "crypto/tls" - "encoding/base64" - "fmt" "net" "net/http" "net/url" @@ -15,8 +12,8 @@ import ( // // Note: baseTransport is considered to be a clone created with transport.Clone() // -// - If a the proxyPath is not empty, a unix socket proxy is created. -// - Otherwise, the proxyURL is used to determine if we should proxy socks5 / http connections +// - If proxyPath is not empty, a unix socket proxy is created. +// - Otherwise, proxyURL is used to determine if we should proxy socks5 / http connections func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyPath string) *http.Transport { handshakeTLS := func(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { // Extract the hostname (without the port) for TLS SNI @@ -24,13 +21,13 @@ func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyP if err != nil { return nil, err } - tlsConn := tls.Client(conn, &tls.Config{ - ServerName: host, - // Pull InsecureSkipVerify from the target host transport - // so that insecure-skip-verify flag settings are honored for the proxy server - InsecureSkipVerify: baseTransport.TLSClientConfig.InsecureSkipVerify, - }) + cfg := baseTransport.TLSClientConfig.Clone() + if cfg.ServerName == "" { + cfg.ServerName = host + } + tlsConn := tls.Client(conn, cfg) if err := tlsConn.HandshakeContext(ctx); err != nil { + tlsConn.Close() return nil, err } return tlsConn, nil @@ -53,82 +50,7 @@ func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyP // clear out any system proxy settings baseTransport.Proxy = nil } else if proxyURL != nil { - switch proxyURL.Scheme { - case "socks5", "socks5h": - // SOCKS proxies work out of the box - no need to manually dial - baseTransport.Proxy = http.ProxyURL(proxyURL) - case "http", "https": - dial := func(ctx context.Context, network, addr string) (net.Conn, error) { - // Dial the proxy - d := net.Dialer{} - conn, err := d.DialContext(ctx, "tcp", proxyURL.Host) - if err != nil { - return nil, err - } - - // this is the whole point of manually dialing the HTTP(S) proxy: - // being able to force HTTP/1. - // When relying on Transport.Proxy, the protocol is always HTTP/2, - // but many proxy servers don't support HTTP/2. - // We don't want to disable HTTP/2 in general because we want to use it when - // connecting to the Sourcegraph API, using HTTP/1 for the proxy connection only. - protocol := "HTTP/1.1" - - // CONNECT is the HTTP method used to set up a tunneling connection with a proxy - method := "CONNECT" - - // Manually writing out the HTTP commands because it's not complicated, - // and http.Request has some janky behavior: - // - ignores the Proto field and hard-codes the protocol to HTTP/1.1 - // - ignores the Host Header (Header.Set("Host", host)) and uses URL.Host instead. - // - When the Host field is set, overrides the URL field - connectReq := fmt.Sprintf("%s %s %s\r\n", method, addr, protocol) - - // A Host header is required per RFC 2616, section 14.23 - connectReq += fmt.Sprintf("Host: %s\r\n", addr) - - // use authentication if proxy credentials are present - if proxyURL.User != nil { - password, _ := proxyURL.User.Password() - auth := base64.StdEncoding.EncodeToString([]byte(proxyURL.User.Username() + ":" + password)) - connectReq += fmt.Sprintf("Proxy-Authorization: Basic %s\r\n", auth) - } - - // finish up with an extra carriage return + newline, as per RFC 7230, section 3 - connectReq += "\r\n" - - // Send the CONNECT request to the proxy to establish the tunnel - if _, err := conn.Write([]byte(connectReq)); err != nil { - conn.Close() - return nil, err - } - - // Read and check the response from the proxy - resp, err := http.ReadResponse(bufio.NewReader(conn), nil) - if err != nil { - conn.Close() - return nil, err - } - if resp.StatusCode != http.StatusOK { - conn.Close() - return nil, fmt.Errorf("failed to connect to proxy %v: %v", proxyURL, resp.Status) - } - resp.Body.Close() - return conn, nil - } - dialTLS := func(ctx context.Context, network, addr string) (net.Conn, error) { - // Dial the underlying connection through the proxy - conn, err := dial(ctx, network, addr) - if err != nil { - return nil, err - } - return handshakeTLS(ctx, conn, addr) - } - baseTransport.DialContext = dial - baseTransport.DialTLSContext = dialTLS - // clear out any system proxy settings - baseTransport.Proxy = nil - } + baseTransport.Proxy = http.ProxyURL(proxyURL) } return baseTransport diff --git a/internal/api/proxy_test.go b/internal/api/proxy_test.go new file mode 100644 index 0000000000..a89d056328 --- /dev/null +++ b/internal/api/proxy_test.go @@ -0,0 +1,399 @@ +package api + +import ( + "crypto/tls" + "encoding/base64" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" +) + +type proxyOpts struct { + useTLS bool + username string + password string + observe func(*http.Request) +} + +// startProxy starts an HTTP or HTTPS CONNECT proxy on a random port. +// If opts.observe is set, it is called for each CONNECT request. +// If opts.username is set, Proxy-Authorization is required. +func startProxy(t *testing.T, opts proxyOpts) *url.URL { + t.Helper() + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if opts.observe != nil { + opts.observe(r) + } + + if r.Method != http.MethodConnect { + http.Error(w, "expected CONNECT", http.StatusMethodNotAllowed) + return + } + + if opts.username != "" { + wantAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(opts.username+":"+opts.password)) + if r.Header.Get("Proxy-Authorization") != wantAuth { + http.Error(w, "proxy auth required", http.StatusProxyAuthRequired) + return + } + } + + serveTunnel(w, r) + })) + + if opts.useTLS { + srv.StartTLS() + } else { + srv.Start() + } + t.Cleanup(srv.Close) + + pURL, _ := url.Parse(srv.URL) + if opts.username != "" { + pURL.User = url.UserPassword(opts.username, opts.password) + } + return pURL +} + +// serveTunnel implements the CONNECT tunnel: dials the target, hijacks the +// client connection, and copies bytes bidirectionally. +func serveTunnel(w http.ResponseWriter, r *http.Request) { + destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer destConn.Close() + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + clientConn, bufrw, err := hijacker.Hijack() + if err != nil { + return + } + + var wg sync.WaitGroup + var once sync.Once + closeBoth := func() { + clientConn.Close() + destConn.Close() + } + defer once.Do(closeBoth) + + wg.Add(2) + // Read from bufrw (not clientConn) so any bytes already buffered + // by the server's bufio.Reader are forwarded to the destination. + go func() { + defer wg.Done() + io.Copy(destConn, bufrw) + once.Do(closeBoth) + }() + go func() { + defer wg.Done() + io.Copy(clientConn, destConn) + once.Do(closeBoth) + }() + wg.Wait() +} + +// newTestTransport creates a base transport suitable for proxy tests. +func newTestTransport() *http.Transport { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + return transport +} + +// startTargetServer starts an HTTPS server (with HTTP/2 enabled) that +// responds with "ok" to GET /. +func startTargetServer(t *testing.T) *httptest.Server { + t.Helper() + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "ok") + })) + srv.EnableHTTP2 = true + srv.StartTLS() + t.Cleanup(srv.Close) + return srv +} + +func TestWithProxyTransport_HTTPProxy(t *testing.T) { + target := startTargetServer(t) + + var mu sync.Mutex + var used bool + var proto string + + proxyURL := startProxy(t, proxyOpts{ + observe: func(r *http.Request) { + mu.Lock() + defer mu.Unlock() + used = true + proto = r.Proto + }, + }) + + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through http proxy: %v", err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if got := strings.TrimSpace(string(body)); got != "ok" { + t.Errorf("expected body 'ok', got %q", got) + } + + mu.Lock() + defer mu.Unlock() + if !used { + t.Fatal("proxy handler was never invoked") + } + if proto != "HTTP/1.1" { + t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto) + } +} + +func TestWithProxyTransport_HTTPSProxy(t *testing.T) { + target := startTargetServer(t) + + var mu sync.Mutex + var used bool + var proto string + + proxyURL := startProxy(t, proxyOpts{ + useTLS: true, + observe: func(r *http.Request) { + mu.Lock() + defer mu.Unlock() + used = true + proto = r.Proto + }, + }) + + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through https proxy: %v", err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if got := strings.TrimSpace(string(body)); got != "ok" { + t.Errorf("expected body 'ok', got %q", got) + } + + mu.Lock() + defer mu.Unlock() + if !used { + t.Fatal("proxy handler was never invoked") + } + if proto != "HTTP/1.1" { + t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto) + } +} + +func TestWithProxyTransport_ProxyAuth(t *testing.T) { + target := startTargetServer(t) + + t.Run("http proxy with auth", func(t *testing.T) { + proxyURL := startProxy(t, proxyOpts{username: "user", password: "pass"}) + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through authenticated http proxy: %v", err) + } + defer resp.Body.Close() + if _, err := io.ReadAll(resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("https proxy with auth", func(t *testing.T) { + proxyURL := startProxy(t, proxyOpts{useTLS: true, username: "user", password: "s3cret"}) + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through authenticated https proxy: %v", err) + } + defer resp.Body.Close() + if _, err := io.ReadAll(resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) +} + +func TestWithProxyTransport_HTTPSProxy_HTTP2ToOrigin(t *testing.T) { + // Verify that when tunneling through an HTTPS proxy, the connection to + // the origin target still negotiates HTTP/2 (not downgraded to HTTP/1.1). + target := startTargetServer(t) + proxyURL := startProxy(t, proxyOpts{useTLS: true}) + + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through https proxy: %v", err) + } + defer resp.Body.Close() + if _, err := io.ReadAll(resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.ProtoMajor != 2 { + t.Errorf("expected HTTP/2 to origin, got %s", resp.Proto) + } +} + +func TestWithProxyTransport_HandshakeFailureClosesConn(t *testing.T) { + // Verify that when the TLS handshake to the origin fails, the underlying + // tunnel connection is closed (regression test for tlsConn.Close on error). + // + // A plain TCP listener acts as the target. The proxy CONNECT succeeds + // (TCP-level), but the subsequent TLS handshake fails because the target + // is not a TLS server. If handshakeTLS properly closes tlsConn on failure, + // the tunnel tears down and the target sees the connection close. + connClosed := make(chan struct{}) + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + // Send non-TLS bytes so the client handshake fails immediately + // rather than waiting for a timeout. + conn.Write([]byte("not-tls\n")) + // Drain until the remote side closes the tunnel. + io.Copy(io.Discard, conn) + close(connClosed) + }() + + proxyURL := startProxy(t, proxyOpts{useTLS: true}) + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 5 * time.Second} + + _, err = client.Get("https://" + ln.Addr().String()) + if err == nil { + t.Fatal("expected TLS handshake error, got nil") + } + + select { + case <-connClosed: + // Connection was properly cleaned up. + case <-time.After(5 * time.Second): + t.Fatal("connection was not closed after TLS handshake failure") + } +} + +func TestWithProxyTransport_ProxyRejectsConnect(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantErr string + }{ + {"407 proxy auth required", http.StatusProxyAuthRequired, "proxy auth required", "Proxy Authentication Required"}, + {"403 forbidden", http.StatusForbidden, "access denied by policy", "Forbidden"}, + {"502 bad gateway", http.StatusBadGateway, "upstream unreachable", "Bad Gateway"}, + } + + // Use a local target so we never depend on external DNS. + target := startTargetServer(t) + + for _, tt := range tests { + t.Run("http proxy/"+tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, tt.body, tt.statusCode) + })) + t.Cleanup(srv.Close) + + proxyURL, _ := url.Parse(srv.URL) + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + _, err := client.Get(target.URL) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error should contain %q, got: %v", tt.wantErr, err) + } + }) + + t.Run("https proxy/"+tt.name, func(t *testing.T) { + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, tt.body, tt.statusCode) + })) + srv.StartTLS() + t.Cleanup(srv.Close) + + proxyURL, _ := url.Parse(srv.URL) + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + _, err := client.Get(target.URL) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error should contain %q, got: %v", tt.wantErr, err) + } + }) + } +}