From f282e877e2dd82b6970a8d3282e646a3cd0d9320 Mon Sep 17 00:00:00 2001
From: Denis Kirillov <denis@nspcc.ru>
Date: Wed, 1 Jun 2022 17:09:28 +0300
Subject: [PATCH] [#484] Handle conditional headers

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
---
 api/handler/attributes.go |  9 ++++++++-
 api/handler/get.go        | 12 ++++--------
 api/handler/head.go       | 12 ++++++++++++
 3 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/api/handler/attributes.go b/api/handler/attributes.go
index 0289280e..ace6d12d 100644
--- a/api/handler/attributes.go
+++ b/api/handler/attributes.go
@@ -38,6 +38,7 @@ type (
 		PartNumberMarker int
 		Attributes       []string
 		VersionID        string
+		Conditional      *conditionalArgs
 	}
 )
 
@@ -90,6 +91,11 @@ func (h *handler) GetObjectAttributesHandler(w http.ResponseWriter, r *http.Requ
 		return
 	}
 
+	if err = checkPreconditions(info, params.Conditional); err != nil {
+		h.logAndSendError(w, "precondition failed", reqInfo, err)
+		return
+	}
+
 	response, err := encodeToObjectAttributesResponse(info, params)
 	if err != nil {
 		h.logAndSendError(w, "couldn't encode object info to response", reqInfo, err)
@@ -152,7 +158,8 @@ func parseGetObjectAttributeArgs(r *http.Request) (*GetObjectAttributesArgs, err
 
 	res.VersionID = queryValues.Get(api.QueryVersionID)
 
-	return res, nil
+	res.Conditional, err = parseConditionalHeaders(r.Header)
+	return res, err
 }
 
 func encodeToObjectAttributesResponse(info *data.ObjectInfo, p *GetObjectAttributesArgs) (*GetObjectAttributesResponse, error) {
diff --git a/api/handler/get.go b/api/handler/get.go
index 0c3e16c8..546ba374 100644
--- a/api/handler/get.go
+++ b/api/handler/get.go
@@ -21,10 +21,6 @@ type conditionalArgs struct {
 	IfNoneMatch       string
 }
 
-type getObjectArgs struct {
-	Conditional *conditionalArgs
-}
-
 func fetchRangeHeader(headers http.Header, fullSize uint64) (*layer.RangeParams, error) {
 	const prefix = "bytes="
 	rangeHeader := headers.Get("Range")
@@ -109,7 +105,7 @@ func (h *handler) GetObjectHandler(w http.ResponseWriter, r *http.Request) {
 		reqInfo = api.GetReqInfo(r.Context())
 	)
 
-	args, err := parseGetObjectArgs(r.Header)
+	conditional, err := parseConditionalHeaders(r.Header)
 	if err != nil {
 		h.logAndSendError(w, "could not parse request params", reqInfo, err)
 		return
@@ -132,7 +128,7 @@ func (h *handler) GetObjectHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	if err = checkPreconditions(info, args.Conditional); err != nil {
+	if err = checkPreconditions(info, conditional); err != nil {
 		h.logAndSendError(w, "precondition failed", reqInfo, err)
 		return
 	}
@@ -194,7 +190,7 @@ func checkPreconditions(info *data.ObjectInfo, args *conditionalArgs) error {
 	return nil
 }
 
-func parseGetObjectArgs(headers http.Header) (*getObjectArgs, error) {
+func parseConditionalHeaders(headers http.Header) (*conditionalArgs, error) {
 	var err error
 	args := &conditionalArgs{
 		IfMatch:     headers.Get(api.IfMatch),
@@ -208,7 +204,7 @@ func parseGetObjectArgs(headers http.Header) (*getObjectArgs, error) {
 		return nil, err
 	}
 
-	return &getObjectArgs{Conditional: args}, nil
+	return args, nil
 }
 
 func parseHTTPTime(data string) (*time.Time, error) {
diff --git a/api/handler/head.go b/api/handler/head.go
index 6070287b..51812b09 100644
--- a/api/handler/head.go
+++ b/api/handler/head.go
@@ -40,6 +40,12 @@ func (h *handler) HeadObjectHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	conditional, err := parseConditionalHeaders(r.Header)
+	if err != nil {
+		h.logAndSendError(w, "could not parse request params", reqInfo, err)
+		return
+	}
+
 	p := &layer.HeadObjectParams{
 		BktInfo:   bktInfo,
 		Object:    reqInfo.ObjectName,
@@ -50,6 +56,12 @@ func (h *handler) HeadObjectHandler(w http.ResponseWriter, r *http.Request) {
 		h.logAndSendError(w, "could not fetch object info", reqInfo, err)
 		return
 	}
+
+	if err = checkPreconditions(info, conditional); err != nil {
+		h.logAndSendError(w, "precondition failed", reqInfo, err)
+		return
+	}
+
 	tagSet, err := h.obj.GetObjectTagging(r.Context(), info)
 	if err != nil && !errors.IsS3Error(err, errors.ErrNoSuchKey) {
 		h.logAndSendError(w, "could not get object tag set", reqInfo, err)