objectsvc: Make requestContextKey private
1. Disallow storing wrongly typed value by this key. 2. This is the way described in context.Value() documentation. Signed-off-by: Evgenii Stratonikov <e.stratonikov@yadro.com>
This commit is contained in:
parent
08c6578156
commit
2d3bed89d2
3 changed files with 34 additions and 28 deletions
|
@ -112,7 +112,7 @@ type wrappedGetObjectStream struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wrappedGetObjectStream) Context() context.Context {
|
func (w *wrappedGetObjectStream) Context() context.Context {
|
||||||
return context.WithValue(w.GetObjectStream.Context(), object.RequestContextKey, &object.RequestContext{
|
return object.NewRequestContext(w.GetObjectStream.Context(), &object.RequestContext{
|
||||||
Namespace: w.requestInfo.ContainerNamespace(),
|
Namespace: w.requestInfo.ContainerNamespace(),
|
||||||
SenderKey: w.requestInfo.SenderKey(),
|
SenderKey: w.requestInfo.SenderKey(),
|
||||||
Role: w.requestInfo.RequestRole(),
|
Role: w.requestInfo.RequestRole(),
|
||||||
|
@ -135,7 +135,7 @@ type wrappedRangeStream struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wrappedRangeStream) Context() context.Context {
|
func (w *wrappedRangeStream) Context() context.Context {
|
||||||
return context.WithValue(w.GetObjectRangeStream.Context(), object.RequestContextKey, &object.RequestContext{
|
return object.NewRequestContext(w.GetObjectRangeStream.Context(), &object.RequestContext{
|
||||||
Namespace: w.requestInfo.ContainerNamespace(),
|
Namespace: w.requestInfo.ContainerNamespace(),
|
||||||
SenderKey: w.requestInfo.SenderKey(),
|
SenderKey: w.requestInfo.SenderKey(),
|
||||||
Role: w.requestInfo.RequestRole(),
|
Role: w.requestInfo.RequestRole(),
|
||||||
|
@ -158,7 +158,7 @@ type wrappedSearchStream struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wrappedSearchStream) Context() context.Context {
|
func (w *wrappedSearchStream) Context() context.Context {
|
||||||
return context.WithValue(w.SearchStream.Context(), object.RequestContextKey, &object.RequestContext{
|
return object.NewRequestContext(w.SearchStream.Context(), &object.RequestContext{
|
||||||
Namespace: w.requestInfo.ContainerNamespace(),
|
Namespace: w.requestInfo.ContainerNamespace(),
|
||||||
SenderKey: w.requestInfo.SenderKey(),
|
SenderKey: w.requestInfo.SenderKey(),
|
||||||
Role: w.requestInfo.RequestRole(),
|
Role: w.requestInfo.RequestRole(),
|
||||||
|
@ -457,7 +457,7 @@ func (b Service) GetRange(request *objectV2.GetRangeRequest, stream object.GetOb
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestContext(ctx context.Context, reqInfo RequestInfo) context.Context {
|
func requestContext(ctx context.Context, reqInfo RequestInfo) context.Context {
|
||||||
return context.WithValue(ctx, object.RequestContextKey, &object.RequestContext{
|
return object.NewRequestContext(ctx, &object.RequestContext{
|
||||||
Namespace: reqInfo.ContainerNamespace(),
|
Namespace: reqInfo.ContainerNamespace(),
|
||||||
SenderKey: reqInfo.SenderKey(),
|
SenderKey: reqInfo.SenderKey(),
|
||||||
Role: reqInfo.RequestRole(),
|
Role: reqInfo.RequestRole(),
|
||||||
|
|
|
@ -3,7 +3,6 @@ package ape
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
|
objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
|
||||||
"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs"
|
"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs"
|
||||||
|
@ -88,25 +87,13 @@ func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error {
|
||||||
return g.GetObjectStream.Send(resp)
|
return g.GetObjectStream.Send(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestContext(ctx context.Context) (*objectSvc.RequestContext, error) {
|
|
||||||
untyped := ctx.Value(objectSvc.RequestContextKey)
|
|
||||||
if untyped == nil {
|
|
||||||
return nil, fmt.Errorf("no key %s in context", objectSvc.RequestContextKey)
|
|
||||||
}
|
|
||||||
rc, ok := untyped.(*objectSvc.RequestContext)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("failed cast to RequestContext")
|
|
||||||
}
|
|
||||||
return rc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Service) Get(request *objectV2.GetRequest, stream objectSvc.GetObjectStream) error {
|
func (c *Service) Get(request *objectV2.GetRequest, stream objectSvc.GetObjectStream) error {
|
||||||
cnrID, objID, err := getAddressParamsSDK(request.GetBody().GetAddress().GetContainerID(), request.GetBody().GetAddress().GetObjectID())
|
cnrID, objID, err := getAddressParamsSDK(request.GetBody().GetAddress().GetContainerID(), request.GetBody().GetAddress().GetObjectID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return toStatusErr(err)
|
return toStatusErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, err := requestContext(stream.Context())
|
reqCtx, err := objectSvc.FromRequestContext(stream.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return toStatusErr(err)
|
return toStatusErr(err)
|
||||||
}
|
}
|
||||||
|
@ -137,7 +124,7 @@ type putStreamBasicChecker struct {
|
||||||
|
|
||||||
func (p *putStreamBasicChecker) Send(ctx context.Context, request *objectV2.PutRequest) error {
|
func (p *putStreamBasicChecker) Send(ctx context.Context, request *objectV2.PutRequest) error {
|
||||||
if partInit, ok := request.GetBody().GetObjectPart().(*objectV2.PutObjectPartInit); ok {
|
if partInit, ok := request.GetBody().GetObjectPart().(*objectV2.PutObjectPartInit); ok {
|
||||||
reqCtx, err := requestContext(ctx)
|
reqCtx, err := objectSvc.FromRequestContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return toStatusErr(err)
|
return toStatusErr(err)
|
||||||
}
|
}
|
||||||
|
@ -184,7 +171,7 @@ func (c *Service) Head(ctx context.Context, request *objectV2.HeadRequest) (*obj
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, err := requestContext(ctx)
|
reqCtx, err := objectSvc.FromRequestContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -248,7 +235,7 @@ func (c *Service) Search(request *objectV2.SearchRequest, stream objectSvc.Searc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, err := requestContext(stream.Context())
|
reqCtx, err := objectSvc.FromRequestContext(stream.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return toStatusErr(err)
|
return toStatusErr(err)
|
||||||
}
|
}
|
||||||
|
@ -273,7 +260,7 @@ func (c *Service) Delete(ctx context.Context, request *objectV2.DeleteRequest) (
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, err := requestContext(ctx)
|
reqCtx, err := objectSvc.FromRequestContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -304,7 +291,7 @@ func (c *Service) GetRange(request *objectV2.GetRangeRequest, stream objectSvc.G
|
||||||
return toStatusErr(err)
|
return toStatusErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, err := requestContext(stream.Context())
|
reqCtx, err := objectSvc.FromRequestContext(stream.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return toStatusErr(err)
|
return toStatusErr(err)
|
||||||
}
|
}
|
||||||
|
@ -330,7 +317,7 @@ func (c *Service) GetRangeHash(ctx context.Context, request *objectV2.GetRangeHa
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, err := requestContext(ctx)
|
reqCtx, err := objectSvc.FromRequestContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -365,7 +352,7 @@ func (c *Service) PutSingle(ctx context.Context, request *objectV2.PutSingleRequ
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, err := requestContext(ctx)
|
reqCtx, err := objectSvc.FromRequestContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,15 @@
|
||||||
package object
|
package object
|
||||||
|
|
||||||
import "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl"
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
type RequestContextKeyT struct{}
|
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl"
|
||||||
|
)
|
||||||
|
|
||||||
var RequestContextKey = RequestContextKeyT{}
|
type requestContextKeyT struct{}
|
||||||
|
|
||||||
|
var requestContextKey = requestContextKeyT{}
|
||||||
|
|
||||||
// RequestContext is a context passed between middleware handlers.
|
// RequestContext is a context passed between middleware handlers.
|
||||||
type RequestContext struct {
|
type RequestContext struct {
|
||||||
|
@ -14,3 +19,17 @@ type RequestContext struct {
|
||||||
|
|
||||||
Role acl.Role
|
Role acl.Role
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewRequestContext returns a copy of ctx which carries value.
|
||||||
|
func NewRequestContext(ctx context.Context, value *RequestContext) context.Context {
|
||||||
|
return context.WithValue(ctx, requestContextKey, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromRequestContext returns RequestContext value stored in ctx if any.
|
||||||
|
func FromRequestContext(ctx context.Context) (*RequestContext, error) {
|
||||||
|
reqCtx, ok := ctx.Value(requestContextKey).(*RequestContext)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no key %s in context", requestContextKey)
|
||||||
|
}
|
||||||
|
return reqCtx, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue