distribution/vendor/github.com/google/s2a-go/s2a.go

413 lines
14 KiB
Go
Raw Normal View History

/*
*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package s2a provides the S2A transport credentials used by a gRPC
// application.
package s2a
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/golang/protobuf/proto"
"github.com/google/s2a-go/fallback"
"github.com/google/s2a-go/internal/handshaker"
"github.com/google/s2a-go/internal/handshaker/service"
"github.com/google/s2a-go/internal/tokenmanager"
"github.com/google/s2a-go/internal/v2"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
)
const (
	// s2aSecurityProtocol is the SecurityProtocol reported via
	// credentials.ProtocolInfo by the legacy S2A transport credentials.
	s2aSecurityProtocol = "tls"

	// defaultTimeout bounds the context used for a server-side handshake.
	defaultTimeout = 30 * time.Second
)
// s2aTransportCreds holds the transport credentials needed to establish a
// secure connection through the S2A; it implements the
// credentials.TransportCredentials interface. Values of this type are built
// by NewClientCreds and NewServerCreds when legacy mode is enabled.
type s2aTransportCreds struct {
	info *credentials.ProtocolInfo

	// TLS parameters forwarded to the S2A handshaker. The ciphersuite list
	// is chosen by the constructors and is currently not configurable.
	minTLSVersion, maxTLSVersion commonpb.TLSVersion
	tlsCiphersuites              []commonpb.Ciphersuite

	// Client-side only: the identity presented by this client.
	localIdentity *commonpb.Identity
	// Server-side only: the identities this server may present.
	localIdentities []*commonpb.Identity
	// Client-side only: identities the server is expected to have.
	targetIdentities []*commonpb.Identity

	// isClient selects which side of the handshake these creds serve.
	isClient bool
	// s2aAddr is the S2A address dialed at the start of every handshake.
	s2aAddr string
	// ensureProcessSessionTickets is passed through to the client handshaker.
	ensureProcessSessionTickets *sync.WaitGroup
}
// NewClientCreds returns a client-side transport credentials object that uses
// the S2A to establish a secure connection with a server.
//
// All identities in opts are converted before the mode is inspected, so a
// malformed identity is reported in either mode. With opts.EnableLegacyMode
// the legacy handshaker is used with TLS 1.3 as both minimum and maximum
// version; otherwise construction is delegated to v2.NewClientCreds.
func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
	if opts == nil {
		return nil, errors.New("nil client options")
	}

	var targetIdentities []*commonpb.Identity
	for _, id := range opts.TargetIdentities {
		protoID, err := toProtoIdentity(id)
		if err != nil {
			return nil, err
		}
		targetIdentities = append(targetIdentities, protoID)
	}
	localIdentity, err := toProtoIdentity(opts.LocalIdentity)
	if err != nil {
		return nil, err
	}

	if !opts.EnableLegacyMode {
		mode := getVerificationMode(opts.VerificationMode)
		// A nil FallbackOpts (or a nil func inside it) leaves fallbackFunc nil.
		var fallbackFunc fallback.ClientHandshake
		if fo := opts.FallbackOpts; fo != nil {
			fallbackFunc = fo.FallbackClientHandshakeFunc
		}
		return v2.NewClientCreds(opts.S2AAddress, localIdentity, mode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
	}

	return &s2aTransportCreds{
		info:          &credentials.ProtocolInfo{SecurityProtocol: s2aSecurityProtocol},
		minTLSVersion: commonpb.TLSVersion_TLS1_3,
		maxTLSVersion: commonpb.TLSVersion_TLS1_3,
		tlsCiphersuites: []commonpb.Ciphersuite{
			commonpb.Ciphersuite_AES_128_GCM_SHA256,
			commonpb.Ciphersuite_AES_256_GCM_SHA384,
			commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
		},
		localIdentity:               localIdentity,
		targetIdentities:            targetIdentities,
		isClient:                    true,
		s2aAddr:                     opts.S2AAddress,
		ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
	}, nil
}
// NewServerCreds returns a server-side transport credentials object that uses
// the S2A to establish a secure connection with a client.
//
// Local identities are converted before the mode is inspected, so a malformed
// identity is reported in either mode. With opts.EnableLegacyMode the legacy
// handshaker is used with TLS 1.3 as both minimum and maximum version;
// otherwise construction is delegated to v2.NewServerCreds.
func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
	if opts == nil {
		return nil, errors.New("nil server options")
	}

	var localIdentities []*commonpb.Identity
	for _, id := range opts.LocalIdentities {
		protoID, err := toProtoIdentity(id)
		if err != nil {
			return nil, err
		}
		localIdentities = append(localIdentities, protoID)
	}

	if !opts.EnableLegacyMode {
		mode := getVerificationMode(opts.VerificationMode)
		return v2.NewServerCreds(opts.S2AAddress, localIdentities, mode, opts.getS2AStream)
	}

	return &s2aTransportCreds{
		info:          &credentials.ProtocolInfo{SecurityProtocol: s2aSecurityProtocol},
		minTLSVersion: commonpb.TLSVersion_TLS1_3,
		maxTLSVersion: commonpb.TLSVersion_TLS1_3,
		tlsCiphersuites: []commonpb.Ciphersuite{
			commonpb.Ciphersuite_AES_128_GCM_SHA256,
			commonpb.Ciphersuite_AES_256_GCM_SHA384,
			commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
		},
		localIdentities: localIdentities,
		isClient:        false,
		s2aAddr:         opts.S2AAddress,
	}, nil
}
// ClientHandshake initiates a client-side TLS handshake using the S2A.
//
// It fails immediately when called on server-side credentials. Otherwise it
// dials the S2A at c.s2aAddr, creates a legacy client handshaker for rawConn
// (targeting serverAuthority) and runs the handshake.
//
// If the handshake fails, the handshaker is closed. A failure while closing is
// logged and appended to the returned error; the original error stays
// reachable via errors.Is/errors.As. The results are named so the deferred
// cleanup can actually change the error seen by the caller.
func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
	if !c.isClient {
		return nil, nil, errors.New("client handshake called using server transport credentials")
	}

	// Connect to the S2A.
	hsConn, err := service.Dial(c.s2aAddr)
	if err != nil {
		grpclog.Infof("Failed to connect to S2A: %v", err)
		return nil, nil, err
	}

	// Everything derived from ctx is cancelled once this function returns.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	opts := &handshaker.ClientHandshakerOptions{
		MinTLSVersion:               c.minTLSVersion,
		MaxTLSVersion:               c.maxTLSVersion,
		TLSCiphersuites:             c.tlsCiphersuites,
		TargetIdentities:            c.targetIdentities,
		LocalIdentity:               c.localIdentity,
		TargetName:                  serverAuthority,
		EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
	}
	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
	if err != nil {
		grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
		return nil, nil, err
	}
	defer func() {
		if err == nil {
			return
		}
		if closeErr := chs.Close(); closeErr != nil {
			grpclog.Infof("Close failed unexpectedly: %v", closeErr)
			err = fmt.Errorf("%w: close unexpectedly failed: %v", err, closeErr)
		}
	}()

	// NOTE(review): the handshake call itself receives context.Background();
	// ctx only reaches the handshaker via NewClientHandshaker — confirm intended.
	secConn, authInfo, err := chs.ClientHandshake(context.Background())
	if err != nil {
		grpclog.Infof("Handshake failed: %v", err)
		return nil, nil, err
	}
	return secConn, authInfo, nil
}
// ServerHandshake initiates a server-side TLS handshake using the S2A.
//
// It fails immediately when called on client-side credentials. Otherwise it
// dials the S2A at c.s2aAddr and runs a legacy server handshake on rawConn.
// The handshaker is created with a context bounded by defaultTimeout.
//
// If the handshake fails, the handshaker is closed. A failure while closing is
// logged and appended to the returned error; the original error stays
// reachable via errors.Is/errors.As. The results are named so the deferred
// cleanup can actually change the error seen by the caller.
func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
	if c.isClient {
		return nil, nil, errors.New("server handshake called using client transport credentials")
	}

	// Connect to the S2A.
	hsConn, err := service.Dial(c.s2aAddr)
	if err != nil {
		grpclog.Infof("Failed to connect to S2A: %v", err)
		return nil, nil, err
	}

	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
	defer cancel()

	opts := &handshaker.ServerHandshakerOptions{
		MinTLSVersion:   c.minTLSVersion,
		MaxTLSVersion:   c.maxTLSVersion,
		TLSCiphersuites: c.tlsCiphersuites,
		LocalIdentities: c.localIdentities,
	}
	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
	if err != nil {
		grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
		return nil, nil, err
	}
	defer func() {
		if err == nil {
			return
		}
		if closeErr := shs.Close(); closeErr != nil {
			grpclog.Infof("Close failed unexpectedly: %v", closeErr)
			err = fmt.Errorf("%w: close unexpectedly failed: %v", err, closeErr)
		}
	}()

	secConn, authInfo, err := shs.ServerHandshake(context.Background())
	if err != nil {
		grpclog.Infof("Handshake failed: %v", err)
		return nil, nil, err
	}
	return secConn, authInfo, nil
}
// Info returns a copy of the protocol information held by c.
func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
	snapshot := *c.info
	return snapshot
}
// Clone returns an independent copy of c.
//
// The ProtocolInfo is copied by value. Identity protos are deep-copied with
// proto.Clone, and the ciphersuite slice is copied, so mutating the clone does
// not affect c. The nil-ness of the identity and ciphersuite slices is
// preserved. The ensureProcessSessionTickets WaitGroup pointer is shared
// rather than copied (a WaitGroup must not be copied); previously it was
// dropped, so cloned client creds silently lost it.
func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
	info := *c.info

	var localIdentity *commonpb.Identity
	if c.localIdentity != nil {
		localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
	}

	// cloneIdentities deep-copies ids, returning nil for a nil input.
	cloneIdentities := func(ids []*commonpb.Identity) []*commonpb.Identity {
		if ids == nil {
			return nil
		}
		out := make([]*commonpb.Identity, len(ids))
		for i, id := range ids {
			out[i] = proto.Clone(id).(*commonpb.Identity)
		}
		return out
	}

	// Copy the ciphersuites so the clone does not alias c's backing array.
	var tlsCiphersuites []commonpb.Ciphersuite
	if c.tlsCiphersuites != nil {
		tlsCiphersuites = make([]commonpb.Ciphersuite, len(c.tlsCiphersuites))
		copy(tlsCiphersuites, c.tlsCiphersuites)
	}

	return &s2aTransportCreds{
		info:                        &info,
		minTLSVersion:               c.minTLSVersion,
		maxTLSVersion:               c.maxTLSVersion,
		tlsCiphersuites:             tlsCiphersuites,
		localIdentity:               localIdentity,
		localIdentities:             cloneIdentities(c.localIdentities),
		targetIdentities:            cloneIdentities(c.targetIdentities),
		isClient:                    c.isClient,
		s2aAddr:                     c.s2aAddr,
		ensureProcessSessionTickets: c.ensureProcessSessionTickets,
	}
}
// OverrideServerName replaces the ServerName reported by Info with
// serverNameOverride. It never returns an error.
func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
	info := c.info
	info.ServerName = serverNameOverride
	return nil
}
// TLSClientConfigOptions specifies parameters for creating client TLS config.
type TLSClientConfigOptions struct {
	// ServerName is required by s2a: it is the name expected when verifying
	// the hostname found in the server's certificate. For example:
	//
	//	tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
	//		ServerName: "example.com",
	//	})
	ServerName string
}
// TLSClientConfigFactory is implemented by types that produce client-side
// *tls.Config values.
type TLSClientConfigFactory interface {
	// Build returns a client TLS config configured according to opts.
	Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
}
// NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
//
// It returns an error when opts is nil or when opts.EnableLegacyMode is set,
// because only S2Av2 is supported here. An access token manager is attached
// when one can be created; if creation fails, the failure is logged and the
// factory is returned without a token manager.
func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
	if opts == nil {
		return nil, errors.New("opts must be non-nil")
	}
	if opts.EnableLegacyMode {
		return nil, errors.New("NewTLSClientConfigFactory only supports S2Av2")
	}

	// tokenManager is left as a nil interface (never a typed nil) when the
	// manager cannot be created.
	var tokenManager tokenmanager.AccessTokenManager
	tm, err := tokenmanager.NewSingleTokenAccessTokenManager()
	if err != nil {
		// The only possible error is: access token not set in the environment,
		// which is okay in environments other than serverless.
		grpclog.Infof("Access token manager not initialized: %v", err)
	} else {
		tokenManager = tm
	}

	return &s2aTLSClientConfigFactory{
		s2av2Address:              opts.S2AAddress,
		tokenManager:              tokenManager,
		verificationMode:          getVerificationMode(opts.VerificationMode),
		serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
	}, nil
}
// s2aTLSClientConfigFactory is the S2Av2-backed TLSClientConfigFactory
// returned by NewTLSClientConfigFactory.
type s2aTLSClientConfigFactory struct {
	// s2av2Address is the S2Av2 address passed to v2.NewClientTLSConfig.
	s2av2Address string
	// tokenManager is nil when no access token manager could be created.
	tokenManager tokenmanager.AccessTokenManager
	// verificationMode is the peer-certificate verification mode to request.
	verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
	// serverAuthorizationPolicy is forwarded verbatim to v2.NewClientTLSConfig.
	serverAuthorizationPolicy []byte
}
// Build returns a client TLS config produced by v2.NewClientTLSConfig using
// the factory's S2Av2 address, token manager, verification mode and server
// authorization policy. A nil opts (or an empty opts.ServerName) results in
// an empty server name being passed along.
func (f *s2aTLSClientConfigFactory) Build(
	ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
	var serverName string
	if opts != nil {
		serverName = opts.ServerName
	}
	return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
}
// getVerificationMode maps the public VerificationModeType onto the S2Av2
// protobuf enum. Any value other than ConnectToGoogle or Spiffe maps to
// UNSPECIFIED.
func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
	if verificationMode == ConnectToGoogle {
		return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
	}
	if verificationMode == Spiffe {
		return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
	}
	return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
}
// NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
// Example use with http.RoundTripper:
//
//	dialTLSContext := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
//		S2AAddress: s2aAddress, // required
//	})
//	transport := http.DefaultTransport.(*http.Transport).Clone()
//	transport.DialTLSContext = dialTLSContext
//
// Each call of the returned function creates a TLS client config factory,
// builds an S2A TLS config (bounded by v2.GetS2ATimeout()) for the host part
// of addr, and dials addr with it.
//
// If any of those steps fails and opts.FallbackOpts.FallbackDialer is fully
// configured (non-nil Dialer, non-empty ServerAddr), FallbackDialer.ServerAddr
// is dialed instead. Otherwise the S2A error is returned. A nil opts yields an
// error from the returned function rather than a panic.
func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
		// fallbackOnErr dials the configured fallback server, if any, and
		// otherwise returns s2aErr unchanged. It is deliberately not named
		// "fallback" so it does not shadow the imported fallback package.
		fallbackOnErr := func(s2aErr error) (net.Conn, error) {
			if opts == nil || opts.FallbackOpts == nil {
				return nil, s2aErr
			}
			fbDialer := opts.FallbackOpts.FallbackDialer
			if fbDialer == nil || fbDialer.Dialer == nil || fbDialer.ServerAddr == "" {
				return nil, s2aErr
			}
			grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
			fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
			if fbErr != nil {
				return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, s2aErr)
			}
			return fbConn, nil
		}

		factory, err := NewTLSClientConfigFactory(opts)
		if err != nil {
			grpclog.Infof("error creating S2A client config factory: %v", err)
			return fallbackOnErr(err)
		}

		// addr may lack a port; in that case the whole string is the name.
		serverName, _, err := net.SplitHostPort(addr)
		if err != nil {
			serverName = addr
		}

		// Only building the config is bounded by the S2A timeout; the TLS
		// dial below uses the caller's ctx.
		timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
		defer cancel()
		s2aTLSConfig, err := factory.Build(timeoutCtx, &TLSClientConfigOptions{
			ServerName: serverName,
		})
		if err != nil {
			grpclog.Infof("error building S2A TLS config: %v", err)
			return fallbackOnErr(err)
		}

		s2aDialer := &tls.Dialer{
			Config: s2aTLSConfig,
		}
		c, err := s2aDialer.DialContext(ctx, network, addr)
		if err != nil {
			grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
			return fallbackOnErr(err)
		}
		grpclog.Infof("success dialing MTLS to %s with S2A", addr)
		return c, nil
	}
}