Skip to content

Commit 8f016c5

Browse files
authored
Merge branch 'main' into peterguy/support-http_proxy
2 parents 6520ac4 + 0e5e006 commit 8f016c5

17 files changed

Lines changed: 531 additions & 128 deletions

.tool-versions

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
golang 1.25.4
1+
golang 1.26.1
22
shfmt 3.8.0
33
shellcheck 0.10.0

cmd/src/auth.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
)
7+
8+
var authCommands commander
9+
10+
func init() {
11+
usage := `'src auth' provides authentication-related helper commands.
12+
13+
Usage:
14+
15+
src auth command [command options]
16+
17+
The commands are:
18+
19+
token prints the current authentication token or Authorization header
20+
21+
Use "src auth [command] -h" for more information about a command.
22+
`
23+
24+
flagSet := flag.NewFlagSet("auth", flag.ExitOnError)
25+
handler := func(args []string) error {
26+
authCommands.run(flagSet, "src auth", usage, args)
27+
return nil
28+
}
29+
30+
commands = append(commands, &command{
31+
flagSet: flagSet,
32+
handler: handler,
33+
usageFunc: func() {
34+
fmt.Println(usage)
35+
},
36+
})
37+
}

cmd/src/auth_token.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"flag"
6+
"fmt"
7+
8+
"github.com/sourcegraph/sourcegraph/lib/errors"
9+
10+
"github.com/sourcegraph/src-cli/internal/oauth"
11+
)
12+
13+
var (
14+
loadOAuthToken = oauth.LoadToken
15+
newOAuthTokenRefresher = func(token *oauth.Token) oauthTokenRefresher {
16+
return oauth.NewTokenRefresher(token)
17+
}
18+
)
19+
20+
type oauthTokenRefresher interface {
21+
GetToken(ctx context.Context) (oauth.Token, error)
22+
}
23+
24+
func init() {
25+
flagSet := flag.NewFlagSet("token", flag.ExitOnError)
26+
header := flagSet.Bool("header", false, "print the token as an Authorization header")
27+
usageFunc := func() {
28+
fmt.Fprintf(flag.CommandLine.Output(), "Usage of 'src auth token':\n\n")
29+
fmt.Fprintf(flag.CommandLine.Output(), "Print the current authentication token.\n")
30+
fmt.Fprintf(flag.CommandLine.Output(), "Use --header to print a complete Authorization header instead.\n\n")
31+
flagSet.PrintDefaults()
32+
}
33+
34+
handler := func(args []string) error {
35+
if err := flagSet.Parse(args); err != nil {
36+
return err
37+
}
38+
39+
token, err := resolveAuthToken(context.Background(), cfg)
40+
if err != nil {
41+
return err
42+
}
43+
44+
token = formatAuthTokenOutput(token, cfg.AuthMode(), *header)
45+
fmt.Println(token)
46+
return nil
47+
}
48+
49+
authCommands = append(authCommands, &command{
50+
flagSet: flagSet,
51+
handler: handler,
52+
usageFunc: usageFunc,
53+
})
54+
}
55+
56+
func resolveAuthToken(ctx context.Context, cfg *config) (string, error) {
57+
if err := cfg.requireCIAccessToken(); err != nil {
58+
return "", err
59+
}
60+
61+
if cfg.accessToken != "" {
62+
return cfg.accessToken, nil
63+
}
64+
65+
oauthToken, err := loadOAuthToken(ctx, cfg.endpointURL)
66+
if err != nil {
67+
return "", errors.Wrap(err, "error loading OAuth token; set SRC_ACCESS_TOKEN or run `src login`")
68+
}
69+
70+
token, err := newOAuthTokenRefresher(oauthToken).GetToken(ctx)
71+
if err != nil {
72+
return "", errors.Wrap(err, "refreshing OAuth token")
73+
}
74+
75+
return token.AccessToken, nil
76+
}
77+
78+
func formatAuthTokenOutput(token string, mode AuthMode, header bool) string {
79+
if !header {
80+
return token
81+
}
82+
83+
if mode == AuthModeAccessToken {
84+
return fmt.Sprintf("Authorization: token %s", token)
85+
}
86+
87+
return fmt.Sprintf("Authorization: Bearer %s", token)
88+
}

cmd/src/auth_token_test.go

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/url"
7+
"testing"
8+
9+
"github.com/sourcegraph/src-cli/internal/oauth"
10+
)
11+
12+
func TestResolveAuthToken(t *testing.T) {
13+
t.Run("uses configured access token before keyring", func(t *testing.T) {
14+
reset := stubAuthTokenDependencies(t)
15+
defer reset()
16+
17+
newRefresherCalled := false
18+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
19+
newRefresherCalled = true
20+
return fakeOAuthTokenRefresher{}
21+
}
22+
23+
token, err := resolveAuthToken(context.Background(), &config{
24+
accessToken: "access-token",
25+
endpointURL: mustParseURL(t, "https://example.com"),
26+
})
27+
if err != nil {
28+
t.Fatal(err)
29+
}
30+
if token != "access-token" {
31+
t.Fatalf("token = %q, want %q", token, "access-token")
32+
}
33+
if newRefresherCalled {
34+
t.Fatal("expected OAuth token refresher not to be created")
35+
}
36+
})
37+
38+
t.Run("requires access token in CI", func(t *testing.T) {
39+
reset := stubAuthTokenDependencies(t)
40+
defer reset()
41+
42+
loadCalled := false
43+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
44+
loadCalled = true
45+
return nil, nil
46+
}
47+
48+
_, err := resolveAuthToken(context.Background(), &config{
49+
inCI: true,
50+
endpointURL: mustParseURL(t, "https://example.com"),
51+
})
52+
if err != errCIAccessTokenRequired {
53+
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
54+
}
55+
if loadCalled {
56+
t.Fatal("expected OAuth token loader not to be called")
57+
}
58+
})
59+
60+
t.Run("uses stored oauth token", func(t *testing.T) {
61+
reset := stubAuthTokenDependencies(t)
62+
defer reset()
63+
64+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
65+
return &oauth.Token{
66+
AccessToken: "oauth-token",
67+
}, nil
68+
}
69+
70+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
71+
return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "oauth-token"}}
72+
}
73+
74+
token, err := resolveAuthToken(context.Background(), &config{
75+
endpointURL: mustParseURL(t, "https://example.com"),
76+
})
77+
if err != nil {
78+
t.Fatal(err)
79+
}
80+
if token != "oauth-token" {
81+
t.Fatalf("token = %q, want %q", token, "oauth-token")
82+
}
83+
})
84+
85+
t.Run("refreshes expiring oauth token", func(t *testing.T) {
86+
reset := stubAuthTokenDependencies(t)
87+
defer reset()
88+
89+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
90+
return &oauth.Token{AccessToken: "old-token"}, nil
91+
}
92+
93+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
94+
return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "new-token"}}
95+
}
96+
97+
token, err := resolveAuthToken(context.Background(), &config{
98+
endpointURL: mustParseURL(t, "https://example.com"),
99+
})
100+
if err != nil {
101+
t.Fatal(err)
102+
}
103+
if token != "new-token" {
104+
t.Fatalf("token = %q, want %q", token, "new-token")
105+
}
106+
})
107+
108+
t.Run("returns refresh error when shared refresh logic fails", func(t *testing.T) {
109+
reset := stubAuthTokenDependencies(t)
110+
defer reset()
111+
112+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
113+
return &oauth.Token{AccessToken: "old-token"}, nil
114+
}
115+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
116+
return fakeOAuthTokenRefresher{err: fmt.Errorf("refresh failed")}
117+
}
118+
119+
_, err := resolveAuthToken(context.Background(), &config{
120+
endpointURL: mustParseURL(t, "https://example.com"),
121+
})
122+
if err == nil {
123+
t.Fatal("expected error")
124+
}
125+
})
126+
}
127+
128+
func TestFormatAuthTokenOutput(t *testing.T) {
129+
tests := []struct {
130+
name string
131+
token string
132+
mode AuthMode
133+
header bool
134+
want string
135+
}{
136+
{
137+
name: "raw access token",
138+
token: "access-token",
139+
mode: AuthModeAccessToken,
140+
header: false,
141+
want: "access-token",
142+
},
143+
{
144+
name: "raw oauth token",
145+
token: "oauth-token",
146+
mode: AuthModeOAuth,
147+
header: false,
148+
want: "oauth-token",
149+
},
150+
{
151+
name: "authorization header for access token",
152+
token: "access-token",
153+
mode: AuthModeAccessToken,
154+
header: true,
155+
want: "Authorization: token access-token",
156+
},
157+
{
158+
name: "authorization header for oauth token",
159+
token: "oauth-token",
160+
mode: AuthModeOAuth,
161+
header: true,
162+
want: "Authorization: Bearer oauth-token",
163+
},
164+
}
165+
166+
for _, test := range tests {
167+
t.Run(test.name, func(t *testing.T) {
168+
if got := formatAuthTokenOutput(test.token, test.mode, test.header); got != test.want {
169+
t.Fatalf("formatAuthTokenOutput(%q, %v, %v) = %q, want %q", test.token, test.mode, test.header, got, test.want)
170+
}
171+
})
172+
}
173+
}
174+
175+
func stubAuthTokenDependencies(t *testing.T) func() {
176+
t.Helper()
177+
178+
prevLoad := loadOAuthToken
179+
prevNewRefresher := newOAuthTokenRefresher
180+
181+
return func() {
182+
loadOAuthToken = prevLoad
183+
newOAuthTokenRefresher = prevNewRefresher
184+
}
185+
}
186+
187+
type fakeOAuthTokenRefresher struct {
188+
token oauth.Token
189+
err error
190+
}
191+
192+
func (r fakeOAuthTokenRefresher) GetToken(context.Context) (oauth.Token, error) {
193+
if r.err != nil {
194+
return oauth.Token{}, r.err
195+
}
196+
return r.token, nil
197+
}

cmd/src/login.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ const (
100100
)
101101

102102
func loginCmd(ctx context.Context, p loginParams) error {
103+
if err := p.cfg.requireCIAccessToken(); err != nil {
104+
return err
105+
}
106+
103107
if p.cfg.configFilePath != "" {
104108
fmt.Fprintln(p.out)
105109
fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.configFilePath)

cmd/src/login_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ func TestLogin(t *testing.T) {
6161
}
6262
})
6363

64+
t.Run("CI requires access token", func(t *testing.T) {
65+
u := &url.URL{Scheme: "https", Host: "example.com"}
66+
out, err := check(t, &config{endpointURL: u, inCI: true}, u)
67+
if err != errCIAccessTokenRequired {
68+
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
69+
}
70+
if out != "" {
71+
t.Fatalf("output = %q, want empty output", out)
72+
}
73+
})
74+
6475
t.Run("warning when using config file", func(t *testing.T) {
6576
endpoint := &url.URL{Scheme: "https", Host: "example.com"}
6677
out, err := check(t, &config{endpointURL: endpoint, configFilePath: "f"}, endpoint)

0 commit comments

Comments
 (0)