diff --git a/api/handler/cors.go b/api/handler/cors.go index c309b45f..ddbc8294 100644 --- a/api/handler/cors.go +++ b/api/handler/cors.go @@ -105,7 +105,7 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) { if reqInfo.BucketName == "" { return } - bktInfo, err := h.obj.GetBucketInfo(ctx, reqInfo.BucketName) + bktInfo, err := h.getBucketInfo(ctx, reqInfo.BucketName) if err != nil { h.reqLogger(ctx).Warn(logs.GetBucketInfo, zap.Error(err)) return @@ -154,7 +154,7 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) { func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) { ctx := r.Context() reqInfo := middleware.GetReqInfo(ctx) - bktInfo, err := h.obj.GetBucketInfo(ctx, reqInfo.BucketName) + bktInfo, err := h.getBucketInfo(ctx, reqInfo.BucketName) if err != nil { h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err) return diff --git a/api/handler/locking_test.go b/api/handler/locking_test.go index d2a85685..b3eb1f3b 100644 --- a/api/handler/locking_test.go +++ b/api/handler/locking_test.go @@ -276,6 +276,7 @@ func TestPutBucketLockConfigurationHandler(t *testing.T) { }{ { name: "bkt not found", + bucket: "not-found-bucket", expectedError: apierr.GetAPIError(apierr.ErrNoSuchBucket), }, { @@ -365,6 +366,7 @@ func TestGetBucketLockConfigurationHandler(t *testing.T) { }{ { name: "bkt not found", + bucket: "not-found-bucket", expectedError: apierr.GetAPIError(apierr.ErrNoSuchBucket), }, { diff --git a/api/handler/patch_test.go b/api/handler/patch_test.go index e4c7e4a8..aa294230 100644 --- a/api/handler/patch_test.go +++ b/api/handler/patch_test.go @@ -102,7 +102,7 @@ func TestPatch(t *testing.T) { res := patchObject(t, tc, bktName, objName, tt.rng, patchPayload, tt.headers) require.Equal(t, data.Quote(hash), res.Object.ETag) } else { - patchObjectErr(t, tc, bktName, objName, tt.rng, patchPayload, tt.headers, tt.code) + patchObjectErr(tc, bktName, objName, tt.rng, patchPayload, tt.headers, tt.code) } }) } @@ -377,7 +377,7 @@ func TestPatchEncryptedObject(t *testing.T) { tc.Handler().PutObjectHandler(w, r) assertStatus(t, w, http.StatusOK) - patchObjectErr(t, tc, bktName, objName, "bytes 2-4/*", []byte("new"), nil, apierr.ErrInternalError) + patchObjectErr(tc, bktName, objName, "bytes 2-4/*", []byte("new"), nil, apierr.ErrInternalError) } func TestPatchMissingHeaders(t *testing.T) { @@ -402,6 +402,14 @@ func TestPatchMissingHeaders(t *testing.T) { assertS3Error(t, w, apierr.GetAPIError(apierr.ErrMissingContentLength)) } +func TestPatchInvalidBucketName(t *testing.T) { + tc := prepareHandlerContext(t) + bktName, objName := "bucket", "object" + createTestBucket(tc, bktName) + + patchObjectErr(tc, "bkt_name", objName, "bytes 2-4/*", []byte("new"), nil, apierr.ErrInvalidBucketName) +} + func TestParsePatchByteRange(t *testing.T) { for _, tt := range []struct { rng string @@ -501,9 +509,9 @@ func patchObjectVersion(t *testing.T, tc *handlerContext, bktName, objName, vers return result } -func patchObjectErr(t *testing.T, tc *handlerContext, bktName, objName, rng string, payload []byte, headers map[string]string, code apierr.ErrorCode) { +func patchObjectErr(tc *handlerContext, bktName, objName, rng string, payload []byte, headers map[string]string, code apierr.ErrorCode) { w := patchObjectBase(tc, bktName, objName, "", rng, payload, headers) - assertS3Error(t, w, apierr.GetAPIError(code)) + assertS3Error(tc.t, w, apierr.GetAPIError(code)) } func patchObjectBase(tc *handlerContext, bktName, objName, version, rng string, payload []byte, headers map[string]string) *httptest.ResponseRecorder { diff --git a/api/handler/util.go b/api/handler/util.go index 625b1b08..bc2d279c 100644 --- a/api/handler/util.go +++ b/api/handler/util.go @@ -55,15 +55,22 @@ func handleDeleteMarker(w http.ResponseWriter, err error) error { } func (h *handler) ResolveBucket(ctx context.Context, bucket string) (*data.BucketInfo, error) { - return h.obj.GetBucketInfo(ctx, bucket) + return h.getBucketInfo(ctx, bucket) } func (h *handler) ResolveCID(ctx context.Context, bucket string) (cid.ID, error) { return h.obj.ResolveCID(ctx, bucket) } +func (h *handler) getBucketInfo(ctx context.Context, bucket string) (*data.BucketInfo, error) { + if err := checkBucketName(bucket); err != nil { + return nil, err + } + return h.obj.GetBucketInfo(ctx, bucket) +} + func (h *handler) getBucketAndCheckOwner(r *http.Request, bucket string, header ...string) (*data.BucketInfo, error) { - bktInfo, err := h.obj.GetBucketInfo(r.Context(), bucket) + bktInfo, err := h.getBucketInfo(r.Context(), bucket) if err != nil { return nil, err }