From 16d6e6c34e3676330118d89e3e108348aea76d64 Mon Sep 17 00:00:00 2001 From: Roman Loginov Date: Fri, 3 May 2024 15:29:51 +0300 Subject: [PATCH] [#112] tokens: Extend test coverage Signed-off-by: Roman Loginov --- tokens/bearer-token_test.go | 193 +++++++++++++++++++++++++++++++----- 1 file changed, 168 insertions(+), 25 deletions(-) diff --git a/tokens/bearer-token_test.go b/tokens/bearer-token_test.go index cc54e74..6fb3bf4 100644 --- a/tokens/bearer-token_test.go +++ b/tokens/bearer-token_test.go @@ -23,19 +23,29 @@ func makeTestCookie(value []byte) *fasthttp.RequestHeader { func makeTestHeader(value []byte) *fasthttp.RequestHeader { header := new(fasthttp.RequestHeader) if value != nil { - header.Set(fasthttp.HeaderAuthorization, bearerTokenHdr+" "+string(value)) + header.Set(fasthttp.HeaderAuthorization, string(value)) } return header } -func Test_fromCookie(t *testing.T) { +func makeBearer(value string) string { + return bearerTokenHdr + " " + value +} + +func TestBearerTokenFromCookie(t *testing.T) { cases := []struct { name string actual []byte expect []byte }{ - {name: "empty"}, - {name: "normal", actual: []byte("TOKEN"), expect: []byte("TOKEN")}, + { + name: "empty", + }, + { + name: "normal", + actual: []byte("TOKEN"), + expect: []byte("TOKEN"), + }, } for _, tt := range cases { @@ -45,14 +55,31 @@ func Test_fromCookie(t *testing.T) { } } -func Test_fromHeader(t *testing.T) { +func TestBearerTokenFromHeader(t *testing.T) { + validToken := "token" + tokenWithoutPrefix := "invalid-token" + cases := []struct { name string actual []byte expect []byte }{ - {name: "empty"}, - {name: "normal", actual: []byte("TOKEN"), expect: []byte("TOKEN")}, + { + name: "empty", + }, + { + name: "token without the bearer prefix", + actual: []byte(tokenWithoutPrefix), + }, + { + name: "token without payload", + actual: []byte(makeBearer("")), + }, + { + name: "normal", + actual: []byte(makeBearer(validToken)), + expect: []byte(validToken), + }, } for _, tt := range cases { @@ -62,7 +89,7 @@ func Test_fromHeader(t *testing.T) { } } -func Test_fetchBearerToken(t *testing.T) { +func TestFetchBearerToken(t *testing.T) { key, err := keys.NewPrivateKey() require.NoError(t, err) var uid user.ID @@ -75,43 +102,77 @@ func Test_fetchBearerToken(t *testing.T) { require.NotEmpty(t, t64) cases := []struct { - name string - + name string cookie string header string - error string + nilCtx bool expect *bearer.Token }{ - {name: "empty"}, - - {name: "bad base64 header", header: "WRONG BASE64", error: "can't base64-decode bearer token"}, - {name: "bad base64 cookie", cookie: "WRONG BASE64", error: "can't base64-decode bearer token"}, - - {name: "header token unmarshal error", header: "dGVzdAo=", error: "can't unmarshal bearer token"}, - {name: "cookie token unmarshal error", cookie: "dGVzdAo=", error: "can't unmarshal bearer token"}, - + { + name: "empty", + }, + { + name: "nil context", + nilCtx: true, + }, + { + name: "bad base64 header", + header: "WRONG BASE64", + error: "can't base64-decode bearer token", + }, + { + name: "bad base64 cookie", + cookie: "WRONG BASE64", + error: "can't base64-decode bearer token", + }, + { + name: "header token unmarshal error", + header: "dGVzdAo=", + error: "can't unmarshal bearer token", + }, + { + name: "cookie token unmarshal error", + cookie: "dGVzdAo=", + error: "can't unmarshal bearer token", + }, { name: "bad header and cookie", header: "WRONG BASE64", cookie: "dGVzdAo=", error: "can't unmarshal bearer token", }, - { name: "bad header, but good cookie", header: "dGVzdAo=", cookie: t64, expect: tkn, }, - - {name: "ok for header", header: t64, expect: tkn}, - {name: "ok for cookie", cookie: t64, expect: tkn}, + { + name: "bad cookie, but good header", + header: t64, + cookie: "dGVzdAo=", + expect: tkn, + }, + { + name: "ok for header", + header: t64, + expect: tkn, + }, + { + name: "ok for cookie", + cookie: t64, + expect: tkn, + }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - ctx := makeTestRequest(tt.cookie, tt.header) + var ctx *fasthttp.RequestCtx + if !tt.nilCtx { + ctx = makeTestRequest(tt.cookie, tt.header) + } + actual, err := fetchBearerToken(ctx) if tt.error == "" { @@ -139,7 +200,7 @@ func makeTestRequest(cookie, header string) *fasthttp.RequestCtx { return ctx } -func Test_checkAndPropagateBearerToken(t *testing.T) { +func TestCheckAndPropagateBearerToken(t *testing.T) { key, err := keys.NewPrivateKey() require.NoError(t, err) var uid user.ID @@ -162,3 +223,85 @@ func Test_checkAndPropagateBearerToken(t *testing.T) { require.NoError(t, err) require.Equal(t, tkn, actual) } + +func TestLoadBearerToken(t *testing.T) { + ctx := context.Background() + token := new(bearer.Token) + + cases := []struct { + name string + appCtx context.Context + error string + }{ + { + name: "token is missing in the context", + appCtx: ctx, + error: "found empty bearer token", + }, + { + name: "normal", + appCtx: context.WithValue(ctx, bearerTokenKey, token), + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tkn, err := LoadBearerToken(tt.appCtx) + + if tt.error == "" { + require.NoError(t, err) + require.Equal(t, token, tkn) + + return + } + + require.Contains(t, err.Error(), tt.error) + }) + } +} + +func TestStoreBearerTokenAppCtx(t *testing.T) { + key, err := keys.NewPrivateKey() + require.NoError(t, err) + var uid user.ID + user.IDFromKey(&uid, key.PrivateKey.PublicKey) + + tkn := new(bearer.Token) + tkn.ForUser(uid) + + t64 := base64.StdEncoding.EncodeToString(tkn.Marshal()) + require.NotEmpty(t, t64) + + cases := []struct { + name string + req *fasthttp.RequestCtx + error string + }{ + { + name: "invalid token", + req: makeTestRequest("dGVzdAo=", ""), + error: "can't unmarshal bearer token", + }, + { + name: "normal", + req: makeTestRequest(t64, ""), + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + ctx, err := StoreBearerTokenAppCtx(context.Background(), tt.req) + + if tt.error == "" { + require.NoError(t, err) + actualToken, ok := ctx.Value(bearerTokenKey).(*bearer.Token) + require.True(t, ok) + require.Equal(t, tkn, actualToken) + + return + } + + require.Contains(t, err.Error(), tt.error) + }) + } +}