frostfs-s3-gw/cmd/s3-authmate/modules/utils.go

187 lines
4.6 KiB
Go
Raw Normal View History

package modules
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/authmate"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/frostfsid/contract"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/spf13/viper"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// PoolConfig holds the settings createFrostFS uses to build a connection pool
// to a single FrostFS node.
type PoolConfig struct {
	// Key is the private key the pool is initialized with and that is also
	// passed to the FrostFS wrapper.
	Key *keys.PrivateKey
	// Address is the endpoint of the only node added to the pool.
	Address string

	// Dial, healthcheck and stream timeouts plus the client rebalance
	// interval, forwarded as-is to the pool init parameters.
	DialTimeout, HealthcheckTimeout, StreamTimeout, RebalanceInterval time.Duration
}
// createFrostFS builds a connection pool to the single node at cfg.Address,
// dials it and wraps the pool into an AuthmateFrostFS.
//
// It returns an error if cfg.Key is nil, if the pool cannot be created, or if
// dialing fails.
func createFrostFS(ctx context.Context, log *zap.Logger, cfg PoolConfig) (*frostfs.AuthmateFrostFS, error) {
	log.Debug(logs.PrepareConnectionPool)

	// Guard against a nil key: dereferencing it below would panic instead of
	// producing a usable error for the CLI.
	if cfg.Key == nil {
		return nil, fmt.Errorf("create pool: private key is not set")
	}

	var prm pool.InitParameters
	prm.SetKey(&cfg.Key.PrivateKey)
	prm.SetNodeDialTimeout(cfg.DialTimeout)
	prm.SetHealthcheckTimeout(cfg.HealthcheckTimeout)
	prm.SetNodeStreamTimeout(cfg.StreamTimeout)
	prm.SetClientRebalanceInterval(cfg.RebalanceInterval)
	prm.SetLogger(log)
	// The pool contains exactly one node: cfg.Address.
	prm.AddNode(pool.NewNodeParam(1, cfg.Address, 1))

	p, err := pool.NewPool(prm)
	if err != nil {
		return nil, fmt.Errorf("create pool: %w", err)
	}

	if err = p.Dial(ctx); err != nil {
		return nil, fmt.Errorf("dial pool: %w", err)
	}

	return frostfs.NewAuthmateFrostFS(frostfs.NewFrostFS(p, cfg.Key)), nil
}
// parsePolicies parses container placement policies given either as inline
// JSON or, when val is not valid JSON, as a path to a JSON file.
//
// An empty val yields nil policies and no error. A config that defines a
// policy for api.DefaultLocationConstraint is rejected.
func parsePolicies(val string) (authmate.ContainerPolicies, error) {
	if val == "" {
		return nil, nil
	}

	data := []byte(val)
	if !json.Valid(data) {
		// Not inline JSON: treat val as a path to a JSON file and keep the
		// read error so the user sees why the file could not be used.
		fileData, err := os.ReadFile(val)
		if err != nil {
			return nil, fmt.Errorf("couldn't read json file or provided json is invalid: %w", err)
		}
		data = fileData
	}

	var policies authmate.ContainerPolicies
	if err := json.Unmarshal(data, &policies); err != nil {
		return nil, fmt.Errorf("unmarshal policies: %w", err)
	}

	if _, ok := policies[api.DefaultLocationConstraint]; ok {
		return nil, fmt.Errorf("config overrides %s location constraint", api.DefaultLocationConstraint)
	}

	return policies, nil
}
// getJSONRules returns JSON rules given either inline or as a file path.
//
// If val itself is valid JSON, its bytes are returned unchanged. Otherwise val
// is treated as a path, and the file contents are returned if they are valid
// JSON. An empty val yields (nil, nil). A read failure is wrapped with %w, so
// callers can inspect it (e.g. with errors.Is(err, os.ErrNotExist)).
func getJSONRules(val string) ([]byte, error) {
	if val == "" {
		return nil, nil
	}

	data := []byte(val)
	if json.Valid(data) {
		return data, nil
	}

	data, err := os.ReadFile(val)
	if err != nil {
		return nil, fmt.Errorf("couldn't read json file or provided json is invalid: %w", err)
	}

	if !json.Valid(data) {
		return nil, fmt.Errorf("couldn't read json file or provided json is invalid: file %q does not contain valid json", val)
	}

	return data, nil
}
// getSessionRules reads JSON session rules from r, which is either inline JSON
// or a path to a JSON file (see getJSONRules).
// The special value "none" means no rules: it returns (nil, true, nil), and
// the boolean result tells the caller to skip session rules.
func getSessionRules(r string) ([]byte, bool, error) {
	const skipMarker = "none"

	if r != skipMarker {
		rules, err := getJSONRules(r)
		return rules, false, err
	}

	return nil, true, nil
}
// getLogger builds a logger from the flags stored in viper.
//
// If withLogFlag is not set, it returns a no-op logger. Otherwise it returns a
// development console logger writing to stdout. The level is debug when
// debugFlag is set and fatal otherwise.
// It panics if the logger cannot be built.
func getLogger() *zap.Logger {
	if !viper.GetBool(withLogFlag) {
		return zap.NewNop()
	}

	// Without the debug flag, only fatal entries are emitted.
	level := zapcore.FatalLevel
	if viper.GetBool(debugFlag) {
		level = zapcore.DebugLevel
	}

	encoderCfg := zapcore.EncoderConfig{
		TimeKey:      "time",
		LevelKey:     "level",
		CallerKey:    "caller",
		MessageKey:   "message",
		EncodeTime:   zapcore.ISO8601TimeEncoder,
		EncodeLevel:  zapcore.CapitalLevelEncoder,
		EncodeCaller: zapcore.ShortCallerEncoder,
	}

	cfg := zap.Config{
		Level:         zap.NewAtomicLevelAt(level),
		Development:   true,
		Encoding:      "console",
		EncoderConfig: encoderCfg,
		OutputPaths:   []string{"stdout"},
	}

	logger, err := cfg.Build()
	if err != nil {
		panic(fmt.Errorf("create logger: %w", err))
	}

	return logger
}
// createFrostFSID creates a frostfsid contract client from cfg.
// Any construction error is wrapped with the "create frostfsid client" prefix.
func createFrostFSID(ctx context.Context, log *zap.Logger, cfg contract.Config) (*contract.FrostFSID, error) {
	log.Debug(logs.PrepareFrostfsIDClient)

	frostfsID, err := contract.New(ctx, cfg)
	if err != nil {
		return nil, fmt.Errorf("create frostfsid client: %w", err)
	}

	return frostfsID, nil
}
// registerPublicKey creates a frostfsid subject for key in namespace and
// passes the CreateSubject result to cli.Wait.
//
// An error whose text contains "subject already exists" is not reported, so
// registering an already-known key does not fail. Any other error is returned
// unchanged.
func registerPublicKey(cli *contract.FrostFSID, namespace string, key *keys.PublicKey) error {
	if err := cli.Wait(cli.CreateSubject(namespace, key)); err != nil {
		// NOTE(review): matching on the error text is fragile; switch to a
		// sentinel/typed error if the contract client exposes one.
		if strings.Contains(err.Error(), "subject already exists") {
			return nil
		}
		return err
	}

	return nil
}
// parseObjectAttrs parses a comma-separated list of key=value pairs
// (e.g. "Owner=alice,Env=test") into object attributes.
//
// An empty string yields nil attributes and no error. Each pair is split at
// its first '=', so values may themselves contain '='. A pair without '=' or
// with an empty key is rejected with an "invalid attribute format" error.
func parseObjectAttrs(attributes string) ([]object.Attribute, error) {
	if attributes == "" {
		return nil, nil
	}

	rawAttrs := strings.Split(attributes, ",")
	attrs := make([]object.Attribute, len(rawAttrs))
	for i, raw := range rawAttrs {
		k, v, found := strings.Cut(raw, "=")
		if !found {
			return nil, fmt.Errorf("invalid attribute format: %s", raw)
		}
		// "=value" would otherwise produce an attribute with an empty key.
		if k == "" {
			return nil, fmt.Errorf("invalid attribute format: empty key in %s", raw)
		}

		attrs[i].SetKey(k)
		attrs[i].SetValue(v)
	}

	return attrs, nil
}