forked from TrueCloudLab/distribution
Simplify configuration and transport
Repository creation now just takes in an http.RoundTripper. Authenticated requests or requests which require additional headers should use the NewTransport function along with a request modifier (such an an authentication handler). Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)
This commit is contained in:
parent
8b0ea19d39
commit
89c396e0f5
6 changed files with 95 additions and 154 deletions
|
@ -124,13 +124,8 @@ func TestUploadReadFrom(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
repoConfig := &RepositoryConfig{}
|
|
||||||
client, err := repoConfig.HTTPClient()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error creating client: %s", err)
|
|
||||||
}
|
|
||||||
layerUpload := &httpLayerUpload{
|
layerUpload := &httpLayerUpload{
|
||||||
client: client,
|
client: &http.Client{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Valid case
|
// Valid case
|
||||||
|
|
|
@ -20,7 +20,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewRepository creates a new Repository for the given repository name and endpoint
|
// NewRepository creates a new Repository for the given repository name and endpoint
|
||||||
func NewRepository(ctx context.Context, name, endpoint string, repoConfig *RepositoryConfig) (distribution.Repository, error) {
|
func NewRepository(ctx context.Context, name, endpoint string, transport http.RoundTripper) (distribution.Repository, error) {
|
||||||
if err := v2.ValidateRespositoryName(name); err != nil {
|
if err := v2.ValidateRespositoryName(name); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -30,9 +30,10 @@ func NewRepository(ctx context.Context, name, endpoint string, repoConfig *Repos
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := repoConfig.HTTPClient()
|
client := &http.Client{
|
||||||
if err != nil {
|
Transport: transport,
|
||||||
return nil, err
|
Timeout: 1 * time.Minute,
|
||||||
|
// TODO(dmcgowan): create cookie jar
|
||||||
}
|
}
|
||||||
|
|
||||||
return &repository{
|
return &repository{
|
||||||
|
|
|
@ -97,7 +97,7 @@ func TestLayerFetch(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepository(context.Background(), "test.example.com/repo1", e, &RepositoryConfig{})
|
r, err := NewRepository(context.Background(), "test.example.com/repo1", e, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ func TestLayerExists(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepository(context.Background(), "test.example.com/repo1", e, &RepositoryConfig{})
|
r, err := NewRepository(context.Background(), "test.example.com/repo1", e, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -227,7 +227,7 @@ func TestLayerUploadChunked(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
|
r, err := NewRepository(context.Background(), repo, e, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -334,7 +334,7 @@ func TestLayerUploadMonolithic(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
|
r, err := NewRepository(context.Background(), repo, e, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -475,7 +475,7 @@ func TestManifestFetch(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
|
r, err := NewRepository(context.Background(), repo, e, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -508,7 +508,7 @@ func TestManifestFetchByTag(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
|
r, err := NewRepository(context.Background(), repo, e, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -553,7 +553,7 @@ func TestManifestDelete(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
|
r, err := NewRepository(context.Background(), repo, e, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -591,7 +591,7 @@ func TestManifestPut(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
|
r, err := NewRepository(context.Background(), repo, e, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -636,7 +636,7 @@ func TestManifestTags(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
|
r, err := NewRepository(context.Background(), repo, e, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,12 +11,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authorizer is used to apply Authorization to an HTTP request
|
|
||||||
type Authorizer interface {
|
|
||||||
// Authorizer updates an HTTP request with the needed authorization
|
|
||||||
Authorize(req *http.Request) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthenticationHandler is an interface for authorizing a request from
|
// AuthenticationHandler is an interface for authorizing a request from
|
||||||
// params from a "WWW-Authenicate" header for a single scheme.
|
// params from a "WWW-Authenicate" header for a single scheme.
|
||||||
type AuthenticationHandler interface {
|
type AuthenticationHandler interface {
|
||||||
|
@ -31,54 +25,11 @@ type CredentialStore interface {
|
||||||
Basic(*url.URL) (string, string)
|
Basic(*url.URL) (string, string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RepositoryConfig holds the base configuration needed to communicate
|
|
||||||
// with a registry including a method of authorization and HTTP headers.
|
|
||||||
type RepositoryConfig struct {
|
|
||||||
Header http.Header
|
|
||||||
AuthSource Authorizer
|
|
||||||
|
|
||||||
BaseTransport http.RoundTripper
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTPClient returns a new HTTP client configured for this configuration
|
|
||||||
func (rc *RepositoryConfig) HTTPClient() (*http.Client, error) {
|
|
||||||
transport := &Transport{
|
|
||||||
ExtraHeader: rc.Header,
|
|
||||||
AuthSource: rc.AuthSource,
|
|
||||||
Base: rc.BaseTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: transport,
|
|
||||||
}
|
|
||||||
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTokenAuthorizer returns an authorizer which is capable of getting a token
|
|
||||||
// from a token server. The expected authorization method will be discovered
|
|
||||||
// by the authorizer, getting the token server endpoint from the URL being
|
|
||||||
// requested. Basic authentication may either be done to the token source or
|
|
||||||
// directly with the requested endpoint depending on the endpoint's
|
|
||||||
// WWW-Authenticate header.
|
|
||||||
func NewTokenAuthorizer(creds CredentialStore, transport http.RoundTripper, header http.Header, scope TokenScope) Authorizer {
|
|
||||||
return &tokenAuthorizer{
|
|
||||||
header: header,
|
|
||||||
challenges: map[string]map[string]authorizationChallenge{},
|
|
||||||
handlers: []AuthenticationHandler{
|
|
||||||
NewTokenHandler(transport, creds, scope, header),
|
|
||||||
NewBasicHandler(creds),
|
|
||||||
},
|
|
||||||
transport: transport,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAuthorizer creates an authorizer which can handle multiple authentication
|
// NewAuthorizer creates an authorizer which can handle multiple authentication
|
||||||
// schemes. The handlers are tried in order, the higher priority authentication
|
// schemes. The handlers are tried in order, the higher priority authentication
|
||||||
// methods should be first.
|
// methods should be first.
|
||||||
func NewAuthorizer(transport http.RoundTripper, header http.Header, handlers ...AuthenticationHandler) Authorizer {
|
func NewAuthorizer(transport http.RoundTripper, handlers ...AuthenticationHandler) RequestModifier {
|
||||||
return &tokenAuthorizer{
|
return &tokenAuthorizer{
|
||||||
header: header,
|
|
||||||
challenges: map[string]map[string]authorizationChallenge{},
|
challenges: map[string]map[string]authorizationChallenge{},
|
||||||
handlers: handlers,
|
handlers: handlers,
|
||||||
transport: transport,
|
transport: transport,
|
||||||
|
@ -86,7 +37,6 @@ func NewAuthorizer(transport http.RoundTripper, header http.Header, handlers ...
|
||||||
}
|
}
|
||||||
|
|
||||||
type tokenAuthorizer struct {
|
type tokenAuthorizer struct {
|
||||||
header http.Header
|
|
||||||
challenges map[string]map[string]authorizationChallenge
|
challenges map[string]map[string]authorizationChallenge
|
||||||
handlers []AuthenticationHandler
|
handlers []AuthenticationHandler
|
||||||
transport http.RoundTripper
|
transport http.RoundTripper
|
||||||
|
@ -99,10 +49,7 @@ func (ta *tokenAuthorizer) ping(endpoint string) (map[string]authorizationChalle
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Transport: &Transport{
|
Transport: ta.transport,
|
||||||
ExtraHeader: ta.header,
|
|
||||||
Base: ta.transport,
|
|
||||||
},
|
|
||||||
// Ping should fail fast
|
// Ping should fail fast
|
||||||
Timeout: 5 * time.Second,
|
Timeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
|
@ -140,7 +87,7 @@ HeaderLoop:
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ta *tokenAuthorizer) Authorize(req *http.Request) error {
|
func (ta *tokenAuthorizer) ModifyRequest(req *http.Request) error {
|
||||||
v2Root := strings.Index(req.URL.Path, "/v2/")
|
v2Root := strings.Index(req.URL.Path, "/v2/")
|
||||||
if v2Root == -1 {
|
if v2Root == -1 {
|
||||||
return nil
|
return nil
|
||||||
|
@ -195,54 +142,52 @@ type TokenScope struct {
|
||||||
Actions []string
|
Actions []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTokenHandler creates a new AuthenicationHandler which supports
|
|
||||||
// fetching tokens from a remote token server.
|
|
||||||
func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope TokenScope, header http.Header) AuthenticationHandler {
|
|
||||||
return &tokenHandler{
|
|
||||||
header: header,
|
|
||||||
creds: creds,
|
|
||||||
scope: scope,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ts TokenScope) String() string {
|
func (ts TokenScope) String() string {
|
||||||
return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ","))
|
return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ","))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *tokenHandler) client() *http.Client {
|
// NewTokenHandler creates a new AuthenicationHandler which supports
|
||||||
return &http.Client{
|
// fetching tokens from a remote token server.
|
||||||
Transport: &Transport{
|
func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope TokenScope) AuthenticationHandler {
|
||||||
ExtraHeader: ts.header,
|
return &tokenHandler{
|
||||||
Base: ts.transport,
|
transport: transport,
|
||||||
},
|
creds: creds,
|
||||||
|
scope: scope,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *tokenHandler) Scheme() string {
|
func (th *tokenHandler) client() *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: th.transport,
|
||||||
|
Timeout: 15 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (th *tokenHandler) Scheme() string {
|
||||||
return "bearer"
|
return "bearer"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error {
|
func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error {
|
||||||
if err := ts.refreshToken(params); err != nil {
|
if err := th.refreshToken(params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.tokenCache))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.tokenCache))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *tokenHandler) refreshToken(params map[string]string) error {
|
func (th *tokenHandler) refreshToken(params map[string]string) error {
|
||||||
ts.tokenLock.Lock()
|
th.tokenLock.Lock()
|
||||||
defer ts.tokenLock.Unlock()
|
defer th.tokenLock.Unlock()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if now.After(ts.tokenExpiration) {
|
if now.After(th.tokenExpiration) {
|
||||||
token, err := ts.fetchToken(params)
|
token, err := th.fetchToken(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ts.tokenCache = token
|
th.tokenCache = token
|
||||||
ts.tokenExpiration = now.Add(time.Minute)
|
th.tokenExpiration = now.Add(time.Minute)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -252,7 +197,7 @@ type tokenResponse struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err error) {
|
func (th *tokenHandler) fetchToken(params map[string]string) (token string, err error) {
|
||||||
//log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, ta.auth.Username)
|
//log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, ta.auth.Username)
|
||||||
realm, ok := params["realm"]
|
realm, ok := params["realm"]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -273,7 +218,7 @@ func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err
|
||||||
|
|
||||||
reqParams := req.URL.Query()
|
reqParams := req.URL.Query()
|
||||||
service := params["service"]
|
service := params["service"]
|
||||||
scope := ts.scope.String()
|
scope := th.scope.String()
|
||||||
|
|
||||||
if service != "" {
|
if service != "" {
|
||||||
reqParams.Add("service", service)
|
reqParams.Add("service", service)
|
||||||
|
@ -283,8 +228,8 @@ func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err
|
||||||
reqParams.Add("scope", scopeField)
|
reqParams.Add("scope", scopeField)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ts.creds != nil {
|
if th.creds != nil {
|
||||||
username, password := ts.creds.Basic(realmURL)
|
username, password := th.creds.Basic(realmURL)
|
||||||
if username != "" && password != "" {
|
if username != "" && password != "" {
|
||||||
reqParams.Add("account", username)
|
reqParams.Add("account", username)
|
||||||
req.SetBasicAuth(username, password)
|
req.SetBasicAuth(username, password)
|
||||||
|
@ -293,7 +238,7 @@ func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err
|
||||||
|
|
||||||
req.URL.RawQuery = reqParams.Encode()
|
req.URL.RawQuery = reqParams.Encode()
|
||||||
|
|
||||||
resp, err := ts.client().Do(req)
|
resp, err := th.client().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,14 +116,8 @@ func TestEndpointAuthorizeToken(t *testing.T) {
|
||||||
e, c := testServerWithAuth(m, authenicate, validCheck)
|
e, c := testServerWithAuth(m, authenicate, validCheck)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
repo1Config := &RepositoryConfig{
|
transport1 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, nil, tokenScope1)))
|
||||||
AuthSource: NewTokenAuthorizer(nil, nil, nil, tokenScope1),
|
client := &http.Client{Transport: transport1}
|
||||||
}
|
|
||||||
|
|
||||||
client, err := repo1Config.HTTPClient()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error creating http client: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
|
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
|
@ -141,13 +135,8 @@ func TestEndpointAuthorizeToken(t *testing.T) {
|
||||||
e2, c2 := testServerWithAuth(m, authenicate, badCheck)
|
e2, c2 := testServerWithAuth(m, authenicate, badCheck)
|
||||||
defer c2()
|
defer c2()
|
||||||
|
|
||||||
repo2Config := &RepositoryConfig{
|
transport2 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, nil, tokenScope2)))
|
||||||
AuthSource: NewTokenAuthorizer(nil, nil, nil, tokenScope2),
|
client2 := &http.Client{Transport: transport2}
|
||||||
}
|
|
||||||
client2, err := repo2Config.HTTPClient()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error creating http client: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, _ = http.NewRequest("GET", e2+"/v2/hello", nil)
|
req, _ = http.NewRequest("GET", e2+"/v2/hello", nil)
|
||||||
resp, err = client2.Do(req)
|
resp, err = client2.Do(req)
|
||||||
|
@ -220,14 +209,9 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) {
|
||||||
username: username,
|
username: username,
|
||||||
password: password,
|
password: password,
|
||||||
}
|
}
|
||||||
repoConfig := &RepositoryConfig{
|
|
||||||
AuthSource: NewTokenAuthorizer(creds, nil, nil, tokenScope),
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := repoConfig.HTTPClient()
|
transport1 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, creds, tokenScope), NewBasicHandler(creds)))
|
||||||
if err != nil {
|
client := &http.Client{Transport: transport1}
|
||||||
t.Fatalf("Error creating http client: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
|
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
|
@ -265,14 +249,9 @@ func TestEndpointAuthorizeBasic(t *testing.T) {
|
||||||
username: username,
|
username: username,
|
||||||
password: password,
|
password: password,
|
||||||
}
|
}
|
||||||
repoConfig := &RepositoryConfig{
|
|
||||||
AuthSource: NewTokenAuthorizer(creds, nil, nil, TokenScope{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := repoConfig.HTTPClient()
|
transport1 := NewTransport(nil, NewAuthorizer(nil, NewBasicHandler(creds)))
|
||||||
if err != nil {
|
client := &http.Client{Transport: transport1}
|
||||||
t.Fatalf("Error creating http client: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
|
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
|
|
|
@ -6,14 +6,36 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Transport is an http.RoundTripper that makes registry HTTP requests,
|
type RequestModifier interface {
|
||||||
// wrapping a base RoundTripper and adding an Authorization header
|
ModifyRequest(*http.Request) error
|
||||||
// from an Auth source
|
}
|
||||||
type Transport struct {
|
|
||||||
AuthSource Authorizer
|
|
||||||
ExtraHeader http.Header
|
|
||||||
|
|
||||||
Base http.RoundTripper
|
type headerModifier http.Header
|
||||||
|
|
||||||
|
func NewHeaderRequestModifier(header http.Header) RequestModifier {
|
||||||
|
return headerModifier(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h headerModifier) ModifyRequest(req *http.Request) error {
|
||||||
|
for k, s := range http.Header(h) {
|
||||||
|
req.Header[k] = append(req.Header[k], s...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper {
|
||||||
|
return &transport{
|
||||||
|
Modifiers: modifiers,
|
||||||
|
Base: base,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// transport is an http.RoundTripper that makes HTTP requests after
|
||||||
|
// copying and modifying the request
|
||||||
|
type transport struct {
|
||||||
|
Modifiers []RequestModifier
|
||||||
|
Base http.RoundTripper
|
||||||
|
|
||||||
mu sync.Mutex // guards modReq
|
mu sync.Mutex // guards modReq
|
||||||
modReq map[*http.Request]*http.Request // original -> modified
|
modReq map[*http.Request]*http.Request // original -> modified
|
||||||
|
@ -22,13 +44,14 @@ type Transport struct {
|
||||||
// RoundTrip authorizes and authenticates the request with an
|
// RoundTrip authorizes and authenticates the request with an
|
||||||
// access token. If no token exists or token is expired,
|
// access token. If no token exists or token is expired,
|
||||||
// tries to refresh/fetch a new token.
|
// tries to refresh/fetch a new token.
|
||||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
req2 := t.cloneRequest(req)
|
req2 := cloneRequest(req)
|
||||||
if t.AuthSource != nil {
|
for _, modifier := range t.Modifiers {
|
||||||
if err := t.AuthSource.Authorize(req2); err != nil {
|
if err := modifier.ModifyRequest(req2); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.setModReq(req, req2)
|
t.setModReq(req, req2)
|
||||||
res, err := t.base().RoundTrip(req2)
|
res, err := t.base().RoundTrip(req2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -43,7 +66,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CancelRequest cancels an in-flight request by closing its connection.
|
// CancelRequest cancels an in-flight request by closing its connection.
|
||||||
func (t *Transport) CancelRequest(req *http.Request) {
|
func (t *transport) CancelRequest(req *http.Request) {
|
||||||
type canceler interface {
|
type canceler interface {
|
||||||
CancelRequest(*http.Request)
|
CancelRequest(*http.Request)
|
||||||
}
|
}
|
||||||
|
@ -56,14 +79,14 @@ func (t *Transport) CancelRequest(req *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) base() http.RoundTripper {
|
func (t *transport) base() http.RoundTripper {
|
||||||
if t.Base != nil {
|
if t.Base != nil {
|
||||||
return t.Base
|
return t.Base
|
||||||
}
|
}
|
||||||
return http.DefaultTransport
|
return http.DefaultTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) setModReq(orig, mod *http.Request) {
|
func (t *transport) setModReq(orig, mod *http.Request) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
if t.modReq == nil {
|
if t.modReq == nil {
|
||||||
|
@ -78,7 +101,7 @@ func (t *Transport) setModReq(orig, mod *http.Request) {
|
||||||
|
|
||||||
// cloneRequest returns a clone of the provided *http.Request.
|
// cloneRequest returns a clone of the provided *http.Request.
|
||||||
// The clone is a shallow copy of the struct and its Header map.
|
// The clone is a shallow copy of the struct and its Header map.
|
||||||
func (t *Transport) cloneRequest(r *http.Request) *http.Request {
|
func cloneRequest(r *http.Request) *http.Request {
|
||||||
// shallow copy of the struct
|
// shallow copy of the struct
|
||||||
r2 := new(http.Request)
|
r2 := new(http.Request)
|
||||||
*r2 = *r
|
*r2 = *r
|
||||||
|
@ -87,9 +110,7 @@ func (t *Transport) cloneRequest(r *http.Request) *http.Request {
|
||||||
for k, s := range r.Header {
|
for k, s := range r.Header {
|
||||||
r2.Header[k] = append([]string(nil), s...)
|
r2.Header[k] = append([]string(nil), s...)
|
||||||
}
|
}
|
||||||
for k, s := range t.ExtraHeader {
|
|
||||||
r2.Header[k] = append(r2.Header[k], s...)
|
|
||||||
}
|
|
||||||
return r2
|
return r2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue