diff --git a/api/middleware/policy.go b/api/middleware/policy.go index 5a7142a2..2501c7a7 100644 --- a/api/middleware/policy.go +++ b/api/middleware/policy.go @@ -190,15 +190,16 @@ func getPolicyRequest(r *http.Request, cfg PolicyConfig, reqType ReqType, bktNam res = fmt.Sprintf(s3.ResourceFormatS3Bucket, bktName) } - properties, err := determineProperties(r, cfg.Decoder, cfg.BucketResolver, cfg.Tagging, reqType, op, bktName, objName, owner, groups, tags) + requestProps, resourceProps, err := determineProperties(r, cfg.Decoder, cfg.BucketResolver, cfg.Tagging, reqType, op, bktName, objName, owner, groups, tags) if err != nil { return nil, nil, nil, fmt.Errorf("determine properties: %w", err) } reqLogOrDefault(r.Context(), cfg.Log).Debug(logs.PolicyRequest, zap.String("action", op), - zap.String("resource", res), zap.Any("properties", properties)) + zap.String("resource", res), zap.Any("request properties", requestProps), + zap.Any("resource properties", resourceProps)) - return testutil.NewRequest(op, testutil.NewResource(res, nil), properties), pk, groups, nil + return testutil.NewRequest(op, testutil.NewResource(res, resourceProps), requestProps), pk, groups, nil } type ReqType int @@ -427,72 +428,59 @@ func determineGeneralOperation(r *http.Request) string { } func determineProperties(r *http.Request, decoder XMLDecoder, resolver BucketResolveFunc, tagging ResourceTagging, reqType ReqType, - op, bktName, objName, owner string, groups []string, tags map[string]string) (map[string]string, error) { - res := map[string]string{ + op, bktName, objName, owner string, groups []string, userClaims map[string]string) (requestProperties map[string]string, resourceProperties map[string]string, err error) { + requestProperties = map[string]string{ s3.PropertyKeyOwner: owner, common.PropertyKeyFrostFSIDGroupID: chain.FormCondSliceContainsValue(groups), common.PropertyKeyFrostFSSourceIP: GetReqInfo(r.Context()).RemoteHost, } queries := GetReqInfo(r.Context()).URL.Query() - for k, v := range tags { - res[fmt.Sprintf(common.PropertyKeyFormatFrostFSIDUserClaim, k)] = v + for k, v := range userClaims { + requestProperties[fmt.Sprintf(common.PropertyKeyFormatFrostFSIDUserClaim, k)] = v } if reqType == objectType { if versionID := queries.Get(QueryVersionID); len(versionID) > 0 { - res[s3.PropertyKeyVersionID] = versionID + requestProperties[s3.PropertyKeyVersionID] = versionID } } if reqType == bucketType && (strings.HasSuffix(op, ListObjectsV1Operation) || strings.HasSuffix(op, ListObjectsV2Operation) || strings.HasSuffix(op, ListBucketObjectVersionsOperation) || strings.HasSuffix(op, ListMultipartUploadsOperation)) { if prefix := queries.Get(QueryPrefix); len(prefix) > 0 { - res[s3.PropertyKeyPrefix] = prefix + requestProperties[s3.PropertyKeyPrefix] = prefix } if delimiter := queries.Get(QueryDelimiter); len(delimiter) > 0 { - res[s3.PropertyKeyDelimiter] = delimiter + requestProperties[s3.PropertyKeyDelimiter] = delimiter } if maxKeys := queries.Get(QueryMaxKeys); len(maxKeys) > 0 { - res[s3.PropertyKeyMaxKeys] = maxKeys + requestProperties[s3.PropertyKeyMaxKeys] = maxKeys } } - tags, err := determineTags(r, decoder, resolver, tagging, reqType, op, bktName, objName, queries.Get(QueryVersionID)) - if err != nil { - return nil, fmt.Errorf("determine tags: %w", err) - } - for k, v := range tags { - res[k] = v - } - - res[s3.PropertyKeyAccessBoxAttrMFA] = "false" + requestProperties[s3.PropertyKeyAccessBoxAttrMFA] = "false" attrs, err := GetAccessBoxAttrs(r.Context()) if err == nil { for _, attr := range attrs { - res[fmt.Sprintf(s3.PropertyKeyFormatAccessBoxAttr, attr.Key())] = attr.Value() + requestProperties[fmt.Sprintf(s3.PropertyKeyFormatAccessBoxAttr, attr.Key())] = attr.Value() } } - return res, nil -} - -func determineTags(r *http.Request, decoder XMLDecoder, resolver BucketResolveFunc, tagging ResourceTagging, reqType ReqType, - op, bktName, objName, versionID string) (map[string]string, error) { - res, err := determineRequestTags(r, decoder, op) + reqTags, err := determineRequestTags(r, decoder, op) if err != nil { - return nil, fmt.Errorf("determine request tags: %w", err) + return nil, nil, fmt.Errorf("determine request tags: %w", err) + } + for k, v := range reqTags { + requestProperties[k] = v } - tags, err := determineResourceTags(r.Context(), reqType, op, bktName, objName, versionID, resolver, tagging) + resourceProperties, err = determineResourceTags(r.Context(), reqType, op, bktName, objName, queries.Get(QueryVersionID), resolver, tagging) if err != nil { - return nil, fmt.Errorf("determine resource tags: %w", err) - } - for k, v := range tags { - res[k] = v + return nil, nil, fmt.Errorf("determine resource tags: %w", err) } - return res, nil + return requestProperties, resourceProperties, nil } func determineRequestTags(r *http.Request, decoder XMLDecoder, op string) (map[string]string, error) {