objectsvc: Make requestContextKey private

1. Disallow storing wrongly typed value by this key.
2. This is the way described in context.Value() documentation.

Signed-off-by: Evgenii Stratonikov <e.stratonikov@yadro.com>
This commit is contained in:
Evgenii Stratonikov 2024-02-09 13:33:44 +03:00
parent 08c6578156
commit 2d3bed89d2
3 changed files with 34 additions and 28 deletions

View file

@ -112,7 +112,7 @@ type wrappedGetObjectStream struct {
} }
func (w *wrappedGetObjectStream) Context() context.Context { func (w *wrappedGetObjectStream) Context() context.Context {
return context.WithValue(w.GetObjectStream.Context(), object.RequestContextKey, &object.RequestContext{ return object.NewRequestContext(w.GetObjectStream.Context(), &object.RequestContext{
Namespace: w.requestInfo.ContainerNamespace(), Namespace: w.requestInfo.ContainerNamespace(),
SenderKey: w.requestInfo.SenderKey(), SenderKey: w.requestInfo.SenderKey(),
Role: w.requestInfo.RequestRole(), Role: w.requestInfo.RequestRole(),
@ -135,7 +135,7 @@ type wrappedRangeStream struct {
} }
func (w *wrappedRangeStream) Context() context.Context { func (w *wrappedRangeStream) Context() context.Context {
return context.WithValue(w.GetObjectRangeStream.Context(), object.RequestContextKey, &object.RequestContext{ return object.NewRequestContext(w.GetObjectRangeStream.Context(), &object.RequestContext{
Namespace: w.requestInfo.ContainerNamespace(), Namespace: w.requestInfo.ContainerNamespace(),
SenderKey: w.requestInfo.SenderKey(), SenderKey: w.requestInfo.SenderKey(),
Role: w.requestInfo.RequestRole(), Role: w.requestInfo.RequestRole(),
@ -158,7 +158,7 @@ type wrappedSearchStream struct {
} }
func (w *wrappedSearchStream) Context() context.Context { func (w *wrappedSearchStream) Context() context.Context {
return context.WithValue(w.SearchStream.Context(), object.RequestContextKey, &object.RequestContext{ return object.NewRequestContext(w.SearchStream.Context(), &object.RequestContext{
Namespace: w.requestInfo.ContainerNamespace(), Namespace: w.requestInfo.ContainerNamespace(),
SenderKey: w.requestInfo.SenderKey(), SenderKey: w.requestInfo.SenderKey(),
Role: w.requestInfo.RequestRole(), Role: w.requestInfo.RequestRole(),
@ -457,7 +457,7 @@ func (b Service) GetRange(request *objectV2.GetRangeRequest, stream object.GetOb
} }
func requestContext(ctx context.Context, reqInfo RequestInfo) context.Context { func requestContext(ctx context.Context, reqInfo RequestInfo) context.Context {
return context.WithValue(ctx, object.RequestContextKey, &object.RequestContext{ return object.NewRequestContext(ctx, &object.RequestContext{
Namespace: reqInfo.ContainerNamespace(), Namespace: reqInfo.ContainerNamespace(),
SenderKey: reqInfo.SenderKey(), SenderKey: reqInfo.SenderKey(),
Role: reqInfo.RequestRole(), Role: reqInfo.RequestRole(),

View file

@ -3,7 +3,6 @@ package ape
import ( import (
"context" "context"
"encoding/hex" "encoding/hex"
"fmt"
objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object" objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs" "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs"
@ -88,25 +87,13 @@ func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error {
return g.GetObjectStream.Send(resp) return g.GetObjectStream.Send(resp)
} }
func requestContext(ctx context.Context) (*objectSvc.RequestContext, error) {
untyped := ctx.Value(objectSvc.RequestContextKey)
if untyped == nil {
return nil, fmt.Errorf("no key %s in context", objectSvc.RequestContextKey)
}
rc, ok := untyped.(*objectSvc.RequestContext)
if !ok {
return nil, fmt.Errorf("failed cast to RequestContext")
}
return rc, nil
}
func (c *Service) Get(request *objectV2.GetRequest, stream objectSvc.GetObjectStream) error { func (c *Service) Get(request *objectV2.GetRequest, stream objectSvc.GetObjectStream) error {
cnrID, objID, err := getAddressParamsSDK(request.GetBody().GetAddress().GetContainerID(), request.GetBody().GetAddress().GetObjectID()) cnrID, objID, err := getAddressParamsSDK(request.GetBody().GetAddress().GetContainerID(), request.GetBody().GetAddress().GetObjectID())
if err != nil { if err != nil {
return toStatusErr(err) return toStatusErr(err)
} }
reqCtx, err := requestContext(stream.Context()) reqCtx, err := objectSvc.FromRequestContext(stream.Context())
if err != nil { if err != nil {
return toStatusErr(err) return toStatusErr(err)
} }
@ -137,7 +124,7 @@ type putStreamBasicChecker struct {
func (p *putStreamBasicChecker) Send(ctx context.Context, request *objectV2.PutRequest) error { func (p *putStreamBasicChecker) Send(ctx context.Context, request *objectV2.PutRequest) error {
if partInit, ok := request.GetBody().GetObjectPart().(*objectV2.PutObjectPartInit); ok { if partInit, ok := request.GetBody().GetObjectPart().(*objectV2.PutObjectPartInit); ok {
reqCtx, err := requestContext(ctx) reqCtx, err := objectSvc.FromRequestContext(ctx)
if err != nil { if err != nil {
return toStatusErr(err) return toStatusErr(err)
} }
@ -184,7 +171,7 @@ func (c *Service) Head(ctx context.Context, request *objectV2.HeadRequest) (*obj
return nil, err return nil, err
} }
reqCtx, err := requestContext(ctx) reqCtx, err := objectSvc.FromRequestContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -248,7 +235,7 @@ func (c *Service) Search(request *objectV2.SearchRequest, stream objectSvc.Searc
} }
} }
reqCtx, err := requestContext(stream.Context()) reqCtx, err := objectSvc.FromRequestContext(stream.Context())
if err != nil { if err != nil {
return toStatusErr(err) return toStatusErr(err)
} }
@ -273,7 +260,7 @@ func (c *Service) Delete(ctx context.Context, request *objectV2.DeleteRequest) (
return nil, err return nil, err
} }
reqCtx, err := requestContext(ctx) reqCtx, err := objectSvc.FromRequestContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -304,7 +291,7 @@ func (c *Service) GetRange(request *objectV2.GetRangeRequest, stream objectSvc.G
return toStatusErr(err) return toStatusErr(err)
} }
reqCtx, err := requestContext(stream.Context()) reqCtx, err := objectSvc.FromRequestContext(stream.Context())
if err != nil { if err != nil {
return toStatusErr(err) return toStatusErr(err)
} }
@ -330,7 +317,7 @@ func (c *Service) GetRangeHash(ctx context.Context, request *objectV2.GetRangeHa
return nil, err return nil, err
} }
reqCtx, err := requestContext(ctx) reqCtx, err := objectSvc.FromRequestContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -365,7 +352,7 @@ func (c *Service) PutSingle(ctx context.Context, request *objectV2.PutSingleRequ
return nil, err return nil, err
} }
reqCtx, err := requestContext(ctx) reqCtx, err := objectSvc.FromRequestContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,10 +1,15 @@
package object package object
import "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl" import (
"context"
"fmt"
type RequestContextKeyT struct{} "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl"
)
var RequestContextKey = RequestContextKeyT{} type requestContextKeyT struct{}
var requestContextKey = requestContextKeyT{}
// RequestContext is a context passed between middleware handlers. // RequestContext is a context passed between middleware handlers.
type RequestContext struct { type RequestContext struct {
@ -14,3 +19,17 @@ type RequestContext struct {
Role acl.Role Role acl.Role
} }
// NewRequestContext returns a copy of ctx which carries value.
func NewRequestContext(ctx context.Context, value *RequestContext) context.Context {
return context.WithValue(ctx, requestContextKey, value)
}
// FromRequestContext returns RequestContext value stored in ctx if any.
func FromRequestContext(ctx context.Context) (*RequestContext, error) {
reqCtx, ok := ctx.Value(requestContextKey).(*RequestContext)
if !ok {
return nil, fmt.Errorf("no key %s in context", requestContextKey)
}
return reqCtx, nil
}