From f4d174e740f741d87e2bd9eabad742a1af1b8dba Mon Sep 17 00:00:00 2001 From: Roman Loginov Date: Thu, 16 May 2024 08:15:15 +0300 Subject: [PATCH] [#387] middleware: Extend test coverage Signed-off-by: Roman Loginov --- api/middleware/policy_test.go | 459 ++++++++++++++++++++++++++++++++ api/middleware/reqinfo_test.go | 70 +++++ api/middleware/response_test.go | 43 +++ api/middleware/tracing_test.go | 115 ++++++++ api/middleware/util_test.go | 179 +++++++++++++ 5 files changed, 866 insertions(+) create mode 100644 api/middleware/reqinfo_test.go create mode 100644 api/middleware/response_test.go create mode 100644 api/middleware/tracing_test.go create mode 100644 api/middleware/util_test.go diff --git a/api/middleware/policy_test.go b/api/middleware/policy_test.go index 3f636c7f..34d3a9c5 100644 --- a/api/middleware/policy_test.go +++ b/api/middleware/policy_test.go @@ -80,3 +80,462 @@ func TestReqTypeDetermination(t *testing.T) { }) } } + +func TestDetermineBucketOperation(t *testing.T) { + const defaultValue = "value" + + for _, tc := range []struct { + name string + method string + queryParam map[string]string + expected string + }{ + { + name: "OptionsOperation", + method: http.MethodOptions, + expected: OptionsOperation, + }, + { + name: "HeadBucketOperation", + method: http.MethodHead, + expected: HeadBucketOperation, + }, + { + name: "ListMultipartUploadsOperation", + method: http.MethodGet, + queryParam: map[string]string{UploadsQuery: defaultValue}, + expected: ListMultipartUploadsOperation, + }, + { + name: "GetBucketLocationOperation", + method: http.MethodGet, + queryParam: map[string]string{LocationQuery: defaultValue}, + expected: GetBucketLocationOperation, + }, + { + name: "GetBucketPolicyOperation", + method: http.MethodGet, + queryParam: map[string]string{PolicyQuery: defaultValue}, + expected: GetBucketPolicyOperation, + }, + { + name: "GetBucketLifecycleOperation", + method: http.MethodGet, + queryParam: map[string]string{LifecycleQuery: defaultValue}, + expected: GetBucketLifecycleOperation, + }, + { + name: "GetBucketEncryptionOperation", + method: http.MethodGet, + queryParam: map[string]string{EncryptionQuery: defaultValue}, + expected: GetBucketEncryptionOperation, + }, + { + name: "GetBucketCorsOperation", + method: http.MethodGet, + queryParam: map[string]string{CorsQuery: defaultValue}, + expected: GetBucketCorsOperation, + }, + { + name: "GetBucketACLOperation", + method: http.MethodGet, + queryParam: map[string]string{ACLQuery: defaultValue}, + expected: GetBucketACLOperation, + }, + { + name: "GetBucketWebsiteOperation", + method: http.MethodGet, + queryParam: map[string]string{WebsiteQuery: defaultValue}, + expected: GetBucketWebsiteOperation, + }, + { + name: "GetBucketAccelerateOperation", + method: http.MethodGet, + queryParam: map[string]string{AccelerateQuery: defaultValue}, + expected: GetBucketAccelerateOperation, + }, + { + name: "GetBucketRequestPaymentOperation", + method: http.MethodGet, + queryParam: map[string]string{RequestPaymentQuery: defaultValue}, + expected: GetBucketRequestPaymentOperation, + }, + { + name: "GetBucketLoggingOperation", + method: http.MethodGet, + queryParam: map[string]string{LoggingQuery: defaultValue}, + expected: GetBucketLoggingOperation, + }, + { + name: "GetBucketReplicationOperation", + method: http.MethodGet, + queryParam: map[string]string{ReplicationQuery: defaultValue}, + expected: GetBucketReplicationOperation, + }, + { + name: "GetBucketTaggingOperation", + method: http.MethodGet, + queryParam: map[string]string{TaggingQuery: defaultValue}, + expected: GetBucketTaggingOperation, + }, + { + name: "GetBucketObjectLockConfigOperation", + method: http.MethodGet, + queryParam: map[string]string{ObjectLockQuery: defaultValue}, + expected: GetBucketObjectLockConfigOperation, + }, + { + name: "GetBucketVersioningOperation", + method: http.MethodGet, + queryParam: map[string]string{VersioningQuery: defaultValue}, + expected: GetBucketVersioningOperation, + }, + { + name: "GetBucketNotificationOperation", + method: http.MethodGet, + queryParam: map[string]string{NotificationQuery: defaultValue}, + expected: GetBucketNotificationOperation, + }, + { + name: "ListenBucketNotificationOperation", + method: http.MethodGet, + queryParam: map[string]string{EventsQuery: defaultValue}, + expected: ListenBucketNotificationOperation, + }, + { + name: "ListBucketObjectVersionsOperation", + method: http.MethodGet, + queryParam: map[string]string{VersionsQuery: defaultValue}, + expected: ListBucketObjectVersionsOperation, + }, + { + name: "ListObjectsV2MOperation", + method: http.MethodGet, + queryParam: map[string]string{ListTypeQuery: "2", MetadataQuery: "true"}, + expected: ListObjectsV2MOperation, + }, + { + name: "ListObjectsV2Operation", + method: http.MethodGet, + queryParam: map[string]string{ListTypeQuery: "2"}, + expected: ListObjectsV2Operation, + }, + { + name: "ListObjectsV1Operation", + method: http.MethodGet, + expected: ListObjectsV1Operation, + }, + { + name: "PutBucketCorsOperation", + method: http.MethodPut, + queryParam: map[string]string{CorsQuery: defaultValue}, + expected: PutBucketCorsOperation, + }, + { + name: "PutBucketACLOperation", + method: http.MethodPut, + queryParam: map[string]string{ACLQuery: defaultValue}, + expected: PutBucketACLOperation, + }, + { + name: "PutBucketLifecycleOperation", + method: http.MethodPut, + queryParam: map[string]string{LifecycleQuery: defaultValue}, + expected: PutBucketLifecycleOperation, + }, + { + name: "PutBucketEncryptionOperation", + method: http.MethodPut, + queryParam: map[string]string{EncryptionQuery: defaultValue}, + expected: PutBucketEncryptionOperation, + }, + { + name: "PutBucketPolicyOperation", + method: http.MethodPut, + queryParam: map[string]string{PolicyQuery: defaultValue}, + expected: PutBucketPolicyOperation, + }, + { + name: "PutBucketObjectLockConfigOperation", + method: http.MethodPut, + queryParam: map[string]string{ObjectLockQuery: defaultValue}, + expected: PutBucketObjectLockConfigOperation, + }, + { + name: "PutBucketTaggingOperation", + method: http.MethodPut, + queryParam: map[string]string{TaggingQuery: defaultValue}, + expected: PutBucketTaggingOperation, + }, + { + name: "PutBucketVersioningOperation", + method: http.MethodPut, + queryParam: map[string]string{VersioningQuery: defaultValue}, + expected: PutBucketVersioningOperation, + }, + { + name: "PutBucketNotificationOperation", + method: http.MethodPut, + queryParam: map[string]string{NotificationQuery: defaultValue}, + expected: PutBucketNotificationOperation, + }, + { + name: "CreateBucketOperation", + method: http.MethodPut, + expected: CreateBucketOperation, + }, + { + name: "DeleteMultipleObjectsOperation", + method: http.MethodPost, + queryParam: map[string]string{DeleteQuery: defaultValue}, + expected: DeleteMultipleObjectsOperation, + }, + { + name: "PostObjectOperation", + method: http.MethodPost, + expected: PostObjectOperation, + }, + { + name: "DeleteBucketCorsOperation", + method: http.MethodDelete, + queryParam: map[string]string{CorsQuery: defaultValue}, + expected: DeleteBucketCorsOperation, + }, + { + name: "DeleteBucketWebsiteOperation", + method: http.MethodDelete, + queryParam: map[string]string{WebsiteQuery: defaultValue}, + expected: DeleteBucketWebsiteOperation, + }, + { + name: "DeleteBucketTaggingOperation", + method: http.MethodDelete, + queryParam: map[string]string{TaggingQuery: defaultValue}, + expected: DeleteBucketTaggingOperation, + }, + { + name: "DeleteBucketPolicyOperation", + method: http.MethodDelete, + queryParam: map[string]string{PolicyQuery: defaultValue}, + expected: DeleteBucketPolicyOperation, + }, + { + name: "DeleteBucketLifecycleOperation", + method: http.MethodDelete, + queryParam: map[string]string{LifecycleQuery: defaultValue}, + expected: DeleteBucketLifecycleOperation, + }, + { + name: "DeleteBucketEncryptionOperation", + method: http.MethodDelete, + queryParam: map[string]string{EncryptionQuery: defaultValue}, + expected: DeleteBucketEncryptionOperation, + }, + { + name: "DeleteBucketOperation", + method: http.MethodDelete, + expected: DeleteBucketOperation, + }, + { + name: "UnmatchedBucketOperation", + method: "invalid-method", + expected: "UnmatchedBucketOperation", + }, + } { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.method, "/test", nil) + if tc.queryParam != nil { + addQueryParams(req, tc.queryParam) + } + + actual := determineBucketOperation(req) + require.Equal(t, tc.expected, actual) + }) + } +} + +func TestDetermineObjectOperation(t *testing.T) { + const ( + amzCopySource = "X-Amz-Copy-Source" + defaultValue = "value" + ) + + for _, tc := range []struct { + name string + method string + queryParam map[string]string + headerKeys []string + expected string + }{ + { + name: "HeadObjectOperation", + method: http.MethodHead, + expected: HeadObjectOperation, + }, + { + name: "ListPartsOperation", + method: http.MethodGet, + queryParam: map[string]string{UploadIDQuery: defaultValue}, + expected: ListPartsOperation, + }, + { + name: "GetObjectACLOperation", + method: http.MethodGet, + queryParam: map[string]string{ACLQuery: defaultValue}, + expected: GetObjectACLOperation, + }, + { + name: "GetObjectTaggingOperation", + method: http.MethodGet, + queryParam: map[string]string{TaggingQuery: defaultValue}, + expected: GetObjectTaggingOperation, + }, + { + name: "GetObjectRetentionOperation", + method: http.MethodGet, + queryParam: map[string]string{RetentionQuery: defaultValue}, + expected: GetObjectRetentionOperation, + }, + { + name: "GetObjectLegalHoldOperation", + method: http.MethodGet, + queryParam: map[string]string{LegalQuery: defaultValue}, + expected: GetObjectLegalHoldOperation, + }, + { + name: "GetObjectAttributesOperation", + method: http.MethodGet, + queryParam: map[string]string{AttributesQuery: defaultValue}, + expected: GetObjectAttributesOperation, + }, + { + name: "GetObjectOperation", + method: http.MethodGet, + expected: GetObjectOperation, + }, + { + name: "UploadPartCopyOperation", + method: http.MethodPut, + queryParam: map[string]string{PartNumberQuery: defaultValue, UploadIDQuery: defaultValue}, + headerKeys: []string{amzCopySource}, + expected: UploadPartCopyOperation, + }, + { + name: "UploadPartOperation", + method: http.MethodPut, + queryParam: map[string]string{PartNumberQuery: defaultValue, UploadIDQuery: defaultValue}, + expected: UploadPartOperation, + }, + { + name: "PutObjectACLOperation", + method: http.MethodPut, + queryParam: map[string]string{ACLQuery: defaultValue}, + expected: PutObjectACLOperation, + }, + { + name: "PutObjectTaggingOperation", + method: http.MethodPut, + queryParam: map[string]string{TaggingQuery: defaultValue}, + expected: PutObjectTaggingOperation, + }, + { + name: "CopyObjectOperation", + method: http.MethodPut, + headerKeys: []string{amzCopySource}, + expected: CopyObjectOperation, + }, + { + name: "PutObjectRetentionOperation", + method: http.MethodPut, + queryParam: map[string]string{RetentionQuery: defaultValue}, + expected: PutObjectRetentionOperation, + }, + { + name: "PutObjectLegalHoldOperation", + method: http.MethodPut, + queryParam: map[string]string{LegalHoldQuery: defaultValue}, + expected: PutObjectLegalHoldOperation, + }, + { + name: "PutObjectOperation", + method: http.MethodPut, + expected: PutObjectOperation, + }, + { + name: "CompleteMultipartUploadOperation", + method: http.MethodPost, + queryParam: map[string]string{UploadIDQuery: defaultValue}, + expected: CompleteMultipartUploadOperation, + }, + { + name: "CreateMultipartUploadOperation", + method: http.MethodPost, + queryParam: map[string]string{UploadsQuery: defaultValue}, + expected: CreateMultipartUploadOperation, + }, + { + name: "SelectObjectContentOperation", + method: http.MethodPost, + expected: SelectObjectContentOperation, + }, + { + name: "AbortMultipartUploadOperation", + method: http.MethodDelete, + queryParam: map[string]string{UploadIDQuery: defaultValue}, + expected: AbortMultipartUploadOperation, + }, + { + name: "DeleteObjectTaggingOperation", + method: http.MethodDelete, + queryParam: map[string]string{TaggingQuery: defaultValue}, + expected: DeleteObjectTaggingOperation, + }, + { + name: "DeleteObjectOperation", + method: http.MethodDelete, + expected: DeleteObjectOperation, + }, + { + name: "UnmatchedObjectOperation", + method: "invalid-method", + expected: "UnmatchedObjectOperation", + }, + } { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.method, "/test", nil) + if tc.queryParam != nil { + addQueryParams(req, tc.queryParam) + } + if tc.headerKeys != nil { + addHeaderParams(req, tc.headerKeys) + } + + actual := determineObjectOperation(req) + require.Equal(t, tc.expected, actual) + }) + } +} + +func addQueryParams(req *http.Request, pairs map[string]string) { + values := req.URL.Query() + for key, val := range pairs { + values.Add(key, val) + } + req.URL.RawQuery = values.Encode() +} + +func addHeaderParams(req *http.Request, keys []string) { + for _, key := range keys { + req.Header.Set(key, "val") + } +} + +func TestDetermineGeneralOperation(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + actual := determineGeneralOperation(req) + require.Equal(t, ListBucketsOperation, actual) + + req = httptest.NewRequest(http.MethodPost, "/test", nil) + actual = determineGeneralOperation(req) + require.Equal(t, "UnmatchedOperation", actual) +} diff --git a/api/middleware/reqinfo_test.go b/api/middleware/reqinfo_test.go new file mode 100644 index 00000000..ca7c1692 --- /dev/null +++ b/api/middleware/reqinfo_test.go @@ -0,0 +1,70 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetSourceIP(t *testing.T) { + for _, tc := range []struct { + name string + req *http.Request + }{ + { + name: "headers not set", + req: func() *http.Request { + request := httptest.NewRequest(http.MethodGet, "/test", nil) + request.RemoteAddr = "192.0.2.1:1234" + return request + }(), + }, + { + name: "headers not set, and the port is not set", + req: func() *http.Request { + request := httptest.NewRequest(http.MethodGet, "/test", nil) + request.RemoteAddr = "192.0.2.1" + return request + }(), + }, + { + name: "x-forwarded-for single-host header", + req: func() *http.Request { + request := httptest.NewRequest(http.MethodGet, "/test", nil) + request.Header.Set(xForwardedFor, "192.0.2.1") + return request + }(), + }, + { + name: "x-forwarded-for header by multiple hosts", + req: func() *http.Request { + request := httptest.NewRequest(http.MethodGet, "/test", nil) + request.Header.Set(xForwardedFor, "192.0.2.1, 10.1.1.1") + return request + }(), + }, + { + name: "x-real-ip header", + req: func() *http.Request { + request := httptest.NewRequest(http.MethodGet, "/test", nil) + request.Header.Set(xRealIP, "192.0.2.1") + return request + }(), + }, + { + name: "forwarded header", + req: func() *http.Request { + request := httptest.NewRequest(http.MethodGet, "/test", nil) + request.Header.Set(forwarded, "for=192.0.2.1, 10.1.1.1; proto=https; by=192.0.2.4") + return request + }(), + }, + } { + t.Run(tc.name, func(t *testing.T) { + actual := getSourceIP(tc.req) + require.Equal(t, actual, "192.0.2.1") + }) + } +} diff --git a/api/middleware/response_test.go b/api/middleware/response_test.go new file mode 100644 index 00000000..6653e5f8 --- /dev/null +++ b/api/middleware/response_test.go @@ -0,0 +1,43 @@ +package middleware + +import ( + "encoding/xml" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +type testXMLData struct { + XMLName xml.Name `xml:"data"` + Text string `xml:"text"` +} + +func TestEncodeResponse(t *testing.T) { + w := httptest.NewRecorder() + + err := EncodeToResponse(w, []byte{}) + require.Error(t, err) + require.Contains(t, err.Error(), "encode xml response") + + err = EncodeToResponse(w, testXMLData{Text: "test"}) + require.NoError(t, err) + + expectedXML := "\n\ntest" + require.Equal(t, expectedXML, w.Body.String()) +} + +func TestErrorResponse(t *testing.T) { + errResp := ErrorResponse{Code: "invalid-code"} + + actual := errResp.Error() + require.Contains(t, actual, "Error response code") + + errResp.Code = "AccessDenied" + actual = errResp.Error() + require.Equal(t, "Access Denied.", actual) + + errResp.Message = "Request body is empty." + actual = errResp.Error() + require.Equal(t, "Request body is empty.", actual) +} diff --git a/api/middleware/tracing_test.go b/api/middleware/tracing_test.go new file mode 100644 index 00000000..830473f9 --- /dev/null +++ b/api/middleware/tracing_test.go @@ -0,0 +1,115 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHTTPResponseCarrierSetGet(t *testing.T) { + const ( + testKey1 = "Key" + testValue1 = "Value" + ) + + respCarrier := httpResponseCarrier{} + respCarrier.resp = httptest.NewRecorder() + + actual := respCarrier.Get(testKey1) + require.Equal(t, "", actual) + + respCarrier.Set(testKey1, testValue1) + actual = respCarrier.Get(testKey1) + require.Equal(t, testValue1, actual) +} + +func TestHTTPResponseCarrierKeys(t *testing.T) { + const ( + testKey1 = "Key1" + testKey2 = "Key2" + testKey3 = "Key3" + testValue1 = "Value1" + testValue2 = "Value2" + testValue3 = "Value3" + ) + + respCarrier := httpResponseCarrier{} + respCarrier.resp = httptest.NewRecorder() + + actual := respCarrier.Keys() + require.Equal(t, 0, len(actual)) + + respCarrier.Set(testKey1, testValue1) + respCarrier.Set(testKey2, testValue2) + respCarrier.Set(testKey3, testValue3) + + actual = respCarrier.Keys() + require.Equal(t, 3, len(actual)) + require.Contains(t, actual, testKey1) + require.Contains(t, actual, testKey2) + require.Contains(t, actual, testKey3) +} + +func TestHTTPRequestCarrierSet(t *testing.T) { + const ( + testKey = "Key" + testValue = "Value" + ) + + reqCarrier := httpRequestCarrier{} + reqCarrier.req = httptest.NewRequest(http.MethodGet, "/test", nil) + reqCarrier.req.Response = httptest.NewRecorder().Result() + + actual := reqCarrier.req.Response.Header.Get(testKey) + require.Equal(t, "", actual) + + reqCarrier.Set(testKey, testValue) + actual = reqCarrier.req.Response.Header.Get(testKey) + require.Contains(t, testValue, actual) +} + +func TestHTTPRequestCarrierGet(t *testing.T) { + const ( + testKey = "Key" + testValue = "Value" + ) + + reqCarrier := httpRequestCarrier{} + reqCarrier.req = httptest.NewRequest(http.MethodGet, "/test", nil) + + actual := reqCarrier.Get(testKey) + require.Equal(t, "", actual) + + reqCarrier.req.Header.Set(testKey, testValue) + actual = reqCarrier.Get(testKey) + require.Equal(t, testValue, actual) +} + +func TestHTTPRequestCarrierKeys(t *testing.T) { + const ( + testKey1 = "Key1" + testKey2 = "Key2" + testKey3 = "Key3" + testValue1 = "Value1" + testValue2 = "Value2" + testValue3 = "Value3" + ) + + reqCarrier := httpRequestCarrier{} + reqCarrier.req = httptest.NewRequest(http.MethodGet, "/test", nil) + + actual := reqCarrier.Keys() + require.Equal(t, 0, len(actual)) + + reqCarrier.req.Header.Set(testKey1, testValue1) + reqCarrier.req.Header.Set(testKey2, testValue2) + reqCarrier.req.Header.Set(testKey3, testValue3) + + actual = reqCarrier.Keys() + require.Equal(t, 3, len(actual)) + require.Contains(t, actual, testKey1) + require.Contains(t, actual, testKey2) + require.Contains(t, actual, testKey3) +} diff --git a/api/middleware/util_test.go b/api/middleware/util_test.go new file mode 100644 index 00000000..d4a591d5 --- /dev/null +++ b/api/middleware/util_test.go @@ -0,0 +1,179 @@ +package middleware + +import ( + "context" + "testing" + "time" + + "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object" + "github.com/stretchr/testify/require" +) + +func TestGetBoxData(t *testing.T) { + for _, tc := range []struct { + name string + value any + error string + }{ + { + name: "valid", + value: &Box{ + AccessBox: &accessbox.Box{}, + }, + }, + { + name: "invalid data", + value: "invalid-data", + error: "couldn't get box from context", + }, + { + name: "box does not exist", + error: "couldn't get box from context", + }, + { + name: "access box is nil", + value: &Box{}, + error: "couldn't get box data from context", + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx := context.WithValue(context.Background(), boxKey, tc.value) + actual, err := GetBoxData(ctx) + if tc.error != "" { + require.Contains(t, err.Error(), tc.error) + return + } + + require.NoError(t, err) + require.NotNil(t, actual) + require.NotNil(t, actual.Gate) + }) + } +} + +func TestGetAuthHeaders(t *testing.T) { + for _, tc := range []struct { + name string + value any + error bool + }{ + { + name: "valid", + value: &Box{ + AuthHeaders: &AuthHeader{ + AccessKeyID: "valid-key", + Region: "valid-region", + SignatureV4: "valid-sign", + }, + }, + }, + { + name: "invalid data", + value: "invalid-data", + error: true, + }, + { + name: "box does not exist", + error: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx := context.WithValue(context.Background(), boxKey, tc.value) + actual, err := GetAuthHeaders(ctx) + if tc.error { + require.Contains(t, err.Error(), "couldn't get box from context") + return + } + + require.NoError(t, err) + require.Equal(t, tc.value.(*Box).AuthHeaders.AccessKeyID, actual.AccessKeyID) + require.Equal(t, tc.value.(*Box).AuthHeaders.Region, actual.Region) + require.Equal(t, tc.value.(*Box).AuthHeaders.SignatureV4, actual.SignatureV4) + }) + } +} + +func TestGetClientTime(t *testing.T) { + for _, tc := range []struct { + name string + value any + error string + }{ + { + name: "valid", + value: &Box{ + ClientTime: time.Now(), + }, + }, + { + name: "invalid data", + value: "invalid-data", + error: "couldn't get box from context", + }, + { + name: "box does not exist", + error: "couldn't get box from context", + }, + { + name: "zero time", + value: &Box{ + ClientTime: time.Time{}, + }, + error: "couldn't get client time from context", + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx := context.WithValue(context.Background(), boxKey, tc.value) + actual, err := GetClientTime(ctx) + if tc.error != "" { + require.Contains(t, err.Error(), tc.error) + return + } + + require.NoError(t, err) + require.Equal(t, tc.value.(*Box).ClientTime, actual) + }) + } +} + +func TestGetAccessBoxAttrs(t *testing.T) { + for _, tc := range []struct { + name string + value any + error bool + }{ + { + name: "valid", + value: func() *Box { + var attr object.Attribute + attr.SetKey("key") + attr.SetValue("value") + return &Box{Attributes: []object.Attribute{attr}} + }(), + }, + { + name: "invalid data", + value: "invalid-data", + error: true, + }, + { + name: "box does not exist", + error: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx := context.WithValue(context.Background(), boxKey, tc.value) + actual, err := GetAccessBoxAttrs(ctx) + if tc.error { + require.Contains(t, err.Error(), "couldn't get box from context") + return + } + + require.NoError(t, err) + require.Equal(t, len(tc.value.(*Box).Attributes), len(actual)) + require.Equal(t, tc.value.(*Box).Attributes[0].Key(), actual[0].Key()) + require.Equal(t, tc.value.(*Box).Attributes[0].Value(), actual[0].Value()) + }) + } +}