Dep helper (#2151)

* Add dep task to update go dependencies

* Update go dependencies
This commit is contained in:
Manuel Alejandro de Brito Fontes 2018-09-29 19:47:07 -03:00 committed by Miek Gieben
parent 8f8b81f56b
commit 0e8977761d
764 changed files with 172 additions and 267451 deletions

View file

@ -1,353 +0,0 @@
package awsutil_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
func ExampleCopy() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Print the result
fmt.Println(awsutil.Prettify(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}
func TestCopy1(t *testing.T) {
type Bar struct {
a *int
B *int
c int
D int
}
type Foo struct {
A int
B []*string
C map[string]*int
D *time.Time
E *Bar
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
intPtr1 := 1
intPtr2 := 2
now := time.Now()
f1 := &Foo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
D: &now,
E: &Bar{
&intPtr1,
&intPtr2,
2,
3,
},
}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Values are equal
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.C, f1.C; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.D, f1.D; !v1.Equal(*v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.E.B, f1.E.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.E.D, f1.E.D; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
// But pointers are not!
str3 := "nothello"
int3 := 57
f2.A = 100
*f2.B[0] = str3
*f2.C["B"] = int3
*f2.D = time.Now()
f2.E.a = &int3
*f2.E.B = int3
f2.E.c = 5
f2.E.D = 5
if v1, v2 := f2.A, f1.A; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B, f1.B; reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.C, f1.C; reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.D, f1.D; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.a, f1.E.a; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.B, f1.E.B; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.c, f1.E.c; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.D, f1.E.D; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
}
func TestCopyNestedWithUnexported(t *testing.T) {
type Bar struct {
a int
B int
}
type Foo struct {
A string
B Bar
}
f1 := &Foo{A: "string", B: Bar{a: 1, B: 2}}
var f2 Foo
awsutil.Copy(&f2, f1)
// Values match
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B.a, f1.B.a; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B.B, f2.B.B; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyIgnoreNilMembers(t *testing.T) {
type Foo struct {
A *string
B []string
C map[string]string
}
f := &Foo{}
if v1 := f.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
var f2 Foo
awsutil.Copy(&f2, f)
if v1 := f2.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f2.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f2.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
fcopy := awsutil.CopyOf(f)
f3 := fcopy.(*Foo)
if v1 := f3.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f3.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f3.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
}
func TestCopyPrimitive(t *testing.T) {
str := "hello"
var s string
awsutil.Copy(&s, &str)
if v1, v2 := "hello", s; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyNil(t *testing.T) {
var s string
awsutil.Copy(&s, nil)
if v1, v2 := "", s; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyReader(t *testing.T) {
var buf io.Reader = bytes.NewReader([]byte("hello world"))
var r io.Reader
awsutil.Copy(&r, buf)
b, err := ioutil.ReadAll(r)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if v1, v2 := []byte("hello world"), b; !bytes.Equal(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
// empty bytes because this is not a deep copy
b, err = ioutil.ReadAll(buf)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if v1, v2 := []byte(""), b; !bytes.Equal(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyDifferentStructs(t *testing.T) {
type SrcFoo struct {
A int
B []*string
C map[string]*int
SrcUnique string
SameNameDiffType int
unexportedPtr *int
ExportedPtr *int
}
type DstFoo struct {
A int
B []*string
C map[string]*int
DstUnique int
SameNameDiffType string
unexportedPtr *int
ExportedPtr *int
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
f1 := &SrcFoo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
SrcUnique: "unique",
SameNameDiffType: 1,
unexportedPtr: &int1,
ExportedPtr: &int2,
}
// Do the copy
var f2 DstFoo
awsutil.Copy(&f2, f1)
// Values are equal
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.C, f1.C; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := "unique", f1.SrcUnique; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := 1, f1.SameNameDiffType; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := 0, f2.DstUnique; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := "", f2.SameNameDiffType; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := int1, *f1.unexportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1 := f2.unexportedPtr; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1, v2 := int2, *f1.ExportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := int2, *f2.ExportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func ExampleCopyOf() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
v := awsutil.CopyOf(f1)
var f2 *Foo = v.(*Foo)
// Print the result
fmt.Println(awsutil.Prettify(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}

View file

@ -1,30 +0,0 @@
package awsutil_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
func TestDeepEqual(t *testing.T) {
cases := []struct {
a, b interface{}
equal bool
}{
{"a", "a", true},
{"a", "b", false},
{"a", aws.String(""), false},
{"a", nil, false},
{"a", aws.String("a"), true},
{(*bool)(nil), (*bool)(nil), true},
{(*bool)(nil), (*string)(nil), false},
{nil, nil, true},
}
for i, c := range cases {
if awsutil.DeepEqual(c.a, c.b) != c.equal {
t.Errorf("%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
}
}
}

View file

@ -1,182 +0,0 @@
package awsutil_test
import (
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
type Struct struct {
A []Struct
z []Struct
B *Struct
D *Struct
C string
E map[string]string
}
var data = Struct{
A: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
z: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
B: &Struct{B: &Struct{C: "terminal"}, D: &Struct{C: "terminal2"}},
C: "initial",
}
var data2 = Struct{A: []Struct{
{A: []Struct{{C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}}},
{A: []Struct{{C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}}},
}}
func TestValueAtPathSuccess(t *testing.T) {
var testCases = []struct {
expect []interface{}
data interface{}
path string
}{
{[]interface{}{"initial"}, data, "C"},
{[]interface{}{"value1"}, data, "A[0].C"},
{[]interface{}{"value2"}, data, "A[1].C"},
{[]interface{}{"value3"}, data, "A[2].C"},
{[]interface{}{"value3"}, data, "a[2].c"},
{[]interface{}{"value3"}, data, "A[-1].C"},
{[]interface{}{"value1", "value2", "value3"}, data, "A[].C"},
{[]interface{}{"terminal"}, data, "B . B . C"},
{[]interface{}{"initial"}, data, "A.D.X || C"},
{[]interface{}{"initial"}, data, "A[0].B || C"},
{[]interface{}{
Struct{A: []Struct{{C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}}},
Struct{A: []Struct{{C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}}},
}, data2, "A"},
}
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
if err != nil {
t.Errorf("case %v, expected no error, %v", i, c.path)
}
if e, a := c.expect, v; !awsutil.DeepEqual(e, a) {
t.Errorf("case %v, %v", i, c.path)
}
}
}
func TestValueAtPathFailure(t *testing.T) {
var testCases = []struct {
expect []interface{}
errContains string
data interface{}
path string
}{
{nil, "", data, "C.x"},
{nil, "SyntaxError: Invalid token: tDot", data, ".x"},
{nil, "", data, "X.Y.Z"},
{nil, "", data, "A[100].C"},
{nil, "", data, "A[3].C"},
{nil, "", data, "B.B.C.Z"},
{nil, "", data, "z[-1].C"},
{nil, "", nil, "A.B.C"},
{[]interface{}{}, "", Struct{}, "A"},
{nil, "", data, "A[0].B.C"},
{nil, "", data, "D"},
}
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
if c.errContains != "" {
if !strings.Contains(err.Error(), c.errContains) {
t.Errorf("case %v, expected error, %v", i, c.path)
}
continue
} else {
if err != nil {
t.Errorf("case %v, expected no error, %v", i, c.path)
}
}
if e, a := c.expect, v; !awsutil.DeepEqual(e, a) {
t.Errorf("case %v, %v", i, c.path)
}
}
}
func TestSetValueAtPathSuccess(t *testing.T) {
var s Struct
awsutil.SetValueAtPath(&s, "C", "test1")
awsutil.SetValueAtPath(&s, "B.B.C", "test2")
awsutil.SetValueAtPath(&s, "B.D.C", "test3")
if e, a := "test1", s.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test2", s.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test3", s.B.D.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
awsutil.SetValueAtPath(&s, "B.*.C", "test0")
if e, a := "test0", s.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test0", s.B.D.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
var s2 Struct
awsutil.SetValueAtPath(&s2, "b.b.c", "test0")
if e, a := "test0", s2.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
awsutil.SetValueAtPath(&s2, "A", []Struct{{}})
if e, a := []Struct{{}}, s2.A; !awsutil.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
str := "foo"
s3 := Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", str)
if e, a := "foo", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{B: &Struct{B: &Struct{C: str}}}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
if e, a := "", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
if e, a := "", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", &str)
if e, a := "foo", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
var s4 struct{ Name *string }
awsutil.SetValueAtPath(&s4, "Name", str)
if e, a := str, *s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", nil)
if e, a := (*string)(nil), s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{Name: &str}
awsutil.SetValueAtPath(&s4, "Name", nil)
if e, a := (*string)(nil), s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", &str)
if e, a := str, *s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
}

View file

@ -1,78 +0,0 @@
package client
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
)
func pushBackTestHandler(name string, list *request.HandlerList) *bool {
called := false
(*list).PushBackNamed(request.NamedHandler{
Name: name,
Fn: func(r *request.Request) {
called = true
},
})
return &called
}
func pushFrontTestHandler(name string, list *request.HandlerList) *bool {
called := false
(*list).PushFrontNamed(request.NamedHandler{
Name: name,
Fn: func(r *request.Request) {
called = true
},
})
return &called
}
func TestNewClient_CopyHandlers(t *testing.T) {
handlers := request.Handlers{}
firstCalled := pushBackTestHandler("first", &handlers.Send)
secondCalled := pushBackTestHandler("second", &handlers.Send)
var clientHandlerCalled *bool
c := New(aws.Config{}, metadata.ClientInfo{}, handlers,
func(c *Client) {
clientHandlerCalled = pushFrontTestHandler("client handler", &c.Handlers.Send)
},
)
if e, a := 2, handlers.Send.Len(); e != a {
t.Errorf("expect %d original handlers, got %d", e, a)
}
if e, a := 3, c.Handlers.Send.Len(); e != a {
t.Errorf("expect %d client handlers, got %d", e, a)
}
handlers.Send.Run(nil)
if !*firstCalled {
t.Errorf("expect first handler to of been called")
}
*firstCalled = false
if !*secondCalled {
t.Errorf("expect second handler to of been called")
}
*secondCalled = false
if *clientHandlerCalled {
t.Errorf("expect client handler to not of been called, but was")
}
c.Handlers.Send.Run(nil)
if !*firstCalled {
t.Errorf("expect client's first handler to of been called")
}
if !*secondCalled {
t.Errorf("expect client's second handler to of been called")
}
if !*clientHandlerCalled {
t.Errorf("expect client's client handler to of been called")
}
}

View file

@ -1,189 +0,0 @@
package client
import (
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestRetryThrottleStatusCodes(t *testing.T) {
cases := []struct {
expectThrottle bool
expectRetry bool
r request.Request
}{
{
false,
false,
request.Request{
HTTPResponse: &http.Response{StatusCode: 200},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 429},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 502},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 503},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 504},
},
},
{
false,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 500},
},
},
}
d := DefaultRetryer{NumMaxRetries: 10}
for i, c := range cases {
throttle := d.shouldThrottle(&c.r)
retry := d.ShouldRetry(&c.r)
if e, a := c.expectThrottle, throttle; e != a {
t.Errorf("%d: expected %v, but received %v", i, e, a)
}
if e, a := c.expectRetry, retry; e != a {
t.Errorf("%d: expected %v, but received %v", i, e, a)
}
}
}
func TestCanUseRetryAfter(t *testing.T) {
cases := []struct {
r request.Request
e bool
}{
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 200},
},
false,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 500},
},
false,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 429},
},
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503},
},
true,
},
}
for i, c := range cases {
a := canUseRetryAfterHeader(&c.r)
if c.e != a {
t.Errorf("%d: expected %v, but received %v", i, c.e, a)
}
}
}
func TestGetRetryDelay(t *testing.T) {
cases := []struct {
r request.Request
e time.Duration
equal bool
ok bool
}{
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 429, Header: http.Header{"Retry-After": []string{"3600"}}},
},
3600 * time.Second,
true,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}},
},
120 * time.Second,
true,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}},
},
1 * time.Second,
false,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}},
},
0 * time.Second,
true,
false,
},
}
for i, c := range cases {
a, ok := getRetryDelay(&c.r)
if c.ok != ok {
t.Errorf("%d: expected %v, but received %v", i, c.ok, ok)
}
if (c.e != a) == c.equal {
t.Errorf("%d: expected %v, but received %v", i, c.e, a)
}
}
}
func TestRetryDelay(t *testing.T) {
r := request.Request{}
for i := 0; i < 100; i++ {
rTemp := r
rTemp.HTTPResponse = &http.Response{StatusCode: 500, Header: http.Header{"Retry-After": []string{""}}}
rTemp.RetryCount = i
a, _ := getRetryDelay(&rTemp)
if a > 5*time.Minute {
t.Errorf("retry delay should never be greater than five minutes, received %d", a)
}
}
for i := 0; i < 100; i++ {
rTemp := r
rTemp.RetryCount = i
rTemp.HTTPResponse = &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}}
a, _ := getRetryDelay(&rTemp)
if a > 5*time.Minute {
t.Errorf("retry delay should never be greater than five minutes, received %d", a)
}
}
}

View file

@ -1,222 +0,0 @@
package client
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
type mockCloser struct {
closed bool
}
func (closer *mockCloser) Read(b []byte) (int, error) {
return 0, io.EOF
}
func (closer *mockCloser) Close() error {
closer.closed = true
return nil
}
func TestTeeReaderCloser(t *testing.T) {
expected := "FOO"
buf := bytes.NewBuffer([]byte(expected))
lw := bytes.NewBuffer(nil)
c := &mockCloser{}
closer := teeReaderCloser{
io.TeeReader(buf, lw),
c,
}
b := make([]byte, len(expected))
_, err := closer.Read(b)
closer.Close()
if expected != lw.String() {
t.Errorf("Expected %q, but received %q", expected, lw.String())
}
if err != nil {
t.Errorf("Expected 'nil', but received %v", err)
}
if !c.closed {
t.Error("Expected 'true', but received 'false'")
}
}
func TestLogWriter(t *testing.T) {
expected := "FOO"
lw := &logWriter{nil, bytes.NewBuffer(nil)}
lw.Write([]byte(expected))
if expected != lw.buf.String() {
t.Errorf("Expected %q, but received %q", expected, lw.buf.String())
}
}
func TestLogRequest(t *testing.T) {
cases := []struct {
Body io.ReadSeeker
ExpectBody []byte
LogLevel aws.LogLevelType
}{
{
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
ExpectBody: []byte("body content"),
},
{
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
LogLevel: aws.LogDebugWithHTTPBody,
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewReader([]byte("body content")),
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewReader([]byte("body content")),
LogLevel: aws.LogDebugWithHTTPBody,
ExpectBody: []byte("body content"),
},
}
for i, c := range cases {
logW := bytes.NewBuffer(nil)
req := request.New(
aws.Config{
Credentials: credentials.AnonymousCredentials,
Logger: &bufLogger{w: logW},
LogLevel: aws.LogLevel(c.LogLevel),
},
metadata.ClientInfo{
Endpoint: "https://mock-service.mock-region.amazonaws.com",
},
testHandlers(),
nil,
&request.Operation{
Name: "APIName",
HTTPMethod: "POST",
HTTPPath: "/",
},
struct{}{}, nil,
)
req.SetReaderBody(c.Body)
req.Build()
logRequest(req)
b, err := ioutil.ReadAll(req.HTTPRequest.Body)
if err != nil {
t.Fatalf("%d, expect to read SDK request Body", i)
}
if e, a := c.ExpectBody, b; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v body, got %v", i, e, a)
}
}
}
func TestLogResponse(t *testing.T) {
cases := []struct {
Body *bytes.Buffer
ExpectBody []byte
ReadBody bool
LogLevel aws.LogLevelType
}{
{
Body: bytes.NewBuffer([]byte("body content")),
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewBuffer([]byte("body content")),
LogLevel: aws.LogDebug,
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewBuffer([]byte("body content")),
LogLevel: aws.LogDebugWithHTTPBody,
ReadBody: true,
ExpectBody: []byte("body content"),
},
}
for i, c := range cases {
var logW bytes.Buffer
req := request.New(
aws.Config{
Credentials: credentials.AnonymousCredentials,
Logger: &bufLogger{w: &logW},
LogLevel: aws.LogLevel(c.LogLevel),
},
metadata.ClientInfo{
Endpoint: "https://mock-service.mock-region.amazonaws.com",
},
testHandlers(),
nil,
&request.Operation{
Name: "APIName",
HTTPMethod: "POST",
HTTPPath: "/",
},
struct{}{}, nil,
)
req.HTTPResponse = &http.Response{
StatusCode: 200,
Status: "OK",
Header: http.Header{
"ABC": []string{"123"},
},
Body: ioutil.NopCloser(c.Body),
}
logResponse(req)
req.Handlers.Unmarshal.Run(req)
if c.ReadBody {
if e, a := len(c.ExpectBody), c.Body.Len(); e != a {
t.Errorf("%d, expect orginal body not to of been read", i)
}
}
if logW.Len() == 0 {
t.Errorf("%d, expect HTTP Response headers to be logged", i)
}
b, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
t.Fatalf("%d, expect to read SDK request Body", i)
}
if e, a := c.ExpectBody, b; !bytes.Equal(e, a) {
t.Errorf("%d, expect %v body, got %v", i, e, a)
}
}
}
type bufLogger struct {
w *bytes.Buffer
}
func (l *bufLogger) Log(args ...interface{}) {
fmt.Fprintln(l.w, args...)
}
func testHandlers() request.Handlers {
var handlers request.Handlers
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
return handlers
}

View file

@ -1,86 +0,0 @@
package aws
import (
"net/http"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws/credentials"
)
var testCredentials = credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
var copyTestConfig = Config{
Credentials: testCredentials,
Endpoint: String("CopyTestEndpoint"),
Region: String("COPY_TEST_AWS_REGION"),
DisableSSL: Bool(true),
HTTPClient: http.DefaultClient,
LogLevel: LogLevel(LogDebug),
Logger: NewDefaultLogger(),
MaxRetries: Int(3),
DisableParamValidation: Bool(true),
DisableComputeChecksums: Bool(true),
S3ForcePathStyle: Bool(true),
}
func TestCopy(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if !reflect.DeepEqual(*got, want) {
t.Errorf("Copy() = %+v", got)
t.Errorf(" want %+v", want)
}
got.Region = String("other")
if got.Region == want.Region {
t.Errorf("Expect setting copy values not not reflect in source")
}
}
func TestCopyReturnsNewInstance(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if got == &want {
t.Errorf("Copy() = %p; want different instance as source %p", got, &want)
}
}
var mergeTestZeroValueConfig = Config{}
var mergeTestConfig = Config{
Credentials: testCredentials,
Endpoint: String("MergeTestEndpoint"),
Region: String("MERGE_TEST_AWS_REGION"),
DisableSSL: Bool(true),
HTTPClient: http.DefaultClient,
LogLevel: LogLevel(LogDebug),
Logger: NewDefaultLogger(),
MaxRetries: Int(10),
DisableParamValidation: Bool(true),
DisableComputeChecksums: Bool(true),
S3ForcePathStyle: Bool(true),
}
var mergeTests = []struct {
cfg *Config
in *Config
want *Config
}{
{&Config{}, nil, &Config{}},
{&Config{}, &mergeTestZeroValueConfig, &Config{}},
{&Config{}, &mergeTestConfig, &mergeTestConfig},
}
func TestMerge(t *testing.T) {
for i, tt := range mergeTests {
got := tt.cfg.Copy()
got.MergeIn(tt.in)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Config %d %+v", i, tt.cfg)
t.Errorf(" Merge(%+v)", tt.in)
t.Errorf(" got %+v", got)
t.Errorf(" want %+v", tt.want)
}
}
}

View file

@ -1,37 +0,0 @@
package aws_test
import (
"fmt"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestSleepWithContext(t *testing.T) {
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
err := aws.SleepWithContext(ctx, 1*time.Millisecond)
if err != nil {
t.Errorf("expect context to not be canceled, got %v", err)
}
}
func TestSleepWithContext_Canceled(t *testing.T) {
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
expectErr := fmt.Errorf("context canceled")
ctx.Error = expectErr
close(ctx.DoneCh)
err := aws.SleepWithContext(ctx, 1*time.Millisecond)
if err == nil {
t.Fatalf("expect error, did not get one")
}
if e, a := expectErr, err; e != a {
t.Errorf("expect %v error, got %v", e, a)
}
}

View file

@ -1,641 +0,0 @@
package aws
import (
"reflect"
"testing"
"time"
)
var testCasesStringSlice = [][]string{
{"a", "b", "c", "d", "e"},
{"a", "b", "", "", "e"},
}
func TestStringSlice(t *testing.T) {
for idx, in := range testCasesStringSlice {
if in == nil {
continue
}
out := StringSlice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := StringValueSlice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesStringValueSlice = [][]*string{
{String("a"), String("b"), nil, String("c")},
}
func TestStringValueSlice(t *testing.T) {
for idx, in := range testCasesStringValueSlice {
if in == nil {
continue
}
out := StringValueSlice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
if out[i] != "" {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := StringSlice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
if *(out2[i]) != "" {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := *in[i], *out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
}
var testCasesStringMap = []map[string]string{
{"a": "1", "b": "2", "c": "3"},
}
func TestStringMap(t *testing.T) {
for idx, in := range testCasesStringMap {
if in == nil {
continue
}
out := StringMap(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := StringValueMap(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesBoolSlice = [][]bool{
{true, true, false, false},
}
func TestBoolSlice(t *testing.T) {
for idx, in := range testCasesBoolSlice {
if in == nil {
continue
}
out := BoolSlice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := BoolValueSlice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesBoolValueSlice = [][]*bool{}
func TestBoolValueSlice(t *testing.T) {
for idx, in := range testCasesBoolValueSlice {
if in == nil {
continue
}
out := BoolValueSlice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
if out[i] {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := BoolSlice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
if *(out2[i]) {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
}
var testCasesBoolMap = []map[string]bool{
{"a": true, "b": false, "c": true},
}
func TestBoolMap(t *testing.T) {
for idx, in := range testCasesBoolMap {
if in == nil {
continue
}
out := BoolMap(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := BoolValueMap(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesIntSlice = [][]int{
{1, 2, 3, 4},
}
func TestIntSlice(t *testing.T) {
for idx, in := range testCasesIntSlice {
if in == nil {
continue
}
out := IntSlice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := IntValueSlice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesIntValueSlice = [][]*int{}
func TestIntValueSlice(t *testing.T) {
for idx, in := range testCasesIntValueSlice {
if in == nil {
continue
}
out := IntValueSlice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
if out[i] != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := IntSlice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
if *(out2[i]) != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
}
var testCasesIntMap = []map[string]int{
{"a": 3, "b": 2, "c": 1},
}
func TestIntMap(t *testing.T) {
for idx, in := range testCasesIntMap {
if in == nil {
continue
}
out := IntMap(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := IntValueMap(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesInt64Slice = [][]int64{
{1, 2, 3, 4},
}
func TestInt64Slice(t *testing.T) {
for idx, in := range testCasesInt64Slice {
if in == nil {
continue
}
out := Int64Slice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Int64ValueSlice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesInt64ValueSlice = [][]*int64{}
func TestInt64ValueSlice(t *testing.T) {
for idx, in := range testCasesInt64ValueSlice {
if in == nil {
continue
}
out := Int64ValueSlice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
if out[i] != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := Int64Slice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
if *(out2[i]) != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
}
var testCasesInt64Map = []map[string]int64{
{"a": 3, "b": 2, "c": 1},
}
func TestInt64Map(t *testing.T) {
for idx, in := range testCasesInt64Map {
if in == nil {
continue
}
out := Int64Map(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Int64ValueMap(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesFloat64Slice = [][]float64{
{1, 2, 3, 4},
}
func TestFloat64Slice(t *testing.T) {
for idx, in := range testCasesFloat64Slice {
if in == nil {
continue
}
out := Float64Slice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Float64ValueSlice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesFloat64ValueSlice = [][]*float64{}
func TestFloat64ValueSlice(t *testing.T) {
for idx, in := range testCasesFloat64ValueSlice {
if in == nil {
continue
}
out := Float64ValueSlice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
if out[i] != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := Float64Slice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
if *(out2[i]) != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
}
var testCasesFloat64Map = []map[string]float64{
{"a": 3, "b": 2, "c": 1},
}
func TestFloat64Map(t *testing.T) {
for idx, in := range testCasesFloat64Map {
if in == nil {
continue
}
out := Float64Map(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Float64ValueMap(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesTimeSlice = [][]time.Time{
{time.Now(), time.Now().AddDate(100, 0, 0)},
}
func TestTimeSlice(t *testing.T) {
for idx, in := range testCasesTimeSlice {
if in == nil {
continue
}
out := TimeSlice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := TimeValueSlice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
var testCasesTimeValueSlice = [][]*time.Time{}
func TestTimeValueSlice(t *testing.T) {
for idx, in := range testCasesTimeValueSlice {
if in == nil {
continue
}
out := TimeValueSlice(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
if !out[i].IsZero() {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := TimeSlice(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
if !(*(out2[i])).IsZero() {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
}
var testCasesTimeMap = []map[string]time.Time{
{"a": time.Now().AddDate(-100, 0, 0), "b": time.Now()},
}
func TestTimeMap(t *testing.T) {
for idx, in := range testCasesTimeMap {
if in == nil {
continue
}
out := TimeMap(in)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := TimeValueMap(out)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
type TimeValueTestCase struct {
in int64
outSecs time.Time
outMillis time.Time
}
var testCasesTimeValue = []TimeValueTestCase{
{
in: int64(1501558289000),
outSecs: time.Unix(1501558289, 0),
outMillis: time.Unix(1501558289, 0),
},
{
in: int64(1501558289001),
outSecs: time.Unix(1501558289, 0),
outMillis: time.Unix(1501558289, 1*1000000),
},
}
func TestSecondsTimeValue(t *testing.T) {
for idx, testCase := range testCasesTimeValue {
out := SecondsTimeValue(&testCase.in)
if e, a := testCase.outSecs, out; e != a {
t.Errorf("Unexpected value for time value at %d", idx)
}
}
}
func TestMillisecondsTimeValue(t *testing.T) {
for idx, testCase := range testCasesTimeValue {
out := MillisecondsTimeValue(&testCase.in)
if e, a := testCase.outMillis, out; e != a {
t.Errorf("Unexpected value for time value at %d", idx)
}
}
}

View file

@ -1,64 +0,0 @@
// +build go1.8
package corehandlers_test
import (
"crypto/tls"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/service/s3"
"golang.org/x/net/http2"
)
func TestSendHandler_HEADNoBody(t *testing.T) {
TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile, err := awstesting.CreateTLSBundleFiles()
if err != nil {
panic(err)
}
defer awstesting.CleanupTLSBundleFiles(TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile)
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
transport := http.DefaultTransport.(*http.Transport)
// test server's certificate is self-signed certificate
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
http2.ConfigureTransport(transport)
sess, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{
HTTPClient: &http.Client{},
Endpoint: aws.String(endpoint),
Region: aws.String("mock-region"),
Credentials: credentials.AnonymousCredentials,
S3ForcePathStyle: aws.Bool(true),
},
})
svc := s3.New(sess)
req, _ := svc.HeadObjectRequest(&s3.HeadObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
})
if e, a := request.NoBody, req.HTTPRequest.Body; e != a {
t.Fatalf("expect %T request body, got %T", e, a)
}
err = req.Send()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := http.StatusOK, req.HTTPResponse.StatusCode; e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
}

View file

@ -1,398 +0,0 @@
package corehandlers_test
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := awstesting.NewClient(aws.NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := awstesting.NewClient()
svc.Handlers.Clear()
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
if err == nil {
t.Errorf("expect error, got none")
}
if e, a := aws.ErrMissingRegion, err; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}
type mockCredsProvider struct {
expired bool
retrieveCalled bool
}
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retrieveCalled = true
return credentials.Value{ProviderName: "mockCredsProvider"}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
return m.expired
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := awstesting.NewClient(&aws.Config{
Credentials: credentials.NewCredentials(credProvider),
MaxRetries: aws.Int(1),
})
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *request.Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400, Body: ioutil.NopCloser(bytes.NewBuffer([]byte{}))}
})
svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
if !svc.Config.Credentials.IsExpired() {
t.Errorf("Expect to start out expired")
}
if credProvider.retrieveCalled {
t.Errorf("expect not called")
}
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.Send()
if !svc.Config.Credentials.IsExpired() {
t.Errorf("Expect to start out expired")
}
if credProvider.retrieveCalled {
t.Errorf("expect not called")
}
_, err := svc.Config.Credentials.Get()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if !credProvider.retrieveCalled {
t.Errorf("expect not called")
}
}
func TestAfterRetryWithContextCanceled(t *testing.T) {
c := awstesting.NewClient()
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{}, 0)}
req.SetContext(ctx)
req.Error = fmt.Errorf("some error")
req.Retryable = aws.Bool(true)
req.HTTPResponse = &http.Response{
StatusCode: 500,
}
close(ctx.DoneCh)
ctx.Error = fmt.Errorf("context canceled")
corehandlers.AfterRetryHandler.Fn(req)
if req.Error == nil {
t.Fatalf("expect error but didn't receive one")
}
aerr := req.Error.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q, error code got %q", e, a)
}
}
func TestAfterRetryWithContext(t *testing.T) {
c := awstesting.NewClient()
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{}, 0)}
req.SetContext(ctx)
req.Error = fmt.Errorf("some error")
req.Retryable = aws.Bool(true)
req.HTTPResponse = &http.Response{
StatusCode: 500,
}
corehandlers.AfterRetryHandler.Fn(req)
if req.Error != nil {
t.Fatalf("expect no error, got %v", req.Error)
}
if e, a := 1, req.RetryCount; e != a {
t.Errorf("expect retry count to be %d, got %d", e, a)
}
}
func TestSendWithContextCanceled(t *testing.T) {
c := awstesting.NewClient(&aws.Config{
SleepDelay: func(dur time.Duration) {
t.Errorf("SleepDelay should not be called")
},
})
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{}, 0)}
req.SetContext(ctx)
req.Error = fmt.Errorf("some error")
req.Retryable = aws.Bool(true)
req.HTTPResponse = &http.Response{
StatusCode: 500,
}
close(ctx.DoneCh)
ctx.Error = fmt.Errorf("context canceled")
corehandlers.SendHandler.Fn(req)
if req.Error == nil {
t.Fatalf("expect error but didn't receive one")
}
aerr := req.Error.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q, error code got %q", e, a)
}
}
type testSendHandlerTransport struct{}
func (t *testSendHandlerTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("mock error")
}
func TestSendHandlerError(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{
HTTPClient: &http.Client{
Transport: &testSendHandlerTransport{},
},
})
svc.Handlers.Clear()
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
r.Send()
if r.Error == nil {
t.Errorf("expect error, got none")
}
if r.HTTPResponse == nil {
t.Errorf("expect response, got none")
}
}
func TestSendWithoutFollowRedirects(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/original":
w.Header().Set("Location", "/redirected")
w.WriteHeader(301)
case "/redirected":
t.Fatalf("expect not to redirect, but was")
}
}))
svc := awstesting.NewClient(&aws.Config{
DisableSSL: aws.Bool(true),
Endpoint: aws.String(server.URL),
})
svc.Handlers.Clear()
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{
Name: "Operation",
HTTPPath: "/original",
}, nil, nil)
r.DisableFollowRedirects = true
err := r.Send()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := 301, r.HTTPResponse.StatusCode; e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
}
func TestValidateReqSigHandler(t *testing.T) {
cases := []struct {
Req *request.Request
Resign bool
}{
{
Req: &request.Request{
Config: aws.Config{Credentials: credentials.AnonymousCredentials},
Time: time.Now().Add(-15 * time.Minute),
},
Resign: false,
},
{
Req: &request.Request{
Time: time.Now().Add(-15 * time.Minute),
},
Resign: true,
},
{
Req: &request.Request{
Time: time.Now().Add(-1 * time.Minute),
},
Resign: false,
},
}
for i, c := range cases {
resigned := false
c.Req.Handlers.Sign.PushBack(func(r *request.Request) {
resigned = true
})
corehandlers.ValidateReqSigHandler.Fn(c.Req)
if c.Req.Error != nil {
t.Errorf("expect no error, got %v", c.Req.Error)
}
if e, a := c.Resign, resigned; e != a {
t.Errorf("%d, expect %v to be %v", i, e, a)
}
}
}
func setupContentLengthTestServer(t *testing.T, hasContentLength bool, contentLength int64) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := r.Header["Content-Length"]
if e, a := hasContentLength, ok; e != a {
t.Errorf("expect %v to be %v", e, a)
}
if hasContentLength {
if e, a := contentLength, r.ContentLength; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
r.Body.Close()
authHeader := r.Header.Get("Authorization")
if hasContentLength {
if e, a := "content-length", authHeader; !strings.Contains(a, e) {
t.Errorf("expect %v to be in %v", e, a)
}
} else {
if e, a := "content-length", authHeader; strings.Contains(a, e) {
t.Errorf("expect %v to not be in %v", e, a)
}
}
if e, a := contentLength, int64(len(b)); e != a {
t.Errorf("expect %v to be %v", e, a)
}
}))
return server
}
func TestBuildContentLength_ZeroBody(t *testing.T) {
server := setupContentLengthTestServer(t, false, 0)
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
})
_, err := svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
}
func TestBuildContentLength_NegativeBody(t *testing.T) {
server := setupContentLengthTestServer(t, false, 0)
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
})
req, _ := svc.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
})
req.HTTPRequest.Header.Set("Content-Length", "-1")
if req.Error != nil {
t.Errorf("expect no error, got %v", req.Error)
}
}
func TestBuildContentLength_WithBody(t *testing.T) {
server := setupContentLengthTestServer(t, true, 1024)
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
})
_, err := svc.PutObject(&s3.PutObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
Body: bytes.NewReader(make([]byte, 1024)),
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
}

View file

@ -1,286 +0,0 @@
package corehandlers_test
import (
"fmt"
"testing"
"reflect"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/kinesis"
)
var testSvc = func() *client.Client {
s := &client.Client{
Config: aws.Config{},
ClientInfo: metadata.ClientInfo{
ServiceName: "mock-service",
APIVersion: "2015-01-01",
},
}
return s
}()
type StructShape struct {
_ struct{} `type:"structure"`
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
}
func (s *StructShape) Validate() error {
invalidParams := request.ErrInvalidParams{Context: "StructShape"}
if s.RequiredList == nil {
invalidParams.Add(request.NewErrParamRequired("RequiredList"))
}
if s.RequiredMap == nil {
invalidParams.Add(request.NewErrParamRequired("RequiredMap"))
}
if s.RequiredBool == nil {
invalidParams.Add(request.NewErrParamRequired("RequiredBool"))
}
if s.RequiredList != nil {
for i, v := range s.RequiredList {
if v == nil {
continue
}
if err := v.Validate(); err != nil {
invalidParams.AddNested(fmt.Sprintf("%s[%v]", "RequiredList", i), err.(request.ErrInvalidParams))
}
}
}
if s.RequiredMap != nil {
for i, v := range s.RequiredMap {
if v == nil {
continue
}
if err := v.Validate(); err != nil {
invalidParams.AddNested(fmt.Sprintf("%s[%v]", "RequiredMap", i), err.(request.ErrInvalidParams))
}
}
}
if s.OptionalStruct != nil {
if err := s.OptionalStruct.Validate(); err != nil {
invalidParams.AddNested("OptionalStruct", err.(request.ErrInvalidParams))
}
}
if invalidParams.Len() > 0 {
return invalidParams
}
return nil
}
type ConditionalStructShape struct {
_ struct{} `type:"structure"`
Name *string `required:"true"`
}
func (s *ConditionalStructShape) Validate() error {
invalidParams := request.ErrInvalidParams{Context: "ConditionalStructShape"}
if s.Name == nil {
invalidParams.Add(request.NewErrParamRequired("Name"))
}
if invalidParams.Len() > 0 {
return invalidParams
}
return nil
}
func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {Name: aws.String("Name")},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
if req.Error != nil {
t.Fatalf("expect no error, got %v", req.Error)
}
}
func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
if req.Error == nil {
t.Fatalf("expect error")
}
if e, a := "InvalidParameter", req.Error.(awserr.Error).Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "3 validation error(s) found.", req.Error.(awserr.Error).Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
errs := req.Error.(awserr.BatchedErrors).OrigErrs()
if e, a := 3, len(errs); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredList.", errs[0].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredMap.", errs[1].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredBool.", errs[2].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "InvalidParameter: 3 validation error(s) found.\n- missing required field, StructShape.RequiredList.\n- missing required field, StructShape.RequiredMap.\n- missing required field, StructShape.RequiredBool.\n", req.Error.Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{{}},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{},
}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
if req.Error == nil {
t.Fatalf("expect error")
}
if e, a := "InvalidParameter", req.Error.(awserr.Error).Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "3 validation error(s) found.", req.Error.(awserr.Error).Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
errs := req.Error.(awserr.BatchedErrors).OrigErrs()
if e, a := 3, len(errs); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredList[0].Name.", errs[0].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredMap[key2].Name.", errs[1].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.OptionalStruct.Name.", errs[2].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
type testInput struct {
StringField *string `min:"5"`
ListField []string `min:"3"`
MapField map[string]string `min:"4"`
}
func (s testInput) Validate() error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
if s.StringField != nil && len(*s.StringField) < 5 {
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
}
if s.ListField != nil && len(s.ListField) < 3 {
invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
}
if s.MapField != nil && len(s.MapField) < 4 {
invalidParams.Add(request.NewErrParamMinLen("MapField", 4))
}
if invalidParams.Len() > 0 {
return invalidParams
}
return nil
}
var testsFieldMin = []struct {
err awserr.Error
in testInput
}{
{
err: func() awserr.Error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
return invalidParams
}(),
in: testInput{StringField: aws.String("abcd")},
},
{
err: func() awserr.Error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
return invalidParams
}(),
in: testInput{StringField: aws.String("abcd"), ListField: []string{"a", "b"}},
},
{
err: func() awserr.Error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
invalidParams.Add(request.NewErrParamMinLen("MapField", 4))
return invalidParams
}(),
in: testInput{StringField: aws.String("abcd"), ListField: []string{"a", "b"}, MapField: map[string]string{"a": "a", "b": "b"}},
},
{
err: nil,
in: testInput{StringField: aws.String("abcde"),
ListField: []string{"a", "b", "c"}, MapField: map[string]string{"a": "a", "b": "b", "c": "c", "d": "d"}},
},
}
func TestValidateFieldMinParameter(t *testing.T) {
for i, c := range testsFieldMin {
req := testSvc.NewRequest(&request.Operation{}, &c.in, nil)
corehandlers.ValidateParametersHandler.Fn(req)
if e, a := c.err, req.Error; !reflect.DeepEqual(e,a) {
t.Errorf("%d, expect %v, got %v", i, e, a)
}
}
}
func BenchmarkValidateAny(b *testing.B) {
input := &kinesis.PutRecordsInput{
StreamName: aws.String("stream"),
}
for i := 0; i < 100; i++ {
record := &kinesis.PutRecordsRequestEntry{
Data: make([]byte, 10000),
PartitionKey: aws.String("partition"),
}
input.Records = append(input.Records, record)
}
req, _ := kinesis.New(unit.Session).PutRecordsRequest(input)
b.ResetTimer()
for i := 0; i < b.N; i++ {
corehandlers.ValidateParametersHandler.Fn(req)
if err := req.Error; err != nil {
b.Fatalf("validation failed: %v", err)
}
}
}

View file

@ -1,40 +0,0 @@
package corehandlers
import (
"net/http"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestAddHostExecEnvUserAgentHander(t *testing.T) {
cases := []struct {
ExecEnv string
Expect string
}{
{ExecEnv: "Lambda", Expect: "exec_env/Lambda"},
{ExecEnv: "", Expect: ""},
{ExecEnv: "someThingCool", Expect: "exec_env/someThingCool"},
}
for i, c := range cases {
os.Clearenv()
os.Setenv(execEnvVar, c.ExecEnv)
req := &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
}
AddHostExecEnvUserAgentHander.Fn(req)
if err := req.Error; err != nil {
t.Fatalf("%d, expect no error, got %v", i, err)
}
if e, a := c.Expect, req.HTTPRequest.Header.Get("User-Agent"); e != a {
t.Errorf("%d, expect %v user agent, got %v", i, e, a)
}
}
}

View file

@ -1,154 +0,0 @@
package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type secondStubProvider struct {
creds Value
expired bool
err error
}
func (s *secondStubProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "secondStubProvider"
return s.creds, s.err
}
func (s *secondStubProvider) IsExpired() bool {
return s.expired
}
func TestChainProviderWithNames(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&secondStubProvider{
creds: Value{
AccessKeyID: "AKIF",
SecretAccessKey: "NOSECRET",
SessionToken: "",
},
},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "secondStubProvider", creds.ProviderName, "Expect provider name to match")
// Also check credentials
assert.Equal(t, "AKIF", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "NOSECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestChainProviderGet(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestChainProviderIsExpired(t *testing.T) {
stubProvider := &stubProvider{expired: true}
p := &ChainProvider{
Providers: []Provider{
stubProvider,
},
}
assert.True(t, p.IsExpired(), "Expect expired to be true before any Retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
stubProvider.expired = true
assert.True(t, p.IsExpired(), "Expect return of expired provider")
_, err = p.Retrieve()
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
}
func TestChainProviderWithNoProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t,
ErrNoValidProvidersFoundInChain,
err,
"Expect no providers error returned")
}
func TestChainProviderWithNoValidProvider(t *testing.T) {
errs := []error{
awserr.New("FirstError", "first provider error", nil),
awserr.New("SecondError", "second provider error", nil),
}
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: errs[0]},
&stubProvider{err: errs[1]},
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t,
ErrNoValidProvidersFoundInChain,
err,
"Expect no providers error returned")
}
func TestChainProviderWithNoValidProviderWithVerboseEnabled(t *testing.T) {
errs := []error{
awserr.New("FirstError", "first provider error", nil),
awserr.New("SecondError", "second provider error", nil),
}
p := &ChainProvider{
VerboseErrors: true,
Providers: []Provider{
&stubProvider{err: errs[0]},
&stubProvider{err: errs[1]},
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t,
awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs),
err,
"Expect no providers error returned")
}

View file

@ -1,90 +0,0 @@
// +build go1.9
package credentials
import (
"fmt"
"strconv"
"sync"
"testing"
"time"
)
func BenchmarkCredentials_Get(b *testing.B) {
stub := &stubProvider{}
cases := []int{1, 10, 100, 500, 1000, 10000}
for _, c := range cases {
b.Run(strconv.Itoa(c), func(b *testing.B) {
creds := NewCredentials(stub)
var wg sync.WaitGroup
wg.Add(c)
for i := 0; i < c; i++ {
go func() {
for j := 0; j < b.N; j++ {
v, err := creds.Get()
if err != nil {
b.Fatalf("expect no error %v, %v", v, err)
}
}
wg.Done()
}()
}
b.ResetTimer()
wg.Wait()
})
}
}
func BenchmarkCredentials_Get_Expire(b *testing.B) {
p := &blockProvider{}
expRates := []int{10000, 1000, 100}
cases := []int{1, 10, 100, 500, 1000, 10000}
for _, expRate := range expRates {
for _, c := range cases {
b.Run(fmt.Sprintf("%d-%d", expRate, c), func(b *testing.B) {
creds := NewCredentials(p)
var wg sync.WaitGroup
wg.Add(c)
for i := 0; i < c; i++ {
go func(id int) {
for j := 0; j < b.N; j++ {
v, err := creds.Get()
if err != nil {
b.Fatalf("expect no error %v, %v", v, err)
}
// periodically expire creds to cause rwlock
if id == 0 && j%expRate == 0 {
creds.Expire()
}
}
wg.Done()
}(i)
}
b.ResetTimer()
wg.Wait()
})
}
}
}
type blockProvider struct {
creds Value
expired bool
err error
}
func (s *blockProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "blockProvider"
time.Sleep(time.Millisecond)
return s.creds, s.err
}
func (s *blockProvider) IsExpired() bool {
return s.expired
}

View file

@ -1,73 +0,0 @@
package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type stubProvider struct {
creds Value
expired bool
err error
}
func (s *stubProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "stubProvider"
return s.creds, s.err
}
func (s *stubProvider) IsExpired() bool {
return s.expired
}
func TestCredentialsGet(t *testing.T) {
c := NewCredentials(&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
expired: true,
})
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestCredentialsGetWithError(t *testing.T) {
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
_, err := c.Get()
assert.Equal(t, "provider error", err.(awserr.Error).Code(), "Expected provider error")
}
func TestCredentialsExpire(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
stub.expired = false
assert.True(t, c.IsExpired(), "Expected to start out expired")
c.Expire()
assert.True(t, c.IsExpired(), "Expected to be expired")
c.forceRefresh = false
assert.False(t, c.IsExpired(), "Expected not to be expired")
stub.expired = true
assert.True(t, c.IsExpired(), "Expected to be expired")
}
func TestCredentialsGetWithProviderName(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, creds.ProviderName, "stubProvider", "Expected provider name to match")
}

View file

@ -1,159 +0,0 @@
package ec2rolecreds_test
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
const credsRespTmpl = `{
"Code": "Success",
"Type": "AWS-HMAC",
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s",
"LastUpdated" : "2009-11-23T0:00:00Z"
}`
const credsFailRespTmpl = `{
"Code": "ErrorCode",
"Message": "ErrorMsg",
"LastUpdated": "2009-11-23T0:00:00Z"
}`
func initTestServer(expireOn string, failAssume bool) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/latest/meta-data/iam/security-credentials" {
fmt.Fprintln(w, "RoleName")
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
if failAssume {
fmt.Fprintf(w, credsFailRespTmpl)
} else {
fmt.Fprintf(w, credsRespTmpl, expireOn)
}
} else {
http.Error(w, "bad request", http.StatusBadRequest)
}
}))
return server
}
func TestEC2RoleProvider(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderFailAssume(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", true)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
creds, err := p.Retrieve()
assert.Error(t, err, "Expect error")
e := err.(awserr.Error)
assert.Equal(t, "ErrorCode", e.Code())
assert.Equal(t, "ErrorMsg", e.Message())
assert.Nil(t, e.OrigErr())
assert.Equal(t, "", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
ExpiryWindow: time.Hour * 1,
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func BenchmarkEC3RoleProvider(b *testing.B) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := p.Retrieve(); err != nil {
b.Fatal(err)
}
}
}

View file

@ -1,111 +0,0 @@
package endpointcreds_test
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/stretchr/testify/assert"
)
func TestRetrieveRefreshableCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/path/to/endpoint", r.URL.Path)
assert.Equal(t, "application/json", r.Header.Get("Accept"))
assert.Equal(t, "else", r.URL.Query().Get("something"))
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
"Token": "TOKEN",
"Expiration": time.Now().Add(1 * time.Hour),
})
if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config,
unit.Session.Handlers,
server.URL+"/path/to/endpoint?something=else",
)
creds, err := client.Retrieve()
assert.NoError(t, err)
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Equal(t, "TOKEN", creds.SessionToken)
assert.False(t, client.IsExpired())
client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
assert.True(t, client.IsExpired())
}
func TestRetrieveStaticCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
})
if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
creds, err := client.Retrieve()
assert.NoError(t, err)
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.False(t, client.IsExpired())
}
func TestFailedRetrieveCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(400)
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"Code": "Error",
"Message": "Message",
})
if err != nil {
fmt.Println("failed to write error", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
creds, err := client.Retrieve()
assert.Error(t, err)
aerr := err.(awserr.Error)
assert.Equal(t, "CredentialsEndpointError", aerr.Code())
assert.Equal(t, "failed to load credentials", aerr.Message())
aerr = aerr.OrigErr().(awserr.Error)
assert.Equal(t, "Error", aerr.Code())
assert.Equal(t, "Message", aerr.Message())
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.True(t, client.IsExpired())
}

View file

@ -1,70 +0,0 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestEnvProviderRetrieve(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEnvProviderIsExpired(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
assert.True(t, e.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, e.IsExpired(), "Expect creds to not be expired after retrieve.")
}
func TestEnvProviderNoAccessKeyID(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrAccessKeyIDNotFound, err, "ErrAccessKeyIDNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderNoSecretAccessKey(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrSecretAccessKeyNotFound, err, "ErrSecretAccessKeyNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderAlternateNames(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY", "access")
os.Setenv("AWS_SECRET_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expected access key ID")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expected secret access key")
assert.Empty(t, creds.SessionToken, "Expected no token")
}

View file

@ -1,136 +0,0 @@
package credentials
import (
"os"
"path/filepath"
"testing"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
"github.com/stretchr/testify/assert"
)
func TestSharedCredentialsProvider(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderIsExpired(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve")
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "example.ini")
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILEAbsPath(t *testing.T) {
os.Clearenv()
wd, err := os.Getwd()
assert.NoError(t, err)
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join(wd, "example.ini"))
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "no_token")
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProviderWithoutTokenFromProfile(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "no_token"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProviderColonInCredFile(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "with_colon"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProvider_DefaultFilename(t *testing.T) {
os.Clearenv()
os.Setenv("USERPROFILE", "profile_dir")
os.Setenv("HOME", "home_dir")
// default filename and profile
p := SharedCredentialsProvider{}
filename, err := p.filename()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := shareddefaults.SharedCredentialsFilename(), filename; e != a {
t.Errorf("expect %q filename, got %q", e, a)
}
}
func BenchmarkSharedCredentialsProvider(b *testing.B) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
}

View file

@ -1,34 +0,0 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestStaticProviderGet(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
creds, err := s.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no session token")
}
func TestStaticProviderIsExpired(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
assert.False(t, s.IsExpired(), "Expect static credentials to never expire")
}

View file

@ -1,150 +0,0 @@
package stscreds
import (
"fmt"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
)
type stubSTS struct {
TestInput func(*sts.AssumeRoleInput)
}
func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
if s.TestInput != nil {
s.TestInput(input)
}
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleOutput{
Credentials: &sts.Credentials{
// Just reflect the role arn to the provider.
AccessKeyId: input.RoleArn,
SecretAccessKey: aws.String("assumedSecretAccessKey"),
SessionToken: aws.String("assumedSessionToken"),
Expiration: &expiry,
},
}, nil
}
func TestAssumeRoleProvider(t *testing.T) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func TestAssumeRoleProvider_WithTokenCode(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Equal(t, "0123456789", *in.SerialNumber)
assert.Equal(t, "code", *in.TokenCode)
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
TokenCode: aws.String("code"),
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Equal(t, "0123456789", *in.SerialNumber)
assert.Equal(t, "code", *in.TokenCode)
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
TokenProvider: func() (string, error) {
return "code", nil
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Fail(t, "API request should not of been called")
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
TokenProvider: func() (string, error) {
return "", fmt.Errorf("error occurred")
},
}
creds, err := p.Retrieve()
assert.Error(t, err)
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
}
func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Fail(t, "API request should not of been called")
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
}
creds, err := p.Retrieve()
assert.Error(t, err)
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
}
func BenchmarkAssumeRoleProvider(b *testing.B) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := p.Retrieve(); err != nil {
b.Fatal(err)
}
}
}

View file

@ -1,74 +0,0 @@
package csm
import (
"encoding/json"
"fmt"
"net"
"testing"
)
func startUDPServer(done chan struct{}, fn func([]byte)) (string, error) {
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
return "", err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return "", err
}
buf := make([]byte, 1024)
go func() {
defer conn.Close()
for {
select {
case <-done:
return
default:
}
n, _, err := conn.ReadFromUDP(buf)
fn(buf[:n])
if err != nil {
panic(err)
}
}
}()
return conn.LocalAddr().String(), nil
}
func TestDifferentParams(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("expected panic with different parameters")
}
}()
Start("clientID2", ":0")
}
var MetricsCh = make(chan map[string]interface{}, 1)
var Done = make(chan struct{})
func init() {
url, err := startUDPServer(Done, func(b []byte) {
m := map[string]interface{}{}
if err := json.Unmarshal(b, &m); err != nil {
panic(fmt.Sprintf("expected no error, but received %v", err))
}
MetricsCh <- m
})
if err != nil {
panic(err)
}
_, err = Start("clientID", url)
if err != nil {
panic(err)
}
}

View file

@ -1,40 +0,0 @@
package csm_test
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
)
func ExampleStart() {
r, err := csm.Start("clientID", ":31000")
if err != nil {
panic(fmt.Errorf("failed starting CSM: %v", err))
}
sess, err := session.NewSession(&aws.Config{})
if err != nil {
panic(fmt.Errorf("failed loading session: %v", err))
}
r.InjectHandlers(&sess.Handlers)
client := s3.New(sess)
client.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
// Pauses monitoring
r.Pause()
client.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
// Resume monitoring
r.Continue()
}

View file

@ -1,72 +0,0 @@
package csm
import (
"testing"
)
func TestMetricChanPush(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
pushed := ch.Push(metric{})
if !pushed {
t.Errorf("expected metrics to be pushed")
}
if e, a := 1, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanPauseContinue(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
ch.Pause()
if !ch.IsPaused() {
t.Errorf("expected to be paused, but did not pause properly")
}
ch.Continue()
if ch.IsPaused() {
t.Errorf("expected to be not paused, but did not continue properly")
}
pushed := ch.Push(metric{})
if !pushed {
t.Errorf("expected metrics to be pushed")
}
if e, a := 1, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanPushWhenPaused(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
ch.Pause()
pushed := ch.Push(metric{})
if pushed {
t.Errorf("expected metrics to not be pushed")
}
if e, a := 0, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanNonBlocking(t *testing.T) {
ch := newMetricChan(0)
defer close(ch.ch)
pushed := ch.Push(metric{})
if pushed {
t.Errorf("expected metrics to be not pushed")
}
if e, a := 0, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}

View file

@ -1,249 +0,0 @@
package csm_test
import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
)
func startUDPServer(done chan struct{}, fn func([]byte)) (string, error) {
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
return "", err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return "", err
}
buf := make([]byte, 1024)
i := 0
go func() {
defer conn.Close()
for {
i++
select {
case <-done:
return
default:
}
n, _, err := conn.ReadFromUDP(buf)
fn(buf[:n])
if err != nil {
panic(err)
}
}
}()
return conn.LocalAddr().String(), nil
}
func TestReportingMetrics(t *testing.T) {
reporter := csm.Get()
if reporter == nil {
t.Errorf("expected non-nil reporter")
}
sess := session.New()
sess.Handlers.Clear()
reporter.InjectHandlers(&sess.Handlers)
md := metadata.ClientInfo{}
op := &request.Operation{}
r := request.New(*sess.Config, md, sess.Handlers, client.DefaultRetryer{NumMaxRetries: 0}, op, nil, nil)
sess.Handlers.Complete.Run(r)
foundAttempt := false
foundCall := false
expectedMetrics := 2
for i := 0; i < expectedMetrics; i++ {
m := <-csm.MetricsCh
for k, v := range m {
switch k {
case "Type":
a := v.(string)
foundCall = foundCall || a == "ApiCall"
foundAttempt = foundAttempt || a == "ApiCallAttempt"
if prefix := "ApiCall"; !strings.HasPrefix(a, prefix) {
t.Errorf("expected 'APICall' prefix, but received %q", a)
}
}
}
}
if !foundAttempt {
t.Errorf("expected attempt event to have occurred")
}
if !foundCall {
t.Errorf("expected call event to have occurred")
}
}
type mockService struct {
*client.Client
}
type input struct{}
type output struct{}
func (s *mockService) Request(i input) *request.Request {
op := &request.Operation{
Name: "foo",
HTTPMethod: "POST",
HTTPPath: "/",
}
o := output{}
req := s.NewRequest(op, &i, &o)
return req
}
func BenchmarkWithCSM(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := session.New(&cfg)
r := csm.Get()
r.InjectHandlers(&sess.Handlers)
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}
func BenchmarkWithCSMNoUDPConnection(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := session.New(&cfg)
r := csm.Get()
r.Pause()
r.InjectHandlers(&sess.Handlers)
defer r.Pause()
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}
func BenchmarkWithoutCSM(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := session.New(&cfg)
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}

View file

@ -1,116 +0,0 @@
package defaults
import (
"fmt"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestHTTPCredProvider(t *testing.T) {
origFn := lookupHostFn
defer func() { lookupHostFn = origFn }()
lookupHostFn = func(host string) ([]string, error) {
m := map[string]struct {
Addrs []string
Err error
}{
"localhost": {Addrs: []string{"::1", "127.0.0.1"}},
"actuallylocal": {Addrs: []string{"127.0.0.2"}},
"notlocal": {Addrs: []string{"::1", "127.0.0.1", "192.168.1.10"}},
"www.example.com": {Addrs: []string{"10.10.10.10"}},
}
h, ok := m[host]
if !ok {
t.Fatalf("unknown host in test, %v", host)
return nil, fmt.Errorf("unknown host")
}
return h.Addrs, h.Err
}
cases := []struct {
Host string
Fail bool
}{
{"localhost", false},
{"actuallylocal", false},
{"127.0.0.1", false},
{"127.1.1.1", false},
{"[::1]", false},
{"www.example.com", true},
{"169.254.170.2", true},
}
defer os.Clearenv()
for i, c := range cases {
u := fmt.Sprintf("http://%s/abc/123", c.Host)
os.Setenv(httpProviderEnvVar, u)
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
t.Fatalf("%d, expect provider not to be nil, but was", i)
}
if c.Fail {
creds, err := provider.Retrieve()
if err == nil {
t.Fatalf("%d, expect error but got none", i)
} else {
aerr := err.(awserr.Error)
if e, a := "CredentialsEndpointError", aerr.Code(); e != a {
t.Errorf("%d, expect %s error code, got %s", i, e, a)
}
}
if e, a := endpointcreds.ProviderName, creds.ProviderName; e != a {
t.Errorf("%d, expect %s provider name got %s", i, e, a)
}
} else {
httpProvider := provider.(*endpointcreds.Provider)
if e, a := u, httpProvider.Client.Endpoint; e != a {
t.Errorf("%d, expect %q endpoint, got %q", i, e, a)
}
}
}
}
func TestECSCredProvider(t *testing.T) {
defer os.Clearenv()
os.Setenv(ecsCredsProviderEnvVar, "/abc/123")
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
httpProvider := provider.(*endpointcreds.Provider)
if httpProvider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
if e, a := "http://169.254.170.2/abc/123", httpProvider.Client.Endpoint; e != a {
t.Errorf("expect %q endpoint, got %q", e, a)
}
}
func TestDefaultEC2RoleProvider(t *testing.T) {
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
ec2Provider := provider.(*ec2rolecreds.EC2RoleProvider)
if ec2Provider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
if e, a := "http://169.254.169.254/latest", ec2Provider.Client.Endpoint; e != a {
t.Errorf("expect %q endpoint, got %q", e, a)
}
}

View file

@ -1,289 +0,0 @@
package ec2metadata_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"path"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
const instanceIdentityDocument = `{
"devpayProductCodes" : null,
"availabilityZone" : "us-east-1d",
"privateIp" : "10.158.112.84",
"version" : "2010-08-31",
"region" : "us-east-1",
"instanceId" : "i-1234567890abcdef0",
"billingProducts" : null,
"instanceType" : "t1.micro",
"accountId" : "123456789012",
"pendingTime" : "2015-11-19T16:32:11Z",
"imageId" : "ami-5fb8c835",
"kernelId" : "aki-919dcaf8",
"ramdiskId" : null,
"architecture" : "x86_64"
}`
const validIamInfo = `{
"Code" : "Success",
"LastUpdated" : "2016-03-17T12:27:32Z",
"InstanceProfileArn" : "arn:aws:iam::123456789012:instance-profile/my-instance-profile",
"InstanceProfileId" : "AIPAABCDEFGHIJKLMN123"
}`
const unsuccessfulIamInfo = `{
"Code" : "Failed",
"LastUpdated" : "2016-03-17T12:27:32Z",
"InstanceProfileArn" : "arn:aws:iam::123456789012:instance-profile/my-instance-profile",
"InstanceProfileId" : "AIPAABCDEFGHIJKLMN123"
}`
func initTestServer(path string, resp string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI != path {
http.Error(w, "not found", http.StatusNotFound)
return
}
w.Write([]byte(resp))
}))
}
func TestEndpoint(t *testing.T) {
c := ec2metadata.New(unit.Session)
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", "testpath"),
}
req := c.NewRequest(op, nil, nil)
if e, a := "http://169.254.169.254/latest", req.ClientInfo.Endpoint; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetMetadata(t *testing.T) {
server := initTestServer(
"/latest/meta-data/some/path",
"success", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetMetadata("some/path")
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "success", resp; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetUserData(t *testing.T) {
server := initTestServer(
"/latest/user-data",
"success", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetUserData()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "success", resp; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetUserData_Error(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reader := strings.NewReader(`<?xml version="1.0" encoding="iso-8859-1"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en">
<head>
<title>404 - Not Found</title>
</head>
<body>
<h1>404 - Not Found</h1>
</body>
</html>`)
w.Header().Set("Content-Type", "text/html")
w.Header().Set("Content-Length", fmt.Sprintf("%d", reader.Len()))
w.WriteHeader(http.StatusNotFound)
io.Copy(w, reader)
}))
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetUserData()
if err == nil {
t.Errorf("expect error")
}
if len(resp) != 0 {
t.Errorf("expect empty, got %v", resp)
}
aerr := err.(awserr.Error)
if e, a := "NotFoundError", aerr.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetRegion(t *testing.T) {
server := initTestServer(
"/latest/meta-data/placement/availability-zone",
"us-west-2a", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
region, err := c.Region()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "us-west-2", region; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataAvailable(t *testing.T) {
server := initTestServer(
"/latest/meta-data/instance-id",
"instance-id",
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
if !c.Available() {
t.Errorf("expect available")
}
}
func TestMetadataIAMInfo_success(t *testing.T) {
server := initTestServer(
"/latest/meta-data/iam/info",
validIamInfo,
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
iamInfo, err := c.IAMInfo()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "Success", iamInfo.Code; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "arn:aws:iam::123456789012:instance-profile/my-instance-profile", iamInfo.InstanceProfileArn; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "AIPAABCDEFGHIJKLMN123", iamInfo.InstanceProfileID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataIAMInfo_failure(t *testing.T) {
server := initTestServer(
"/latest/meta-data/iam/info",
unsuccessfulIamInfo,
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
iamInfo, err := c.IAMInfo()
if err == nil {
t.Errorf("expect error")
}
if e, a := "", iamInfo.Code; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "", iamInfo.InstanceProfileArn; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "", iamInfo.InstanceProfileID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataNotAvailable(t *testing.T) {
c := ec2metadata.New(unit.Session)
c.Handlers.Send.Clear()
c.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
r.Error = awserr.New("RequestError", "send request failed", nil)
r.Retryable = aws.Bool(true) // network errors are retryable
})
if c.Available() {
t.Errorf("expect not available")
}
}
func TestMetadataErrorResponse(t *testing.T) {
c := ec2metadata.New(unit.Session)
c.Handlers.Send.Clear()
c.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest),
Body: ioutil.NopCloser(strings.NewReader("error message text")),
}
r.Retryable = aws.Bool(false) // network errors are retryable
})
data, err := c.GetMetadata("uri/path")
if len(data) != 0 {
t.Errorf("expect empty, got %v", data)
}
if e, a := "error message text", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %v to be in %v", e, a)
}
}
func TestEC2RoleProviderInstanceIdentity(t *testing.T) {
server := initTestServer(
"/latest/dynamic/instance-identity/document",
instanceIdentityDocument,
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
doc, err := c.GetInstanceIdentityDocument()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := doc.AccountID, "123456789012"; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := doc.AvailabilityZone, "us-east-1d"; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := doc.Region, "us-east-1"; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}

View file

@ -1,120 +0,0 @@
package ec2metadata_test
import (
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e == a {
t.Errorf("expect %v, not to equal %v", e, a)
}
if e, a := 5*time.Second, svc.Config.HTTPClient.Timeout; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}
func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
http.DefaultClient.Transport = &http.Transport{}
defer func() {
http.DefaultClient.Transport = nil
}()
svc := ec2metadata.New(unit.Session)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
t.Errorf("expect %v, got %v", e, a)
}
tr := svc.Config.HTTPClient.Transport.(*http.Transport)
if tr == nil {
t.Fatalf("expect transport not to be nil")
}
if tr.Dial != nil {
t.Errorf("expect dial to be nil, was not")
}
}
func TestClientDisableOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session, aws.NewConfig().WithEC2MetadataDisableTimeoutOverride(true))
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestClientOverrideDefaultHTTPClientTimeoutRace(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("us-east-1a"))
}))
cfg := aws.NewConfig().WithEndpoint(server.URL)
runEC2MetadataClients(t, cfg, 100)
}
func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("us-east-1a"))
}))
cfg := aws.NewConfig().WithEndpoint(server.URL).WithHTTPClient(&http.Client{
Transport: http.DefaultTransport,
})
runEC2MetadataClients(t, cfg, 100)
}
func TestClientDisableIMDS(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
os.Setenv("AWS_EC2_METADATA_DISABLED", "true")
svc := ec2metadata.New(unit.Session)
resp, err := svc.Region()
if err == nil {
t.Fatalf("expect error, got none")
}
if len(resp) != 0 {
t.Errorf("expect no response, got %v", resp)
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %v error code, got %v", e, a)
}
if e, a := "AWS_EC2_METADATA_DISABLED", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expect %v in error message, got %v", e, a)
}
}
func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
var wg sync.WaitGroup
wg.Add(atOnce)
for i := 0; i < atOnce; i++ {
go func() {
svc := ec2metadata.New(unit.Session, cfg)
_, err := svc.Region()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
wg.Done()
}()
}
wg.Wait()
}

View file

@ -1,117 +0,0 @@
package endpoints
import (
"strings"
"testing"
)
func TestDecodeEndpoints_V3(t *testing.T) {
const v3Doc = `
{
"version": 3,
"partitions": [
{
"defaults": {
"hostname": "{service}.{region}.{dnsSuffix}",
"protocols": [
"https"
],
"signatureVersions": [
"v4"
]
},
"dnsSuffix": "amazonaws.com",
"partition": "aws",
"partitionName": "AWS Standard",
"regionRegex": "^(us|eu|ap|sa|ca)\\-\\w+\\-\\d+$",
"regions": {
"ap-northeast-1": {
"description": "Asia Pacific (Tokyo)"
}
},
"services": {
"acm": {
"endpoints": {
"ap-northeast-1": {}
}
},
"s3": {
"endpoints": {
"ap-northeast-1": {}
}
}
}
}
]
}`
resolver, err := DecodeModel(strings.NewReader(v3Doc))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
endpoint, err := resolver.EndpointFor("acm", "ap-northeast-1")
if err != nil {
t.Fatalf("failed to resolve endpoint, %v", err)
}
if a, e := endpoint.URL, "https://acm.ap-northeast-1.amazonaws.com"; a != e {
t.Errorf("expected %q URL got %q", e, a)
}
p := resolver.(partitions)[0]
s3Defaults := p.Services["s3"].Defaults
if a, e := s3Defaults.HasDualStack, boxedTrue; a != e {
t.Errorf("expect s3 service to have dualstack enabled")
}
if a, e := s3Defaults.DualStackHostname, "{service}.dualstack.{region}.{dnsSuffix}"; a != e {
t.Errorf("expect s3 dualstack host pattern to be %q, got %q", e, a)
}
ec2metaEndpoint := p.Services["ec2metadata"].Endpoints["aws-global"]
if a, e := ec2metaEndpoint.Hostname, "169.254.169.254/latest"; a != e {
t.Errorf("expect ec2metadata host to be %q, got %q", e, a)
}
}
func TestDecodeEndpoints_NoPartitions(t *testing.T) {
const doc = `{ "version": 3 }`
resolver, err := DecodeModel(strings.NewReader(doc))
if err == nil {
t.Fatalf("expected error")
}
if resolver != nil {
t.Errorf("expect resolver to be nil")
}
}
func TestDecodeEndpoints_UnsupportedVersion(t *testing.T) {
const doc = `{ "version": 2 }`
resolver, err := DecodeModel(strings.NewReader(doc))
if err == nil {
t.Fatalf("expected error decoding model")
}
if resolver != nil {
t.Errorf("expect resolver to be nil")
}
}
func TestDecodeModelOptionsSet(t *testing.T) {
var actual DecodeModelOptions
actual.Set(func(o *DecodeModelOptions) {
o.SkipCustomizations = true
})
expect := DecodeModelOptions{
SkipCustomizations: true,
}
if actual != expect {
t.Errorf("expect %v options got %v", expect, actual)
}
}

View file

@ -1,342 +0,0 @@
package endpoints
import "testing"
func TestEnumDefaultPartitions(t *testing.T) {
resolver := DefaultResolver()
enum, ok := resolver.(EnumPartitions)
if ok != true {
t.Fatalf("resolver must satisfy EnumPartition interface")
}
ps := enum.Partitions()
if a, e := len(ps), len(defaultPartitions); a != e {
t.Errorf("expected %d partitions, got %d", e, a)
}
}
func TestEnumDefaultRegions(t *testing.T) {
expectPart := defaultPartitions[0]
partEnum := defaultPartitions[0].Partition()
regEnum := partEnum.Regions()
if a, e := len(regEnum), len(expectPart.Regions); a != e {
t.Errorf("expected %d regions, got %d", e, a)
}
}
func TestEnumPartitionServices(t *testing.T) {
expectPart := testPartitions[0]
partEnum := testPartitions[0].Partition()
if a, e := partEnum.ID(), "part-id"; a != e {
t.Errorf("expect %q partition ID, got %q", e, a)
}
svcEnum := partEnum.Services()
if a, e := len(svcEnum), len(expectPart.Services); a != e {
t.Errorf("expected %d regions, got %d", e, a)
}
}
func TestEnumRegionServices(t *testing.T) {
p := testPartitions[0].Partition()
rs := p.Regions()
if a, e := len(rs), 2; a != e {
t.Errorf("expect %d regions got %d", e, a)
}
if _, ok := rs["us-east-1"]; !ok {
t.Errorf("expect us-east-1 region to be found, was not")
}
if _, ok := rs["us-west-2"]; !ok {
t.Errorf("expect us-west-2 region to be found, was not")
}
r := rs["us-east-1"]
if a, e := r.ID(), "us-east-1"; a != e {
t.Errorf("expect %q region ID, got %q", e, a)
}
if a, e := r.Description(), "region description"; a != e {
t.Errorf("expect %q region Description, got %q", e, a)
}
ss := r.Services()
if a, e := len(ss), 1; a != e {
t.Errorf("expect %d services for us-east-1, got %d", e, a)
}
if _, ok := ss["service1"]; !ok {
t.Errorf("expect service1 service to be found, was not")
}
resolved, err := r.ResolveEndpoint("service1")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service1.us-east-1.amazonaws.com"; a != e {
t.Errorf("expect %q resolved URL, got %q", e, a)
}
}
func TestEnumServiceRegions(t *testing.T) {
p := testPartitions[0].Partition()
rs := p.Services()["service1"].Regions()
if e, a := 2, len(rs); e != a {
t.Errorf("expect %d regions, got %d", e, a)
}
if _, ok := rs["us-east-1"]; !ok {
t.Errorf("expect region to be found")
}
if _, ok := rs["us-west-2"]; !ok {
t.Errorf("expect region to be found")
}
}
func TestEnumServicesEndpoints(t *testing.T) {
p := testPartitions[0].Partition()
ss := p.Services()
if a, e := len(ss), 5; a != e {
t.Errorf("expect %d regions got %d", e, a)
}
if _, ok := ss["service1"]; !ok {
t.Errorf("expect service1 region to be found, was not")
}
if _, ok := ss["service2"]; !ok {
t.Errorf("expect service2 region to be found, was not")
}
s := ss["service1"]
if a, e := s.ID(), "service1"; a != e {
t.Errorf("expect %q service ID, got %q", e, a)
}
resolved, err := s.ResolveEndpoint("us-west-2")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service1.us-west-2.amazonaws.com"; a != e {
t.Errorf("expect %q resolved URL, got %q", e, a)
}
}
func TestEnumEndpoints(t *testing.T) {
p := testPartitions[0].Partition()
s := p.Services()["service1"]
es := s.Endpoints()
if a, e := len(es), 2; a != e {
t.Errorf("expect %d endpoints for service2, got %d", e, a)
}
if _, ok := es["us-east-1"]; !ok {
t.Errorf("expect us-east-1 to be found, was not")
}
e := es["us-east-1"]
if a, e := e.ID(), "us-east-1"; a != e {
t.Errorf("expect %q endpoint ID, got %q", e, a)
}
if a, e := e.ServiceID(), "service1"; a != e {
t.Errorf("expect %q service ID, got %q", e, a)
}
resolved, err := e.ResolveEndpoint()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service1.us-east-1.amazonaws.com"; a != e {
t.Errorf("expect %q resolved URL, got %q", e, a)
}
}
func TestResolveEndpointForPartition(t *testing.T) {
enum := testPartitions.Partitions()[0]
expected, err := testPartitions.EndpointFor("service1", "us-east-1")
actual, err := enum.EndpointFor("service1", "us-east-1")
if err != nil {
t.Fatalf("unexpected error, %v", err)
}
if expected != actual {
t.Errorf("expect resolved endpoint to be %v, but got %v", expected, actual)
}
}
func TestAddScheme(t *testing.T) {
cases := []struct {
In string
Expect string
DisableSSL bool
}{
{
In: "https://example.com",
Expect: "https://example.com",
},
{
In: "example.com",
Expect: "https://example.com",
},
{
In: "http://example.com",
Expect: "http://example.com",
},
{
In: "example.com",
Expect: "http://example.com",
DisableSSL: true,
},
{
In: "https://example.com",
Expect: "https://example.com",
DisableSSL: true,
},
}
for i, c := range cases {
actual := AddScheme(c.In, c.DisableSSL)
if actual != c.Expect {
t.Errorf("%d, expect URL to be %q, got %q", i, c.Expect, actual)
}
}
}
func TestResolverFunc(t *testing.T) {
var resolver Resolver
resolver = ResolverFunc(func(s, r string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return ResolvedEndpoint{
URL: "https://service.region.dnssuffix.com",
SigningRegion: "region",
SigningName: "service",
}, nil
})
resolved, err := resolver.EndpointFor("service", "region", func(o *Options) {
o.DisableSSL = true
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service.region.dnssuffix.com"; a != e {
t.Errorf("expect %q endpoint URL, got %q", e, a)
}
if a, e := resolved.SigningRegion, "region"; a != e {
t.Errorf("expect %q region, got %q", e, a)
}
if a, e := resolved.SigningName, "service"; a != e {
t.Errorf("expect %q signing name, got %q", e, a)
}
}
func TestOptionsSet(t *testing.T) {
var actual Options
actual.Set(DisableSSLOption, UseDualStackOption, StrictMatchingOption)
expect := Options{
DisableSSL: true,
UseDualStack: true,
StrictMatching: true,
}
if actual != expect {
t.Errorf("expect %v options got %v", expect, actual)
}
}
func TestRegionsForService(t *testing.T) {
ps := DefaultPartitions()
var expect map[string]Region
var serviceID string
for _, s := range ps[0].Services() {
expect = s.Regions()
serviceID = s.ID()
if len(expect) > 0 {
break
}
}
actual, ok := RegionsForService(ps, ps[0].ID(), serviceID)
if !ok {
t.Fatalf("expect regions to be found, was not")
}
if len(actual) == 0 {
t.Fatalf("expect service %s to have regions", serviceID)
}
if e, a := len(expect), len(actual); e != a {
t.Fatalf("expect %d regions, got %d", e, a)
}
for id, r := range actual {
if e, a := id, r.ID(); e != a {
t.Errorf("expect %s region id, got %s", e, a)
}
if _, ok := expect[id]; !ok {
t.Errorf("expect %s region to be found", id)
}
if a, e := r.Description(), expect[id].desc; a != e {
t.Errorf("expect %q region Description, got %q", e, a)
}
}
}
func TestRegionsForService_NotFound(t *testing.T) {
ps := testPartitions.Partitions()
actual, ok := RegionsForService(ps, ps[0].ID(), "service-not-exists")
if ok {
t.Fatalf("expect no regions to be found, but were")
}
if len(actual) != 0 {
t.Errorf("expect no regions, got %v", actual)
}
}
func TestPartitionForRegion(t *testing.T) {
ps := DefaultPartitions()
expect := ps[len(ps)%2]
var regionID string
for id := range expect.Regions() {
regionID = id
break
}
actual, ok := PartitionForRegion(ps, regionID)
if !ok {
t.Fatalf("expect partition to be found")
}
if e, a := expect.ID(), actual.ID(); e != a {
t.Errorf("expect %s partition, got %s", e, a)
}
}
func TestPartitionForRegion_NotFound(t *testing.T) {
ps := DefaultPartitions()
actual, ok := PartitionForRegion(ps, "regionNotExists")
if ok {
t.Errorf("expect no partition to be found, got %v", actual)
}
}

View file

@ -1,66 +0,0 @@
package endpoints_test
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sqs"
)
func ExampleEnumPartitions() {
resolver := endpoints.DefaultResolver()
partitions := resolver.(endpoints.EnumPartitions).Partitions()
for _, p := range partitions {
fmt.Println("Regions for", p.ID())
for id := range p.Regions() {
fmt.Println("*", id)
}
fmt.Println("Services for", p.ID())
for id := range p.Services() {
fmt.Println("*", id)
}
}
}
func ExampleResolverFunc() {
myCustomResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if service == endpoints.S3ServiceID {
return endpoints.ResolvedEndpoint{
URL: "s3.custom.endpoint.com",
SigningRegion: "custom-signing-region",
}, nil
}
return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
}
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String("us-west-2"),
EndpointResolver: endpoints.ResolverFunc(myCustomResolver),
}))
// Create the S3 service client with the shared session. This will
// automatically use the S3 custom endpoint configured in the custom
// endpoint resolver wrapping the default endpoint resolver.
s3Svc := s3.New(sess)
// Operation calls will be made to the custom endpoint.
s3Svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String("myBucket"),
Key: aws.String("myObjectKey"),
})
// Create the SQS service client with the shared session. This will
// fallback to the default endpoint resolver because the customization
// passes any non S3 service endpoint resolve to the default resolver.
sqsSvc := sqs.New(sess)
// Operation calls will be made to the default endpoint for SQS for the
// region configured.
sqsSvc.ReceiveMessage(&sqs.ReceiveMessageInput{
QueueUrl: aws.String("my-queue-url"),
})
}

View file

@ -1,541 +0,0 @@
package endpoints
import (
"encoding/json"
"reflect"
"regexp"
"testing"
)
func TestUnmarshalRegionRegex(t *testing.T) {
var input = []byte(`
{
"regionRegex": "^(us|eu|ap|sa|ca)\\-\\w+\\-\\d+$"
}`)
p := partition{}
err := json.Unmarshal(input, &p)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
expectRegexp, err := regexp.Compile(`^(us|eu|ap|sa|ca)\-\w+\-\d+$`)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := expectRegexp.String(), p.RegionRegex.Regexp.String(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestUnmarshalRegion(t *testing.T) {
var input = []byte(`
{
"aws-global": {
"description": "AWS partition-global endpoint"
},
"us-east-1": {
"description": "US East (N. Virginia)"
}
}`)
rs := regions{}
err := json.Unmarshal(input, &rs)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := 2, len(rs); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
r, ok := rs["aws-global"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := "AWS partition-global endpoint", r.Description; e != a {
t.Errorf("expect %v, got %v", e, a)
}
r, ok = rs["us-east-1"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := "US East (N. Virginia)", r.Description; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestUnmarshalServices(t *testing.T) {
var input = []byte(`
{
"acm": {
"endpoints": {
"us-east-1": {}
}
},
"apigateway": {
"isRegionalized": true,
"endpoints": {
"us-east-1": {},
"us-west-2": {}
}
},
"notRegionalized": {
"isRegionalized": false,
"endpoints": {
"us-east-1": {},
"us-west-2": {}
}
}
}`)
ss := services{}
err := json.Unmarshal(input, &ss)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := 3, len(ss); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
s, ok := ss["acm"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := 1, len(s.Endpoints); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
if e, a := boxedBoolUnset, s.IsRegionalized; e != a {
t.Errorf("expect %v, got %v", e, a)
}
s, ok = ss["apigateway"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := 2, len(s.Endpoints); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
if e, a := boxedTrue, s.IsRegionalized; e != a {
t.Errorf("expect %v, got %v", e, a)
}
s, ok = ss["notRegionalized"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := 2, len(s.Endpoints); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
if e, a := boxedFalse, s.IsRegionalized; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestUnmarshalEndpoints(t *testing.T) {
var inputs = []byte(`
{
"aws-global": {
"hostname": "cloudfront.amazonaws.com",
"protocols": [
"http",
"https"
],
"signatureVersions": [ "v4" ],
"credentialScope": {
"region": "us-east-1",
"service": "serviceName"
},
"sslCommonName": "commonName"
},
"us-east-1": {}
}`)
es := endpoints{}
err := json.Unmarshal(inputs, &es)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := 2, len(es); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
s, ok := es["aws-global"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := "cloudfront.amazonaws.com", s.Hostname; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := []string{"http", "https"}, s.Protocols; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := []string{"v4"}, s.SignatureVersions; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := (credentialScope{"us-east-1", "serviceName"}), s.CredentialScope; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "commonName", s.SSLCommonName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEndpointResolve(t *testing.T) {
defs := []endpoint{
{
Hostname: "{service}.{region}.{dnsSuffix}",
SignatureVersions: []string{"v2"},
SSLCommonName: "sslCommonName",
},
{
Hostname: "other-hostname",
Protocols: []string{"http"},
CredentialScope: credentialScope{
Region: "signing_region",
Service: "signing_service",
},
},
}
e := endpoint{
Hostname: "{service}.{region}.{dnsSuffix}",
Protocols: []string{"http", "https"},
SignatureVersions: []string{"v4"},
SSLCommonName: "new sslCommonName",
}
resolved := e.resolve("service", "region", "dnsSuffix",
defs, Options{},
)
if e, a := "https://service.region.dnsSuffix", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "signing_service", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "signing_region", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "v4", resolved.SigningMethod; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEndpointMergeIn(t *testing.T) {
expected := endpoint{
Hostname: "other hostname",
Protocols: []string{"http"},
SignatureVersions: []string{"v4"},
SSLCommonName: "ssl common name",
CredentialScope: credentialScope{
Region: "region",
Service: "service",
},
}
actual := endpoint{}
actual.mergeIn(endpoint{
Hostname: "other hostname",
Protocols: []string{"http"},
SignatureVersions: []string{"v4"},
SSLCommonName: "ssl common name",
CredentialScope: credentialScope{
Region: "region",
Service: "service",
},
})
if e, a := expected, actual; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
}
var testPartitions = partitions{
partition{
ID: "part-id",
Name: "partitionName",
DNSSuffix: "amazonaws.com",
RegionRegex: regionRegex{
Regexp: func() *regexp.Regexp {
reg, _ := regexp.Compile("^(us|eu|ap|sa|ca)\\-\\w+\\-\\d+$")
return reg
}(),
},
Defaults: endpoint{
Hostname: "{service}.{region}.{dnsSuffix}",
Protocols: []string{"https"},
SignatureVersions: []string{"v4"},
},
Regions: regions{
"us-east-1": region{
Description: "region description",
},
"us-west-2": region{},
},
Services: services{
"s3": service{},
"service1": service{
Defaults: endpoint{
CredentialScope: credentialScope{
Service: "service1",
},
},
Endpoints: endpoints{
"us-east-1": {},
"us-west-2": {
HasDualStack: boxedTrue,
DualStackHostname: "{service}.dualstack.{region}.{dnsSuffix}",
},
},
},
"service2": service{
Defaults: endpoint{
CredentialScope: credentialScope{
Service: "service2",
},
},
},
"httpService": service{
Defaults: endpoint{
Protocols: []string{"http"},
},
},
"globalService": service{
IsRegionalized: boxedFalse,
PartitionEndpoint: "aws-global",
Endpoints: endpoints{
"aws-global": endpoint{
CredentialScope: credentialScope{
Region: "us-east-1",
},
Hostname: "globalService.amazonaws.com",
},
},
},
},
},
}
func TestResolveEndpoint(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-west-2")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service2.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_DisableSSL(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-west-2", DisableSSLOption)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "http://service2.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_UseDualStack(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service1", "us-west-2", UseDualStackOption)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service1.dualstack.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service1", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_HTTPProtocol(t *testing.T) {
resolved, err := testPartitions.EndpointFor("httpService", "us-west-2")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "http://httpService.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "httpService", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}
func TestResolveEndpoint_UnknownService(t *testing.T) {
_, err := testPartitions.EndpointFor("unknownservice", "us-west-2")
if err == nil {
t.Errorf("expect error, got none")
}
_, ok := err.(UnknownServiceError)
if !ok {
t.Errorf("expect error to be UnknownServiceError")
}
}
func TestResolveEndpoint_ResolveUnknownService(t *testing.T) {
resolved, err := testPartitions.EndpointFor("unknown-service", "us-region-1",
ResolveUnknownServiceOption)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://unknown-service.us-region-1.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-region-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "unknown-service", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}
func TestResolveEndpoint_UnknownMatchedRegion(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-region-1")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service2.us-region-1.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-region-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_UnknownRegion(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "unknownregion")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service2.unknownregion.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "unknownregion", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_StrictPartitionUnknownEndpoint(t *testing.T) {
_, err := testPartitions[0].EndpointFor("service2", "unknownregion", StrictMatchingOption)
if err == nil {
t.Errorf("expect error, got none")
}
_, ok := err.(UnknownEndpointError)
if !ok {
t.Errorf("expect error to be UnknownEndpointError")
}
}
func TestResolveEndpoint_StrictPartitionsUnknownEndpoint(t *testing.T) {
_, err := testPartitions.EndpointFor("service2", "us-region-1", StrictMatchingOption)
if err == nil {
t.Errorf("expect error, got none")
}
_, ok := err.(UnknownEndpointError)
if !ok {
t.Errorf("expect error to be UnknownEndpointError")
}
}
func TestResolveEndpoint_NotRegionalized(t *testing.T) {
resolved, err := testPartitions.EndpointFor("globalService", "us-west-2")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://globalService.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-east-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "globalService", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}
func TestResolveEndpoint_AwsGlobal(t *testing.T) {
resolved, err := testPartitions.EndpointFor("globalService", "aws-global")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://globalService.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-east-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "globalService", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}

View file

@ -1,9 +0,0 @@
// +build appengine plan9
package request_test
import (
"errors"
)
var stubConnectionResetError = errors.New("connection reset")

View file

@ -1,11 +0,0 @@
// +build !appengine,!plan9
package request_test
import (
"net"
"os"
"syscall"
)
var stubConnectionResetError = &net.OpError{Err: &os.SyscallError{Syscall: "read", Err: syscall.ECONNRESET}}

View file

@ -1,266 +0,0 @@
package request_test
import (
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestHandlerList(t *testing.T) {
s := ""
r := &request.Request{}
l := request.HandlerList{}
l.PushBack(func(r *request.Request) {
s += "a"
r.Data = s
})
l.Run(r)
if e, a := "a", s; e != a {
t.Errorf("expect %q update got %q", e, a)
}
if e, a := "a", r.Data.(string); e != a {
t.Errorf("expect %q data update got %q", e, a)
}
}
func TestMultipleHandlers(t *testing.T) {
r := &request.Request{}
l := request.HandlerList{}
l.PushBack(func(r *request.Request) { r.Data = nil })
l.PushFront(func(r *request.Request) { r.Data = aws.Bool(true) })
l.Run(r)
if r.Data != nil {
t.Error("Expected handler to execute")
}
}
func TestNamedHandlers(t *testing.T) {
l := request.HandlerList{}
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {}}
named2 := request.NamedHandler{Name: "NotName", Fn: func(r *request.Request) {}}
l.PushBackNamed(named)
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.PushBack(func(r *request.Request) {})
if e, a := 4, l.Len(); e != a {
t.Errorf("expect %d list length, got %d", e, a)
}
l.Remove(named)
if e, a := 2, l.Len(); e != a {
t.Errorf("expect %d list length, got %d", e, a)
}
}
func TestSwapHandlers(t *testing.T) {
firstHandlerCalled := 0
swappedOutHandlerCalled := 0
swappedInHandlerCalled := 0
l := request.HandlerList{}
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {
firstHandlerCalled++
}}
named2 := request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
swappedOutHandlerCalled++
}}
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.PushBackNamed(named)
l.SwapNamed(request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
swappedInHandlerCalled++
}})
l.Run(&request.Request{})
if e, a := 2, firstHandlerCalled; e != a {
t.Errorf("expect first handler to be called %d, was called %d times", e, a)
}
if n := swappedOutHandlerCalled; n != 0 {
t.Errorf("expect swapped out handler to not be called, was called %d times", n)
}
if e, a := 1, swappedInHandlerCalled; e != a {
t.Errorf("expect swapped in handler to be called %d, was called %d times", e, a)
}
}
func TestSetBackNamed_Exists(t *testing.T) {
firstHandlerCalled := 0
swappedOutHandlerCalled := 0
swappedInHandlerCalled := 0
l := request.HandlerList{}
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {
firstHandlerCalled++
}}
named2 := request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
swappedOutHandlerCalled++
}}
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.SetBackNamed(request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
swappedInHandlerCalled++
}})
l.Run(&request.Request{})
if e, a := 1, firstHandlerCalled; e != a {
t.Errorf("expect first handler to be called %d, was called %d times", e, a)
}
if n := swappedOutHandlerCalled; n != 0 {
t.Errorf("expect swapped out handler to not be called, was called %d times", n)
}
if e, a := 1, swappedInHandlerCalled; e != a {
t.Errorf("expect swapped in handler to be called %d, was called %d times", e, a)
}
}
func TestSetBackNamed_NotExists(t *testing.T) {
firstHandlerCalled := 0
secondHandlerCalled := 0
swappedInHandlerCalled := 0
l := request.HandlerList{}
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {
firstHandlerCalled++
}}
named2 := request.NamedHandler{Name: "OtherName", Fn: func(r *request.Request) {
secondHandlerCalled++
}}
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.SetBackNamed(request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
swappedInHandlerCalled++
}})
l.Run(&request.Request{})
if e, a := 1, firstHandlerCalled; e != a {
t.Errorf("expect first handler to be called %d, was called %d times", e, a)
}
if e, a := 1, secondHandlerCalled; e != a {
t.Errorf("expect second handler to be called %d, was called %d times", e, a)
}
if e, a := 1, swappedInHandlerCalled; e != a {
t.Errorf("expect swapped in handler to be called %d, was called %d times", e, a)
}
}
func TestLoggedHandlers(t *testing.T) {
expectedHandlers := []string{"name1", "name2"}
l := request.HandlerList{}
loggedHandlers := []string{}
l.AfterEachFn = request.HandlerListLogItem
cfg := aws.Config{Logger: aws.LoggerFunc(func(args ...interface{}) {
loggedHandlers = append(loggedHandlers, args[2].(string))
})}
named1 := request.NamedHandler{Name: "name1", Fn: func(r *request.Request) {}}
named2 := request.NamedHandler{Name: "name2", Fn: func(r *request.Request) {}}
l.PushBackNamed(named1)
l.PushBackNamed(named2)
l.Run(&request.Request{Config: cfg})
if !reflect.DeepEqual(expectedHandlers, loggedHandlers) {
t.Errorf("expect handlers executed %v to match logged handlers, %v",
expectedHandlers, loggedHandlers)
}
}
func TestStopHandlers(t *testing.T) {
l := request.HandlerList{}
stopAt := 1
l.AfterEachFn = func(item request.HandlerListRunItem) bool {
return item.Index != stopAt
}
called := 0
l.PushBackNamed(request.NamedHandler{Name: "name1", Fn: func(r *request.Request) {
called++
}})
l.PushBackNamed(request.NamedHandler{Name: "name2", Fn: func(r *request.Request) {
called++
}})
l.PushBackNamed(request.NamedHandler{Name: "name3", Fn: func(r *request.Request) {
t.Fatalf("third handler should not be called")
}})
l.Run(&request.Request{})
if e, a := 2, called; e != a {
t.Errorf("expect %d handlers called, got %d", e, a)
}
}
func BenchmarkNewRequest(b *testing.B) {
svc := s3.New(unit.Session)
for i := 0; i < b.N; i++ {
r, _ := svc.GetObjectRequest(nil)
if r == nil {
b.Fatal("r should not be nil")
}
}
}
func BenchmarkHandlersCopy(b *testing.B) {
handlers := request.Handlers{}
handlers.Validate.PushBack(func(r *request.Request) {})
handlers.Validate.PushBack(func(r *request.Request) {})
handlers.Build.PushBack(func(r *request.Request) {})
handlers.Build.PushBack(func(r *request.Request) {})
handlers.Send.PushBack(func(r *request.Request) {})
handlers.Send.PushBack(func(r *request.Request) {})
handlers.Unmarshal.PushBack(func(r *request.Request) {})
handlers.Unmarshal.PushBack(func(r *request.Request) {})
for i := 0; i < b.N; i++ {
h := handlers.Copy()
if e, a := handlers.Validate.Len(), h.Validate.Len(); e != a {
b.Fatalf("expected %d handlers got %d", e, a)
}
}
}
func BenchmarkHandlersPushBack(b *testing.B) {
handlers := request.Handlers{}
for i := 0; i < b.N; i++ {
h := handlers.Copy()
h.Validate.PushBack(func(r *request.Request) {})
h.Validate.PushBack(func(r *request.Request) {})
h.Validate.PushBack(func(r *request.Request) {})
h.Validate.PushBack(func(r *request.Request) {})
}
}
func BenchmarkHandlersPushFront(b *testing.B) {
handlers := request.Handlers{}
for i := 0; i < b.N; i++ {
h := handlers.Copy()
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
}
}
func BenchmarkHandlersClear(b *testing.B) {
handlers := request.Handlers{}
for i := 0; i < b.N; i++ {
h := handlers.Copy()
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Clear()
}
}

View file

@ -1,34 +0,0 @@
package request
import (
"bytes"
"io/ioutil"
"net/http"
"net/url"
"sync"
"testing"
)
func TestRequestCopyRace(t *testing.T) {
origReq := &http.Request{URL: &url.URL{}, Header: http.Header{}}
origReq.Header.Set("Header", "OrigValue")
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
req := copyHTTPRequest(origReq, ioutil.NopCloser(&bytes.Buffer{}))
req.Header.Set("Header", "Value")
go func() {
req2 := copyHTTPRequest(req, ioutil.NopCloser(&bytes.Buffer{}))
req2.Header.Add("Header", "Value2")
}()
_ = req.Header.Get("Header")
wg.Done()
}()
_ = origReq.Header.Get("Header")
}
origReq.Header.Get("Header")
wg.Wait()
}

View file

@ -1,37 +0,0 @@
// +build go1.5
package request_test
import (
"errors"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/mock"
"github.com/stretchr/testify/assert"
)
func TestRequestCancelRetry(t *testing.T) {
c := make(chan struct{})
reqNum := 0
s := mock.NewMockClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.Clear()
s.Handlers.UnmarshalMeta.Clear()
s.Handlers.UnmarshalError.Clear()
s.Handlers.Send.PushFront(func(r *request.Request) {
reqNum++
r.Error = errors.New("net/http: request canceled")
})
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
r.HTTPRequest.Cancel = c
close(c)
err := r.Send()
assert.True(t, strings.Contains(err.Error(), "canceled"))
assert.Equal(t, 1, reqNum)
}

View file

@ -1,140 +0,0 @@
package request
import (
"bytes"
"io"
"math/rand"
"sync"
"testing"
"time"
"github.com/aws/aws-sdk-go/internal/sdkio"
"github.com/stretchr/testify/assert"
)
func TestOffsetReaderRead(t *testing.T) {
buf := []byte("testData")
reader := &offsetReader{buf: bytes.NewReader(buf)}
tempBuf := make([]byte, len(buf))
n, err := reader.Read(tempBuf)
assert.Equal(t, n, len(buf))
assert.Nil(t, err)
assert.Equal(t, buf, tempBuf)
}
func TestOffsetReaderSeek(t *testing.T) {
buf := []byte("testData")
reader := newOffsetReader(bytes.NewReader(buf), 0)
orig, err := reader.Seek(0, sdkio.SeekCurrent)
assert.NoError(t, err)
assert.Equal(t, int64(0), orig)
n, err := reader.Seek(0, sdkio.SeekEnd)
assert.NoError(t, err)
assert.Equal(t, int64(len(buf)), n)
n, err = reader.Seek(orig, sdkio.SeekStart)
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
}
func TestOffsetReaderClose(t *testing.T) {
buf := []byte("testData")
reader := &offsetReader{buf: bytes.NewReader(buf)}
err := reader.Close()
assert.Nil(t, err)
tempBuf := make([]byte, len(buf))
n, err := reader.Read(tempBuf)
assert.Equal(t, n, 0)
assert.Equal(t, err, io.EOF)
}
func TestOffsetReaderCloseAndCopy(t *testing.T) {
buf := []byte("testData")
tempBuf := make([]byte, len(buf))
reader := &offsetReader{buf: bytes.NewReader(buf)}
newReader := reader.CloseAndCopy(0)
n, err := reader.Read(tempBuf)
assert.Equal(t, n, 0)
assert.Equal(t, err, io.EOF)
n, err = newReader.Read(tempBuf)
assert.Equal(t, n, len(buf))
assert.Nil(t, err)
assert.Equal(t, buf, tempBuf)
}
func TestOffsetReaderCloseAndCopyOffset(t *testing.T) {
buf := []byte("testData")
tempBuf := make([]byte, len(buf))
reader := &offsetReader{buf: bytes.NewReader(buf)}
newReader := reader.CloseAndCopy(4)
n, err := newReader.Read(tempBuf)
assert.Equal(t, n, len(buf)-4)
assert.Nil(t, err)
expected := []byte{'D', 'a', 't', 'a', 0, 0, 0, 0}
assert.Equal(t, expected, tempBuf)
}
func TestOffsetReaderRace(t *testing.T) {
wg := sync.WaitGroup{}
f := func(reader *offsetReader) {
defer wg.Done()
var err error
buf := make([]byte, 1)
_, err = reader.Read(buf)
for err != io.EOF {
_, err = reader.Read(buf)
}
}
closeFn := func(reader *offsetReader) {
defer wg.Done()
time.Sleep(time.Duration(rand.Intn(20)+1) * time.Millisecond)
reader.Close()
}
for i := 0; i < 50; i++ {
reader := &offsetReader{buf: bytes.NewReader(make([]byte, 1024*1024))}
wg.Add(1)
go f(reader)
wg.Add(1)
go closeFn(reader)
}
wg.Wait()
}
func BenchmarkOffsetReader(b *testing.B) {
bufSize := 1024 * 1024 * 100
buf := make([]byte, bufSize)
reader := &offsetReader{buf: bytes.NewReader(buf)}
tempBuf := make([]byte, 1024)
for i := 0; i < b.N; i++ {
reader.Read(tempBuf)
}
}
func BenchmarkBytesReader(b *testing.B) {
bufSize := 1024 * 1024 * 100
buf := make([]byte, bufSize)
reader := bytes.NewReader(buf)
tempBuf := make([]byte, 1024)
for i := 0; i < b.N; i++ {
reader.Read(tempBuf)
}
}

View file

@ -1,11 +0,0 @@
// +build !go1.6
package request_test
import (
"errors"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var errTimeout = awserr.New("foo", "bar", errors.New("net/http: request canceled Timeout"))

View file

@ -1,51 +0,0 @@
// +build go1.6
package request_test
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
)
// go version 1.4 and 1.5 do not return an error. Version 1.5 will url encode
// the uri while 1.4 will not
func TestRequestInvalidEndpoint(t *testing.T) {
endpoint := "http://localhost:90 "
r := request.New(
aws.Config{},
metadata.ClientInfo{Endpoint: endpoint},
defaults.Handlers(),
client.DefaultRetryer{},
&request.Operation{},
nil,
nil,
)
assert.Error(t, r.Error)
}
type timeoutErr struct {
error
}
var errTimeout = awserr.New("foo", "bar", &timeoutErr{
errors.New("net/http: request canceled"),
})
func (e *timeoutErr) Timeout() bool {
return true
}
func (e *timeoutErr) Temporary() bool {
return true
}

View file

@ -1,24 +0,0 @@
// +build !go1.8
package request
import (
"net/http"
"strings"
"testing"
)
func TestResetBody_WithEmptyBody(t *testing.T) {
r := Request{
HTTPRequest: &http.Request{},
}
reader := strings.NewReader("")
r.Body = reader
r.ResetBody()
if a, e := r.HTTPRequest.Body, (noBody{}); a != e {
t.Errorf("expected request body to be set to reader, got %#v", r.HTTPRequest.Body)
}
}

View file

@ -1,85 +0,0 @@
// +build go1.8
package request_test
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
func TestResetBody_WithEmptyBody(t *testing.T) {
r := request.Request{
HTTPRequest: &http.Request{},
}
reader := strings.NewReader("")
r.Body = reader
r.ResetBody()
if a, e := r.HTTPRequest.Body, http.NoBody; a != e {
t.Errorf("expected request body to be set to reader, got %#v",
r.HTTPRequest.Body)
}
}
func TestRequest_FollowPUTRedirects(t *testing.T) {
const bodySize = 1024
redirectHit := 0
endpointHit := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/redirect-me":
u := *r.URL
u.Path = "/endpoint"
w.Header().Set("Location", u.String())
w.WriteHeader(307)
redirectHit++
case "/endpoint":
b := bytes.Buffer{}
io.Copy(&b, r.Body)
r.Body.Close()
if e, a := bodySize, b.Len(); e != a {
t.Fatalf("expect %d body size, got %d", e, a)
}
endpointHit++
default:
t.Fatalf("unexpected endpoint used, %q", r.URL.String())
}
}))
svc := awstesting.NewClient(&aws.Config{
Region: unit.Session.Config.Region,
DisableSSL: aws.Bool(true),
Endpoint: aws.String(server.URL),
})
req := svc.NewRequest(&request.Operation{
Name: "Operation",
HTTPMethod: "PUT",
HTTPPath: "/redirect-me",
}, &struct{}{}, &struct{}{})
req.SetReaderBody(bytes.NewReader(make([]byte, bodySize)))
err := req.Send()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := 1, redirectHit; e != a {
t.Errorf("expect %d redirect hits, got %d", e, a)
}
if e, a := 1, endpointHit; e != a {
t.Errorf("expect %d endpoint hits, got %d", e, a)
}
}

View file

@ -1,46 +0,0 @@
package request_test
import (
"fmt"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestRequest_SetContext(t *testing.T) {
svc := awstesting.NewClient()
svc.Handlers.Clear()
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
r.SetContext(ctx)
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
err := r.Send()
if err == nil {
t.Fatalf("expected error, got none")
}
// Only check against canceled because go 1.6 will not use the context's
// Err().
if e, a := "canceled", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %q to be in %q, but was not", e, a)
}
}
func TestRequest_SetContextPanic(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatalf("expect SetContext to panic, did not")
}
}()
r := &request.Request{}
r.SetContext(nil)
}

View file

@ -1,27 +0,0 @@
package request
import (
"testing"
)
func TestCopy(t *testing.T) {
handlers := Handlers{}
op := &Operation{}
op.HTTPMethod = "Foo"
req := &Request{}
req.Operation = op
req.Handlers = handlers
r := req.copy()
if r == req {
t.Fatal("expect request pointer copy to be different")
}
if r.Operation == req.Operation {
t.Errorf("expect request operation pointer to be different")
}
if e, a := req.Operation.HTTPMethod, r.Operation.HTTPMethod; e != a {
t.Errorf("expect %q http method, got %q", e, a)
}
}

View file

@ -1,648 +0,0 @@
package request_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/route53"
"github.com/aws/aws-sdk-go/service/s3"
)
// Use DynamoDB methods for simplicity
func TestPaginationQueryPage(t *testing.T) {
db := dynamodb.New(unit.Session)
tokens, pages, numPages, gotToEnd := []map[string]*dynamodb.AttributeValue{}, []map[string]*dynamodb.AttributeValue{}, 0, false
reqNum := 0
resps := []*dynamodb.QueryOutput{
{
LastEvaluatedKey: map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key1")}},
Count: aws.Int64(1),
Items: []map[string]*dynamodb.AttributeValue{
{
"key": {S: aws.String("key1")},
},
},
},
{
LastEvaluatedKey: map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key2")}},
Count: aws.Int64(1),
Items: []map[string]*dynamodb.AttributeValue{
{
"key": {S: aws.String("key2")},
},
},
},
{
LastEvaluatedKey: map[string]*dynamodb.AttributeValue{},
Count: aws.Int64(1),
Items: []map[string]*dynamodb.AttributeValue{
{
"key": {S: aws.String("key3")},
},
},
},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.QueryInput)
if in == nil {
tokens = append(tokens, nil)
} else if len(in.ExclusiveStartKey) != 0 {
tokens = append(tokens, in.ExclusiveStartKey)
}
})
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.QueryInput{
Limit: aws.Int64(2),
TableName: aws.String("tablename"),
}
err := db.QueryPages(params, func(p *dynamodb.QueryOutput, last bool) bool {
numPages++
for _, item := range p.Items {
pages = append(pages, item)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Nil(t, err)
assert.Equal(t,
[]map[string]*dynamodb.AttributeValue{
{"key": {S: aws.String("key1")}},
{"key": {S: aws.String("key2")}},
}, tokens)
assert.Equal(t,
[]map[string]*dynamodb.AttributeValue{
{"key": {S: aws.String("key1")}},
{"key": {S: aws.String("key2")}},
{"key": {S: aws.String("key3")}},
}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, params.ExclusiveStartKey)
}
// Use DynamoDB methods for simplicity
func TestPagination(t *testing.T) {
db := dynamodb.New(unit.Session)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
} else if in.ExclusiveStartTableName != nil {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool {
numPages++
for _, t := range p.TableNames {
pages = append(pages, *t)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, []string{"Table2", "Table4"}, tokens)
assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, err)
assert.Nil(t, params.ExclusiveStartTableName)
}
// Use DynamoDB methods for simplicity
func TestPaginationEachPage(t *testing.T) {
db := dynamodb.New(unit.Session)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
} else if in.ExclusiveStartTableName != nil {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
req, _ := db.ListTablesRequest(params)
err := req.EachPage(func(p interface{}, last bool) bool {
numPages++
for _, t := range p.(*dynamodb.ListTablesOutput).TableNames {
pages = append(pages, *t)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, []string{"Table2", "Table4"}, tokens)
assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, err)
}
// Use DynamoDB methods for simplicity
func TestPaginationEarlyExit(t *testing.T) {
db := dynamodb.New(unit.Session)
numPages, gotToEnd := 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool {
numPages++
if numPages == 2 {
return false
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, 2, numPages)
assert.False(t, gotToEnd)
assert.Nil(t, err)
}
func TestSkipPagination(t *testing.T) {
client := s3.New(unit.Session)
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = &s3.HeadBucketOutput{}
})
req, _ := client.HeadBucketRequest(&s3.HeadBucketInput{Bucket: aws.String("bucket")})
numPages, gotToEnd := 0, false
req.EachPage(func(p interface{}, last bool) bool {
numPages++
if last {
gotToEnd = true
}
return true
})
assert.Equal(t, 1, numPages)
assert.True(t, gotToEnd)
}
// Use S3 for simplicity
func TestPaginationTruncation(t *testing.T) {
client := s3.New(unit.Session)
reqNum := 0
resps := []*s3.ListObjectsOutput{
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key1")}}},
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key2")}}},
{IsTruncated: aws.Bool(false), Contents: []*s3.Object{{Key: aws.String("Key3")}}},
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key4")}}},
}
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &s3.ListObjectsInput{Bucket: aws.String("bucket")}
results := []string{}
err := client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool {
results = append(results, *p.Contents[0].Key)
return true
})
assert.Equal(t, []string{"Key1", "Key2", "Key3"}, results)
assert.Nil(t, err)
// Try again without truncation token at all
reqNum = 0
resps[1].IsTruncated = nil
resps[2].IsTruncated = aws.Bool(true)
results = []string{}
err = client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool {
results = append(results, *p.Contents[0].Key)
return true
})
assert.Equal(t, []string{"Key1", "Key2"}, results)
assert.Nil(t, err)
}
func TestPaginationNilToken(t *testing.T) {
client := route53.New(unit.Session)
reqNum := 0
resps := []*route53.ListResourceRecordSetsOutput{
{
ResourceRecordSets: []*route53.ResourceRecordSet{
{Name: aws.String("first.example.com.")},
},
IsTruncated: aws.Bool(true),
NextRecordName: aws.String("second.example.com."),
NextRecordType: aws.String("MX"),
NextRecordIdentifier: aws.String("second"),
MaxItems: aws.String("1"),
},
{
ResourceRecordSets: []*route53.ResourceRecordSet{
{Name: aws.String("second.example.com.")},
},
IsTruncated: aws.Bool(true),
NextRecordName: aws.String("third.example.com."),
NextRecordType: aws.String("MX"),
MaxItems: aws.String("1"),
},
{
ResourceRecordSets: []*route53.ResourceRecordSet{
{Name: aws.String("third.example.com.")},
},
IsTruncated: aws.Bool(false),
MaxItems: aws.String("1"),
},
}
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
idents := []string{}
client.Handlers.Build.PushBack(func(r *request.Request) {
p := r.Params.(*route53.ListResourceRecordSetsInput)
idents = append(idents, aws.StringValue(p.StartRecordIdentifier))
})
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &route53.ListResourceRecordSetsInput{
HostedZoneId: aws.String("id-zone"),
}
results := []string{}
err := client.ListResourceRecordSetsPages(params, func(p *route53.ListResourceRecordSetsOutput, last bool) bool {
results = append(results, *p.ResourceRecordSets[0].Name)
return true
})
assert.NoError(t, err)
assert.Equal(t, []string{"", "second", ""}, idents)
assert.Equal(t, []string{"first.example.com.", "second.example.com.", "third.example.com."}, results)
}
func TestPaginationNilInput(t *testing.T) {
// Code generation doesn't have a great way to verify the code is correct
// other than being run via unit tests in the SDK. This should be fixed
// So code generation can be validated independently.
client := s3.New(unit.Session)
client.Handlers.Validate.Clear()
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = &s3.ListObjectsOutput{}
})
gotToEnd := false
numPages := 0
err := client.ListObjectsPages(nil, func(p *s3.ListObjectsOutput, last bool) bool {
numPages++
if last {
gotToEnd = true
}
return true
})
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
if e, a := 1, numPages; e != a {
t.Errorf("expect %d number pages but got %d", e, a)
}
if !gotToEnd {
t.Errorf("expect to of gotten to end, did not")
}
}
func TestPaginationWithContextNilInput(t *testing.T) {
// Code generation doesn't have a great way to verify the code is correct
// other than being run via unit tests in the SDK. This should be fixed
// So code generation can be validated independently.
client := s3.New(unit.Session)
client.Handlers.Validate.Clear()
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = &s3.ListObjectsOutput{}
})
gotToEnd := false
numPages := 0
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
err := client.ListObjectsPagesWithContext(ctx, nil, func(p *s3.ListObjectsOutput, last bool) bool {
numPages++
if last {
gotToEnd = true
}
return true
})
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
if e, a := 1, numPages; e != a {
t.Errorf("expect %d number pages but got %d", e, a)
}
if !gotToEnd {
t.Errorf("expect to of gotten to end, did not")
}
}
func TestPagination_Standalone(t *testing.T) {
type testPageInput struct {
NextToken *string
}
type testPageOutput struct {
Value *string
NextToken *string
}
type testCase struct {
Value, PrevToken, NextToken *string
}
type testCaseList struct {
StopOnSameToken bool
Cases []testCase
}
cases := []testCaseList{
{
Cases: []testCase{
testCase{aws.String("FirstValue"), aws.String("InitalToken"), aws.String("FirstToken")},
testCase{aws.String("SecondValue"), aws.String("FirstToken"), aws.String("SecondToken")},
testCase{aws.String("ThirdValue"), aws.String("SecondToken"), nil},
},
StopOnSameToken: false,
},
{
Cases: []testCase{
testCase{aws.String("FirstValue"), aws.String("InitalToken"), aws.String("FirstToken")},
testCase{aws.String("SecondValue"), aws.String("FirstToken"), aws.String("SecondToken")},
testCase{aws.String("ThirdValue"), aws.String("SecondToken"), aws.String("")},
},
StopOnSameToken: false,
},
{
Cases: []testCase{
testCase{aws.String("FirstValue"), aws.String("InitalToken"), aws.String("FirstToken")},
testCase{aws.String("SecondValue"), aws.String("FirstToken"), aws.String("SecondToken")},
testCase{nil, aws.String("SecondToken"), aws.String("SecondToken")},
},
StopOnSameToken: true,
},
{
Cases: []testCase{
testCase{aws.String("FirstValue"), aws.String("InitalToken"), aws.String("FirstToken")},
testCase{aws.String("SecondValue"), aws.String("FirstToken"), aws.String("SecondToken")},
testCase{aws.String("SecondValue"), aws.String("SecondToken"), aws.String("SecondToken")},
},
StopOnSameToken: true,
},
}
for _, testcase := range cases {
c := testcase.Cases
input := testPageInput{
NextToken: c[0].PrevToken,
}
svc := awstesting.NewClient()
i := 0
p := request.Pagination{
EndPageOnSameToken: testcase.StopOnSameToken,
NewRequest: func() (*request.Request, error) {
r := svc.NewRequest(
&request.Operation{
Name: "Operation",
Paginator: &request.Paginator{
InputTokens: []string{"NextToken"},
OutputTokens: []string{"NextToken"},
},
},
&input, &testPageOutput{},
)
// Setup handlers for testing
r.Handlers.Clear()
r.Handlers.Build.PushBack(func(req *request.Request) {
if e, a := len(c), i+1; a > e {
t.Fatalf("expect no more than %d requests, got %d", e, a)
}
in := req.Params.(*testPageInput)
if e, a := aws.StringValue(c[i].PrevToken), aws.StringValue(in.NextToken); e != a {
t.Errorf("%d, expect NextToken input %q, got %q", i, e, a)
}
})
r.Handlers.Unmarshal.PushBack(func(req *request.Request) {
out := &testPageOutput{
Value: c[i].Value,
}
if c[i].NextToken != nil {
next := *c[i].NextToken
out.NextToken = aws.String(next)
}
req.Data = out
})
return r, nil
},
}
for p.Next() {
data := p.Page().(*testPageOutput)
if e, a := aws.StringValue(c[i].Value), aws.StringValue(data.Value); e != a {
t.Errorf("%d, expect Value to be %q, got %q", i, e, a)
}
if e, a := aws.StringValue(c[i].NextToken), aws.StringValue(data.NextToken); e != a {
t.Errorf("%d, expect NextToken to be %q, got %q", i, e, a)
}
i++
}
if e, a := len(c), i; e != a {
t.Errorf("expected to process %d pages, did %d", e, a)
}
if err := p.Err(); err != nil {
t.Fatalf("%d, expected no error, got %v", i, err)
}
}
}
// Benchmarks
var benchResps = []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE")}},
}
var benchDb = func() *dynamodb.DynamoDB {
db := dynamodb.New(unit.Session)
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
return db
}
func BenchmarkCodegenIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
input := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
iter := func(fn func(*dynamodb.ListTablesOutput, bool) bool) error {
page, _ := db.ListTablesRequest(input)
for ; page != nil; page = page.NextPage() {
page.Send()
out := page.Data.(*dynamodb.ListTablesOutput)
if result := fn(out, !page.HasNextPage()); page.Error != nil || !result {
return page.Error
}
}
return nil
}
for i := 0; i < b.N; i++ {
reqNum = 0
iter(func(p *dynamodb.ListTablesOutput, last bool) bool {
return true
})
}
}
func BenchmarkEachPageIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
input := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
for i := 0; i < b.N; i++ {
reqNum = 0
req, _ := db.ListTablesRequest(input)
req.EachPage(func(p interface{}, last bool) bool {
return true
})
}
}

View file

@ -1,107 +0,0 @@
package request
import (
"bytes"
"io"
"net/http"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
)
func TestResetBody_WithBodyContents(t *testing.T) {
r := Request{
HTTPRequest: &http.Request{},
}
reader := strings.NewReader("abc")
r.Body = reader
r.ResetBody()
if v, ok := r.HTTPRequest.Body.(*offsetReader); !ok || v == nil {
t.Errorf("expected request body to be set to reader, got %#v",
r.HTTPRequest.Body)
}
}
type mockReader struct{}
func (mockReader) Read([]byte) (int, error) {
return 0, io.EOF
}
func TestResetBody_ExcludeEmptyUnseekableBodyByMethod(t *testing.T) {
cases := []struct {
Method string
Body io.ReadSeeker
IsNoBody bool
}{
{
Method: "GET",
IsNoBody: true,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "HEAD",
IsNoBody: true,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "DELETE",
IsNoBody: true,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "PUT",
IsNoBody: false,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "PATCH",
IsNoBody: false,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "POST",
IsNoBody: false,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "GET",
IsNoBody: false,
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc"))),
},
{
Method: "GET",
IsNoBody: true,
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
},
{
Method: "POST",
IsNoBody: false,
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc"))),
},
{
Method: "POST",
IsNoBody: true,
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
},
}
for i, c := range cases {
r := Request{
HTTPRequest: &http.Request{},
Operation: &Operation{
HTTPMethod: c.Method,
},
}
r.SetReaderBody(c.Body)
if a, e := r.HTTPRequest.Body == NoBody, c.IsNoBody; a != e {
t.Errorf("%d, expect body to be set to noBody(%t), but was %t", i, e, a)
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,62 +0,0 @@
package request
import (
"errors"
"fmt"
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
)
func TestRequestThrottling(t *testing.T) {
req := Request{}
req.Error = awserr.New("Throttling", "", nil)
if e, a := true, req.IsErrorThrottle(); e != a {
t.Errorf("expect %t to be throttled, was %t", e, a)
}
}
type mockTempError bool
func (e mockTempError) Error() string {
return fmt.Sprintf("mock temporary error: %t", e.Temporary())
}
func (e mockTempError) Temporary() bool {
return bool(e)
}
func TestIsErrorRetryable(t *testing.T) {
cases := []struct {
Err error
IsTemp bool
}{
{
Err: awserr.New(ErrCodeSerialization, "temporary error", mockTempError(true)),
IsTemp: true,
},
{
Err: awserr.New(ErrCodeSerialization, "temporary error", mockTempError(false)),
IsTemp: false,
},
{
Err: awserr.New(ErrCodeSerialization, "some error", errors.New("blah")),
IsTemp: false,
},
{
Err: awserr.New("SomeError", "some error", nil),
IsTemp: false,
},
{
Err: awserr.New("RequestError", "some error", nil),
IsTemp: true,
},
}
for i, c := range cases {
retryable := IsErrorRetryable(c.Err)
if e, a := c.IsTemp, retryable; e != a {
t.Errorf("%d, expect %t temporary error, got %t", i, e, a)
}
}
}

View file

@ -1,76 +0,0 @@
package request_test
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
)
func BenchmarkTimeoutReadCloser(b *testing.B) {
resp := `
{
"Bar": "qux"
}
`
handlers := request.Handlers{}
handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewBuffer([]byte(resp))),
}
})
handlers.Sign.PushBackNamed(v4.SignRequestHandler)
handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
op := &request.Operation{
Name: "op",
HTTPMethod: "POST",
HTTPPath: "/",
}
meta := metadata.ClientInfo{
ServiceName: "fooService",
SigningName: "foo",
SigningRegion: "foo",
Endpoint: "localhost",
APIVersion: "2001-01-01",
JSONVersion: "1.1",
TargetPrefix: "Foo",
}
req := request.New(
*unit.Session.Config,
meta,
handlers,
client.DefaultRetryer{NumMaxRetries: 5},
op,
&struct {
Foo *string
}{},
&struct {
Bar *string
}{},
)
req.ApplyOptions(request.WithResponseReadTimeout(15 * time.Second))
for i := 0; i < b.N; i++ {
err := req.Send()
if err != nil {
b.Errorf("Expected no error, but received %v", err)
}
}
}

View file

@ -1,118 +0,0 @@
package request
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
type testReader struct {
duration time.Duration
count int
}
func (r *testReader) Read(b []byte) (int, error) {
if r.count > 0 {
r.count--
return len(b), nil
}
time.Sleep(r.duration)
return 0, io.EOF
}
func (r *testReader) Close() error {
return nil
}
func TestTimeoutReadCloser(t *testing.T) {
reader := timeoutReadCloser{
reader: &testReader{
duration: time.Second,
count: 5,
},
duration: time.Millisecond,
}
b := make([]byte, 100)
_, err := reader.Read(b)
if err != nil {
t.Log(err)
}
}
func TestTimeoutReadCloserSameDuration(t *testing.T) {
reader := timeoutReadCloser{
reader: &testReader{
duration: time.Millisecond,
count: 5,
},
duration: time.Millisecond,
}
b := make([]byte, 100)
_, err := reader.Read(b)
if err != nil {
t.Log(err)
}
}
func TestWithResponseReadTimeout(t *testing.T) {
r := Request{
HTTPResponse: &http.Response{
Body: ioutil.NopCloser(bytes.NewReader(nil)),
},
}
r.ApplyOptions(WithResponseReadTimeout(time.Second))
err := r.Send()
if err != nil {
t.Error(err)
}
v, ok := r.HTTPResponse.Body.(*timeoutReadCloser)
if !ok {
t.Error("Expected the body to be a timeoutReadCloser")
}
if v.duration != time.Second {
t.Errorf("Expected %v, but receive %v\n", time.Second, v.duration)
}
}
func TestAdaptToResponseTimeout(t *testing.T) {
testCases := []struct {
childErr error
r Request
expectedRootCode string
}{
{
childErr: awserr.New(ErrCodeResponseTimeout, "timeout!", nil),
r: Request{
Error: awserr.New("ErrTest", "FooBar", awserr.New(ErrCodeResponseTimeout, "timeout!", nil)),
},
expectedRootCode: ErrCodeResponseTimeout,
},
{
childErr: awserr.New(ErrCodeResponseTimeout+"1", "timeout!", nil),
r: Request{
Error: awserr.New("ErrTest", "FooBar", awserr.New(ErrCodeResponseTimeout+"1", "timeout!", nil)),
},
expectedRootCode: "ErrTest",
},
{
r: Request{
Error: awserr.New("ErrTest", "FooBar", nil),
},
expectedRootCode: "ErrTest",
},
}
for i, c := range testCases {
adaptToResponseTimeoutError(&c.r)
if aerr, ok := c.r.Error.(awserr.Error); !ok {
t.Errorf("Case %d: Expected 'awserr', but received %v", i+1, c.r.Error)
} else if aerr.Code() != c.expectedRootCode {
t.Errorf("Case %d: Expected %q, but received %s", i+1, c.expectedRootCode, aerr.Code())
}
}
}

View file

@ -1,654 +0,0 @@
package request_test
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
type mockClient struct {
*client.Client
}
type MockInput struct{}
type MockOutput struct {
States []*MockState
}
type MockState struct {
State *string
}
func (c *mockClient) MockRequest(input *MockInput) (*request.Request, *MockOutput) {
op := &request.Operation{
Name: "Mock",
HTTPMethod: "POST",
HTTPPath: "/",
}
if input == nil {
input = &MockInput{}
}
output := &MockOutput{}
req := c.NewRequest(op, input, output)
req.Data = output
return req, output
}
func BuildNewMockRequest(c *mockClient, in *MockInput) func([]request.Option) (*request.Request, error) {
return func(opts []request.Option) (*request.Request, error) {
req, _ := c.MockRequest(in)
req.ApplyOptions(opts...)
return req, nil
}
}
func TestWaiterPathAll(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 2
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
},
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
},
},
}
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
w := request.Waiter{
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(0),
SleepWithContext: aws.SleepWithContext,
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.PathAllWaiterMatch,
Argument: "States[].State",
Expected: "running",
},
},
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
}
err := w.WaitWithContext(aws.BackgroundContext())
assert.NoError(t, err)
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
}
func TestWaiterPath(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 2
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
},
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
},
},
}
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
w := request.Waiter{
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(0),
SleepWithContext: aws.SleepWithContext,
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.PathWaiterMatch,
Argument: "States[].State",
Expected: "running",
},
},
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
}
err := w.WaitWithContext(aws.BackgroundContext())
assert.NoError(t, err)
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
}
func TestWaiterFailure(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 2
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
},
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("stopping")},
},
},
}
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
w := request.Waiter{
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(0),
SleepWithContext: aws.SleepWithContext,
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.PathAllWaiterMatch,
Argument: "States[].State",
Expected: "running",
},
{
State: request.FailureWaiterState,
Matcher: request.PathAnyWaiterMatch,
Argument: "States[].State",
Expected: "stopping",
},
},
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
}
err := w.WaitWithContext(aws.BackgroundContext()).(awserr.Error)
assert.Error(t, err)
assert.Equal(t, request.WaiterResourceNotReadyErrorCode, err.Code())
assert.Equal(t, "failed waiting for successful resource state", err.Message())
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
}
func TestWaiterError(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.UnmarshalError.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 1, error case retry
},
{ // Request 2, error case failure
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
},
},
}
reqErrs := make([]error, len(resps))
reqErrs[1] = awserr.New("MockException", "mock exception message", nil)
reqErrs[2] = awserr.New("FailureException", "mock failure exception message", nil)
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Send.PushBack(func(r *request.Request) {
code := 200
if reqNum == 1 {
code = 400
}
r.HTTPResponse = &http.Response{
StatusCode: code,
Status: http.StatusText(code),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
svc.Handlers.UnmarshalMeta.PushBack(func(r *request.Request) {
// If there was an error unmarshal error will be called instead of unmarshal
// need to increment count here also
if err := reqErrs[reqNum]; err != nil {
r.Error = err
reqNum++
}
})
w := request.Waiter{
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(0),
SleepWithContext: aws.SleepWithContext,
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.PathAllWaiterMatch,
Argument: "States[].State",
Expected: "running",
},
{
State: request.RetryWaiterState,
Matcher: request.ErrorWaiterMatch,
Argument: "",
Expected: "MockException",
},
{
State: request.FailureWaiterState,
Matcher: request.ErrorWaiterMatch,
Argument: "",
Expected: "FailureException",
},
},
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
}
err := w.WaitWithContext(aws.BackgroundContext())
if err == nil {
t.Fatalf("expected error, but did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q error code, got %q", e, a)
}
if e, a := 3, numBuiltReq; e != a {
t.Errorf("expect %d built requests got %d", e, a)
}
if e, a := 3, reqNum; e != a {
t.Errorf("expect %d reqNum got %d", e, a)
}
}
func TestWaiterStatus(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
reqNum++
})
svc.Handlers.Send.PushBack(func(r *request.Request) {
code := 200
if reqNum == 3 {
code = 404
r.Error = awserr.New("NotFound", "Not Found", nil)
}
r.HTTPResponse = &http.Response{
StatusCode: code,
Status: http.StatusText(code),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
})
w := request.Waiter{
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(0),
SleepWithContext: aws.SleepWithContext,
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Argument: "",
Expected: 404,
},
},
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
}
err := w.WaitWithContext(aws.BackgroundContext())
assert.NoError(t, err)
assert.Equal(t, 3, reqNum)
}
func TestWaiter_ApplyOptions(t *testing.T) {
w := request.Waiter{}
logger := aws.NewDefaultLogger()
w.ApplyOptions(
request.WithWaiterLogger(logger),
request.WithWaiterRequestOptions(request.WithLogLevel(aws.LogDebug)),
request.WithWaiterMaxAttempts(2),
request.WithWaiterDelay(request.ConstantWaiterDelay(5*time.Second)),
)
if e, a := logger, w.Logger; e != a {
t.Errorf("expect logger to be set, and match, was not, %v, %v", e, a)
}
if len(w.RequestOptions) != 1 {
t.Fatalf("expect request options to be set to only a single option, %v", w.RequestOptions)
}
r := request.Request{}
r.ApplyOptions(w.RequestOptions...)
if e, a := aws.LogDebug, r.Config.LogLevel.Value(); e != a {
t.Errorf("expect %v loglevel got %v", e, a)
}
if e, a := 2, w.MaxAttempts; e != a {
t.Errorf("expect %d retryer max attempts, got %d", e, a)
}
if e, a := 5*time.Second, w.Delay(0); e != a {
t.Errorf("expect %d retryer delay, got %d", e, a)
}
}
func TestWaiter_WithContextCanceled(t *testing.T) {
c := awstesting.NewClient()
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
reqCount := 0
w := request.Waiter{
Name: "TestWaiter",
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
SleepWithContext: aws.SleepWithContext,
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 200,
},
},
Logger: aws.NewDefaultLogger(),
NewRequest: func(opts []request.Option) (*request.Request, error) {
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound}
req.Handlers.Clear()
req.Data = struct{}{}
req.Handlers.Send.PushBack(func(r *request.Request) {
if reqCount == 1 {
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
}
reqCount++
})
return req, nil
},
}
w.SleepWithContext = func(c aws.Context, delay time.Duration) error {
context := c.(*awstesting.FakeContext)
select {
case <-context.DoneCh:
return context.Err()
default:
return nil
}
}
err := w.WaitWithContext(ctx)
if err == nil {
t.Fatalf("expect waiter to be canceled.")
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q error code, got %q", e, a)
}
if e, a := 2, reqCount; e != a {
t.Errorf("expect %d requests, got %d", e, a)
}
}
func TestWaiter_WithContext(t *testing.T) {
c := awstesting.NewClient()
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
reqCount := 0
statuses := []int{http.StatusNotFound, http.StatusOK}
w := request.Waiter{
Name: "TestWaiter",
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
SleepWithContext: aws.SleepWithContext,
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 200,
},
},
Logger: aws.NewDefaultLogger(),
NewRequest: func(opts []request.Option) (*request.Request, error) {
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.HTTPResponse = &http.Response{StatusCode: statuses[reqCount]}
req.Handlers.Clear()
req.Data = struct{}{}
req.Handlers.Send.PushBack(func(r *request.Request) {
if reqCount == 1 {
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
}
reqCount++
})
return req, nil
},
}
err := w.WaitWithContext(ctx)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := 2, reqCount; e != a {
t.Errorf("expect %d requests, got %d", e, a)
}
}
func TestWaiter_AttemptsExpires(t *testing.T) {
c := awstesting.NewClient()
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
reqCount := 0
w := request.Waiter{
Name: "TestWaiter",
MaxAttempts: 2,
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
SleepWithContext: aws.SleepWithContext,
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 200,
},
},
Logger: aws.NewDefaultLogger(),
NewRequest: func(opts []request.Option) (*request.Request, error) {
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound}
req.Handlers.Clear()
req.Data = struct{}{}
req.Handlers.Send.PushBack(func(r *request.Request) {
reqCount++
})
return req, nil
},
}
err := w.WaitWithContext(ctx)
if err == nil {
t.Fatalf("expect error did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q error code, got %q", e, a)
}
if e, a := 2, reqCount; e != a {
t.Errorf("expect %d requests, got %d", e, a)
}
}
func TestWaiterNilInput(t *testing.T) {
// Code generation doesn't have a great way to verify the code is correct
// other than being run via unit tests in the SDK. This should be fixed
// So code generation can be validated independently.
client := s3.New(unit.Session)
client.Handlers.Validate.Clear()
client.Handlers.Send.Clear() // mock sending
client.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: http.StatusOK,
}
})
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Config.SleepDelay = func(dur time.Duration) {}
// Ensure waiters do not panic on nil input. It doesn't make sense to
// call a waiter without an input, Validation will
err := client.WaitUntilBucketExists(nil)
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
}
func TestWaiterWithContextNilInput(t *testing.T) {
// Code generation doesn't have a great way to verify the code is correct
// other than being run via unit tests in the SDK. This should be fixed
// So code generation can be validated independently.
client := s3.New(unit.Session)
client.Handlers.Validate.Clear()
client.Handlers.Send.Clear() // mock sending
client.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: http.StatusOK,
}
})
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
// Ensure waiters do not panic on nil input
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
err := client.WaitUntilBucketExistsWithContext(ctx, nil,
request.WithWaiterDelay(request.ConstantWaiterDelay(0)),
request.WithWaiterMaxAttempts(1),
)
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
}

View file

@ -1,243 +0,0 @@
package session
import (
"bytes"
"fmt"
"net"
"net/http"
"os"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/awstesting"
)
var TLSBundleCertFile string
var TLSBundleKeyFile string
var TLSBundleCAFile string
func TestMain(m *testing.M) {
var err error
TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile, err = awstesting.CreateTLSBundleFiles()
if err != nil {
panic(err)
}
fmt.Println("TestMain", TLSBundleCertFile, TLSBundleKeyFile)
code := m.Run()
err = awstesting.CleanupTLSBundleFiles(TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile)
if err != nil {
panic(err)
}
os.Exit(code)
}
func TestNewSession_WithCustomCABundle_Env(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
os.Setenv("AWS_CA_BUNDLE", TLSBundleCAFile)
s, err := NewSession(&aws.Config{
HTTPClient: &http.Client{},
Endpoint: aws.String(endpoint),
Region: aws.String("mock-region"),
Credentials: credentials.AnonymousCredentials,
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if s == nil {
t.Fatalf("expect session to be created, got none")
}
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
resp, err := s.Config.HTTPClient.Do(req)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := http.StatusOK, resp.StatusCode; e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
}
func TestNewSession_WithCustomCABundle_EnvNotExists(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_CA_BUNDLE", "file-not-exists")
s, err := NewSession()
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := "LoadCustomCABundleError", err.(awserr.Error).Code(); e != a {
t.Errorf("expect %s error code, got %s", e, a)
}
if s != nil {
t.Errorf("expect nil session, got %v", s)
}
}
func TestNewSession_WithCustomCABundle_Option(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
s, err := NewSessionWithOptions(Options{
Config: aws.Config{
HTTPClient: &http.Client{},
Endpoint: aws.String(endpoint),
Region: aws.String("mock-region"),
Credentials: credentials.AnonymousCredentials,
},
CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if s == nil {
t.Fatalf("expect session to be created, got none")
}
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
resp, err := s.Config.HTTPClient.Do(req)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := http.StatusOK, resp.StatusCode; e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
}
func TestNewSession_WithCustomCABundle_OptionPriority(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
os.Setenv("AWS_CA_BUNDLE", "file-not-exists")
s, err := NewSessionWithOptions(Options{
Config: aws.Config{
HTTPClient: &http.Client{},
Endpoint: aws.String(endpoint),
Region: aws.String("mock-region"),
Credentials: credentials.AnonymousCredentials,
},
CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if s == nil {
t.Fatalf("expect session to be created, got none")
}
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
resp, err := s.Config.HTTPClient.Do(req)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := http.StatusOK, resp.StatusCode; e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
}
type mockRoundTripper struct{}
func (m *mockRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
return nil, nil
}
func TestNewSession_WithCustomCABundle_UnsupportedTransport(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
s, err := NewSessionWithOptions(Options{
Config: aws.Config{
HTTPClient: &http.Client{
Transport: &mockRoundTripper{},
},
},
CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
})
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := "LoadCustomCABundleError", err.(awserr.Error).Code(); e != a {
t.Errorf("expect %s error code, got %s", e, a)
}
if s != nil {
t.Errorf("expect nil session, got %v", s)
}
aerrMsg := err.(awserr.Error).Message()
if e, a := "transport unsupported type", aerrMsg; !strings.Contains(a, e) {
t.Errorf("expect %s to be in %s", e, a)
}
}
func TestNewSession_WithCustomCABundle_TransportSet(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
s, err := NewSessionWithOptions(Options{
Config: aws.Config{
Endpoint: aws.String(endpoint),
Region: aws.String("mock-region"),
Credentials: credentials.AnonymousCredentials,
HTTPClient: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).Dial,
TLSHandshakeTimeout: 2 * time.Second,
},
},
},
CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if s == nil {
t.Fatalf("expect session to be created, got none")
}
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
resp, err := s.Config.HTTPClient.Do(req)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := http.StatusOK, resp.StatusCode; e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
}

View file

@ -1,306 +0,0 @@
package session
import (
"os"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func TestLoadEnvConfig_Creds(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
cases := []struct {
Env map[string]string
Val credentials.Value
}{
{
Env: map[string]string{
"AWS_ACCESS_KEY": "AKID",
},
Val: credentials.Value{},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY_ID": "AKID",
},
Val: credentials.Value{},
},
{
Env: map[string]string{
"AWS_SECRET_KEY": "SECRET",
},
Val: credentials.Value{},
},
{
Env: map[string]string{
"AWS_SECRET_ACCESS_KEY": "SECRET",
},
Val: credentials.Value{},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY_ID": "AKID",
"AWS_SECRET_ACCESS_KEY": "SECRET",
},
Val: credentials.Value{
AccessKeyID: "AKID", SecretAccessKey: "SECRET",
ProviderName: "EnvConfigCredentials",
},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY": "AKID",
"AWS_SECRET_KEY": "SECRET",
},
Val: credentials.Value{
AccessKeyID: "AKID", SecretAccessKey: "SECRET",
ProviderName: "EnvConfigCredentials",
},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY": "AKID",
"AWS_SECRET_KEY": "SECRET",
"AWS_SESSION_TOKEN": "TOKEN",
},
Val: credentials.Value{
AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "TOKEN",
ProviderName: "EnvConfigCredentials",
},
},
}
for _, c := range cases {
os.Clearenv()
for k, v := range c.Env {
os.Setenv(k, v)
}
cfg := loadEnvConfig()
if !reflect.DeepEqual(c.Val, cfg.Creds) {
t.Errorf("expect credentials to match.\n%s",
awstesting.SprintExpectActual(c.Val, cfg.Creds))
}
}
}
func TestLoadEnvConfig(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
cases := []struct {
Env map[string]string
UseSharedConfigCall bool
Config envConfig
}{
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_PROFILE": "profile",
},
Config: envConfig{
Region: "region", Profile: "profile",
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_DEFAULT_REGION": "default_region",
"AWS_PROFILE": "profile",
"AWS_DEFAULT_PROFILE": "default_profile",
},
Config: envConfig{
Region: "region", Profile: "profile",
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_DEFAULT_REGION": "default_region",
"AWS_PROFILE": "profile",
"AWS_DEFAULT_PROFILE": "default_profile",
"AWS_SDK_LOAD_CONFIG": "1",
},
Config: envConfig{
Region: "region", Profile: "profile",
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
Env: map[string]string{
"AWS_DEFAULT_REGION": "default_region",
"AWS_DEFAULT_PROFILE": "default_profile",
},
Config: envConfig{
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
Env: map[string]string{
"AWS_DEFAULT_REGION": "default_region",
"AWS_DEFAULT_PROFILE": "default_profile",
"AWS_SDK_LOAD_CONFIG": "1",
},
Config: envConfig{
Region: "default_region", Profile: "default_profile",
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_PROFILE": "profile",
},
Config: envConfig{
Region: "region", Profile: "profile",
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_DEFAULT_REGION": "default_region",
"AWS_PROFILE": "profile",
"AWS_DEFAULT_PROFILE": "default_profile",
},
Config: envConfig{
Region: "region", Profile: "profile",
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_DEFAULT_REGION": "default_region",
"AWS_PROFILE": "profile",
"AWS_DEFAULT_PROFILE": "default_profile",
"AWS_SDK_LOAD_CONFIG": "1",
},
Config: envConfig{
Region: "region", Profile: "profile",
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_DEFAULT_REGION": "default_region",
"AWS_DEFAULT_PROFILE": "default_profile",
},
Config: envConfig{
Region: "default_region", Profile: "default_profile",
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_DEFAULT_REGION": "default_region",
"AWS_DEFAULT_PROFILE": "default_profile",
"AWS_SDK_LOAD_CONFIG": "1",
},
Config: envConfig{
Region: "default_region", Profile: "default_profile",
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_CA_BUNDLE": "custom_ca_bundle",
},
Config: envConfig{
CustomCABundle: "custom_ca_bundle",
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
Env: map[string]string{
"AWS_CA_BUNDLE": "custom_ca_bundle",
},
Config: envConfig{
CustomCABundle: "custom_ca_bundle",
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_SHARED_CREDENTIALS_FILE": "/path/to/credentials/file",
"AWS_CONFIG_FILE": "/path/to/config/file",
},
Config: envConfig{
SharedCredentialsFile: "/path/to/credentials/file",
SharedConfigFile: "/path/to/config/file",
},
},
}
for _, c := range cases {
os.Clearenv()
for k, v := range c.Env {
os.Setenv(k, v)
}
var cfg envConfig
if c.UseSharedConfigCall {
cfg = loadSharedEnvConfig()
} else {
cfg = loadEnvConfig()
}
if !reflect.DeepEqual(c.Config, cfg) {
t.Errorf("expect config to match.\n%s",
awstesting.SprintExpectActual(c.Config, cfg))
}
}
}
func TestSetEnvValue(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
os.Setenv("empty_key", "")
os.Setenv("second_key", "2")
os.Setenv("third_key", "3")
var dst string
setFromEnvVal(&dst, []string{
"empty_key", "first_key", "second_key", "third_key",
})
if e, a := "2", dst; e != a {
t.Errorf("expect %s value from environment, got %s", e, a)
}
}

View file

@ -1,446 +0,0 @@
package session
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestNewDefaultSession(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
s := New(&aws.Config{Region: aws.String("region")})
assert.Equal(t, "region", *s.Config.Region)
assert.Equal(t, http.DefaultClient, s.Config.HTTPClient)
assert.NotNil(t, s.Config.Logger)
assert.Equal(t, aws.LogOff, *s.Config.LogLevel)
}
func TestNew_WithCustomCreds(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
customCreds := credentials.NewStaticCredentials("AKID", "SECRET", "TOKEN")
s := New(&aws.Config{Credentials: customCreds})
assert.Equal(t, customCreds, s.Config.Credentials)
}
type mockLogger struct {
*bytes.Buffer
}
func (w mockLogger) Log(args ...interface{}) {
fmt.Fprintln(w, args...)
}
func TestNew_WithSessionLoadError(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_CONFIG_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_invalid_source_profile")
logger := bytes.Buffer{}
s := New(&aws.Config{Logger: &mockLogger{&logger}})
assert.NotNil(t, s)
svc := s3.New(s)
_, err := svc.ListBuckets(&s3.ListBucketsInput{})
assert.Error(t, err)
assert.Contains(t, logger.String(), "ERROR: failed to create session with AWS_SDK_LOAD_CONFIG enabled")
assert.Contains(t, err.Error(), SharedConfigAssumeRoleError{
RoleARN: "assume_role_invalid_source_profile_role_arn",
}.Error())
}
func TestSessionCopy(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_REGION", "orig_region")
s := Session{
Config: defaults.Config(),
Handlers: defaults.Handlers(),
}
newSess := s.Copy(&aws.Config{Region: aws.String("new_region")})
assert.Equal(t, "orig_region", *s.Config.Region)
assert.Equal(t, "new_region", *newSess.Config.Region)
}
func TestSessionClientConfig(t *testing.T) {
s, err := NewSession(&aws.Config{
Credentials: credentials.AnonymousCredentials,
Region: aws.String("orig_region"),
EndpointResolver: endpoints.ResolverFunc(
func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if e, a := "mock-service", service; e != a {
t.Errorf("expect %q service, got %q", e, a)
}
if e, a := "other-region", region; e != a {
t.Errorf("expect %q region, got %q", e, a)
}
return endpoints.ResolvedEndpoint{
URL: "https://" + service + "." + region + ".amazonaws.com",
SigningRegion: region,
}, nil
},
),
})
assert.NoError(t, err)
cfg := s.ClientConfig("mock-service", &aws.Config{Region: aws.String("other-region")})
assert.Equal(t, "https://mock-service.other-region.amazonaws.com", cfg.Endpoint)
assert.Equal(t, "other-region", cfg.SigningRegion)
assert.Equal(t, "other-region", *cfg.Config.Region)
}
func TestNewSession_NoCredentials(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
s, err := NewSession()
assert.NoError(t, err)
assert.NotNil(t, s.Config.Credentials)
assert.NotEqual(t, credentials.AnonymousCredentials, s.Config.Credentials)
}
func TestNewSessionWithOptions_OverrideProfile(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "other_profile")
s, err := NewSessionWithOptions(Options{
Profile: "full_profile",
})
assert.NoError(t, err)
assert.Equal(t, "full_profile_region", *s.Config.Region)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "full_profile_akid", creds.AccessKeyID)
assert.Equal(t, "full_profile_secret", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
}
func TestNewSessionWithOptions_OverrideSharedConfigEnable(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "0")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "full_profile")
s, err := NewSessionWithOptions(Options{
SharedConfigState: SharedConfigEnable,
})
assert.NoError(t, err)
assert.Equal(t, "full_profile_region", *s.Config.Region)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "full_profile_akid", creds.AccessKeyID)
assert.Equal(t, "full_profile_secret", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
}
func TestNewSessionWithOptions_OverrideSharedConfigDisable(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "full_profile")
s, err := NewSessionWithOptions(Options{
SharedConfigState: SharedConfigDisable,
})
assert.NoError(t, err)
assert.Empty(t, *s.Config.Region)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "full_profile_akid", creds.AccessKeyID)
assert.Equal(t, "full_profile_secret", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
}
func TestNewSessionWithOptions_OverrideSharedConfigFiles(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "config_file_load_order")
s, err := NewSessionWithOptions(Options{
SharedConfigFiles: []string{testConfigOtherFilename},
})
assert.NoError(t, err)
assert.Equal(t, "shared_config_other_region", *s.Config.Region)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "shared_config_other_akid", creds.AccessKeyID)
assert.Equal(t, "shared_config_other_secret", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
}
func TestNewSessionWithOptions_Overrides(t *testing.T) {
cases := []struct {
InEnvs map[string]string
InProfile string
OutRegion string
OutCreds credentials.Value
}{
{
InEnvs: map[string]string{
"AWS_SDK_LOAD_CONFIG": "0",
"AWS_SHARED_CREDENTIALS_FILE": testConfigFilename,
"AWS_PROFILE": "other_profile",
},
InProfile: "full_profile",
OutRegion: "full_profile_region",
OutCreds: credentials.Value{
AccessKeyID: "full_profile_akid",
SecretAccessKey: "full_profile_secret",
ProviderName: "SharedConfigCredentials",
},
},
{
InEnvs: map[string]string{
"AWS_SDK_LOAD_CONFIG": "0",
"AWS_SHARED_CREDENTIALS_FILE": testConfigFilename,
"AWS_REGION": "env_region",
"AWS_ACCESS_KEY": "env_akid",
"AWS_SECRET_ACCESS_KEY": "env_secret",
"AWS_PROFILE": "other_profile",
},
InProfile: "full_profile",
OutRegion: "env_region",
OutCreds: credentials.Value{
AccessKeyID: "env_akid",
SecretAccessKey: "env_secret",
ProviderName: "EnvConfigCredentials",
},
},
{
InEnvs: map[string]string{
"AWS_SDK_LOAD_CONFIG": "0",
"AWS_SHARED_CREDENTIALS_FILE": testConfigFilename,
"AWS_CONFIG_FILE": testConfigOtherFilename,
"AWS_PROFILE": "shared_profile",
},
InProfile: "config_file_load_order",
OutRegion: "shared_config_region",
OutCreds: credentials.Value{
AccessKeyID: "shared_config_akid",
SecretAccessKey: "shared_config_secret",
ProviderName: "SharedConfigCredentials",
},
},
}
for _, c := range cases {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
for k, v := range c.InEnvs {
os.Setenv(k, v)
}
s, err := NewSessionWithOptions(Options{
Profile: c.InProfile,
SharedConfigState: SharedConfigEnable,
})
assert.NoError(t, err)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, c.OutRegion, *s.Config.Region)
assert.Equal(t, c.OutCreds.AccessKeyID, creds.AccessKeyID)
assert.Equal(t, c.OutCreds.SecretAccessKey, creds.SecretAccessKey)
assert.Equal(t, c.OutCreds.SessionToken, creds.SessionToken)
assert.Contains(t, creds.ProviderName, c.OutCreds.ProviderName)
}
}
const assumeRoleRespMsg = `
<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleResult>
<AssumedRoleUser>
<Arn>arn:aws:sts::account_id:assumed-role/role/session_name</Arn>
<AssumedRoleId>AKID:session_name</AssumedRoleId>
</AssumedRoleUser>
<Credentials>
<AccessKeyId>AKID</AccessKeyId>
<SecretAccessKey>SECRET</SecretAccessKey>
<SessionToken>SESSION_TOKEN</SessionToken>
<Expiration>%s</Expiration>
</Credentials>
</AssumeRoleResult>
<ResponseMetadata>
<RequestId>request-id</RequestId>
</ResponseMetadata>
</AssumeRoleResponse>
`
func TestSesisonAssumeRole(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_REGION", "us-east-1")
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_w_creds")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))
}))
s, err := NewSession(&aws.Config{Endpoint: aws.String(server.URL), DisableSSL: aws.Bool(true)})
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Equal(t, "SESSION_TOKEN", creds.SessionToken)
assert.Contains(t, creds.ProviderName, "AssumeRoleProvider")
}
func TestSessionAssumeRole_WithMFA(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_REGION", "us-east-1")
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_w_creds")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.FormValue("SerialNumber"), "0123456789")
assert.Equal(t, r.FormValue("TokenCode"), "tokencode")
w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))
}))
customProviderCalled := false
sess, err := NewSessionWithOptions(Options{
Profile: "assume_role_w_mfa",
Config: aws.Config{
Region: aws.String("us-east-1"),
Endpoint: aws.String(server.URL),
DisableSSL: aws.Bool(true),
},
SharedConfigState: SharedConfigEnable,
AssumeRoleTokenProvider: func() (string, error) {
customProviderCalled = true
return "tokencode", nil
},
})
assert.NoError(t, err)
creds, err := sess.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, customProviderCalled)
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Equal(t, "SESSION_TOKEN", creds.SessionToken)
assert.Contains(t, creds.ProviderName, "AssumeRoleProvider")
}
func TestSessionAssumeRole_WithMFA_NoTokenProvider(t *testing.T) {
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_REGION", "us-east-1")
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_w_creds")
_, err := NewSessionWithOptions(Options{
Profile: "assume_role_w_mfa",
SharedConfigState: SharedConfigEnable,
})
assert.Equal(t, err, AssumeRoleTokenProviderNotSetError{})
}
func TestSessionAssumeRole_DisableSharedConfig(t *testing.T) {
// Backwards compatibility with Shared config disabled
// assume role should not be built into the config.
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "0")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_w_creds")
s, err := NewSession()
assert.NoError(t, err)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "assume_role_w_creds_akid", creds.AccessKeyID)
assert.Equal(t, "assume_role_w_creds_secret", creds.SecretAccessKey)
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
}
func TestSessionAssumeRole_InvalidSourceProfile(t *testing.T) {
// Backwards compatibility with Shared config disabled
// assume role should not be built into the config.
oldEnv := initSessionTestEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_invalid_source_profile")
s, err := NewSession()
assert.Error(t, err)
assert.Contains(t, err.Error(), "SharedConfigAssumeRoleError: failed to load assume role")
assert.Nil(t, s)
}
func initSessionTestEnv() (oldEnv []string) {
oldEnv = awstesting.StashEnv()
os.Setenv("AWS_CONFIG_FILE", "file_not_exists")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "file_not_exists")
return oldEnv
}

View file

@ -1,274 +0,0 @@
package session
import (
"fmt"
"path/filepath"
"testing"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/go-ini/ini"
"github.com/stretchr/testify/assert"
)
var (
testConfigFilename = filepath.Join("testdata", "shared_config")
testConfigOtherFilename = filepath.Join("testdata", "shared_config_other")
)
func TestLoadSharedConfig(t *testing.T) {
cases := []struct {
Filenames []string
Profile string
Expected sharedConfig
Err error
}{
{
Filenames: []string{"file_not_exists"},
Profile: "default",
},
{
Filenames: []string{testConfigFilename},
Expected: sharedConfig{
Region: "default_region",
},
},
{
Filenames: []string{testConfigOtherFilename, testConfigFilename},
Profile: "config_file_load_order",
Expected: sharedConfig{
Region: "shared_config_region",
Creds: credentials.Value{
AccessKeyID: "shared_config_akid",
SecretAccessKey: "shared_config_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
},
},
{
Filenames: []string{testConfigFilename, testConfigOtherFilename},
Profile: "config_file_load_order",
Expected: sharedConfig{
Region: "shared_config_other_region",
Creds: credentials.Value{
AccessKeyID: "shared_config_other_akid",
SecretAccessKey: "shared_config_other_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigOtherFilename),
},
},
},
{
Filenames: []string{testConfigOtherFilename, testConfigFilename},
Profile: "assume_role",
Expected: sharedConfig{
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_role_arn",
SourceProfile: "complete_creds",
},
AssumeRoleSource: &sharedConfig{
Creds: credentials.Value{
AccessKeyID: "complete_creds_akid",
SecretAccessKey: "complete_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
},
},
},
{
Filenames: []string{testConfigOtherFilename, testConfigFilename},
Profile: "assume_role_invalid_source_profile",
Expected: sharedConfig{
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_invalid_source_profile_role_arn",
SourceProfile: "profile_not_exists",
},
},
Err: SharedConfigAssumeRoleError{RoleARN: "assume_role_invalid_source_profile_role_arn"},
},
{
Filenames: []string{testConfigOtherFilename, testConfigFilename},
Profile: "assume_role_w_creds",
Expected: sharedConfig{
Creds: credentials.Value{
AccessKeyID: "assume_role_w_creds_akid",
SecretAccessKey: "assume_role_w_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_w_creds_role_arn",
SourceProfile: "assume_role_w_creds",
ExternalID: "1234",
RoleSessionName: "assume_role_w_creds_session_name",
},
AssumeRoleSource: &sharedConfig{
Creds: credentials.Value{
AccessKeyID: "assume_role_w_creds_akid",
SecretAccessKey: "assume_role_w_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
},
},
},
{
Filenames: []string{testConfigOtherFilename, testConfigFilename},
Profile: "assume_role_wo_creds",
Expected: sharedConfig{
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_wo_creds_role_arn",
SourceProfile: "assume_role_wo_creds",
},
},
Err: SharedConfigAssumeRoleError{RoleARN: "assume_role_wo_creds_role_arn"},
},
{
Filenames: []string{filepath.Join("testdata", "shared_config_invalid_ini")},
Profile: "profile_name",
Err: SharedConfigLoadError{Filename: filepath.Join("testdata", "shared_config_invalid_ini")},
},
}
for i, c := range cases {
cfg, err := loadSharedConfig(c.Profile, c.Filenames)
if c.Err != nil {
assert.Contains(t, err.Error(), c.Err.Error(), "expected error, %d", i)
continue
}
assert.NoError(t, err, "unexpected error, %d", i)
assert.Equal(t, c.Expected, cfg, "not equal, %d", i)
}
}
func TestLoadSharedConfigFromFile(t *testing.T) {
filename := testConfigFilename
f, err := ini.Load(filename)
if err != nil {
t.Fatalf("failed to load test config file, %s, %v", filename, err)
}
iniFile := sharedConfigFile{IniData: f, Filename: filename}
cases := []struct {
Profile string
Expected sharedConfig
Err error
}{
{
Profile: "default",
Expected: sharedConfig{Region: "default_region"},
},
{
Profile: "alt_profile_name",
Expected: sharedConfig{Region: "alt_profile_name_region"},
},
{
Profile: "short_profile_name_first",
Expected: sharedConfig{Region: "short_profile_name_first_short"},
},
{
Profile: "partial_creds",
Expected: sharedConfig{},
},
{
Profile: "complete_creds",
Expected: sharedConfig{
Creds: credentials.Value{
AccessKeyID: "complete_creds_akid",
SecretAccessKey: "complete_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
},
},
{
Profile: "complete_creds_with_token",
Expected: sharedConfig{
Creds: credentials.Value{
AccessKeyID: "complete_creds_with_token_akid",
SecretAccessKey: "complete_creds_with_token_secret",
SessionToken: "complete_creds_with_token_token",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
},
},
{
Profile: "full_profile",
Expected: sharedConfig{
Creds: credentials.Value{
AccessKeyID: "full_profile_akid",
SecretAccessKey: "full_profile_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
Region: "full_profile_region",
},
},
{
Profile: "partial_assume_role",
Expected: sharedConfig{},
},
{
Profile: "assume_role",
Expected: sharedConfig{
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_role_arn",
SourceProfile: "complete_creds",
},
},
},
{
Profile: "assume_role_w_mfa",
Expected: sharedConfig{
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_role_arn",
SourceProfile: "complete_creds",
MFASerial: "0123456789",
},
},
},
{
Profile: "does_not_exists",
Err: SharedConfigProfileNotExistsError{Profile: "does_not_exists"},
},
}
for i, c := range cases {
cfg := sharedConfig{}
err := cfg.setFromIniFile(c.Profile, iniFile)
if c.Err != nil {
assert.Contains(t, err.Error(), c.Err.Error(), "expected error, %d", i)
continue
}
assert.NoError(t, err, "unexpected error, %d", i)
assert.Equal(t, c.Expected, cfg, "not equal, %d", i)
}
}
func TestLoadSharedConfigIniFiles(t *testing.T) {
cases := []struct {
Filenames []string
Expected []sharedConfigFile
}{
{
Filenames: []string{"not_exists", testConfigFilename},
Expected: []sharedConfigFile{
{Filename: testConfigFilename},
},
},
{
Filenames: []string{testConfigFilename, testConfigOtherFilename},
Expected: []sharedConfigFile{
{Filename: testConfigFilename},
{Filename: testConfigOtherFilename},
},
},
}
for i, c := range cases {
files, err := loadSharedConfigIniFiles(c.Filenames)
assert.NoError(t, err, "unexpected error, %d", i)
assert.Equal(t, len(c.Expected), len(files), "expected num files, %d", i)
for i, expectedFile := range c.Expected {
assert.Equal(t, expectedFile.Filename, files[i].Filename)
}
}
}

View file

@ -1,86 +0,0 @@
// +build go1.5
package v4_test
import (
"fmt"
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
func TestStandaloneSign(t *testing.T) {
creds := unit.Session.Config.Credentials
signer := v4.NewSigner(creds)
for _, c := range standaloneSignCases {
host := fmt.Sprintf("https://%s.%s.%s.amazonaws.com",
c.SubDomain, c.Region, c.Service)
req, err := http.NewRequest("GET", host, nil)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
// URL.EscapedPath() will be used by the signer to get the
// escaped form of the request's URI path.
req.URL.Path = c.OrigURI
req.URL.RawQuery = c.OrigQuery
_, err = signer.Sign(req, nil, c.Service, c.Region, time.Unix(0, 0))
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
actual := req.Header.Get("Authorization")
if e, a := c.ExpSig, actual; e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
if e, a := c.OrigURI, req.URL.Path; e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
if e, a := c.EscapedURI, req.URL.EscapedPath(); e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
}
}
func TestStandaloneSign_RawPath(t *testing.T) {
creds := unit.Session.Config.Credentials
signer := v4.NewSigner(creds)
for _, c := range standaloneSignCases {
host := fmt.Sprintf("https://%s.%s.%s.amazonaws.com",
c.SubDomain, c.Region, c.Service)
req, err := http.NewRequest("GET", host, nil)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
// URL.EscapedPath() will be used by the signer to get the
// escaped form of the request's URI path.
req.URL.Path = c.OrigURI
req.URL.RawPath = c.EscapedURI
req.URL.RawQuery = c.OrigQuery
_, err = signer.Sign(req, nil, c.Service, c.Region, time.Unix(0, 0))
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
actual := req.Header.Get("Authorization")
if e, a := c.ExpSig, actual; e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
if e, a := c.OrigURI, req.URL.Path; e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
if e, a := c.EscapedURI, req.URL.EscapedPath(); e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
}
}

View file

@ -1,254 +0,0 @@
package v4_test
import (
"net/http"
"net/url"
"reflect"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
var standaloneSignCases = []struct {
OrigURI string
OrigQuery string
Region, Service, SubDomain string
ExpSig string
EscapedURI string
}{
{
OrigURI: `/logs-*/_search`,
OrigQuery: `pretty=true`,
Region: "us-west-2", Service: "es", SubDomain: "hostname-clusterkey",
EscapedURI: `/logs-%2A/_search`,
ExpSig: `AWS4-HMAC-SHA256 Credential=AKID/19700101/us-west-2/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=79d0760751907af16f64a537c1242416dacf51204a7dd5284492d15577973b91`,
},
}
func TestPresignHandler(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
ContentDisposition: aws.String("a+b c$d"),
ACL: aws.String("public-read"),
})
req.Time = time.Unix(0, 0)
urlstr, err := req.Presign(5 * time.Minute)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
expectedHost := "bucket.s3.mock-region.amazonaws.com"
expectedDate := "19700101T000000Z"
expectedHeaders := "content-disposition;host;x-amz-acl"
expectedSig := "2d76a414208c0eac2a23ef9c834db9635ecd5a0fbb447a00ad191f82d854f55b"
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
u, _ := url.Parse(urlstr)
urlQ := u.Query()
if e, a := expectedHost, u.Host; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedSig, urlQ.Get("X-Amz-Signature"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedCred, urlQ.Get("X-Amz-Credential"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedDate, urlQ.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "300", urlQ.Get("X-Amz-Expires"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if a := urlQ.Get("X-Amz-Content-Sha256"); len(a) != 0 {
t.Errorf("expect no content sha256 got %v", a)
}
if e, a := "+", urlstr; strings.Contains(a, e) { // + encoded as %20
t.Errorf("expect %v not to be in %v", e, a)
}
}
func TestPresignRequest(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
ContentDisposition: aws.String("a+b c$d"),
ACL: aws.String("public-read"),
})
req.Time = time.Unix(0, 0)
urlstr, headers, err := req.PresignRequest(5 * time.Minute)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
expectedHost := "bucket.s3.mock-region.amazonaws.com"
expectedDate := "19700101T000000Z"
expectedHeaders := "content-disposition;host;x-amz-acl"
expectedSig := "2d76a414208c0eac2a23ef9c834db9635ecd5a0fbb447a00ad191f82d854f55b"
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
expectedHeaderMap := http.Header{
"x-amz-acl": []string{"public-read"},
"content-disposition": []string{"a+b c$d"},
}
u, _ := url.Parse(urlstr)
urlQ := u.Query()
if e, a := expectedHost, u.Host; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedSig, urlQ.Get("X-Amz-Signature"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedCred, urlQ.Get("X-Amz-Credential"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedDate, urlQ.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedHeaderMap, headers; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "300", urlQ.Get("X-Amz-Expires"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if a := urlQ.Get("X-Amz-Content-Sha256"); len(a) != 0 {
t.Errorf("expect no content sha256 got %v", a)
}
if e, a := "+", urlstr; strings.Contains(a, e) { // + encoded as %20
t.Errorf("expect %v not to be in %v", e, a)
}
}
func TestStandaloneSign_CustomURIEscape(t *testing.T) {
var expectSig = `AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=6601e883cc6d23871fd6c2a394c5677ea2b8c82b04a6446786d64cd74f520967`
creds := unit.Session.Config.Credentials
signer := v4.NewSigner(creds, func(s *v4.Signer) {
s.DisableURIPathEscaping = true
})
host := "https://subdomain.us-east-1.es.amazonaws.com"
req, err := http.NewRequest("GET", host, nil)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
req.URL.Path = `/log-*/_search`
req.URL.Opaque = "//subdomain.us-east-1.es.amazonaws.com/log-%2A/_search"
_, err = signer.Sign(req, nil, "es", "us-east-1", time.Unix(0, 0))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
actual := req.Header.Get("Authorization")
if e, a := expectSig, actual; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestStandaloneSign_WithPort(t *testing.T) {
cases := []struct {
description string
url string
expectedSig string
}{
{
"default HTTPS port",
"https://estest.us-east-1.es.amazonaws.com:443/_search",
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=e573fc9aa3a156b720976419319be98fb2824a3abc2ddd895ecb1d1611c6a82d",
},
{
"default HTTP port",
"http://example.com:80/_search",
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=54ebe60c4ae03a40948b849e13c333523235f38002e2807059c64a9a8c7cb951",
},
{
"non-standard HTTP port",
"http://example.com:9200/_search",
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=cd9d926a460f8d3b58b57beadbd87666dc667e014c0afaa4cea37b2867f51b4f",
},
{
"non-standard HTTPS port",
"https://example.com:9200/_search",
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=cd9d926a460f8d3b58b57beadbd87666dc667e014c0afaa4cea37b2867f51b4f",
},
}
for _, c := range cases {
signer := v4.NewSigner(unit.Session.Config.Credentials)
req, _ := http.NewRequest("GET", c.url, nil)
_, err := signer.Sign(req, nil, "es", "us-east-1", time.Unix(0, 0))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
actual := req.Header.Get("Authorization")
if e, a := c.expectedSig, actual; e != a {
t.Errorf("%s, expect %v, got %v", c.description, e, a)
}
}
}
func TestStandalonePresign_WithPort(t *testing.T) {
cases := []struct {
description string
url string
expectedSig string
}{
{
"default HTTPS port",
"https://estest.us-east-1.es.amazonaws.com:443/_search",
"0abcf61a351063441296febf4b485734d780634fba8cf1e7d9769315c35255d6",
},
{
"default HTTP port",
"http://example.com:80/_search",
"fce9976dd6c849c21adfa6d3f3e9eefc651d0e4a2ccd740d43efddcccfdc8179",
},
{
"non-standard HTTP port",
"http://example.com:9200/_search",
"f33c25a81c735e42bef35ed5e9f720c43940562e3e616ff0777bf6dde75249b0",
},
{
"non-standard HTTPS port",
"https://example.com:9200/_search",
"f33c25a81c735e42bef35ed5e9f720c43940562e3e616ff0777bf6dde75249b0",
},
}
for _, c := range cases {
signer := v4.NewSigner(unit.Session.Config.Credentials)
req, _ := http.NewRequest("GET", c.url, nil)
_, err := signer.Presign(req, nil, "es", "us-east-1", 5*time.Minute, time.Unix(0, 0))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
actual := req.URL.Query().Get("X-Amz-Signature")
if e, a := c.expectedSig, actual; e != a {
t.Errorf("%s, expect %v, got %v", c.description, e, a)
}
}
}

View file

@ -1,77 +0,0 @@
package v4
import (
"testing"
)
func TestRuleCheckWhitelist(t *testing.T) {
w := whitelist{
mapRule{
"Cache-Control": struct{}{},
},
}
if !w.IsValid("Cache-Control") {
t.Error("expected true value")
}
if w.IsValid("Cache-") {
t.Error("expected false value")
}
}
func TestRuleCheckBlacklist(t *testing.T) {
b := blacklist{
mapRule{
"Cache-Control": struct{}{},
},
}
if b.IsValid("Cache-Control") {
t.Error("expected false value")
}
if !b.IsValid("Cache-") {
t.Error("expected true value")
}
}
func TestRuleCheckPattern(t *testing.T) {
p := patterns{"X-Amz-Meta-"}
if !p.IsValid("X-Amz-Meta-") {
t.Error("expected true value")
}
if !p.IsValid("X-Amz-Meta-Star") {
t.Error("expected true value")
}
if p.IsValid("Cache-") {
t.Error("expected false value")
}
}
func TestRuleComplexWhitelist(t *testing.T) {
w := rules{
whitelist{
mapRule{
"Cache-Control": struct{}{},
},
},
patterns{"X-Amz-Meta-"},
}
r := rules{
inclusiveRules{patterns{"X-Amz-"}, blacklist{w}},
}
if !r.IsValid("X-Amz-Blah") {
t.Error("expected true value")
}
if r.IsValid("X-Amz-Meta-") {
t.Error("expected false value")
}
if r.IsValid("X-Amz-Meta-Star") {
t.Error("expected false value")
}
if r.IsValid("Cache-Control") {
t.Error("expected false value")
}
}

View file

@ -1,737 +0,0 @@
package v4
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestStripExcessHeaders(t *testing.T) {
vals := []string{
"",
"123",
"1 2 3",
"1 2 3 ",
" 1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2 ",
" 1 2 ",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
expected := []string{
"",
"123",
"1 2 3",
"1 2 3",
"1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2",
"1 2",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
stripExcessSpaces(vals)
for i := 0; i < len(vals); i++ {
if e, a := expected[i], vals[i]; e != a {
t.Errorf("%d, expect %v, got %v", i, e, a)
}
}
}
func buildRequest(serviceName, region, body string) (*http.Request, io.ReadSeeker) {
reader := strings.NewReader(body)
return buildRequestWithBodyReader(serviceName, region, reader)
}
func buildRequestWithBodyReader(serviceName, region string, body io.Reader) (*http.Request, io.ReadSeeker) {
var bodyLen int
type lenner interface {
Len() int
}
if lr, ok := body.(lenner); ok {
bodyLen = lr.Len()
}
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
req, _ := http.NewRequest("POST", endpoint, body)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Set("X-Amz-Target", "prefix.Operation")
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
if bodyLen > 0 {
req.Header.Set("Content-Length", strconv.Itoa(bodyLen))
}
req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
var seeker io.ReadSeeker
if sr, ok := body.(io.ReadSeeker); ok {
seeker = sr
} else {
seeker = aws.ReadSeekCloser(body)
}
return req, seeker
}
func buildSigner() Signer {
return Signer{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
}
}
func removeWS(text string) string {
text = strings.Replace(text, " ", "", -1)
text = strings.Replace(text, "\n", "", -1)
text = strings.Replace(text, "\t", "", -1)
return text
}
func assertEqual(t *testing.T, expected, given string) {
if removeWS(expected) != removeWS(given) {
t.Errorf("\nExpected: %s\nGiven: %s", expected, given)
}
}
func TestPresignRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
signer := buildSigner()
signer.Presign(req, body, "dynamodb", "us-east-1", 300*time.Second, time.Unix(0, 0))
expectedDate := "19700101T000000Z"
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
expectedSig := "122f0b9e091e4ba84286097e2b3404a1f1f4c4aad479adda95b7dff0ccbe5581"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
expectedTarget := "prefix.Operation"
q := req.URL.Query()
if e, a := expectedSig, q.Get("X-Amz-Signature"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedCred, q.Get("X-Amz-Credential"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedHeaders, q.Get("X-Amz-SignedHeaders"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if a := q.Get("X-Amz-Meta-Other-Header"); len(a) != 0 {
t.Errorf("expect %v to be empty", a)
}
if e, a := expectedTarget, q.Get("X-Amz-Target"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestPresignBodyWithArrayRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
signer := buildSigner()
signer.Presign(req, body, "dynamodb", "us-east-1", 300*time.Second, time.Unix(0, 0))
expectedDate := "19700101T000000Z"
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
expectedSig := "e3ac55addee8711b76c6d608d762cff285fe8b627a057f8b5ec9268cf82c08b1"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
expectedTarget := "prefix.Operation"
q := req.URL.Query()
if e, a := expectedSig, q.Get("X-Amz-Signature"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedCred, q.Get("X-Amz-Credential"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedHeaders, q.Get("X-Amz-SignedHeaders"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if a := q.Get("X-Amz-Meta-Other-Header"); len(a) != 0 {
t.Errorf("expect %v to be empty, was not", a)
}
if e, a := expectedTarget, q.Get("X-Amz-Target"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
signer := buildSigner()
signer.Sign(req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9"
q := req.Header
if e, a := expectedSig, q.Get("Authorization"); e != a {
t.Errorf("expect\n%v\nactual\n%v\n", e, a)
}
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect\n%v\nactual\n%v\n", e, a)
}
}
func TestSignBodyS3(t *testing.T) {
req, body := buildRequest("s3", "us-east-1", "hello")
signer := buildSigner()
signer.Sign(req, body, "s3", "us-east-1", time.Now())
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignBodyGlacier(t *testing.T) {
req, body := buildRequest("glacier", "us-east-1", "hello")
signer := buildSigner()
signer.Sign(req, body, "glacier", "us-east-1", time.Now())
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestPresign_SignedPayload(t *testing.T) {
req, body := buildRequest("glacier", "us-east-1", "hello")
signer := buildSigner()
signer.Presign(req, body, "glacier", "us-east-1", 5*time.Minute, time.Now())
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestPresign_UnsignedPayload(t *testing.T) {
req, body := buildRequest("service-name", "us-east-1", "hello")
signer := buildSigner()
signer.UnsignedPayload = true
signer.Presign(req, body, "service-name", "us-east-1", 5*time.Minute, time.Now())
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "UNSIGNED-PAYLOAD", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestPresign_UnsignedPayload_S3(t *testing.T) {
req, body := buildRequest("s3", "us-east-1", "hello")
signer := buildSigner()
signer.Presign(req, body, "s3", "us-east-1", 5*time.Minute, time.Now())
if a := req.Header.Get("X-Amz-Content-Sha256"); len(a) != 0 {
t.Errorf("expect no content sha256 got %v", a)
}
}
func TestSignUnseekableBody(t *testing.T) {
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
signer := buildSigner()
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
if err == nil {
t.Fatalf("expect error signing request")
}
if e, a := "unseekable request body", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %q to be in %q", e, a)
}
}
func TestSignUnsignedPayloadUnseekableBody(t *testing.T) {
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
signer := buildSigner()
signer.UnsignedPayload = true
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "UNSIGNED-PAYLOAD", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignPreComputedHashUnseekableBody(t *testing.T) {
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
signer := buildSigner()
req.Header.Set("X-Amz-Content-Sha256", "some-content-sha256")
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "some-content-sha256", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignPrecomputedBodyChecksum(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "hello")
req.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
signer := buildSigner()
signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "PRECOMPUTED", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestAnonymousCredentials(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{Credentials: credentials.AnonymousCredentials})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
SignSDKRequest(r)
urlQ := r.HTTPRequest.URL.Query()
if a := urlQ.Get("X-Amz-Signature"); len(a) != 0 {
t.Errorf("expect %v to be empty, was not", a)
}
if a := urlQ.Get("X-Amz-Credential"); len(a) != 0 {
t.Errorf("expect %v to be empty, was not", a)
}
if a := urlQ.Get("X-Amz-SignedHeaders"); len(a) != 0 {
t.Errorf("expect %v to be empty, was not", a)
}
if a := urlQ.Get("X-Amz-Date"); len(a) != 0 {
t.Errorf("expect %v to be empty, was not", a)
}
hQ := r.HTTPRequest.Header
if a := hQ.Get("Authorization"); len(a) != 0 {
t.Errorf("expect %v to be empty, was not", a)
}
if a := hQ.Get("X-Amz-Date"); len(a) != 0 {
t.Errorf("expect %v to be empty, was not", a)
}
}
func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
SignSDKRequest(r)
sig := r.HTTPRequest.Header.Get("Authorization")
signSDKRequestWithCurrTime(r, func() time.Time {
// Simulate one second has passed so that signature's date changes
// when it is resigned.
return time.Now().Add(1 * time.Second)
})
if e, a := sig, r.HTTPRequest.Header.Get("Authorization"); e == a {
t.Errorf("expect %v to be %v, but was not", e, a)
}
}
func TestIgnorePreResignRequestWithValidCreds(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
SignSDKRequest(r)
sig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
signSDKRequestWithCurrTime(r, func() time.Time {
// Simulate one second has passed so that signature's date changes
// when it is resigned.
return time.Now().Add(1 * time.Second)
})
if e, a := sig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"); e == a {
t.Errorf("expect %v to be %v, but was not", e, a)
}
}
func TestResignRequestExpiredCreds(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
svc := awstesting.NewClient(&aws.Config{Credentials: creds})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
SignSDKRequest(r)
querySig := r.HTTPRequest.Header.Get("Authorization")
var origSignedHeaders string
for _, p := range strings.Split(querySig, ", ") {
if strings.HasPrefix(p, "SignedHeaders=") {
origSignedHeaders = p[len("SignedHeaders="):]
break
}
}
if a := origSignedHeaders; len(a) == 0 {
t.Errorf("expect not to be empty, but was")
}
if e, a := origSignedHeaders, "authorization"; strings.Contains(a, e) {
t.Errorf("expect %v to not be in %v, but was", e, a)
}
origSignedAt := r.LastSignedAt
creds.Expire()
signSDKRequestWithCurrTime(r, func() time.Time {
// Simulate one second has passed so that signature's date changes
// when it is resigned.
return time.Now().Add(1 * time.Second)
})
updatedQuerySig := r.HTTPRequest.Header.Get("Authorization")
if e, a := querySig, updatedQuerySig; e == a {
t.Errorf("expect %v to be %v, was not", e, a)
}
var updatedSignedHeaders string
for _, p := range strings.Split(updatedQuerySig, ", ") {
if strings.HasPrefix(p, "SignedHeaders=") {
updatedSignedHeaders = p[len("SignedHeaders="):]
break
}
}
if a := updatedSignedHeaders; len(a) == 0 {
t.Errorf("expect not to be empty, but was")
}
if e, a := updatedQuerySig, "authorization"; strings.Contains(a, e) {
t.Errorf("expect %v to not be in %v, but was", e, a)
}
if e, a := origSignedAt, r.LastSignedAt; e == a {
t.Errorf("expect %v to be %v, was not", e, a)
}
}
func TestPreResignRequestExpiredCreds(t *testing.T) {
provider := &credentials.StaticProvider{Value: credentials.Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "SESSION",
}}
creds := credentials.NewCredentials(provider)
svc := awstesting.NewClient(&aws.Config{Credentials: creds})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
SignSDKRequest(r)
querySig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
signedHeaders := r.HTTPRequest.URL.Query().Get("X-Amz-SignedHeaders")
if a := signedHeaders; len(a) == 0 {
t.Errorf("expect not to be empty, but was")
}
origSignedAt := r.LastSignedAt
creds.Expire()
signSDKRequestWithCurrTime(r, func() time.Time {
// Simulate the request occurred 15 minutes in the past
return time.Now().Add(-48 * time.Hour)
})
if e, a := querySig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"); e == a {
t.Errorf("expect %v to be %v, was not", e, a)
}
resignedHeaders := r.HTTPRequest.URL.Query().Get("X-Amz-SignedHeaders")
if e, a := signedHeaders, resignedHeaders; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := signedHeaders, "x-amz-signedHeaders"; strings.Contains(a, e) {
t.Errorf("expect %v to not be in %v, but was", e, a)
}
if e, a := origSignedAt, r.LastSignedAt; e == a {
t.Errorf("expect %v to be %v, was not", e, a)
}
}
func TestResignRequestExpiredRequest(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
svc := awstesting.NewClient(&aws.Config{Credentials: creds})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
SignSDKRequest(r)
querySig := r.HTTPRequest.Header.Get("Authorization")
origSignedAt := r.LastSignedAt
signSDKRequestWithCurrTime(r, func() time.Time {
// Simulate the request occurred 15 minutes in the past
return time.Now().Add(15 * time.Minute)
})
if e, a := querySig, r.HTTPRequest.Header.Get("Authorization"); e == a {
t.Errorf("expect %v to be %v, was not", e, a)
}
if e, a := origSignedAt, r.LastSignedAt; e == a {
t.Errorf("expect %v to be %v, was not", e, a)
}
}
func TestSignWithRequestBody(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
signer := NewSigner(creds)
expectBody := []byte("abc123")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := expectBody, b; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
w.WriteHeader(http.StatusOK)
}))
req, err := http.NewRequest("POST", server.URL, nil)
_, err = signer.Sign(req, bytes.NewReader(expectBody), "service", "region", time.Now())
if err != nil {
t.Errorf("expect not no error, got %v", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("expect not no error, got %v", err)
}
if e, a := http.StatusOK, resp.StatusCode; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignWithRequestBody_Overwrite(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
signer := NewSigner(creds)
var expectBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
t.Errorf("expect not no error, got %v", err)
}
if e, a := len(expectBody), len(b); e != a {
t.Errorf("expect %v, got %v", e, a)
}
w.WriteHeader(http.StatusOK)
}))
req, err := http.NewRequest("GET", server.URL, strings.NewReader("invalid body"))
_, err = signer.Sign(req, nil, "service", "region", time.Now())
req.ContentLength = 0
if err != nil {
t.Errorf("expect not no error, got %v", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("expect not no error, got %v", err)
}
if e, a := http.StatusOK, resp.StatusCode; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestBuildCanonicalRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
ctx := &signingCtx{
ServiceName: "dynamodb",
Region: "us-east-1",
Request: req,
Body: body,
Query: req.URL.Query(),
Time: time.Now(),
ExpireTime: 5 * time.Second,
}
ctx.buildCanonicalString()
expected := "https://example.org/bucket/key-._~,!@#$%^&*()?Foo=z&Foo=o&Foo=m&Foo=a"
if e, a := expected, ctx.Request.URL.String(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignWithBody_ReplaceRequestBody(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
req, seekerBody := buildRequest("dynamodb", "us-east-1", "{}")
req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
s := NewSigner(creds)
origBody := req.Body
_, err := s.Sign(req, seekerBody, "dynamodb", "us-east-1", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if req.Body == origBody {
t.Errorf("expeect request body to not be origBody")
}
if req.Body == nil {
t.Errorf("expect request body to be changed but was nil")
}
}
func TestSignWithBody_NoReplaceRequestBody(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
req, seekerBody := buildRequest("dynamodb", "us-east-1", "{}")
req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
s := NewSigner(creds, func(signer *Signer) {
signer.DisableRequestBodyOverwrite = true
})
origBody := req.Body
_, err := s.Sign(req, seekerBody, "dynamodb", "us-east-1", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if req.Body != origBody {
t.Errorf("expect request body to not be chagned")
}
}
func TestRequestHost(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
req.Host = "myhost"
ctx := &signingCtx{
ServiceName: "dynamodb",
Region: "us-east-1",
Request: req,
Body: body,
Query: req.URL.Query(),
Time: time.Now(),
ExpireTime: 5 * time.Second,
}
ctx.buildCanonicalHeaders(ignoredHeaders, ctx.Request.Header)
if !strings.Contains(ctx.canonicalHeaders, "host:"+req.Host) {
t.Errorf("canonical host header invalid")
}
}
func BenchmarkPresignRequest(b *testing.B) {
signer := buildSigner()
req, body := buildRequest("dynamodb", "us-east-1", "{}")
for i := 0; i < b.N; i++ {
signer.Presign(req, body, "dynamodb", "us-east-1", 300*time.Second, time.Now())
}
}
func BenchmarkSignRequest(b *testing.B) {
signer := buildSigner()
req, body := buildRequest("dynamodb", "us-east-1", "{}")
for i := 0; i < b.N; i++ {
signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
}
}
var stripExcessSpaceCases = []string{
`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
`123 321 123 321`,
` 123 321 123 321 `,
` 123 321 123 321 `,
"123",
"1 2 3",
" 1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2 ",
" 1 2 ",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
func BenchmarkStripExcessSpaces(b *testing.B) {
for i := 0; i < b.N; i++ {
// Make sure to start with a copy of the cases
cases := append([]string{}, stripExcessSpaceCases...)
stripExcessSpaces(cases)
}
}

View file

@ -1,92 +0,0 @@
package aws
import (
"bytes"
"math/rand"
"testing"
)
func TestWriteAtBuffer(t *testing.T) {
b := &WriteAtBuffer{}
n, err := b.WriteAt([]byte{1}, 0)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 1, n; e != a {
t.Errorf("expected %d, but recieved %d", e, a)
}
n, err = b.WriteAt([]byte{1, 1, 1}, 5)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 3, n; e != a {
t.Errorf("expected %d, but recieved %d", e, a)
}
n, err = b.WriteAt([]byte{2}, 1)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 1, n; e != a {
t.Errorf("expected %d, but recieved %d", e, a)
}
n, err = b.WriteAt([]byte{3}, 2)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 1, n; e != a {
t.Errorf("expected %d, but received %d", e, a)
}
if !bytes.Equal([]byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes()) {
t.Errorf("expected %v, but received %v", []byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes())
}
}
func BenchmarkWriteAtBuffer(b *testing.B) {
buf := &WriteAtBuffer{}
r := rand.New(rand.NewSource(1))
b.ResetTimer()
for i := 0; i < b.N; i++ {
to := r.Intn(10) * 4096
bs := make([]byte, to)
buf.WriteAt(bs, r.Int63n(10)*4096)
}
}
func BenchmarkWriteAtBufferOrderedWrites(b *testing.B) {
// test the performance of a WriteAtBuffer when written in an
// ordered fashion. This is similar to the behavior of the
// s3.Downloader, since downloads the first chunk of the file, then
// the second, and so on.
//
// This test simulates a 150MB file being written in 30 ordered 5MB chunks.
chunk := int64(5e6)
max := chunk * 30
// we'll write the same 5MB chunk every time
tmp := make([]byte, chunk)
for i := 0; i < b.N; i++ {
buf := &WriteAtBuffer{}
for i := int64(0); i < max; i += chunk {
buf.WriteAt(tmp, i)
}
}
}
func BenchmarkWriteAtBufferParallel(b *testing.B) {
buf := &WriteAtBuffer{}
r := rand.New(rand.NewSource(1))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
to := r.Intn(10) * 4096
bs := make([]byte, to)
buf.WriteAt(bs, r.Int63n(10)*4096)
}
})
}

View file

@ -1,40 +0,0 @@
package shareddefaults_test
import (
"os"
"path/filepath"
"testing"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func TestSharedCredsFilename(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
os.Setenv("HOME", "home_dir")
os.Setenv("USERPROFILE", "profile_dir")
expect := filepath.Join("home_dir", ".aws", "credentials")
name := shareddefaults.SharedCredentialsFilename()
if e, a := expect, name; e != a {
t.Errorf("expect %q shared creds filename, got %q", e, a)
}
}
func TestSharedConfigFilename(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
os.Setenv("HOME", "home_dir")
os.Setenv("USERPROFILE", "profile_dir")
expect := filepath.Join("home_dir", ".aws", "config")
name := shareddefaults.SharedConfigFilename()
if e, a := expect, name; e != a {
t.Errorf("expect %q shared config filename, got %q", e, a)
}
}

View file

@ -1,40 +0,0 @@
package shareddefaults_test
import (
"os"
"path/filepath"
"testing"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func TestSharedCredsFilename(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
os.Setenv("HOME", "home_dir")
os.Setenv("USERPROFILE", "profile_dir")
expect := filepath.Join("profile_dir", ".aws", "credentials")
name := shareddefaults.SharedCredentialsFilename()
if e, a := expect, name; e != a {
t.Errorf("expect %q shared creds filename, got %q", e, a)
}
}
func TestSharedConfigFilename(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
os.Setenv("HOME", "home_dir")
os.Setenv("USERPROFILE", "profile_dir")
expect := filepath.Join("profile_dir", ".aws", "config")
name := shareddefaults.SharedConfigFilename()
if e, a := expect, name; e != a {
t.Errorf("expect %q shared config filename, got %q", e, a)
}
}

View file

@ -1,106 +0,0 @@
package protocol_test
import (
"reflect"
"testing"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/stretchr/testify/assert"
)
func TestCanSetIdempotencyToken(t *testing.T) {
cases := []struct {
CanSet bool
Case interface{}
}{
{
true,
struct {
Field *string `idempotencyToken:"true"`
}{},
},
{
true,
struct {
Field string `idempotencyToken:"true"`
}{},
},
{
false,
struct {
Field *string `idempotencyToken:"true"`
}{Field: new(string)},
},
{
false,
struct {
Field string `idempotencyToken:"true"`
}{Field: "value"},
},
{
false,
struct {
Field *int `idempotencyToken:"true"`
}{},
},
{
false,
struct {
Field *string
}{},
},
}
for i, c := range cases {
v := reflect.Indirect(reflect.ValueOf(c.Case))
ty := v.Type()
canSet := protocol.CanSetIdempotencyToken(v.Field(0), ty.Field(0))
assert.Equal(t, c.CanSet, canSet, "Expect case %d can set to match", i)
}
}
func TestSetIdempotencyToken(t *testing.T) {
cases := []struct {
Case interface{}
}{
{
&struct {
Field *string `idempotencyToken:"true"`
}{},
},
{
&struct {
Field string `idempotencyToken:"true"`
}{},
},
{
&struct {
Field *string `idempotencyToken:"true"`
}{Field: new(string)},
},
{
&struct {
Field string `idempotencyToken:"true"`
}{Field: ""},
},
}
for i, c := range cases {
v := reflect.Indirect(reflect.ValueOf(c.Case))
protocol.SetIdempotencyToken(v.Field(0))
assert.NotEmpty(t, v.Field(0).Interface(), "Expect case %d to be set", i)
}
}
func TestUUIDVersion4(t *testing.T) {
uuid := protocol.UUIDVersion4(make([]byte, 16))
assert.Equal(t, `00000000-0000-4000-8000-000000000000`, uuid)
b := make([]byte, 16)
for i := 0; i < len(b); i++ {
b[i] = 1
}
uuid = protocol.UUIDVersion4(b)
assert.Equal(t, `01010101-0101-4101-8101-010101010101`, uuid)
}

View file

@ -1,93 +0,0 @@
package protocol
import (
"fmt"
"reflect"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
)
var testJSONValueCases = []struct {
Value aws.JSONValue
Mode EscapeMode
String string
}{
{
Value: aws.JSONValue{
"abc": 123.,
},
Mode: NoEscape,
String: `{"abc":123}`,
},
{
Value: aws.JSONValue{
"abc": 123.,
},
Mode: Base64Escape,
String: `eyJhYmMiOjEyM30=`,
},
{
Value: aws.JSONValue{
"abc": 123.,
},
Mode: QuotedEscape,
String: `"{\"abc\":123}"`,
},
}
func TestEncodeJSONValue(t *testing.T) {
for i, c := range testJSONValueCases {
str, err := EncodeJSONValue(c.Value, c.Mode)
if err != nil {
t.Fatalf("%d, expect no error, got %v", i, err)
}
if e, a := c.String, str; e != a {
t.Errorf("%d, expect %v encoded value, got %v", i, e, a)
}
}
}
func TestDecodeJSONValue(t *testing.T) {
for i, c := range testJSONValueCases {
val, err := DecodeJSONValue(c.String, c.Mode)
if err != nil {
t.Fatalf("%d, expect no error, got %v", i, err)
}
if e, a := c.Value, val; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v encoded value, got %v", i, e, a)
}
}
}
func TestEncodeJSONValue_PanicUnkownMode(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("expect panic, got none")
} else {
reason := fmt.Sprintf("%v", r)
if e, a := "unknown EscapeMode", reason; !strings.Contains(a, e) {
t.Errorf("expect %q to be in %v", e, a)
}
}
}()
val := aws.JSONValue{}
EncodeJSONValue(val, 123456)
}
func TestDecodeJSONValue_PanicUnkownMode(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("expect panic, got none")
} else {
reason := fmt.Sprintf("%v", r)
if e, a := "unknown EscapeMode", reason; !strings.Contains(a, e) {
t.Errorf("expect %q to be in %v", e, a)
}
}
}()
DecodeJSONValue(`{"abc":123}`, 123456)
}

View file

@ -1,203 +0,0 @@
package protocol_test
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/private/protocol/ec2query"
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
"github.com/aws/aws-sdk-go/private/protocol/query"
"github.com/aws/aws-sdk-go/private/protocol/rest"
"github.com/aws/aws-sdk-go/private/protocol/restjson"
"github.com/aws/aws-sdk-go/private/protocol/restxml"
)
func xmlData(set bool, b []byte, size, delta int) {
const openingTags = "<B><A>"
const closingTags = "</A></B>"
if !set {
copy(b, []byte(openingTags))
}
if size == 0 {
copy(b[delta-len(closingTags):], []byte(closingTags))
}
}
func jsonData(set bool, b []byte, size, delta int) {
if !set {
copy(b, []byte("{\"A\": \""))
}
if size == 0 {
copy(b[delta-len("\"}"):], []byte("\"}"))
}
}
func buildNewRequest(data interface{}) *request.Request {
v := url.Values{}
v.Set("test", "TEST")
v.Add("test1", "TEST1")
req := &request.Request{
HTTPRequest: &http.Request{
Header: make(http.Header),
Body: &awstesting.ReadCloser{Size: 2048},
URL: &url.URL{
RawQuery: v.Encode(),
},
},
Params: &struct {
LocationName string `locationName:"test"`
}{
"Test",
},
ClientInfo: metadata.ClientInfo{
ServiceName: "test",
TargetPrefix: "test",
JSONVersion: "test",
APIVersion: "test",
Endpoint: "test",
SigningName: "test",
SigningRegion: "test",
},
Operation: &request.Operation{
Name: "test",
},
}
req.HTTPResponse = &http.Response{
Body: &awstesting.ReadCloser{Size: 2048},
Header: http.Header{
"X-Amzn-Requestid": []string{"1"},
},
StatusCode: http.StatusOK,
}
if data == nil {
data = &struct {
_ struct{} `type:"structure"`
LocationName *string `locationName:"testName"`
Location *string `location:"statusCode"`
A *string `type:"string"`
}{}
}
req.Data = data
return req
}
type expected struct {
dataType int
closed bool
size int
errExists bool
}
const (
jsonType = iota
xmlType
)
func checkForLeak(data interface{}, build, fn func(*request.Request), t *testing.T, result expected) {
req := buildNewRequest(data)
reader := req.HTTPResponse.Body.(*awstesting.ReadCloser)
switch result.dataType {
case jsonType:
reader.FillData = jsonData
case xmlType:
reader.FillData = xmlData
}
build(req)
fn(req)
if result.errExists {
assert.NotNil(t, req.Error)
} else {
assert.Nil(t, req.Error)
}
assert.Equal(t, reader.Closed, result.closed)
assert.Equal(t, reader.Size, result.size)
}
func TestJSONRpc(t *testing.T) {
checkForLeak(nil, jsonrpc.Build, jsonrpc.Unmarshal, t, expected{jsonType, true, 0, false})
checkForLeak(nil, jsonrpc.Build, jsonrpc.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
checkForLeak(nil, jsonrpc.Build, jsonrpc.UnmarshalError, t, expected{jsonType, true, 0, true})
}
func TestQuery(t *testing.T) {
checkForLeak(nil, query.Build, query.Unmarshal, t, expected{jsonType, true, 0, false})
checkForLeak(nil, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
checkForLeak(nil, query.Build, query.UnmarshalError, t, expected{jsonType, true, 0, true})
}
func TestRest(t *testing.T) {
// case 1: Payload io.ReadSeeker
checkForLeak(nil, rest.Build, rest.Unmarshal, t, expected{jsonType, false, 2048, false})
checkForLeak(nil, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
// case 2: Payload *string
// should close the body
dataStr := struct {
_ struct{} `type:"structure" payload:"Payload"`
LocationName *string `locationName:"testName"`
Location *string `location:"statusCode"`
A *string `type:"string"`
Payload *string `locationName:"payload" type:"blob" required:"true"`
}{}
checkForLeak(&dataStr, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, false})
checkForLeak(&dataStr, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
// case 3: Payload []byte
// should close the body
dataBytes := struct {
_ struct{} `type:"structure" payload:"Payload"`
LocationName *string `locationName:"testName"`
Location *string `location:"statusCode"`
A *string `type:"string"`
Payload []byte `locationName:"payload" type:"blob" required:"true"`
}{}
checkForLeak(&dataBytes, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, false})
checkForLeak(&dataBytes, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
// case 4: Payload unsupported type
// should close the body
dataUnsupported := struct {
_ struct{} `type:"structure" payload:"Payload"`
LocationName *string `locationName:"testName"`
Location *string `location:"statusCode"`
A *string `type:"string"`
Payload string `locationName:"payload" type:"blob" required:"true"`
}{}
checkForLeak(&dataUnsupported, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, true})
checkForLeak(&dataUnsupported, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
}
func TestRestJSON(t *testing.T) {
checkForLeak(nil, restjson.Build, restjson.Unmarshal, t, expected{jsonType, true, 0, false})
checkForLeak(nil, restjson.Build, restjson.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
checkForLeak(nil, restjson.Build, restjson.UnmarshalError, t, expected{jsonType, true, 0, true})
}
func TestRestXML(t *testing.T) {
checkForLeak(nil, restxml.Build, restxml.Unmarshal, t, expected{xmlType, true, 0, false})
checkForLeak(nil, restxml.Build, restxml.UnmarshalMeta, t, expected{xmlType, false, 2048, false})
checkForLeak(nil, restxml.Build, restxml.UnmarshalError, t, expected{xmlType, true, 0, true})
}
func TestXML(t *testing.T) {
checkForLeak(nil, ec2query.Build, ec2query.Unmarshal, t, expected{jsonType, true, 0, false})
checkForLeak(nil, ec2query.Build, ec2query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
checkForLeak(nil, ec2query.Build, ec2query.UnmarshalError, t, expected{jsonType, true, 0, true})
}
func TestProtocol(t *testing.T) {
checkForLeak(nil, restxml.Build, protocol.UnmarshalDiscardBody, t, expected{xmlType, true, 0, false})
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,63 +0,0 @@
package rest
import (
"net/http"
"net/url"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestCleanPath(t *testing.T) {
uri := &url.URL{
Path: "//foo//bar",
Scheme: "https",
Host: "host",
}
cleanPath(uri)
expected := "https://host/foo/bar"
if a, e := uri.String(), expected; a != e {
t.Errorf("expect %q URI, got %q", e, a)
}
}
func TestMarshalPath(t *testing.T) {
in := struct {
Bucket *string `location:"uri" locationName:"bucket"`
Key *string `location:"uri" locationName:"key"`
}{
Bucket: aws.String("mybucket"),
Key: aws.String("my/cool+thing space/object世界"),
}
expectURL := `/mybucket/my/cool+thing space/object世界`
expectEscapedURL := `/mybucket/my/cool%2Bthing%20space/object%E4%B8%96%E7%95%8C`
req := &request.Request{
HTTPRequest: &http.Request{
URL: &url.URL{Scheme: "https", Host: "exmaple.com", Path: "/{bucket}/{key+}"},
},
Params: &in,
}
Build(req)
if req.Error != nil {
t.Fatalf("unexpected error, %v", req.Error)
}
if a, e := req.HTTPRequest.URL.Path, expectURL; a != e {
t.Errorf("expect %q URI, got %q", e, a)
}
if a, e := req.HTTPRequest.URL.RawPath, expectEscapedURL; a != e {
t.Errorf("expect %q escaped URI, got %q", e, a)
}
if a, e := req.HTTPRequest.URL.EscapedPath(), expectEscapedURL; a != e {
t.Errorf("expect %q escaped URI, got %q", e, a)
}
}

View file

@ -1,63 +0,0 @@
package rest_test
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/private/protocol/rest"
)
func TestUnsetHeaders(t *testing.T) {
cfg := &aws.Config{Region: aws.String("us-west-2")}
c := unit.Session.ClientConfig("testService", cfg)
svc := client.New(
*cfg,
metadata.ClientInfo{
ServiceName: "testService",
SigningName: c.SigningName,
SigningRegion: c.SigningRegion,
Endpoint: c.Endpoint,
APIVersion: "",
},
c.Handlers,
)
// Handlers
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(rest.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(rest.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(rest.UnmarshalMetaHandler)
op := &request.Operation{
Name: "test-operation",
HTTPPath: "/",
}
input := &struct {
Foo aws.JSONValue `location:"header" locationName:"x-amz-foo" type:"jsonvalue"`
Bar aws.JSONValue `location:"header" locationName:"x-amz-bar" type:"jsonvalue"`
}{}
output := &struct {
Foo aws.JSONValue `location:"header" locationName:"x-amz-foo" type:"jsonvalue"`
Bar aws.JSONValue `location:"header" locationName:"x-amz-bar" type:"jsonvalue"`
}{}
req := svc.NewRequest(op, input, output)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(bytes.NewBuffer(nil)), Header: http.Header{}}
req.HTTPResponse.Header.Set("X-Amz-Foo", "e30=")
// unmarshal response
rest.UnmarshalMeta(req)
rest.Unmarshal(req)
if req.Error != nil {
t.Fatal(req.Error)
}
}

View file

@ -1,366 +0,0 @@
// +build bench
package restxml_test
import (
"net/http"
"net/http/httptest"
"os"
"testing"
"bytes"
"encoding/xml"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/private/protocol/restxml"
"github.com/aws/aws-sdk-go/service/cloudfront"
"github.com/aws/aws-sdk-go/service/s3"
)
var (
cloudfrontSvc *cloudfront.CloudFront
s3Svc *s3.S3
)
func TestMain(m *testing.M) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
sess := session.Must(session.NewSession(&aws.Config{
Credentials: credentials.NewStaticCredentials("Key", "Secret", "Token"),
Endpoint: aws.String(server.URL),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
Region: aws.String(endpoints.UsWest2RegionID),
}))
cloudfrontSvc = cloudfront.New(sess)
s3Svc = s3.New(sess)
c := m.Run()
server.Close()
os.Exit(c)
}
func BenchmarkRESTXMLBuild_Complex_CFCreateDistro(b *testing.B) {
params := cloudfrontCreateDistributionInput()
benchRESTXMLBuild(b, func() *request.Request {
req, _ := cloudfrontSvc.CreateDistributionRequest(params)
return req
})
}
func BenchmarkRESTXMLBuild_Simple_CFDeleteDistro(b *testing.B) {
params := cloudfrontDeleteDistributionInput()
benchRESTXMLBuild(b, func() *request.Request {
req, _ := cloudfrontSvc.DeleteDistributionRequest(params)
return req
})
}
func BenchmarkRESTXMLBuild_REST_S3HeadObject(b *testing.B) {
params := s3HeadObjectInput()
benchRESTXMLBuild(b, func() *request.Request {
req, _ := s3Svc.HeadObjectRequest(params)
return req
})
}
func BenchmarkRESTXMLBuild_XML_S3PutObjectAcl(b *testing.B) {
params := s3PutObjectAclInput()
benchRESTXMLBuild(b, func() *request.Request {
req, _ := s3Svc.PutObjectAclRequest(params)
return req
})
}
func BenchmarkRESTXMLRequest_Complex_CFCreateDistro(b *testing.B) {
benchRESTXMLRequest(b, func() *request.Request {
req, _ := cloudfrontSvc.CreateDistributionRequest(cloudfrontCreateDistributionInput())
return req
})
}
func BenchmarkRESTXMLRequest_Simple_CFDeleteDistro(b *testing.B) {
benchRESTXMLRequest(b, func() *request.Request {
req, _ := cloudfrontSvc.DeleteDistributionRequest(cloudfrontDeleteDistributionInput())
return req
})
}
func BenchmarkRESTXMLRequest_REST_S3HeadObject(b *testing.B) {
benchRESTXMLRequest(b, func() *request.Request {
req, _ := s3Svc.HeadObjectRequest(s3HeadObjectInput())
return req
})
}
func BenchmarkRESTXMLRequest_XML_S3PutObjectAcl(b *testing.B) {
benchRESTXMLRequest(b, func() *request.Request {
req, _ := s3Svc.PutObjectAclRequest(s3PutObjectAclInput())
return req
})
}
func BenchmarkEncodingXML_Simple(b *testing.B) {
params := cloudfrontDeleteDistributionInput()
for i := 0; i < b.N; i++ {
buf := &bytes.Buffer{}
encoder := xml.NewEncoder(buf)
if err := encoder.Encode(params); err != nil {
b.Fatal("Unexpected error", err)
}
}
}
func benchRESTXMLBuild(b *testing.B, reqFn func() *request.Request) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
req := reqFn()
restxml.Build(req)
if req.Error != nil {
b.Fatal("Unexpected error", req.Error)
}
}
}
func benchRESTXMLRequest(b *testing.B, reqFn func() *request.Request) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := reqFn().Send()
if err != nil {
b.Fatal("Unexpected error", err)
}
}
}
func cloudfrontCreateDistributionInput() *cloudfront.CreateDistributionInput {
return &cloudfront.CreateDistributionInput{
DistributionConfig: &cloudfront.DistributionConfig{ // Required
CallerReference: aws.String("string"), // Required
Comment: aws.String("string"), // Required
DefaultCacheBehavior: &cloudfront.DefaultCacheBehavior{ // Required
ForwardedValues: &cloudfront.ForwardedValues{ // Required
Cookies: &cloudfront.CookiePreference{ // Required
Forward: aws.String("ItemSelection"), // Required
WhitelistedNames: &cloudfront.CookieNames{
Quantity: aws.Int64(1), // Required
Items: []*string{
aws.String("string"), // Required
// More values...
},
},
},
QueryString: aws.Bool(true), // Required
Headers: &cloudfront.Headers{
Quantity: aws.Int64(1), // Required
Items: []*string{
aws.String("string"), // Required
// More values...
},
},
},
MinTTL: aws.Int64(1), // Required
TargetOriginId: aws.String("string"), // Required
TrustedSigners: &cloudfront.TrustedSigners{ // Required
Enabled: aws.Bool(true), // Required
Quantity: aws.Int64(1), // Required
Items: []*string{
aws.String("string"), // Required
// More values...
},
},
ViewerProtocolPolicy: aws.String("ViewerProtocolPolicy"), // Required
AllowedMethods: &cloudfront.AllowedMethods{
Items: []*string{ // Required
aws.String("Method"), // Required
// More values...
},
Quantity: aws.Int64(1), // Required
CachedMethods: &cloudfront.CachedMethods{
Items: []*string{ // Required
aws.String("Method"), // Required
// More values...
},
Quantity: aws.Int64(1), // Required
},
},
DefaultTTL: aws.Int64(1),
MaxTTL: aws.Int64(1),
SmoothStreaming: aws.Bool(true),
},
Enabled: aws.Bool(true), // Required
Origins: &cloudfront.Origins{ // Required
Quantity: aws.Int64(1), // Required
Items: []*cloudfront.Origin{
{ // Required
DomainName: aws.String("string"), // Required
Id: aws.String("string"), // Required
CustomOriginConfig: &cloudfront.CustomOriginConfig{
HTTPPort: aws.Int64(1), // Required
HTTPSPort: aws.Int64(1), // Required
OriginProtocolPolicy: aws.String("OriginProtocolPolicy"), // Required
},
OriginPath: aws.String("string"),
S3OriginConfig: &cloudfront.S3OriginConfig{
OriginAccessIdentity: aws.String("string"), // Required
},
},
// More values...
},
},
Aliases: &cloudfront.Aliases{
Quantity: aws.Int64(1), // Required
Items: []*string{
aws.String("string"), // Required
// More values...
},
},
CacheBehaviors: &cloudfront.CacheBehaviors{
Quantity: aws.Int64(1), // Required
Items: []*cloudfront.CacheBehavior{
{ // Required
ForwardedValues: &cloudfront.ForwardedValues{ // Required
Cookies: &cloudfront.CookiePreference{ // Required
Forward: aws.String("ItemSelection"), // Required
WhitelistedNames: &cloudfront.CookieNames{
Quantity: aws.Int64(1), // Required
Items: []*string{
aws.String("string"), // Required
// More values...
},
},
},
QueryString: aws.Bool(true), // Required
Headers: &cloudfront.Headers{
Quantity: aws.Int64(1), // Required
Items: []*string{
aws.String("string"), // Required
// More values...
},
},
},
MinTTL: aws.Int64(1), // Required
PathPattern: aws.String("string"), // Required
TargetOriginId: aws.String("string"), // Required
TrustedSigners: &cloudfront.TrustedSigners{ // Required
Enabled: aws.Bool(true), // Required
Quantity: aws.Int64(1), // Required
Items: []*string{
aws.String("string"), // Required
// More values...
},
},
ViewerProtocolPolicy: aws.String("ViewerProtocolPolicy"), // Required
AllowedMethods: &cloudfront.AllowedMethods{
Items: []*string{ // Required
aws.String("Method"), // Required
// More values...
},
Quantity: aws.Int64(1), // Required
CachedMethods: &cloudfront.CachedMethods{
Items: []*string{ // Required
aws.String("Method"), // Required
// More values...
},
Quantity: aws.Int64(1), // Required
},
},
DefaultTTL: aws.Int64(1),
MaxTTL: aws.Int64(1),
SmoothStreaming: aws.Bool(true),
},
// More values...
},
},
CustomErrorResponses: &cloudfront.CustomErrorResponses{
Quantity: aws.Int64(1), // Required
Items: []*cloudfront.CustomErrorResponse{
{ // Required
ErrorCode: aws.Int64(1), // Required
ErrorCachingMinTTL: aws.Int64(1),
ResponseCode: aws.String("string"),
ResponsePagePath: aws.String("string"),
},
// More values...
},
},
DefaultRootObject: aws.String("string"),
Logging: &cloudfront.LoggingConfig{
Bucket: aws.String("string"), // Required
Enabled: aws.Bool(true), // Required
IncludeCookies: aws.Bool(true), // Required
Prefix: aws.String("string"), // Required
},
PriceClass: aws.String("PriceClass"),
Restrictions: &cloudfront.Restrictions{
GeoRestriction: &cloudfront.GeoRestriction{ // Required
Quantity: aws.Int64(1), // Required
RestrictionType: aws.String("GeoRestrictionType"), // Required
Items: []*string{
aws.String("string"), // Required
// More values...
},
},
},
ViewerCertificate: &cloudfront.ViewerCertificate{
CloudFrontDefaultCertificate: aws.Bool(true),
IAMCertificateId: aws.String("string"),
MinimumProtocolVersion: aws.String("MinimumProtocolVersion"),
SSLSupportMethod: aws.String("SSLSupportMethod"),
},
},
}
}
func cloudfrontDeleteDistributionInput() *cloudfront.DeleteDistributionInput {
return &cloudfront.DeleteDistributionInput{
Id: aws.String("string"), // Required
IfMatch: aws.String("string"),
}
}
func s3HeadObjectInput() *s3.HeadObjectInput {
return &s3.HeadObjectInput{
Bucket: aws.String("somebucketname"),
Key: aws.String("keyname"),
VersionId: aws.String("someVersion"),
IfMatch: aws.String("IfMatch"),
}
}
func s3PutObjectAclInput() *s3.PutObjectAclInput {
return &s3.PutObjectAclInput{
Bucket: aws.String("somebucketname"),
Key: aws.String("keyname"),
AccessControlPolicy: &s3.AccessControlPolicy{
Grants: []*s3.Grant{
{
Grantee: &s3.Grantee{
DisplayName: aws.String("someName"),
EmailAddress: aws.String("someAddr"),
ID: aws.String("someID"),
Type: aws.String(s3.TypeCanonicalUser),
URI: aws.String("someURI"),
},
Permission: aws.String(s3.PermissionWrite),
},
},
Owner: &s3.Owner{
DisplayName: aws.String("howdy"),
ID: aws.String("someID"),
},
},
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,40 +0,0 @@
package protocol_test
import (
"net/http"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/stretchr/testify/assert"
)
type mockCloser struct {
*strings.Reader
Closed bool
}
func (m *mockCloser) Close() error {
m.Closed = true
return nil
}
func TestUnmarshalDrainBody(t *testing.T) {
b := &mockCloser{Reader: strings.NewReader("example body")}
r := &request.Request{HTTPResponse: &http.Response{
Body: b,
}}
protocol.UnmarshalDiscardBody(r)
assert.NoError(t, r.Error)
assert.Equal(t, 0, b.Len())
assert.True(t, b.Closed)
}
func TestUnmarshalDrainBodyNoBody(t *testing.T) {
r := &request.Request{HTTPResponse: &http.Response{}}
protocol.UnmarshalDiscardBody(r)
assert.NoError(t, r.Error)
}

View file

@ -1,142 +0,0 @@
package xmlutil
import (
"encoding/xml"
"fmt"
"io"
"reflect"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
type mockBody struct {
DoneErr error
Body io.Reader
}
func (m *mockBody) Read(p []byte) (int, error) {
n, err := m.Body.Read(p)
if (n == 0 || err == io.EOF) && m.DoneErr != nil {
return n, m.DoneErr
}
return n, err
}
type mockOutput struct {
_ struct{} `type:"structure"`
String *string `type:"string"`
Integer *int64 `type:"integer"`
Nested *mockNestedStruct `type:"structure"`
List []*mockListElem `locationName:"List" locationNameList:"Elem" type:"list"`
Closed *mockClosedTags `type:"structure"`
}
type mockNestedStruct struct {
_ struct{} `type:"structure"`
NestedString *string `type:"string"`
NestedInt *int64 `type:"integer"`
}
type mockClosedTags struct {
_ struct{} `type:"structure" xmlPrefix:"xsi" xmlURI:"http://www.w3.org/2001/XMLSchema-instance"`
Attr *string `locationName:"xsi:attrval" type:"string" xmlAttribute:"true"`
}
type mockListElem struct {
_ struct{} `type:"structure" xmlPrefix:"xsi" xmlURI:"http://www.w3.org/2001/XMLSchema-instance"`
String *string `type:"string"`
NestedElem *mockNestedListElem `type:"structure"`
}
type mockNestedListElem struct {
_ struct{} `type:"structure" xmlPrefix:"xsi" xmlURI:"http://www.w3.org/2001/XMLSchema-instance"`
String *string `type:"string"`
Type *string `locationName:"xsi:type" type:"string" xmlAttribute:"true"`
}
func TestUnmarshal(t *testing.T) {
const xmlBodyStr = `<?xml version="1.0" encoding="UTF-8"?>
<MockResponse xmlns="http://xmlns.example.com">
<String>string value</String>
<Integer>123</Integer>
<Closed xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:attrval="attr value"/>
<Nested>
<NestedString>nested string value</NestedString>
<NestedInt>321</NestedInt>
</Nested>
<List>
<Elem>
<NestedElem xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="type">
<String>nested elem string value</String>
</NestedElem>
<String>elem string value</String>
</Elem>
</List>
</MockResponse>`
expect := mockOutput{
String: aws.String("string value"),
Integer: aws.Int64(123),
Closed: &mockClosedTags{
Attr: aws.String("attr value"),
},
Nested: &mockNestedStruct{
NestedString: aws.String("nested string value"),
NestedInt: aws.Int64(321),
},
List: []*mockListElem{
{
String: aws.String("elem string value"),
NestedElem: &mockNestedListElem{
String: aws.String("nested elem string value"),
Type: aws.String("type"),
},
},
},
}
actual := mockOutput{}
decoder := xml.NewDecoder(strings.NewReader(xmlBodyStr))
err := UnmarshalXML(&actual, decoder, "")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if !reflect.DeepEqual(expect, actual) {
t.Errorf("expect unmarshal to match\nExpect: %s\nActual: %s",
awsutil.Prettify(expect), awsutil.Prettify(actual))
}
}
func TestUnmarshal_UnexpectedEOF(t *testing.T) {
const partialXMLBody = `<?xml version="1.0" encoding="UTF-8"?>
<First>first value</First>
<Second>Second val`
out := struct {
First *string `locationName:"First" type:"string"`
Second *string `locationName:"Second" type:"string"`
}{}
expect := out
expect.First = aws.String("first")
expect.Second = aws.String("second")
expectErr := fmt.Errorf("expected read error")
body := &mockBody{
DoneErr: expectErr,
Body: strings.NewReader(partialXMLBody),
}
decoder := xml.NewDecoder(body)
err := UnmarshalXML(&out, decoder, "")
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := expectErr, err; e != a {
t.Errorf("expect %v error in %v, but was not", e, a)
}
}

View file

@ -1,35 +0,0 @@
package route53_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/route53"
)
func TestBuildCorrectURI(t *testing.T) {
const expectPath = "/2013-04-01/hostedzone/ABCDEFG"
svc := route53.New(unit.Session)
svc.Handlers.Validate.Clear()
req, _ := svc.GetHostedZoneRequest(&route53.GetHostedZoneInput{
Id: aws.String("/hostedzone/ABCDEFG"),
})
req.HTTPRequest.URL.RawQuery = "abc=123"
req.Build()
if a, e := req.HTTPRequest.URL.Path, expectPath; a != e {
t.Errorf("expect path %q, got %q", e, a)
}
if a, e := req.HTTPRequest.URL.RawPath, expectPath; a != e {
t.Errorf("expect raw path %q, got %q", e, a)
}
if a, e := req.HTTPRequest.URL.RawQuery, "abc=123"; a != e {
t.Errorf("expect query to be %q, got %q", e, a)
}
}

View file

@ -1,936 +0,0 @@
// Code generated by private/model/cli/gen-api/main.go. DO NOT EDIT.
package route53_test
import (
"fmt"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/route53"
)
var _ time.Duration
var _ strings.Reader
var _ aws.Config
func parseTime(layout, value string) *time.Time {
t, err := time.Parse(layout, value)
if err != nil {
panic(err)
}
return &t
}
// To associate a VPC with a hosted zone
//
// The following example associates the VPC with ID vpc-1a2b3c4d with the hosted zone
// with ID Z3M3LMPEXAMPLE.
func ExampleRoute53_AssociateVPCWithHostedZone_shared00() {
svc := route53.New(session.New())
input := &route53.AssociateVPCWithHostedZoneInput{
Comment: aws.String(""),
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
VPC: &route53.VPC{
VPCId: aws.String("vpc-1a2b3c4d"),
VPCRegion: aws.String("us-east-2"),
},
}
result, err := svc.AssociateVPCWithHostedZone(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNotAuthorizedException:
fmt.Println(route53.ErrCodeNotAuthorizedException, aerr.Error())
case route53.ErrCodeInvalidVPCId:
fmt.Println(route53.ErrCodeInvalidVPCId, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePublicZoneVPCAssociation:
fmt.Println(route53.ErrCodePublicZoneVPCAssociation, aerr.Error())
case route53.ErrCodeConflictingDomainExists:
fmt.Println(route53.ErrCodeConflictingDomainExists, aerr.Error())
case route53.ErrCodeLimitsExceeded:
fmt.Println(route53.ErrCodeLimitsExceeded, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To create a basic resource record set
//
// The following example creates a resource record set that routes Internet traffic
// to a resource with an IP address of 192.0.2.44.
func ExampleRoute53_ChangeResourceRecordSets_shared00() {
svc := route53.New(session.New())
input := &route53.ChangeResourceRecordSetsInput{
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
Name: aws.String("example.com"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.44"),
},
},
TTL: aws.Int64(60),
Type: aws.String("A"),
},
},
},
Comment: aws.String("Web server for example.com"),
},
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.ChangeResourceRecordSets(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeInvalidChangeBatch:
fmt.Println(route53.ErrCodeInvalidChangeBatch, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To create weighted resource record sets
//
// The following example creates two weighted resource record sets. The resource with
// a Weight of 100 will get 1/3rd of traffic (100/100+200), and the other resource will
// get the rest of the traffic for example.com.
func ExampleRoute53_ChangeResourceRecordSets_shared01() {
svc := route53.New(session.New())
input := &route53.ChangeResourceRecordSetsInput{
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
HealthCheckId: aws.String("abcdef11-2222-3333-4444-555555fedcba"),
Name: aws.String("example.com"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.44"),
},
},
SetIdentifier: aws.String("Seattle data center"),
TTL: aws.Int64(60),
Type: aws.String("A"),
Weight: aws.Int64(100),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
HealthCheckId: aws.String("abcdef66-7777-8888-9999-000000fedcba"),
Name: aws.String("example.com"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.45"),
},
},
SetIdentifier: aws.String("Portland data center"),
TTL: aws.Int64(60),
Type: aws.String("A"),
Weight: aws.Int64(200),
},
},
},
Comment: aws.String("Web servers for example.com"),
},
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.ChangeResourceRecordSets(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeInvalidChangeBatch:
fmt.Println(route53.ErrCodeInvalidChangeBatch, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To create an alias resource record set
//
// The following example creates an alias resource record set that routes traffic to
// a CloudFront distribution.
func ExampleRoute53_ChangeResourceRecordSets_shared02() {
svc := route53.New(session.New())
input := &route53.ChangeResourceRecordSetsInput{
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("d123rk29d0stfj.cloudfront.net"),
EvaluateTargetHealth: aws.Bool(false),
HostedZoneId: aws.String("Z2FDTNDATAQYW2"),
},
Name: aws.String("example.com"),
Type: aws.String("A"),
},
},
},
Comment: aws.String("CloudFront distribution for example.com"),
},
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.ChangeResourceRecordSets(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeInvalidChangeBatch:
fmt.Println(route53.ErrCodeInvalidChangeBatch, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To create weighted alias resource record sets
//
// The following example creates two weighted alias resource record sets that route
// traffic to ELB load balancers. The resource with a Weight of 100 will get 1/3rd of
// traffic (100/100+200), and the other resource will get the rest of the traffic for
// example.com.
func ExampleRoute53_ChangeResourceRecordSets_shared03() {
svc := route53.New(session.New())
input := &route53.ChangeResourceRecordSetsInput{
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("example-com-123456789.us-east-2.elb.amazonaws.com "),
EvaluateTargetHealth: aws.Bool(true),
HostedZoneId: aws.String("Z3AADJGX6KTTL2"),
},
Name: aws.String("example.com"),
SetIdentifier: aws.String("Ohio region"),
Type: aws.String("A"),
Weight: aws.Int64(100),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("example-com-987654321.us-west-2.elb.amazonaws.com "),
EvaluateTargetHealth: aws.Bool(true),
HostedZoneId: aws.String("Z1H1FL5HABSF5"),
},
Name: aws.String("example.com"),
SetIdentifier: aws.String("Oregon region"),
Type: aws.String("A"),
Weight: aws.Int64(200),
},
},
},
Comment: aws.String("ELB load balancers for example.com"),
},
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.ChangeResourceRecordSets(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeInvalidChangeBatch:
fmt.Println(route53.ErrCodeInvalidChangeBatch, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To create latency resource record sets
//
// The following example creates two latency resource record sets that route traffic
// to EC2 instances. Traffic for example.com is routed either to the Ohio region or
// the Oregon region, depending on the latency between the user and those regions.
func ExampleRoute53_ChangeResourceRecordSets_shared04() {
svc := route53.New(session.New())
input := &route53.ChangeResourceRecordSetsInput{
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
HealthCheckId: aws.String("abcdef11-2222-3333-4444-555555fedcba"),
Name: aws.String("example.com"),
Region: aws.String("us-east-2"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.44"),
},
},
SetIdentifier: aws.String("Ohio region"),
TTL: aws.Int64(60),
Type: aws.String("A"),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
HealthCheckId: aws.String("abcdef66-7777-8888-9999-000000fedcba"),
Name: aws.String("example.com"),
Region: aws.String("us-west-2"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.45"),
},
},
SetIdentifier: aws.String("Oregon region"),
TTL: aws.Int64(60),
Type: aws.String("A"),
},
},
},
Comment: aws.String("EC2 instances for example.com"),
},
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.ChangeResourceRecordSets(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeInvalidChangeBatch:
fmt.Println(route53.ErrCodeInvalidChangeBatch, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To create latency alias resource record sets
//
// The following example creates two latency alias resource record sets that route traffic
// for example.com to ELB load balancers. Requests are routed either to the Ohio region
// or the Oregon region, depending on the latency between the user and those regions.
func ExampleRoute53_ChangeResourceRecordSets_shared05() {
svc := route53.New(session.New())
input := &route53.ChangeResourceRecordSetsInput{
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("example-com-123456789.us-east-2.elb.amazonaws.com "),
EvaluateTargetHealth: aws.Bool(true),
HostedZoneId: aws.String("Z3AADJGX6KTTL2"),
},
Name: aws.String("example.com"),
Region: aws.String("us-east-2"),
SetIdentifier: aws.String("Ohio region"),
Type: aws.String("A"),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("example-com-987654321.us-west-2.elb.amazonaws.com "),
EvaluateTargetHealth: aws.Bool(true),
HostedZoneId: aws.String("Z1H1FL5HABSF5"),
},
Name: aws.String("example.com"),
Region: aws.String("us-west-2"),
SetIdentifier: aws.String("Oregon region"),
Type: aws.String("A"),
},
},
},
Comment: aws.String("ELB load balancers for example.com"),
},
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.ChangeResourceRecordSets(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeInvalidChangeBatch:
fmt.Println(route53.ErrCodeInvalidChangeBatch, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To create failover resource record sets
//
// The following example creates primary and secondary failover resource record sets
// that route traffic to EC2 instances. Traffic is generally routed to the primary resource,
// in the Ohio region. If that resource is unavailable, traffic is routed to the secondary
// resource, in the Oregon region.
func ExampleRoute53_ChangeResourceRecordSets_shared06() {
svc := route53.New(session.New())
input := &route53.ChangeResourceRecordSetsInput{
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
Failover: aws.String("PRIMARY"),
HealthCheckId: aws.String("abcdef11-2222-3333-4444-555555fedcba"),
Name: aws.String("example.com"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.44"),
},
},
SetIdentifier: aws.String("Ohio region"),
TTL: aws.Int64(60),
Type: aws.String("A"),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
Failover: aws.String("SECONDARY"),
HealthCheckId: aws.String("abcdef66-7777-8888-9999-000000fedcba"),
Name: aws.String("example.com"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.45"),
},
},
SetIdentifier: aws.String("Oregon region"),
TTL: aws.Int64(60),
Type: aws.String("A"),
},
},
},
Comment: aws.String("Failover configuration for example.com"),
},
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.ChangeResourceRecordSets(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeInvalidChangeBatch:
fmt.Println(route53.ErrCodeInvalidChangeBatch, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To create failover alias resource record sets
//
// The following example creates primary and secondary failover alias resource record
// sets that route traffic to ELB load balancers. Traffic is generally routed to the
// primary resource, in the Ohio region. If that resource is unavailable, traffic is
// routed to the secondary resource, in the Oregon region.
func ExampleRoute53_ChangeResourceRecordSets_shared07() {
svc := route53.New(session.New())
input := &route53.ChangeResourceRecordSetsInput{
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("example-com-123456789.us-east-2.elb.amazonaws.com "),
EvaluateTargetHealth: aws.Bool(true),
HostedZoneId: aws.String("Z3AADJGX6KTTL2"),
},
Failover: aws.String("PRIMARY"),
Name: aws.String("example.com"),
SetIdentifier: aws.String("Ohio region"),
Type: aws.String("A"),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("example-com-987654321.us-west-2.elb.amazonaws.com "),
EvaluateTargetHealth: aws.Bool(true),
HostedZoneId: aws.String("Z1H1FL5HABSF5"),
},
Failover: aws.String("SECONDARY"),
Name: aws.String("example.com"),
SetIdentifier: aws.String("Oregon region"),
Type: aws.String("A"),
},
},
},
Comment: aws.String("Failover alias configuration for example.com"),
},
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.ChangeResourceRecordSets(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeInvalidChangeBatch:
fmt.Println(route53.ErrCodeInvalidChangeBatch, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To create geolocation resource record sets
//
// The following example creates four geolocation resource record sets that use IPv4
// addresses to route traffic to resources such as web servers running on EC2 instances.
// Traffic is routed to one of four IP addresses, for North America (NA), for South
// America (SA), for Europe (EU), and for all other locations (*).
func ExampleRoute53_ChangeResourceRecordSets_shared08() {
svc := route53.New(session.New())
input := &route53.ChangeResourceRecordSetsInput{
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
GeoLocation: &route53.GeoLocation{
ContinentCode: aws.String("NA"),
},
Name: aws.String("example.com"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.44"),
},
},
SetIdentifier: aws.String("North America"),
TTL: aws.Int64(60),
Type: aws.String("A"),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
GeoLocation: &route53.GeoLocation{
ContinentCode: aws.String("SA"),
},
Name: aws.String("example.com"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.45"),
},
},
SetIdentifier: aws.String("South America"),
TTL: aws.Int64(60),
Type: aws.String("A"),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
GeoLocation: &route53.GeoLocation{
ContinentCode: aws.String("EU"),
},
Name: aws.String("example.com"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.46"),
},
},
SetIdentifier: aws.String("Europe"),
TTL: aws.Int64(60),
Type: aws.String("A"),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
GeoLocation: &route53.GeoLocation{
CountryCode: aws.String("*"),
},
Name: aws.String("example.com"),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("192.0.2.47"),
},
},
SetIdentifier: aws.String("Other locations"),
TTL: aws.Int64(60),
Type: aws.String("A"),
},
},
},
Comment: aws.String("Geolocation configuration for example.com"),
},
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.ChangeResourceRecordSets(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeInvalidChangeBatch:
fmt.Println(route53.ErrCodeInvalidChangeBatch, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To create geolocation alias resource record sets
//
// The following example creates four geolocation alias resource record sets that route
// traffic to ELB load balancers. Traffic is routed to one of four IP addresses, for
// North America (NA), for South America (SA), for Europe (EU), and for all other locations
// (*).
func ExampleRoute53_ChangeResourceRecordSets_shared09() {
svc := route53.New(session.New())
input := &route53.ChangeResourceRecordSetsInput{
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("example-com-123456789.us-east-2.elb.amazonaws.com "),
EvaluateTargetHealth: aws.Bool(true),
HostedZoneId: aws.String("Z3AADJGX6KTTL2"),
},
GeoLocation: &route53.GeoLocation{
ContinentCode: aws.String("NA"),
},
Name: aws.String("example.com"),
SetIdentifier: aws.String("North America"),
Type: aws.String("A"),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("example-com-234567890.sa-east-1.elb.amazonaws.com "),
EvaluateTargetHealth: aws.Bool(true),
HostedZoneId: aws.String("Z2P70J7HTTTPLU"),
},
GeoLocation: &route53.GeoLocation{
ContinentCode: aws.String("SA"),
},
Name: aws.String("example.com"),
SetIdentifier: aws.String("South America"),
Type: aws.String("A"),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("example-com-234567890.eu-central-1.elb.amazonaws.com "),
EvaluateTargetHealth: aws.Bool(true),
HostedZoneId: aws.String("Z215JYRZR1TBD5"),
},
GeoLocation: &route53.GeoLocation{
ContinentCode: aws.String("EU"),
},
Name: aws.String("example.com"),
SetIdentifier: aws.String("Europe"),
Type: aws.String("A"),
},
},
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
AliasTarget: &route53.AliasTarget{
DNSName: aws.String("example-com-234567890.ap-southeast-1.elb.amazonaws.com "),
EvaluateTargetHealth: aws.Bool(true),
HostedZoneId: aws.String("Z1LMS91P8CMLE5"),
},
GeoLocation: &route53.GeoLocation{
CountryCode: aws.String("*"),
},
Name: aws.String("example.com"),
SetIdentifier: aws.String("Other locations"),
Type: aws.String("A"),
},
},
},
Comment: aws.String("Geolocation alias configuration for example.com"),
},
HostedZoneId: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.ChangeResourceRecordSets(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeInvalidChangeBatch:
fmt.Println(route53.ErrCodeInvalidChangeBatch, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To add or remove tags from a hosted zone or health check
//
// The following example adds two tags and removes one tag from the hosted zone with
// ID Z3M3LMPEXAMPLE.
func ExampleRoute53_ChangeTagsForResource_shared00() {
svc := route53.New(session.New())
input := &route53.ChangeTagsForResourceInput{
AddTags: []*route53.Tag{
{
Key: aws.String("apex"),
Value: aws.String("3874"),
},
{
Key: aws.String("acme"),
Value: aws.String("4938"),
},
},
RemoveTagKeys: []*string{
aws.String("Nadir"),
},
ResourceId: aws.String("Z3M3LMPEXAMPLE"),
ResourceType: aws.String("hostedzone"),
}
result, err := svc.ChangeTagsForResource(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
case route53.ErrCodeNoSuchHealthCheck:
fmt.Println(route53.ErrCodeNoSuchHealthCheck, aerr.Error())
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodePriorRequestNotComplete:
fmt.Println(route53.ErrCodePriorRequestNotComplete, aerr.Error())
case route53.ErrCodeThrottlingException:
fmt.Println(route53.ErrCodeThrottlingException, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To get information about a hosted zone
//
// The following example gets information about the Z3M3LMPEXAMPLE hosted zone.
func ExampleRoute53_GetHostedZone_shared00() {
svc := route53.New(session.New())
input := &route53.GetHostedZoneInput{
Id: aws.String("Z3M3LMPEXAMPLE"),
}
result, err := svc.GetHostedZone(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case route53.ErrCodeNoSuchHostedZone:
fmt.Println(route53.ErrCodeNoSuchHostedZone, aerr.Error())
case route53.ErrCodeInvalidInput:
fmt.Println(route53.ErrCodeInvalidInput, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}

View file

@ -1,43 +0,0 @@
package route53
import (
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestUnmarhsalErrorLeak(t *testing.T) {
req := &request.Request{
Operation: &request.Operation{
Name: opChangeResourceRecordSets,
},
HTTPRequest: &http.Request{
Header: make(http.Header),
Body: &awstesting.ReadCloser{Size: 2048},
},
}
req.HTTPResponse = &http.Response{
Body: &awstesting.ReadCloser{Size: 2048},
Header: http.Header{
"X-Amzn-Requestid": []string{"1"},
},
StatusCode: http.StatusOK,
}
reader := req.HTTPResponse.Body.(*awstesting.ReadCloser)
unmarshalChangeResourceRecordSetsError(req)
if req.Error == nil {
t.Error("expected an error, but received none")
}
if !reader.Closed {
t.Error("expected reader to be closed")
}
if e, a := 0, reader.Size; e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}

View file

@ -1,130 +0,0 @@
package route53_test
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/route53"
)
func makeClientWithResponse(response string) *route53.Route53 {
r := route53.New(unit.Session)
r.Handlers.Send.Clear()
r.Handlers.Send.PushBack(func(r *request.Request) {
body := ioutil.NopCloser(bytes.NewReader([]byte(response)))
r.HTTPResponse = &http.Response{
ContentLength: int64(len(response)),
StatusCode: 400,
Status: "Bad Request",
Body: body,
}
})
return r
}
func TestUnmarshalStandardError(t *testing.T) {
const errorResponse = `<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse xmlns="https://route53.amazonaws.com/doc/2013-04-01/">
<Error>
<Code>InvalidDomainName</Code>
<Message>The domain name is invalid</Message>
</Error>
<RequestId>12345</RequestId>
</ErrorResponse>
`
r := makeClientWithResponse(errorResponse)
_, err := r.CreateHostedZone(&route53.CreateHostedZoneInput{
CallerReference: aws.String("test"),
Name: aws.String("test_zone"),
})
if err == nil {
t.Error("expected error, but received none")
}
if e, a := "InvalidDomainName", err.(awserr.Error).Code(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "The domain name is invalid", err.(awserr.Error).Message(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}
func TestUnmarshalInvalidChangeBatch(t *testing.T) {
const errorMessage = `
Tried to create resource record set duplicate.example.com. type A,
but it already exists
`
const errorResponse = `<?xml version="1.0" encoding="UTF-8"?>
<InvalidChangeBatch xmlns="https://route53.amazonaws.com/doc/2013-04-01/">
<Messages>
<Message>` + errorMessage + `</Message>
</Messages>
</InvalidChangeBatch>
`
r := makeClientWithResponse(errorResponse)
req := &route53.ChangeResourceRecordSetsInput{
HostedZoneId: aws.String("zoneId"),
ChangeBatch: &route53.ChangeBatch{
Changes: []*route53.Change{
{
Action: aws.String("CREATE"),
ResourceRecordSet: &route53.ResourceRecordSet{
Name: aws.String("domain"),
Type: aws.String("CNAME"),
TTL: aws.Int64(120),
ResourceRecords: []*route53.ResourceRecord{
{
Value: aws.String("cname"),
},
},
},
},
},
},
}
_, err := r.ChangeResourceRecordSets(req)
if err == nil {
t.Error("expected error, but received none")
}
if reqErr, ok := err.(awserr.RequestFailure); ok {
if reqErr == nil {
t.Error("expected error, but received none")
}
if e, a := 400, reqErr.StatusCode(); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
} else {
t.Fatal("returned error is not a RequestFailure")
}
if batchErr, ok := err.(awserr.BatchedErrors); ok {
errs := batchErr.OrigErrs()
if e, a := 1, len(errs); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
if e, a := "InvalidChangeBatch", errs[0].(awserr.Error).Code(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := errorMessage, errs[0].(awserr.Error).Message(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
} else {
t.Fatal("returned error is not a BatchedErrors")
}
}

View file

@ -1,45 +0,0 @@
package sts_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/sts"
)
var svc = sts.New(unit.Session, &aws.Config{
Region: aws.String("mock-region"),
})
func TestUnsignedRequest_AssumeRoleWithSAML(t *testing.T) {
req, _ := svc.AssumeRoleWithSAMLRequest(&sts.AssumeRoleWithSAMLInput{
PrincipalArn: aws.String("ARN01234567890123456789"),
RoleArn: aws.String("ARN01234567890123456789"),
SAMLAssertion: aws.String("ASSERT"),
})
err := req.Sign()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "", req.HTTPRequest.Header.Get("Authorization"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestUnsignedRequest_AssumeRoleWithWebIdentity(t *testing.T) {
req, _ := svc.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
RoleArn: aws.String("ARN01234567890123456789"),
RoleSessionName: aws.String("SESSION"),
WebIdentityToken: aws.String("TOKEN"),
})
err := req.Sign()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "", req.HTTPRequest.Header.Get("Authorization"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}

View file

@ -1,282 +0,0 @@
// Code generated by private/model/cli/gen-api/main.go. DO NOT EDIT.
package sts_test
import (
"fmt"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
)
var _ time.Duration
var _ strings.Reader
var _ aws.Config
func parseTime(layout, value string) *time.Time {
t, err := time.Parse(layout, value)
if err != nil {
panic(err)
}
return &t
}
// To assume a role
//
func ExampleSTS_AssumeRole_shared00() {
svc := sts.New(session.New())
input := &sts.AssumeRoleInput{
DurationSeconds: aws.Int64(3600),
ExternalId: aws.String("123ABC"),
Policy: aws.String("{\"Version\":\"2012-10-17\",\"Statement\":[{\"Sid\":\"Stmt1\",\"Effect\":\"Allow\",\"Action\":\"s3:*\",\"Resource\":\"*\"}]}"),
RoleArn: aws.String("arn:aws:iam::123456789012:role/demo"),
RoleSessionName: aws.String("Bob"),
}
result, err := svc.AssumeRole(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case sts.ErrCodeMalformedPolicyDocumentException:
fmt.Println(sts.ErrCodeMalformedPolicyDocumentException, aerr.Error())
case sts.ErrCodePackedPolicyTooLargeException:
fmt.Println(sts.ErrCodePackedPolicyTooLargeException, aerr.Error())
case sts.ErrCodeRegionDisabledException:
fmt.Println(sts.ErrCodeRegionDisabledException, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To assume a role as an OpenID Connect-federated user
//
func ExampleSTS_AssumeRoleWithWebIdentity_shared00() {
svc := sts.New(session.New())
input := &sts.AssumeRoleWithWebIdentityInput{
DurationSeconds: aws.Int64(3600),
ProviderId: aws.String("www.amazon.com"),
RoleArn: aws.String("arn:aws:iam::123456789012:role/FederatedWebIdentityRole"),
RoleSessionName: aws.String("app1"),
WebIdentityToken: aws.String("Atza%7CIQEBLjAsAhRFiXuWpUXuRvQ9PZL3GMFcYevydwIUFAHZwXZXXXXXXXXJnrulxKDHwy87oGKPznh0D6bEQZTSCzyoCtL_8S07pLpr0zMbn6w1lfVZKNTBdDansFBmtGnIsIapjI6xKR02Yc_2bQ8LZbUXSGm6Ry6_BG7PrtLZtj_dfCTj92xNGed-CrKqjG7nPBjNIL016GGvuS5gSvPRUxWES3VYfm1wl7WTI7jn-Pcb6M-buCgHhFOzTQxod27L9CqnOLio7N3gZAGpsp6n1-AJBOCJckcyXe2c6uD0srOJeZlKUm2eTDVMf8IehDVI0r1QOnTV6KzzAI3OY87Vd_cVMQ"),
}
result, err := svc.AssumeRoleWithWebIdentity(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case sts.ErrCodeMalformedPolicyDocumentException:
fmt.Println(sts.ErrCodeMalformedPolicyDocumentException, aerr.Error())
case sts.ErrCodePackedPolicyTooLargeException:
fmt.Println(sts.ErrCodePackedPolicyTooLargeException, aerr.Error())
case sts.ErrCodeIDPRejectedClaimException:
fmt.Println(sts.ErrCodeIDPRejectedClaimException, aerr.Error())
case sts.ErrCodeIDPCommunicationErrorException:
fmt.Println(sts.ErrCodeIDPCommunicationErrorException, aerr.Error())
case sts.ErrCodeInvalidIdentityTokenException:
fmt.Println(sts.ErrCodeInvalidIdentityTokenException, aerr.Error())
case sts.ErrCodeExpiredTokenException:
fmt.Println(sts.ErrCodeExpiredTokenException, aerr.Error())
case sts.ErrCodeRegionDisabledException:
fmt.Println(sts.ErrCodeRegionDisabledException, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To decode information about an authorization status of a request
//
func ExampleSTS_DecodeAuthorizationMessage_shared00() {
svc := sts.New(session.New())
input := &sts.DecodeAuthorizationMessageInput{
EncodedMessage: aws.String("<encoded-message>"),
}
result, err := svc.DecodeAuthorizationMessage(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case sts.ErrCodeInvalidAuthorizationMessageException:
fmt.Println(sts.ErrCodeInvalidAuthorizationMessageException, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To get details about a calling IAM user
//
// This example shows a request and response made with the credentials for a user named
// Alice in the AWS account 123456789012.
func ExampleSTS_GetCallerIdentity_shared00() {
svc := sts.New(session.New())
input := &sts.GetCallerIdentityInput{}
result, err := svc.GetCallerIdentity(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To get details about a calling user federated with AssumeRole
//
// This example shows a request and response made with temporary credentials created
// by AssumeRole. The name of the assumed role is my-role-name, and the RoleSessionName
// is set to my-role-session-name.
func ExampleSTS_GetCallerIdentity_shared01() {
svc := sts.New(session.New())
input := &sts.GetCallerIdentityInput{}
result, err := svc.GetCallerIdentity(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To get details about a calling user federated with GetFederationToken
//
// This example shows a request and response made with temporary credentials created
// by using GetFederationToken. The Name parameter is set to my-federated-user-name.
func ExampleSTS_GetCallerIdentity_shared02() {
svc := sts.New(session.New())
input := &sts.GetCallerIdentityInput{}
result, err := svc.GetCallerIdentity(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To get temporary credentials for a role by using GetFederationToken
//
func ExampleSTS_GetFederationToken_shared00() {
svc := sts.New(session.New())
input := &sts.GetFederationTokenInput{
DurationSeconds: aws.Int64(3600),
Name: aws.String("Bob"),
Policy: aws.String("{\"Version\":\"2012-10-17\",\"Statement\":[{\"Sid\":\"Stmt1\",\"Effect\":\"Allow\",\"Action\":\"s3:*\",\"Resource\":\"*\"}]}"),
}
result, err := svc.GetFederationToken(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case sts.ErrCodeMalformedPolicyDocumentException:
fmt.Println(sts.ErrCodeMalformedPolicyDocumentException, aerr.Error())
case sts.ErrCodePackedPolicyTooLargeException:
fmt.Println(sts.ErrCodePackedPolicyTooLargeException, aerr.Error())
case sts.ErrCodeRegionDisabledException:
fmt.Println(sts.ErrCodeRegionDisabledException, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}
// To get temporary credentials for an IAM user or an AWS account
//
func ExampleSTS_GetSessionToken_shared00() {
svc := sts.New(session.New())
input := &sts.GetSessionTokenInput{
DurationSeconds: aws.Int64(3600),
SerialNumber: aws.String("YourMFASerialNumber"),
TokenCode: aws.String("123456"),
}
result, err := svc.GetSessionToken(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case sts.ErrCodeRegionDisabledException:
fmt.Println(sts.ErrCodeRegionDisabledException, aerr.Error())
default:
fmt.Println(aerr.Error())
}
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
}
return
}
fmt.Println(result)
}