diff --git a/integrations/oauth_test.go b/integrations/oauth_test.go index 53b83bb01..9674146f8 100644 --- a/integrations/oauth_test.go +++ b/integrations/oauth_test.go @@ -136,3 +136,44 @@ func TestAccessTokenExchangeWithInvalidCredentials(t *testing.T) { }) MakeRequest(t, req, 400) } + +func TestAccessTokenExchangeWithBasicAuth(t *testing.T) { + prepareTestEnv(t) + req := NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ + "grant_type": "authorization_code", + "redirect_uri": "a", + "code": "authcode", + "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", // test PKCE additionally + }) + req.Header.Add("Authorization", "Basic ZGE3ZGEzYmEtOWExMy00MTY3LTg1NmYtMzg5OWRlMGIwMTM4OjRNSzhOYTZSNTVzbWRDWTBXdUNDdW1aNmhqUlBuR1k1c2FXVlJISGpKaUE9") + resp := MakeRequest(t, req, 200) + type response struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + } + parsed := new(response) + assert.NoError(t, json.Unmarshal(resp.Body.Bytes(), parsed)) + assert.True(t, len(parsed.AccessToken) > 10) + assert.True(t, len(parsed.RefreshToken) > 10) + + // use wrong client_secret + req = NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ + "grant_type": "authorization_code", + "redirect_uri": "a", + "code": "authcode", + "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", // test PKCE additionally + }) + req.Header.Add("Authorization", "Basic ZGE3ZGEzYmEtOWExMy00MTY3LTg1NmYtMzg5OWRlMGIwMTM4OmJsYWJsYQ==") + resp = MakeRequest(t, req, 400) + + // missing header + req = NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ + "grant_type": "authorization_code", + "redirect_uri": "a", + "code": "authcode", + "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", // test PKCE additionally + }) + resp = MakeRequest(t, req, 400) +} diff --git a/routers/user/oauth.go b/routers/user/oauth.go index dbb3c4a39..110fa93b3 100644 --- a/routers/user/oauth.go +++ b/routers/user/oauth.go @@ -5,8 +5,10 @@ package user import ( + "encoding/base64" "fmt" "net/url" + "strings" "github.com/dgrijalva/jwt-go" "github.com/go-macaron/binding" @@ -305,6 +307,30 @@ func GrantApplicationOAuth(ctx *context.Context, form auth.GrantApplicationForm) // AccessTokenOAuth manages all access token requests by the client func AccessTokenOAuth(ctx *context.Context, form auth.AccessTokenForm) { + if form.ClientID == "" { + authHeader := ctx.Req.Header.Get("Authorization") + authContent := strings.SplitN(authHeader, " ", 2) + if len(authContent) == 2 && authContent[0] == "Basic" { + payload, err := base64.StdEncoding.DecodeString(authContent[1]) + if err != nil { + handleAccessTokenError(ctx, AccessTokenError{ + ErrorCode: AccessTokenErrorCodeInvalidRequest, + ErrorDescription: "cannot parse basic auth header", + }) + return + } + pair := strings.SplitN(string(payload), ":", 2) + if len(pair) != 2 { + handleAccessTokenError(ctx, AccessTokenError{ + ErrorCode: AccessTokenErrorCodeInvalidRequest, + ErrorDescription: "cannot parse basic auth header", + }) + return + } + form.ClientID = pair[0] + form.ClientSecret = pair[1] + } + } switch form.GrantType { case "refresh_token": handleRefreshToken(ctx, form) @@ -361,7 +387,7 @@ func handleAuthorizationCode(ctx *context.Context, form auth.AccessTokenForm) { if err != nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidClient, - ErrorDescription: "cannot load client", + ErrorDescription: fmt.Sprintf("cannot load client with client id: '%s'", form.ClientID), }) return }