[##] middleware: Extend test coverage

Signed-off-by: Roman Loginov <r.loginov@yadro.com>
Roman Loginov 2024-05-16 08:15:15 +03:00
parent c718902e2c
commit 4357734719
5 changed files with 866 additions and 0 deletions

View File

@ -80,3 +80,462 @@ func TestReqTypeDetermination(t *testing.T) {
})
}
}
func TestDetermineBucketOperation(t *testing.T) {
const defaultValue = "value"
for _, tc := range []struct {
name string
method string
queryParam map[string]string
excepted string
}{
{
name: "OptionsOperation",
method: http.MethodOptions,
excepted: OptionsOperation,
},
{
name: "HeadBucketOperation",
method: http.MethodHead,
excepted: HeadBucketOperation,
},
{
name: "ListMultipartUploadsOperation",
method: http.MethodGet,
queryParam: map[string]string{UploadsQuery: defaultValue},
excepted: ListMultipartUploadsOperation,
},
{
name: "GetBucketLocationOperation",
method: http.MethodGet,
queryParam: map[string]string{LocationQuery: defaultValue},
excepted: GetBucketLocationOperation,
},
{
name: "GetBucketPolicyOperation",
method: http.MethodGet,
queryParam: map[string]string{PolicyQuery: defaultValue},
excepted: GetBucketPolicyOperation,
},
{
name: "GetBucketLifecycleOperation",
method: http.MethodGet,
queryParam: map[string]string{LifecycleQuery: defaultValue},
excepted: GetBucketLifecycleOperation,
},
{
name: "GetBucketEncryptionOperation",
method: http.MethodGet,
queryParam: map[string]string{EncryptionQuery: defaultValue},
excepted: GetBucketEncryptionOperation,
},
{
name: "GetBucketCorsOperation",
method: http.MethodGet,
queryParam: map[string]string{CorsQuery: defaultValue},
excepted: GetBucketCorsOperation,
},
{
name: "GetBucketACLOperation",
method: http.MethodGet,
queryParam: map[string]string{ACLQuery: defaultValue},
excepted: GetBucketACLOperation,
},
{
name: "GetBucketWebsiteOperation",
method: http.MethodGet,
queryParam: map[string]string{WebsiteQuery: defaultValue},
excepted: GetBucketWebsiteOperation,
},
{
name: "GetBucketAccelerateOperation",
method: http.MethodGet,
queryParam: map[string]string{AccelerateQuery: defaultValue},
excepted: GetBucketAccelerateOperation,
},
{
name: "GetBucketRequestPaymentOperation",
method: http.MethodGet,
queryParam: map[string]string{RequestPaymentQuery: defaultValue},
excepted: GetBucketRequestPaymentOperation,
},
{
name: "GetBucketLoggingOperation",
method: http.MethodGet,
queryParam: map[string]string{LoggingQuery: defaultValue},
excepted: GetBucketLoggingOperation,
},
{
name: "GetBucketReplicationOperation",
method: http.MethodGet,
queryParam: map[string]string{ReplicationQuery: defaultValue},
excepted: GetBucketReplicationOperation,
},
{
name: "GetBucketTaggingOperation",
method: http.MethodGet,
queryParam: map[string]string{TaggingQuery: defaultValue},
excepted: GetBucketTaggingOperation,
},
{
name: "GetBucketObjectLockConfigOperation",
method: http.MethodGet,
queryParam: map[string]string{ObjectLockQuery: defaultValue},
excepted: GetBucketObjectLockConfigOperation,
},
{
name: "GetBucketVersioningOperation",
method: http.MethodGet,
queryParam: map[string]string{VersioningQuery: defaultValue},
excepted: GetBucketVersioningOperation,
},
{
name: "GetBucketNotificationOperation",
method: http.MethodGet,
queryParam: map[string]string{NotificationQuery: defaultValue},
excepted: GetBucketNotificationOperation,
},
{
name: "ListenBucketNotificationOperation",
method: http.MethodGet,
queryParam: map[string]string{EventsQuery: defaultValue},
excepted: ListenBucketNotificationOperation,
},
{
name: "ListBucketObjectVersionsOperation",
method: http.MethodGet,
queryParam: map[string]string{VersionsQuery: defaultValue},
excepted: ListBucketObjectVersionsOperation,
},
{
name: "ListObjectsV2MOperation",
method: http.MethodGet,
queryParam: map[string]string{ListTypeQuery: "2", MetadataQuery: "true"},
excepted: ListObjectsV2MOperation,
},
{
name: "ListObjectsV2Operation",
method: http.MethodGet,
queryParam: map[string]string{ListTypeQuery: "2"},
excepted: ListObjectsV2Operation,
},
{
name: "ListObjectsV1Operation",
method: http.MethodGet,
excepted: ListObjectsV1Operation,
},
{
name: "PutBucketCorsOperation",
method: http.MethodPut,
queryParam: map[string]string{CorsQuery: defaultValue},
excepted: PutBucketCorsOperation,
},
{
name: "PutBucketACLOperation",
method: http.MethodPut,
queryParam: map[string]string{ACLQuery: defaultValue},
excepted: PutBucketACLOperation,
},
{
name: "PutBucketLifecycleOperation",
method: http.MethodPut,
queryParam: map[string]string{LifecycleQuery: defaultValue},
excepted: PutBucketLifecycleOperation,
},
{
name: "PutBucketEncryptionOperation",
method: http.MethodPut,
queryParam: map[string]string{EncryptionQuery: defaultValue},
excepted: PutBucketEncryptionOperation,
},
{
name: "PutBucketPolicyOperation",
method: http.MethodPut,
queryParam: map[string]string{PolicyQuery: defaultValue},
excepted: PutBucketPolicyOperation,
},
{
name: "PutBucketObjectLockConfigOperation",
method: http.MethodPut,
queryParam: map[string]string{ObjectLockQuery: defaultValue},
excepted: PutBucketObjectLockConfigOperation,
},
{
name: "PutBucketTaggingOperation",
method: http.MethodPut,
queryParam: map[string]string{TaggingQuery: defaultValue},
excepted: PutBucketTaggingOperation,
},
{
name: "PutBucketVersioningOperation",
method: http.MethodPut,
queryParam: map[string]string{VersioningQuery: defaultValue},
excepted: PutBucketVersioningOperation,
},
{
name: "PutBucketNotificationOperation",
method: http.MethodPut,
queryParam: map[string]string{NotificationQuery: defaultValue},
excepted: PutBucketNotificationOperation,
},
{
name: "CreateBucketOperation",
method: http.MethodPut,
excepted: CreateBucketOperation,
},
{
name: "DeleteMultipleObjectsOperation",
method: http.MethodPost,
queryParam: map[string]string{DeleteQuery: defaultValue},
excepted: DeleteMultipleObjectsOperation,
},
{
name: "PostObjectOperation",
method: http.MethodPost,
excepted: PostObjectOperation,
},
{
name: "DeleteBucketCorsOperation",
method: http.MethodDelete,
queryParam: map[string]string{CorsQuery: defaultValue},
excepted: DeleteBucketCorsOperation,
},
{
name: "DeleteBucketWebsiteOperation",
method: http.MethodDelete,
queryParam: map[string]string{WebsiteQuery: defaultValue},
excepted: DeleteBucketWebsiteOperation,
},
{
name: "DeleteBucketTaggingOperation",
method: http.MethodDelete,
queryParam: map[string]string{TaggingQuery: defaultValue},
excepted: DeleteBucketTaggingOperation,
},
{
name: "DeleteBucketPolicyOperation",
method: http.MethodDelete,
queryParam: map[string]string{PolicyQuery: defaultValue},
excepted: DeleteBucketPolicyOperation,
},
{
name: "DeleteBucketLifecycleOperation",
method: http.MethodDelete,
queryParam: map[string]string{LifecycleQuery: defaultValue},
excepted: DeleteBucketLifecycleOperation,
},
{
name: "DeleteBucketEncryptionOperation",
method: http.MethodDelete,
queryParam: map[string]string{EncryptionQuery: defaultValue},
excepted: DeleteBucketEncryptionOperation,
},
{
name: "DeleteBucketOperation",
method: http.MethodDelete,
excepted: DeleteBucketOperation,
},
{
name: "UnmatchedBucketOperation",
method: "invalid-method",
excepted: "UnmatchedBucketOperation",
},
} {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.method, "/test", nil)
if tc.queryParam != nil {
addQueryParams(req, tc.queryParam)
}
actual := determineBucketOperation(req)
require.Equal(t, tc.excepted, actual)
})
}
}
func TestDetermineObjectOperation(t *testing.T) {
const (
amzCopySource = "X-Amz-Copy-Source"
defaultValue = "value"
)
for _, tc := range []struct {
name string
method string
queryParam map[string]string
headerKeys []string
excepted string
}{
{
name: "HeadObjectOperation",
method: http.MethodHead,
excepted: HeadObjectOperation,
},
{
name: "ListPartsOperation",
method: http.MethodGet,
queryParam: map[string]string{UploadIDQuery: defaultValue},
excepted: ListPartsOperation,
},
{
name: "GetObjectACLOperation",
method: http.MethodGet,
queryParam: map[string]string{ACLQuery: defaultValue},
excepted: GetObjectACLOperation,
},
{
name: "GetObjectTaggingOperation",
method: http.MethodGet,
queryParam: map[string]string{TaggingQuery: defaultValue},
excepted: GetObjectTaggingOperation,
},
{
name: "GetObjectRetentionOperation",
method: http.MethodGet,
queryParam: map[string]string{RetentionQuery: defaultValue},
excepted: GetObjectRetentionOperation,
},
{
name: "GetObjectLegalHoldOperation",
method: http.MethodGet,
queryParam: map[string]string{LegalQuery: defaultValue},
excepted: GetObjectLegalHoldOperation,
},
{
name: "GetObjectAttributesOperation",
method: http.MethodGet,
queryParam: map[string]string{AttributesQuery: defaultValue},
excepted: GetObjectAttributesOperation,
},
{
name: "GetObjectOperation",
method: http.MethodGet,
excepted: GetObjectOperation,
},
{
name: "UploadPartCopyOperation",
method: http.MethodPut,
queryParam: map[string]string{PartNumberQuery: defaultValue, UploadIDQuery: defaultValue},
headerKeys: []string{amzCopySource},
excepted: UploadPartCopyOperation,
},
{
name: "UploadPartOperation",
method: http.MethodPut,
queryParam: map[string]string{PartNumberQuery: defaultValue, UploadIDQuery: defaultValue},
excepted: UploadPartOperation,
},
{
name: "PutObjectACLOperation",
method: http.MethodPut,
queryParam: map[string]string{ACLQuery: defaultValue},
excepted: PutObjectACLOperation,
},
{
name: "PutObjectTaggingOperation",
method: http.MethodPut,
queryParam: map[string]string{TaggingQuery: defaultValue},
excepted: PutObjectTaggingOperation,
},
{
name: "CopyObjectOperation",
method: http.MethodPut,
headerKeys: []string{amzCopySource},
excepted: CopyObjectOperation,
},
{
name: "PutObjectRetentionOperation",
method: http.MethodPut,
queryParam: map[string]string{RetentionQuery: defaultValue},
excepted: PutObjectRetentionOperation,
},
{
name: "PutObjectLegalHoldOperation",
method: http.MethodPut,
queryParam: map[string]string{LegalHoldQuery: defaultValue},
excepted: PutObjectLegalHoldOperation,
},
{
name: "PutObjectOperation",
method: http.MethodPut,
excepted: PutObjectOperation,
},
{
name: "CompleteMultipartUploadOperation",
method: http.MethodPost,
queryParam: map[string]string{UploadIDQuery: defaultValue},
excepted: CompleteMultipartUploadOperation,
},
{
name: "CreateMultipartUploadOperation",
method: http.MethodPost,
queryParam: map[string]string{UploadsQuery: defaultValue},
excepted: CreateMultipartUploadOperation,
},
{
name: "SelectObjectContentOperation",
method: http.MethodPost,
excepted: SelectObjectContentOperation,
},
{
name: "AbortMultipartUploadOperation",
method: http.MethodDelete,
queryParam: map[string]string{UploadIDQuery: defaultValue},
excepted: AbortMultipartUploadOperation,
},
{
name: "DeleteObjectTaggingOperation",
method: http.MethodDelete,
queryParam: map[string]string{TaggingQuery: defaultValue},
excepted: DeleteObjectTaggingOperation,
},
{
name: "DeleteObjectOperation",
method: http.MethodDelete,
excepted: DeleteObjectOperation,
},
{
name: "UnmatchedObjectOperation",
method: "invalid-method",
excepted: "UnmatchedObjectOperation",
},
} {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.method, "/test", nil)
if tc.queryParam != nil {
addQueryParams(req, tc.queryParam)
}
if tc.headerKeys != nil {
addHeaderParams(req, tc.headerKeys)
}
actual := determineObjectOperation(req)
require.Equal(t, tc.excepted, actual)
})
}
}
func addQueryParams(req *http.Request, pairs map[string]string) {
values := req.URL.Query()
for key, val := range pairs {
values.Add(key, val)
}
req.URL.RawQuery = values.Encode()
}
func addHeaderParams(req *http.Request, keys []string) {
for _, key := range keys {
req.Header.Set(key, "val")
}
}
func TestDetermineGeneralOperation(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
actual := determineGeneralOperation(req)
require.Equal(t, ListBucketsOperation, actual)
req = httptest.NewRequest(http.MethodPost, "/test", nil)
actual = determineGeneralOperation(req)
require.Equal(t, "UnmatchedOperation", actual)
}

View File

@ -0,0 +1,70 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestGetSourceIP(t *testing.T) {
for _, tc := range []struct {
name string
req *http.Request
}{
{
name: "headers not set",
req: func() *http.Request {
request := httptest.NewRequest(http.MethodGet, "/test", nil)
request.RemoteAddr = "192.0.2.1:1234"
return request
}(),
},
{
name: "headers not set, and the port is not set",
req: func() *http.Request {
request := httptest.NewRequest(http.MethodGet, "/test", nil)
request.RemoteAddr = "192.0.2.1"
return request
}(),
},
{
name: "x-forwarded-for single-host header",
req: func() *http.Request {
request := httptest.NewRequest(http.MethodGet, "/test", nil)
request.Header.Set(xForwardedFor, "192.0.2.1")
return request
}(),
},
{
name: "x-forwarded-for header by multiple hosts",
req: func() *http.Request {
request := httptest.NewRequest(http.MethodGet, "/test", nil)
request.Header.Set(xForwardedFor, "192.0.2.1, 10.1.1.1")
return request
}(),
},
{
name: "x-real-ip header",
req: func() *http.Request {
request := httptest.NewRequest(http.MethodGet, "/test", nil)
request.Header.Set(xRealIP, "192.0.2.1")
return request
}(),
},
{
name: "forwarded header",
req: func() *http.Request {
request := httptest.NewRequest(http.MethodGet, "/test", nil)
request.Header.Set(forwarded, "for=192.0.2.1, 10.1.1.1; proto=https; by=192.0.2.4")
return request
}(),
},
} {
t.Run(tc.name, func(t *testing.T) {
actual := getSourceIP(tc.req)
require.Equal(t, actual, "192.0.2.1")
})
}
}

View File

@ -0,0 +1,43 @@
package middleware
import (
"encoding/xml"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
type testXMLData struct {
XMLName xml.Name `xml:"data"`
Text string `xml:"text"`
}
func TestEncodeResponse(t *testing.T) {
w := httptest.NewRecorder()
err := EncodeToResponse(w, []byte{})
require.Error(t, err)
require.Contains(t, err.Error(), "encode xml response")
err = EncodeToResponse(w, testXMLData{Text: "test"})
require.NoError(t, err)
expectedXML := "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<data><text>test</text></data>"
require.Equal(t, expectedXML, w.Body.String())
}
func TestErrorResponse(t *testing.T) {
errResp := ErrorResponse{Code: "invalid-code"}
actual := errResp.Error()
require.Contains(t, actual, "Error response code")
errResp.Code = "AccessDenied"
actual = errResp.Error()
require.Equal(t, "Access Denied.", actual)
errResp.Message = "Request body is empty."
actual = errResp.Error()
require.Equal(t, "Request body is empty.", actual)
}

View File

@ -0,0 +1,115 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestHTTPResponseCarrierSetGet(t *testing.T) {
const (
testKey1 = "Key"
testValue1 = "Value"
)
respCarrier := httpResponseCarrier{}
respCarrier.resp = httptest.NewRecorder()
actual := respCarrier.Get(testKey1)
require.Equal(t, "", actual)
respCarrier.Set(testKey1, testValue1)
actual = respCarrier.Get(testKey1)
require.Equal(t, testValue1, actual)
}
func TestHTTPResponseCarrierKeys(t *testing.T) {
const (
testKey1 = "Key1"
testKey2 = "Key2"
testKey3 = "Key3"
testValue1 = "Value1"
testValue2 = "Value2"
testValue3 = "Value3"
)
respCarrier := httpResponseCarrier{}
respCarrier.resp = httptest.NewRecorder()
actual := respCarrier.Keys()
require.Equal(t, 0, len(actual))
respCarrier.Set(testKey1, testValue1)
respCarrier.Set(testKey2, testValue2)
respCarrier.Set(testKey3, testValue3)
actual = respCarrier.Keys()
require.Equal(t, 3, len(actual))
require.Contains(t, actual, testKey1)
require.Contains(t, actual, testKey2)
require.Contains(t, actual, testKey3)
}
func TestHTTPRequestCarrierSet(t *testing.T) {
const (
testKey = "Key"
testValue = "Value"
)
reqCarrier := httpRequestCarrier{}
reqCarrier.req = httptest.NewRequest(http.MethodGet, "/test", nil)
reqCarrier.req.Response = httptest.NewRecorder().Result()
actual := reqCarrier.req.Response.Header.Get(testKey)
require.Equal(t, "", actual)
reqCarrier.Set(testKey, testValue)
actual = reqCarrier.req.Response.Header.Get(testKey)
require.Contains(t, testValue, actual)
}
func TestHTTPRequestCarrierGet(t *testing.T) {
const (
testKey = "Key"
testValue = "Value"
)
reqCarrier := httpRequestCarrier{}
reqCarrier.req = httptest.NewRequest(http.MethodGet, "/test", nil)
actual := reqCarrier.Get(testKey)
require.Equal(t, "", actual)
reqCarrier.req.Header.Set(testKey, testValue)
actual = reqCarrier.Get(testKey)
require.Equal(t, testValue, actual)
}
func TestHTTPRequestCarrierKeys(t *testing.T) {
const (
testKey1 = "Key1"
testKey2 = "Key2"
testKey3 = "Key3"
testValue1 = "Value1"
testValue2 = "Value2"
testValue3 = "Value3"
)
reqCarrier := httpRequestCarrier{}
reqCarrier.req = httptest.NewRequest(http.MethodGet, "/test", nil)
actual := reqCarrier.Keys()
require.Equal(t, 0, len(actual))
reqCarrier.req.Header.Set(testKey1, testValue1)
reqCarrier.req.Header.Set(testKey2, testValue2)
reqCarrier.req.Header.Set(testKey3, testValue3)
actual = reqCarrier.Keys()
require.Equal(t, 3, len(actual))
require.Contains(t, actual, testKey1)
require.Contains(t, actual, testKey2)
require.Contains(t, actual, testKey3)
}

View File

@ -0,0 +1,179 @@
package middleware
import (
"context"
"testing"
"time"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
"github.com/stretchr/testify/require"
)
func TestGetBoxData(t *testing.T) {
for _, tc := range []struct {
name string
value any
error string
}{
{
name: "valid",
value: &Box{
AccessBox: &accessbox.Box{},
},
},
{
name: "invalid data",
value: "invalid-data",
error: "couldn't get box from context",
},
{
name: "box does not exist",
error: "couldn't get box from context",
},
{
name: "access box is nil",
value: &Box{},
error: "couldn't get box data from context",
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx := context.WithValue(context.Background(), boxKey, tc.value)
actual, err := GetBoxData(ctx)
if tc.error != "" {
require.Contains(t, err.Error(), tc.error)
return
}
require.NoError(t, err)
require.NotNil(t, actual)
require.NotNil(t, actual.Gate)
})
}
}
func TestGetAuthHeaders(t *testing.T) {
for _, tc := range []struct {
name string
value any
error bool
}{
{
name: "valid",
value: &Box{
AuthHeaders: &AuthHeader{
AccessKeyID: "valid-key",
Region: "valid-region",
SignatureV4: "valid-sign",
},
},
},
{
name: "invalid data",
value: "invalid-data",
error: true,
},
{
name: "box does not exist",
error: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx := context.WithValue(context.Background(), boxKey, tc.value)
actual, err := GetAuthHeaders(ctx)
if tc.error {
require.Contains(t, err.Error(), "couldn't get box from context")
return
}
require.NoError(t, err)
require.Equal(t, tc.value.(*Box).AuthHeaders.AccessKeyID, actual.AccessKeyID)
require.Equal(t, tc.value.(*Box).AuthHeaders.Region, actual.Region)
require.Equal(t, tc.value.(*Box).AuthHeaders.SignatureV4, actual.SignatureV4)
})
}
}
func TestGetClientTime(t *testing.T) {
for _, tc := range []struct {
name string
value any
error string
}{
{
name: "valid",
value: &Box{
ClientTime: time.Now(),
},
},
{
name: "invalid data",
value: "invalid-data",
error: "couldn't get box from context",
},
{
name: "box does not exist",
error: "couldn't get box from context",
},
{
name: "zero time",
value: &Box{
ClientTime: time.Time{},
},
error: "couldn't get client time from context",
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx := context.WithValue(context.Background(), boxKey, tc.value)
actual, err := GetClientTime(ctx)
if tc.error != "" {
require.Contains(t, err.Error(), tc.error)
return
}
require.NoError(t, err)
require.Equal(t, tc.value.(*Box).ClientTime, actual)
})
}
}
func TestGetAccessBoxAttrs(t *testing.T) {
for _, tc := range []struct {
name string
value any
error bool
}{
{
name: "valid",
value: func() *Box {
var attr object.Attribute
attr.SetKey("key")
attr.SetValue("value")
return &Box{Attributes: []object.Attribute{attr}}
}(),
},
{
name: "invalid data",
value: "invalid-data",
error: true,
},
{
name: "box does not exist",
error: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx := context.WithValue(context.Background(), boxKey, tc.value)
actual, err := GetAccessBoxAttrs(ctx)
if tc.error {
require.Contains(t, err.Error(), "couldn't get box from context")
return
}
require.NoError(t, err)
require.Equal(t, len(tc.value.(*Box).Attributes), len(actual))
require.Equal(t, tc.value.(*Box).Attributes[0].Key(), actual[0].Key())
require.Equal(t, tc.value.(*Box).Attributes[0].Value(), actual[0].Value())
})
}
}