[#137] Refactor context data retrievers

Signed-off-by: Roman Loginov <r.loginov@yadro.com>
This commit is contained in:
Roman Loginov 2023-08-14 18:34:41 +03:00 committed by Denis Kirillov
parent 52b89d3497
commit 40d7f844e3
19 changed files with 106 additions and 76 deletions

View file

@ -266,7 +266,7 @@ func (h *handler) GetBucketACLHandler(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) bearerTokenIssuerKey(ctx context.Context) (*keys.PublicKey, error) {
box, err := layer.GetBoxData(ctx)
box, err := middleware.GetBoxData(ctx)
if err != nil {
return nil, err
}

View file

@ -2,7 +2,6 @@ package handler
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/rand"
"crypto/sha256"
@ -1428,7 +1427,7 @@ func TestPutBucketPolicy(t *testing.T) {
createBucket(t, hc, bktName, box)
w, r := prepareTestPayloadRequest(hc, bktName, "", bytes.NewReader([]byte(bktPolicy)))
ctx := context.WithValue(r.Context(), middleware.BoxData, box)
ctx := middleware.SetBoxData(r.Context(), box)
r = r.WithContext(ctx)
hc.Handler().PutBucketPolicyHandler(w, r)
assertStatus(hc.t, w, http.StatusOK)
@ -1450,7 +1449,7 @@ func putBucketPolicy(hc *handlerContext, bktName string, bktPolicy *bucketPolicy
require.NoError(hc.t, err)
w, r := prepareTestPayloadRequest(hc, bktName, "", bytes.NewReader(body))
ctx := context.WithValue(r.Context(), middleware.BoxData, box)
ctx := middleware.SetBoxData(r.Context(), box)
r = r.WithContext(ctx)
hc.Handler().PutBucketPolicyHandler(w, r)
assertStatus(hc.t, w, status)
@ -1517,7 +1516,7 @@ func createBucketAssertS3Error(hc *handlerContext, bktName string, box *accessbo
func createBucketBase(hc *handlerContext, bktName string, box *accessbox.Box) *httptest.ResponseRecorder {
w, r := prepareTestRequest(hc, bktName, "", nil)
ctx := context.WithValue(r.Context(), middleware.BoxData, box)
ctx := middleware.SetBoxData(r.Context(), box)
r = r.WithContext(ctx)
hc.Handler().CreateBucketHandler(w, r)
return w
@ -1528,7 +1527,7 @@ func putBucketACL(t *testing.T, tc *handlerContext, bktName string, box *accessb
for key, val := range header {
r.Header.Set(key, val)
}
ctx := context.WithValue(r.Context(), middleware.BoxData, box)
ctx := middleware.SetBoxData(r.Context(), box)
r = r.WithContext(ctx)
tc.Handler().PutBucketACLHandler(w, r)
assertStatus(t, w, http.StatusOK)

View file

@ -1,7 +1,6 @@
package handler
import (
"context"
"net/http"
"strings"
"testing"
@ -24,14 +23,14 @@ func TestCORSOriginWildcard(t *testing.T) {
bktName := "bucket-for-cors"
box, _ := createAccessBox(t)
w, r := prepareTestRequest(hc, bktName, "", nil)
ctx := context.WithValue(r.Context(), middleware.BoxData, box)
ctx := middleware.SetBoxData(r.Context(), box)
r = r.WithContext(ctx)
r.Header.Add(api.AmzACL, "public-read")
hc.Handler().CreateBucketHandler(w, r)
assertStatus(t, w, http.StatusOK)
w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body))
ctx = context.WithValue(r.Context(), middleware.BoxData, box)
ctx = middleware.SetBoxData(r.Context(), box)
r = r.WithContext(ctx)
hc.Handler().PutBucketCorsHandler(w, r)
assertStatus(t, w, http.StatusOK)

View file

@ -281,7 +281,7 @@ func (h *handler) DeleteBucketHandler(w http.ResponseWriter, r *http.Request) {
var sessionToken *session.Container
boxData, err := layer.GetBoxData(r.Context())
boxData, err := middleware.GetBoxData(r.Context())
if err == nil {
sessionToken = boxData.Gate.SessionTokenForDelete()
}

View file

@ -152,7 +152,7 @@ func prepareHandlerContextBase(t *testing.T, minCache bool) *handlerContext {
h: h,
tp: tp,
tree: treeMock,
context: context.WithValue(context.Background(), middleware.BoxData, newTestAccessBox(t, key)),
context: middleware.SetBoxData(context.Background(), newTestAccessBox(t, key)),
kludge: kludge,
}
}

View file

@ -1,7 +1,6 @@
package handler
import (
"context"
"net/http"
"testing"
"time"
@ -95,7 +94,7 @@ func TestInvalidAccessThroughCache(t *testing.T) {
headObject(t, hc, bktName, objName, nil, http.StatusOK)
w, r := prepareTestRequest(hc, bktName, objName, nil)
hc.Handler().HeadObjectHandler(w, r.WithContext(context.WithValue(r.Context(), middleware.BoxData, newTestAccessBox(t, nil))))
hc.Handler().HeadObjectHandler(w, r.WithContext(middleware.SetBoxData(r.Context(), newTestAccessBox(t, nil))))
assertStatus(t, w, http.StatusForbidden)
}

View file

@ -166,7 +166,7 @@ func (h *handler) sendNotifications(ctx context.Context, p *SendNotificationPara
return nil
}
box, err := layer.GetBoxData(ctx)
box, err := middleware.GetBoxData(ctx)
if err == nil && box.Gate.BearerToken != nil {
p.User = bearer.ResolveIssuer(*box.Gate.BearerToken).EncodeToString()
}

View file

@ -748,7 +748,7 @@ func (h *handler) CreateBucketHandler(w http.ResponseWriter, r *http.Request) {
}
var policies []*accessbox.ContainerPolicy
boxData, err := layer.GetBoxData(ctx)
boxData, err := middleware.GetBoxData(ctx)
if err == nil {
policies = boxData.Policies
p.SessionContainerCreation = boxData.Gate.SessionTokenForPut()

View file

@ -282,14 +282,14 @@ func getChunkedRequest(ctx context.Context, t *testing.T, bktName, objName strin
w := httptest.NewRecorder()
reqInfo := middleware.NewReqInfo(w, req, middleware.ObjectRequest{Bucket: bktName, Object: objName})
req = req.WithContext(middleware.SetReqInfo(ctx, reqInfo))
req = req.WithContext(context.WithValue(req.Context(), middleware.ClientTime, signTime))
req = req.WithContext(context.WithValue(req.Context(), middleware.AuthHeaders, &auth.AuthHeader{
req = req.WithContext(middleware.SetClientTime(req.Context(), signTime))
req = req.WithContext(middleware.SetAuthHeaders(req.Context(), &auth.AuthHeader{
AccessKeyID: AWSAccessKeyID,
SignatureV4: "4f232c4386841ef735655705268965c44a0e4690baa4adea153f7db9fa80a0a9",
Service: "s3",
Region: "us-east-1",
}))
req = req.WithContext(context.WithValue(req.Context(), middleware.BoxData, &accessbox.Box{
req = req.WithContext(middleware.SetBoxData(req.Context(), &accessbox.Box{
Gate: &accessbox.GateData{
AccessKey: AWSSecretAccessKey,
},

View file

@ -9,11 +9,9 @@ import (
"net/http"
"time"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4"
errs "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"github.com/aws/aws-sdk-go/aws/credentials"
)
@ -191,15 +189,13 @@ func (c *s3ChunkReader) Read(buf []byte) (num int, err error) {
}
func newSignV4ChunkedReader(req *http.Request) (io.ReadCloser, error) {
// Expecting to refactor this in future:
// https://git.frostfs.info/TrueCloudLab/frostfs-s3-gw/issues/137
box, ok := req.Context().Value(middleware.BoxData).(*accessbox.Box)
if !ok {
box, err := middleware.GetBoxData(req.Context())
if err != nil {
return nil, errs.GetAPIError(errs.ErrAuthorizationHeaderMalformed)
}
authHeaders, ok := req.Context().Value(middleware.AuthHeaders).(*auth.AuthHeader)
if !ok {
authHeaders, err := middleware.GetAuthHeaders(req.Context())
if err != nil {
return nil, errs.GetAPIError(errs.ErrAuthorizationHeaderMalformed)
}
@ -209,8 +205,8 @@ func newSignV4ChunkedReader(req *http.Request) (io.ReadCloser, error) {
return nil, errs.GetAPIError(errs.ErrSignatureDoesNotMatch)
}
reqTime, ok := req.Context().Value(middleware.ClientTime).(time.Time)
if !ok {
reqTime, err := middleware.GetClientTime(req.Context())
if err != nil {
return nil, errs.GetAPIError(errs.ErrMalformedDate)
}
newStreamSigner := v4.NewStreamSigner(authHeaders.Region, "s3", seed, currentCredentials)

View file

@ -131,7 +131,7 @@ func parseRange(s string) (*layer.RangeParams, error) {
}
func getSessionTokenSetEACL(ctx context.Context) (*session.Container, error) {
boxData, err := layer.GetBoxData(ctx)
boxData, err := middleware.GetBoxData(ctx)
if err != nil {
return nil, err
}

View file

@ -12,7 +12,6 @@ import (
objectv2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/checksum"
apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
@ -367,7 +366,7 @@ func (t *TestFrostFS) checkAccess(cnrID cid.ID, owner user.ID, op eacl.Operation
}
func getBearerOwner(ctx context.Context) user.ID {
if bd, ok := ctx.Value(middleware.BoxData).(*accessbox.Box); ok && bd != nil && bd.Gate != nil && bd.Gate.BearerToken != nil {
if bd, err := middleware.GetBoxData(ctx); err == nil && bd.Gate.BearerToken != nil {
return bearer.ResolveIssuer(*bd.Gate.BearerToken)
}

View file

@ -16,7 +16,6 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/encryption"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client"
cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
@ -328,13 +327,13 @@ func (n *layer) IsNotificationEnabled() bool {
// IsAuthenticatedRequest checks if access box exists in the current request.
func IsAuthenticatedRequest(ctx context.Context) bool {
_, ok := ctx.Value(middleware.BoxData).(*accessbox.Box)
return ok
_, err := middleware.GetBoxData(ctx)
return err == nil
}
// TimeNow returns client time from request or time.Now().
func TimeNow(ctx context.Context) time.Time {
if now, ok := ctx.Value(middleware.ClientTime).(time.Time); ok {
if now, err := middleware.GetClientTime(ctx); err == nil {
return now
}
@ -343,7 +342,7 @@ func TimeNow(ctx context.Context) time.Time {
// BearerOwner returns owner id from BearerToken (context) or from client owner.
func (n *layer) BearerOwner(ctx context.Context) user.ID {
if bd, ok := ctx.Value(middleware.BoxData).(*accessbox.Box); ok && bd != nil && bd.Gate != nil && bd.Gate.BearerToken != nil {
if bd, err := middleware.GetBoxData(ctx); err == nil && bd.Gate.BearerToken != nil {
return bearer.ResolveIssuer(*bd.Gate.BearerToken)
}
@ -362,7 +361,7 @@ func (n *layer) reqLogger(ctx context.Context) *zap.Logger {
}
func (n *layer) prepareAuthParameters(ctx context.Context, prm *PrmAuth, bktOwner user.ID) {
if bd, ok := ctx.Value(middleware.BoxData).(*accessbox.Box); ok && bd != nil && bd.Gate != nil && bd.Gate.BearerToken != nil {
if bd, err := middleware.GetBoxData(ctx); err == nil && bd.Gate.BearerToken != nil {
if bd.Gate.BearerToken.Impersonate() || bktOwner.Equals(bearer.ResolveIssuer(*bd.Gate.BearerToken)) {
prm.BearerToken = bd.Gate.BearerToken
return

View file

@ -1,7 +1,6 @@
package layer
import (
"context"
"encoding/hex"
"fmt"
"os"
@ -11,8 +10,6 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/encryption"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
)
@ -137,18 +134,3 @@ func NameFromString(name string) (string, string) {
ind := strings.LastIndex(name, PathSeparator)
return name[ind+1:], name[:ind+1]
}
// GetBoxData extracts accessbox.Box from context.
func GetBoxData(ctx context.Context) (*accessbox.Box, error) {
var boxData *accessbox.Box
data, ok := ctx.Value(middleware.BoxData).(*accessbox.Box)
if !ok || data == nil {
return nil, fmt.Errorf("couldn't get box data from context")
}
boxData = data
if boxData.Gate == nil {
boxData.Gate = &accessbox.GateData{}
}
return boxData, nil
}

View file

@ -144,7 +144,7 @@ func prepareContext(t *testing.T, cachesConfig ...*CachesConfig) *testContext {
bearerToken := bearertest.Token()
require.NoError(t, bearerToken.Sign(key.PrivateKey))
ctx := context.WithValue(context.Background(), middleware.BoxData, &accessbox.Box{
ctx := middleware.SetBoxData(context.Background(), &accessbox.Box{
Gate: &accessbox.GateData{
BearerToken: &bearerToken,
GateKey: key.PublicKey(),

View file

@ -1,7 +1,6 @@
package middleware
import (
"context"
"net/http"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth"
@ -9,18 +8,6 @@ import (
"go.uber.org/zap"
)
// KeyWrapper is wrapper for context keys.
type KeyWrapper string
// AuthHeaders is a wrapper for authentication headers of a request.
var AuthHeaders = KeyWrapper("__context_auth_headers_key")
// BoxData is an ID used to store accessbox.Box in a context.
var BoxData = KeyWrapper("__context_box_key")
// ClientTime is an ID used to store client time.Time in a context.
var ClientTime = KeyWrapper("__context_client_time")
func Auth(center auth.Center, log *zap.Logger) Func {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -38,11 +25,11 @@ func Auth(center auth.Center, log *zap.Logger) Func {
return
}
} else {
ctx = context.WithValue(ctx, BoxData, box.AccessBox)
ctx = SetBoxData(ctx, box.AccessBox)
if !box.ClientTime.IsZero() {
ctx = context.WithValue(ctx, ClientTime, box.ClientTime)
ctx = SetClientTime(ctx, box.ClientTime)
}
ctx = context.WithValue(ctx, AuthHeaders, box.AuthHeaders)
ctx = SetAuthHeaders(ctx, box.AuthHeaders)
}
h.ServeHTTP(w, r.WithContext(ctx))

View file

@ -10,7 +10,6 @@ import (
"time"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/metrics"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
"go.uber.org/zap"
@ -152,7 +151,7 @@ func resolveCID(log *zap.Logger, resolveBucket BucketResolveFunc) cidResolveFunc
func resolveUser(ctx context.Context) string {
user := "anon"
if bd, ok := ctx.Value(BoxData).(*accessbox.Box); ok && bd != nil && bd.Gate != nil && bd.Gate.BearerToken != nil {
if bd, err := GetBoxData(ctx); err == nil && bd.Gate.BearerToken != nil {
user = bearer.ResolveIssuer(*bd.Gate.BearerToken).String()
}
return user

72
api/middleware/util.go Normal file
View file

@ -0,0 +1,72 @@
package middleware
import (
"context"
"fmt"
"time"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
)
// keyWrapper is wrapper for context keys.
type keyWrapper string
// authHeaders is a wrapper for authentication headers of a request.
var authHeadersKey = keyWrapper("__context_auth_headers_key")
// boxData is an ID used to store accessbox.Box in a context.
var boxDataKey = keyWrapper("__context_box_key")
// clientTime is an ID used to store client time.Time in a context.
var clientTimeKey = keyWrapper("__context_client_time")
// GetBoxData extracts accessbox.Box from context.
func GetBoxData(ctx context.Context) (*accessbox.Box, error) {
var box *accessbox.Box
data, ok := ctx.Value(boxDataKey).(*accessbox.Box)
if !ok || data == nil {
return nil, fmt.Errorf("couldn't get box data from context")
}
box = data
if box.Gate == nil {
box.Gate = &accessbox.GateData{}
}
return box, nil
}
// GetAuthHeaders extracts auth.AuthHeader from context.
func GetAuthHeaders(ctx context.Context) (*auth.AuthHeader, error) {
authHeaders, ok := ctx.Value(authHeadersKey).(*auth.AuthHeader)
if !ok {
return nil, fmt.Errorf("couldn't get auth headers from context")
}
return authHeaders, nil
}
// GetClientTime extracts time.Time from context.
func GetClientTime(ctx context.Context) (time.Time, error) {
clientTime, ok := ctx.Value(clientTimeKey).(time.Time)
if !ok {
return time.Time{}, fmt.Errorf("couldn't get client time from context")
}
return clientTime, nil
}
// SetBoxData sets accessbox.Box in the context.
func SetBoxData(ctx context.Context, box *accessbox.Box) context.Context {
return context.WithValue(ctx, boxDataKey, box)
}
// SetAuthHeaders sets auth.AuthHeader in the context.
func SetAuthHeaders(ctx context.Context, header *auth.AuthHeader) context.Context {
return context.WithValue(ctx, authHeadersKey, header)
}
// SetClientTime sets time.Time in the context.
func SetClientTime(ctx context.Context, newTime time.Time) context.Context {
return context.WithValue(ctx, clientTimeKey, newTime)
}

View file

@ -8,7 +8,6 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
errorsFrost "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/pkg/service/tree"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
@ -169,7 +168,7 @@ func (w *PoolWrapper) RemoveNode(ctx context.Context, bktInfo *data.BucketInfo,
}
func getBearer(ctx context.Context, bktInfo *data.BucketInfo) []byte {
if bd, ok := ctx.Value(middleware.BoxData).(*accessbox.Box); ok && bd != nil && bd.Gate != nil {
if bd, err := middleware.GetBoxData(ctx); err == nil {
if bd.Gate.BearerToken != nil {
if bd.Gate.BearerToken.Impersonate() || bktInfo.Owner.Equals(bearer.ResolveIssuer(*bd.Gate.BearerToken)) {
return bd.Gate.BearerToken.Marshal()