[#XX] Check bucket name not only during creation

Signed-off-by: Denis Kirillov <d.kirillov@yadro.com>
This commit is contained in:
Denis Kirillov 2024-11-19 14:48:59 +03:00
parent 368c7d2acd
commit 543d47db4b
4 changed files with 25 additions and 8 deletions

View file

@ -105,7 +105,7 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
if reqInfo.BucketName == "" { if reqInfo.BucketName == "" {
return return
} }
bktInfo, err := h.obj.GetBucketInfo(ctx, reqInfo.BucketName) bktInfo, err := h.getBucketInfo(ctx, reqInfo.BucketName)
if err != nil { if err != nil {
h.reqLogger(ctx).Warn(logs.GetBucketInfo, zap.Error(err)) h.reqLogger(ctx).Warn(logs.GetBucketInfo, zap.Error(err))
return return
@ -154,7 +154,7 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) { func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx) reqInfo := middleware.GetReqInfo(ctx)
bktInfo, err := h.obj.GetBucketInfo(ctx, reqInfo.BucketName) bktInfo, err := h.getBucketInfo(ctx, reqInfo.BucketName)
if err != nil { if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err) h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
return return

View file

@ -276,6 +276,7 @@ func TestPutBucketLockConfigurationHandler(t *testing.T) {
}{ }{
{ {
name: "bkt not found", name: "bkt not found",
bucket: "not-found-bucket",
expectedError: apierr.GetAPIError(apierr.ErrNoSuchBucket), expectedError: apierr.GetAPIError(apierr.ErrNoSuchBucket),
}, },
{ {
@ -365,6 +366,7 @@ func TestGetBucketLockConfigurationHandler(t *testing.T) {
}{ }{
{ {
name: "bkt not found", name: "bkt not found",
bucket: "not-found-bucket",
expectedError: apierr.GetAPIError(apierr.ErrNoSuchBucket), expectedError: apierr.GetAPIError(apierr.ErrNoSuchBucket),
}, },
{ {

View file

@ -102,7 +102,7 @@ func TestPatch(t *testing.T) {
res := patchObject(t, tc, bktName, objName, tt.rng, patchPayload, tt.headers) res := patchObject(t, tc, bktName, objName, tt.rng, patchPayload, tt.headers)
require.Equal(t, data.Quote(hash), res.Object.ETag) require.Equal(t, data.Quote(hash), res.Object.ETag)
} else { } else {
patchObjectErr(t, tc, bktName, objName, tt.rng, patchPayload, tt.headers, tt.code) patchObjectErr(tc, bktName, objName, tt.rng, patchPayload, tt.headers, tt.code)
} }
}) })
} }
@ -377,7 +377,7 @@ func TestPatchEncryptedObject(t *testing.T) {
tc.Handler().PutObjectHandler(w, r) tc.Handler().PutObjectHandler(w, r)
assertStatus(t, w, http.StatusOK) assertStatus(t, w, http.StatusOK)
patchObjectErr(t, tc, bktName, objName, "bytes 2-4/*", []byte("new"), nil, apierr.ErrInternalError) patchObjectErr(tc, bktName, objName, "bytes 2-4/*", []byte("new"), nil, apierr.ErrInternalError)
} }
func TestPatchMissingHeaders(t *testing.T) { func TestPatchMissingHeaders(t *testing.T) {
@ -402,6 +402,14 @@ func TestPatchMissingHeaders(t *testing.T) {
assertS3Error(t, w, apierr.GetAPIError(apierr.ErrMissingContentLength)) assertS3Error(t, w, apierr.GetAPIError(apierr.ErrMissingContentLength))
} }
func TestPatchInvalidBucketName(t *testing.T) {
tc := prepareHandlerContext(t)
bktName, objName := "bucket", "object"
createTestBucket(tc, bktName)
patchObjectErr(tc, "bkt_name", objName, "bytes 2-4/*", []byte("new"), nil, apierr.ErrInvalidBucketName)
}
func TestParsePatchByteRange(t *testing.T) { func TestParsePatchByteRange(t *testing.T) {
for _, tt := range []struct { for _, tt := range []struct {
rng string rng string
@ -501,9 +509,9 @@ func patchObjectVersion(t *testing.T, tc *handlerContext, bktName, objName, vers
return result return result
} }
func patchObjectErr(t *testing.T, tc *handlerContext, bktName, objName, rng string, payload []byte, headers map[string]string, code apierr.ErrorCode) { func patchObjectErr(tc *handlerContext, bktName, objName, rng string, payload []byte, headers map[string]string, code apierr.ErrorCode) {
w := patchObjectBase(tc, bktName, objName, "", rng, payload, headers) w := patchObjectBase(tc, bktName, objName, "", rng, payload, headers)
assertS3Error(t, w, apierr.GetAPIError(code)) assertS3Error(tc.t, w, apierr.GetAPIError(code))
} }
func patchObjectBase(tc *handlerContext, bktName, objName, version, rng string, payload []byte, headers map[string]string) *httptest.ResponseRecorder { func patchObjectBase(tc *handlerContext, bktName, objName, version, rng string, payload []byte, headers map[string]string) *httptest.ResponseRecorder {

View file

@ -55,15 +55,22 @@ func handleDeleteMarker(w http.ResponseWriter, err error) error {
} }
func (h *handler) ResolveBucket(ctx context.Context, bucket string) (*data.BucketInfo, error) { func (h *handler) ResolveBucket(ctx context.Context, bucket string) (*data.BucketInfo, error) {
return h.obj.GetBucketInfo(ctx, bucket) return h.getBucketInfo(ctx, bucket)
} }
func (h *handler) ResolveCID(ctx context.Context, bucket string) (cid.ID, error) { func (h *handler) ResolveCID(ctx context.Context, bucket string) (cid.ID, error) {
return h.obj.ResolveCID(ctx, bucket) return h.obj.ResolveCID(ctx, bucket)
} }
func (h *handler) getBucketInfo(ctx context.Context, bucket string) (*data.BucketInfo, error) {
if err := checkBucketName(bucket); err != nil {
return nil, err
}
return h.obj.GetBucketInfo(ctx, bucket)
}
func (h *handler) getBucketAndCheckOwner(r *http.Request, bucket string, header ...string) (*data.BucketInfo, error) { func (h *handler) getBucketAndCheckOwner(r *http.Request, bucket string, header ...string) (*data.BucketInfo, error) {
bktInfo, err := h.obj.GetBucketInfo(r.Context(), bucket) bktInfo, err := h.getBucketInfo(r.Context(), bucket)
if err != nil { if err != nil {
return nil, err return nil, err
} }