Update github.com/kurin/blazer

This commit is contained in:
Alexander Neumann 2017-10-16 19:43:03 +02:00
parent faadbd734b
commit c87f2420a6
9 changed files with 508 additions and 182 deletions

6
Gopkg.lock generated
View file

@ -88,8 +88,8 @@
[[projects]]
name = "github.com/kurin/blazer"
packages = ["b2","base","internal/b2types","internal/blog"]
revision = "cad56a04490fe20c43548d70a5a9af2be53ff14e"
version = "v0.2.0"
revision = "e269a1a17bb6aec278c06a57cb7e8f8d0d333e04"
version = "v0.2.1"
[[projects]]
branch = "master"
@ -214,6 +214,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "ea711bd1a9bfc8902b973a4de3a840f42536b9091fd8558980f44d6ca1622227"
inputs-digest = "abc33af201086afac21e33a2a7987a473daa6a229c3699ca13761f4d4fd7f52e"
solver-name = "gps-cdcl"
solver-version = 1

View file

@ -72,3 +72,7 @@
[[constraint]]
branch = "master"
name = "golang.org/x/sys"
[[constraint]]
name = "github.com/kurin/blazer"
branch = "master"

View file

@ -8,4 +8,6 @@ branches:
- master
before_script: go run internal/bin/cleanup/cleanup.go
script: go test -v ./base ./b2 ./x/...
script:
- go test -v ./base ./b2 ./x/...
- go vet -v ./base ./b2 ./x/...

View file

@ -27,6 +27,8 @@ import (
"sync/atomic"
"testing"
"time"
"github.com/kurin/blazer/x/transport"
)
const (
@ -744,10 +746,10 @@ func (rt *rtCounter) RoundTrip(r *http.Request) (*http.Response, error) {
}
func TestAttrsNoRoundtrip(t *testing.T) {
rt := &rtCounter{rt: transport}
transport = rt
rt := &rtCounter{rt: defaultTransport}
defaultTransport = rt
defer func() {
transport = rt.rt
defaultTransport = rt.rt
}()
ctx := context.Background()
@ -767,7 +769,7 @@ func TestAttrsNoRoundtrip(t *testing.T) {
t.Fatal(err)
}
if len(objs) != 1 {
t.Fatal("unexpected objects: got %d, want 1", len(objs))
t.Fatalf("unexpected objects: got %d, want 1", len(objs))
}
trips := rt.trips
@ -842,7 +844,7 @@ func listObjects(ctx context.Context, f func(context.Context, int, *Cursor) ([]*
return ch
}
var transport = http.DefaultTransport
var defaultTransport = http.DefaultTransport
type eofTripper struct {
rt http.RoundTripper
@ -919,9 +921,10 @@ func startLiveTest(ctx context.Context, t *testing.T) (*Bucket, func()) {
t.Skipf("B2_ACCOUNT_ID or B2_SECRET_KEY unset; skipping integration tests")
return nil, nil
}
ccport := &ccTripper{rt: transport, t: t}
ccport := &ccTripper{rt: defaultTransport, t: t}
tport := eofTripper{rt: ccport, t: t}
client, err := NewClient(ctx, id, key, FailSomeUploads(), ExpireSomeAuthTokens(), Transport(tport), UserAgent("b2-test"), UserAgent("integration-test"))
errport := transport.WithFailures(tport, transport.FailureRate(.25), transport.MatchPathSubstring("/b2_get_upload_url"), transport.Response(503))
client, err := NewClient(ctx, id, key, FailSomeUploads(), ExpireSomeAuthTokens(), Transport(errport), UserAgent("b2-test"), UserAgent("integration-test"))
if err != nil {
t.Fatal(err)
return nil, nil

View file

@ -43,7 +43,7 @@ import (
const (
APIBase = "https://api.backblazeb2.com"
DefaultUserAgent = "blazer/0.1.1"
DefaultUserAgent = "blazer/0.2.1"
)
type b2err struct {
@ -131,25 +131,32 @@ const (
func mkErr(resp *http.Response) error {
data, err := ioutil.ReadAll(resp.Body)
var msgBody string
if err != nil {
return err
msgBody = fmt.Sprintf("couldn't read message body: %v", err)
}
logResponse(resp, data)
msg := &b2types.ErrorMessage{}
if err := json.Unmarshal(data, msg); err != nil {
return err
if msgBody != "" {
msgBody = fmt.Sprintf("couldn't read message body: %v", err)
}
}
if msgBody == "" {
msgBody = msg.Msg
}
var retryAfter int
retry := resp.Header.Get("Retry-After")
if retry != "" {
r, err := strconv.ParseInt(retry, 10, 64)
if err != nil {
return err
r = 0
blog.V(1).Infof("couldn't parse retry-after header %q: %v", retry, err)
}
retryAfter = int(r)
}
return b2err{
msg: msg.Msg,
msg: msgBody,
retry: retryAfter,
code: resp.StatusCode,
method: resp.Request.Header.Get("X-Blazer-Method"),
@ -222,6 +229,19 @@ type b2Options struct {
userAgent string
}
func (o *b2Options) addHeaders(req *http.Request) {
if o.failSomeUploads {
req.Header.Add("X-Bz-Test-Mode", "fail_some_uploads")
}
if o.expireTokens {
req.Header.Add("X-Bz-Test-Mode", "expire_some_account_authorization_tokens")
}
if o.capExceeded {
req.Header.Add("X-Bz-Test-Mode", "force_cap_exceeded")
}
req.Header.Set("User-Agent", o.getUserAgent())
}
func (o *b2Options) getAPIBase() string {
if o.apiBase != "" {
return o.apiBase
@ -268,14 +288,22 @@ type httpReply struct {
err error
}
func makeNetRequest(req *http.Request, rt http.RoundTripper) <-chan httpReply {
ch := make(chan httpReply)
go func() {
func makeNetRequest(ctx context.Context, req *http.Request, rt http.RoundTripper) (*http.Response, error) {
req = req.WithContext(ctx)
resp, err := rt.RoundTrip(req)
ch <- httpReply{resp, err}
close(ch)
}()
return ch
switch err {
case nil:
return resp, nil
case context.Canceled, context.DeadlineExceeded:
return nil, err
default:
method := req.Header.Get("X-Blazer-Method")
blog.V(2).Infof(">> %s uri: %v err: %v", method, req.URL, err)
return nil, b2err{
msg: err.Error(),
retry: 1,
}
}
}
type requestBody struct {
@ -351,38 +379,14 @@ func (o *b2Options) makeRequest(ctx context.Context, method, verb, uri string, b
}
req.Header.Set(k, v)
}
req.Header.Set("User-Agent", o.getUserAgent())
req.Header.Set("X-Blazer-Request-ID", fmt.Sprintf("%d", atomic.AddInt64(&reqID, 1)))
req.Header.Set("X-Blazer-Method", method)
if o.failSomeUploads {
req.Header.Add("X-Bz-Test-Mode", "fail_some_uploads")
}
if o.expireTokens {
req.Header.Add("X-Bz-Test-Mode", "expire_some_account_authorization_tokens")
}
if o.capExceeded {
req.Header.Add("X-Bz-Test-Mode", "force_cap_exceeded")
}
cancel := make(chan struct{})
req.Cancel = cancel
o.addHeaders(req)
logRequest(req, args)
ch := makeNetRequest(req, o.getTransport())
var reply httpReply
select {
case reply = <-ch:
case <-ctx.Done():
close(cancel)
return ctx.Err()
resp, err := makeNetRequest(ctx, req, o.getTransport())
if err != nil {
return err
}
if reply.err != nil {
// Connection errors are retryable.
blog.V(2).Infof(">> %s uri: %v err: %v", method, req.URL, reply.err)
return b2err{
msg: reply.err.Error(),
retry: 1,
}
}
resp := reply.resp
defer resp.Body.Close()
if resp.StatusCode != 200 {
return mkErr(resp)
@ -397,10 +401,11 @@ func (o *b2Options) makeRequest(ctx context.Context, method, verb, uri string, b
}
replyArgs = rbuf.Bytes()
} else {
replyArgs, err = ioutil.ReadAll(resp.Body)
ra, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
blog.V(1).Infof("%s: couldn't read response: %v", method, err)
}
replyArgs = ra
}
logResponse(resp, replyArgs)
return nil
@ -1038,7 +1043,7 @@ func mkRange(offset, size int64) string {
// DownloadFileByName wraps b2_download_file_by_name.
func (b *Bucket) DownloadFileByName(ctx context.Context, name string, offset, size int64) (*FileReader, error) {
uri := fmt.Sprintf("%s/file/%s/%s", b.b2.downloadURI, b.Name, name)
uri := fmt.Sprintf("%s/file/%s/%s", b.b2.downloadURI, b.Name, escape(name))
req, err := http.NewRequest("GET", uri, nil)
if err != nil {
return nil, err
@ -1046,25 +1051,16 @@ func (b *Bucket) DownloadFileByName(ctx context.Context, name string, offset, si
req.Header.Set("Authorization", b.b2.authToken)
req.Header.Set("X-Blazer-Request-ID", fmt.Sprintf("%d", atomic.AddInt64(&reqID, 1)))
req.Header.Set("X-Blazer-Method", "b2_download_file_by_name")
b.b2.opts.addHeaders(req)
rng := mkRange(offset, size)
if rng != "" {
req.Header.Set("Range", rng)
}
cancel := make(chan struct{})
req.Cancel = cancel
logRequest(req, nil)
ch := makeNetRequest(req, b.b2.opts.getTransport())
var reply httpReply
select {
case reply = <-ch:
case <-ctx.Done():
close(cancel)
return nil, ctx.Err()
resp, err := makeNetRequest(ctx, req, b.b2.opts.getTransport())
if err != nil {
return nil, err
}
if reply.err != nil {
return nil, reply.err
}
resp := reply.resp
logResponse(resp, nil)
if resp.StatusCode != 200 && resp.StatusCode != 206 {
defer resp.Body.Close()

View file

@ -20,14 +20,14 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"reflect"
"strings"
"testing"
"time"
"github.com/kurin/blazer/x/transport"
"context"
)
@ -270,48 +270,7 @@ func TestStorage(t *testing.T) {
}
}
// This slow motion train wreck of a type exists to axe a net connection after
// N bytes have been written. Because of the specific bug it's built to test,
// it can't just *close* the connection, so it just sleeps forever.
type wonkyNetConn struct {
net.Conn
ctx context.Context // implode once cancelled
die *bool // only implode once
n int // bytes to allow before imploding, roughly
i int // bytes written
}
func (w *wonkyNetConn) Write(b []byte) (int, error) {
if w.i > w.n && w.ctx.Err() != nil && *w.die {
*w.die = false
select {}
}
n, err := w.Conn.Write(b)
w.i += n
return n, err
}
func newWonkyNetConn(ctx context.Context, die *bool, n int, netw, addr string) (net.Conn, error) {
conn, err := net.Dial(netw, addr)
if err != nil {
return nil, err
}
return &wonkyNetConn{
Conn: conn,
ctx: ctx,
n: n,
die: die,
}, nil
}
func makeBadDialContext(ctx context.Context) func(context.Context, string, string) (net.Conn, error) {
die := true
return func(noCtx context.Context, network, addr string) (net.Conn, error) {
return newWonkyNetConn(ctx, &die, 10000, network, addr)
}
}
func TestBadUpload(t *testing.T) {
func TestUploadAuthAfterConnectionHang(t *testing.T) {
id := os.Getenv(apiID)
key := os.Getenv(apiKey)
if id == "" || key == "" {
@ -319,19 +278,16 @@ func TestBadUpload(t *testing.T) {
}
ctx := context.Background()
octx, ocancel := context.WithCancel(ctx)
defer ocancel()
hung := make(chan struct{})
badTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: makeBadDialContext(octx),
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
// An http.RoundTripper that dies after sending ~10k bytes.
hang := func() {
close(hung)
select {}
}
tport := transport.WithFailures(nil, transport.AfterNBytes(10000, hang))
b2, err := AuthorizeAccount(ctx, id, key, Transport(badTransport))
b2, err := AuthorizeAccount(ctx, id, key, Transport(tport))
if err != nil {
t.Fatal(err)
}
@ -358,13 +314,13 @@ func TestBadUpload(t *testing.T) {
t.Error(err)
}
smallSHA1 := fmt.Sprintf("%x", hash.Sum(nil))
ocancel()
go func() {
ue.UploadFile(ctx, buf, buf.Len(), smallFileName, "application/octet-stream", smallSHA1, nil)
t.Fatal("this ought not to be reachable")
}()
time.Sleep(time.Second) // give this a chance to hang
<-hung
// Do the whole thing again with the same upload auth, before the remote end
// notices we're gone.
@ -381,7 +337,97 @@ func TestBadUpload(t *testing.T) {
}
}
if Action(err) != AttemptNewUpload {
t.Error("Action(%v): got %v, want AttemptNewUpload", err, Action(err))
t.Errorf("Action(%v): got %v, want AttemptNewUpload", err, Action(err))
}
}
func TestCancelledContextCancelsHTTPRequest(t *testing.T) {
id := os.Getenv(apiID)
key := os.Getenv(apiKey)
if id == "" || key == "" {
t.Skipf("B2_ACCOUNT_ID or B2_SECRET_KEY unset; skipping integration tests")
}
ctx := context.Background()
tport := transport.WithFailures(nil, transport.MatchPathSubstring("b2_upload_file"), transport.FailureRate(1), transport.Stall(2*time.Second))
b2, err := AuthorizeAccount(ctx, id, key, Transport(tport))
if err != nil {
t.Fatal(err)
}
bname := id + "-" + bucketName
bucket, err := b2.CreateBucket(ctx, bname, "", nil, nil)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := bucket.DeleteBucket(ctx); err != nil {
t.Error(err)
}
}()
ue, err := bucket.GetUploadURL(ctx)
if err != nil {
t.Fatal(err)
}
smallFile := io.LimitReader(zReader{}, 1024*50) // 50k
hash := sha1.New()
buf := &bytes.Buffer{}
w := io.MultiWriter(hash, buf)
if _, err := io.Copy(w, smallFile); err != nil {
t.Error(err)
}
smallSHA1 := fmt.Sprintf("%x", hash.Sum(nil))
cctx, cancel := context.WithCancel(ctx)
go func() {
time.Sleep(1)
cancel()
}()
if _, err := ue.UploadFile(cctx, buf, buf.Len(), smallFileName, "application/octet-stream", smallSHA1, nil); err != context.Canceled {
t.Errorf("expected canceled context, but got %v", err)
}
}
func TestDeadlineExceededContextCancelsHTTPRequest(t *testing.T) {
id := os.Getenv(apiID)
key := os.Getenv(apiKey)
if id == "" || key == "" {
t.Skipf("B2_ACCOUNT_ID or B2_SECRET_KEY unset; skipping integration tests")
}
ctx := context.Background()
tport := transport.WithFailures(nil, transport.MatchPathSubstring("b2_upload_file"), transport.FailureRate(1), transport.Stall(2*time.Second))
b2, err := AuthorizeAccount(ctx, id, key, Transport(tport))
if err != nil {
t.Fatal(err)
}
bname := id + "-" + bucketName
bucket, err := b2.CreateBucket(ctx, bname, "", nil, nil)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := bucket.DeleteBucket(ctx); err != nil {
t.Error(err)
}
}()
ue, err := bucket.GetUploadURL(ctx)
if err != nil {
t.Fatal(err)
}
smallFile := io.LimitReader(zReader{}, 1024*50) // 50k
hash := sha1.New()
buf := &bytes.Buffer{}
w := io.MultiWriter(hash, buf)
if _, err := io.Copy(w, smallFile); err != nil {
t.Error(err)
}
smallSHA1 := fmt.Sprintf("%x", hash.Sum(nil))
cctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
if _, err := ue.UploadFile(cctx, buf, buf.Len(), smallFileName, "application/octet-stream", smallSHA1, nil); err != context.DeadlineExceeded {
t.Errorf("expected deadline exceeded error, but got %v", err)
}
}
@ -533,3 +579,72 @@ func TestEscapes(t *testing.T) {
}
}
}
func TestUploadDownloadFilenameEscaping(t *testing.T) {
filename := "file%foo.txt"
id := os.Getenv(apiID)
key := os.Getenv(apiKey)
if id == "" || key == "" {
t.Skipf("B2_ACCOUNT_ID or B2_SECRET_KEY unset; skipping integration tests")
}
ctx := context.Background()
// b2_authorize_account
b2, err := AuthorizeAccount(ctx, id, key, UserAgent("blazer-base-test"))
if err != nil {
t.Fatal(err)
}
// b2_create_bucket
bname := id + "-" + bucketName
bucket, err := b2.CreateBucket(ctx, bname, "", nil, nil)
if err != nil {
t.Fatal(err)
}
defer func() {
// b2_delete_bucket
if err := bucket.DeleteBucket(ctx); err != nil {
t.Error(err)
}
}()
// b2_get_upload_url
ue, err := bucket.GetUploadURL(ctx)
if err != nil {
t.Fatal(err)
}
// b2_upload_file
smallFile := io.LimitReader(zReader{}, 128)
hash := sha1.New()
buf := &bytes.Buffer{}
w := io.MultiWriter(hash, buf)
if _, err := io.Copy(w, smallFile); err != nil {
t.Error(err)
}
smallSHA1 := fmt.Sprintf("%x", hash.Sum(nil))
file, err := ue.UploadFile(ctx, buf, buf.Len(), filename, "application/octet-stream", smallSHA1, nil)
if err != nil {
t.Fatal(err)
}
defer func() {
// b2_delete_file_version
if err := file.DeleteFileVersion(ctx); err != nil {
t.Error(err)
}
}()
// b2_download_file_by_name
fr, err := bucket.DownloadFileByName(ctx, filename, 0, 0)
if err != nil {
t.Fatal(err)
}
lbuf := &bytes.Buffer{}
if _, err := io.Copy(lbuf, fr); err != nil {
t.Fatal(err)
}
}

View file

@ -15,67 +15,14 @@
package base
import (
"bytes"
"errors"
"fmt"
"net/url"
"strings"
)
func noEscape(c byte) bool {
switch c {
case '.', '_', '-', '/', '~', '!', '$', '\'', '(', ')', '*', ';', '=', ':', '@':
return true
}
return false
}
func escape(s string) string {
// cribbed from url.go, kinda
b := &bytes.Buffer{}
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case c == '/':
b.WriteByte(c)
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9':
b.WriteByte(c)
case noEscape(c):
b.WriteByte(c)
default:
fmt.Fprintf(b, "%%%X", c)
}
}
return b.String()
return strings.Replace(url.QueryEscape(s), "%2F", "/", -1)
}
func unescape(s string) (string, error) {
b := &bytes.Buffer{}
for i := 0; i < len(s); i++ {
c := s[i]
switch c {
case '/':
b.WriteString("/")
case '+':
b.WriteString(" ")
case '%':
if len(s)-i < 3 {
return "", errors.New("unescape: bad encoding")
}
b.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2]))
i += 2
default:
b.WriteByte(c)
}
}
return b.String(), nil
}
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
return url.QueryUnescape(s)
}

52
vendor/github.com/kurin/blazer/base/strings_test.go generated vendored Normal file
View file

@ -0,0 +1,52 @@
package base
import (
"fmt"
"testing"
)
func TestEncodeDecode(t *testing.T) {
// crashes identified by go-fuzz
origs := []string{
"&\x020000",
"&\x020000\x9c",
"&\x020\x9c0",
"&\x0230j",
"&\x02\x98000",
"&\x02\x983\xc8j00",
"00\x000",
"00\x0000",
"00\x0000000000000",
"\x11\x030",
}
for _, orig := range origs {
escaped := escape(orig)
unescaped, err := unescape(escaped)
if err != nil {
t.Errorf("%s: orig: %#v, escaped: %#v, unescaped: %#v\n", err.Error(), orig, escaped, unescaped)
continue
}
if unescaped != orig {
t.Errorf("expected: %#v, got: %#v", orig, unescaped)
}
}
}
// hook for go-fuzz: https://github.com/dvyukov/go-fuzz
func Fuzz(data []byte) int {
orig := string(data)
escaped := escape(orig)
unescaped, err := unescape(escaped)
if err != nil {
return 0
}
if unescaped != orig {
panic(fmt.Sprintf("unescaped: \"%#v\", != orig: \"%#v\"", unescaped, orig))
}
return 1
}

207
vendor/github.com/kurin/blazer/x/transport/transport.go generated vendored Normal file
View file

@ -0,0 +1,207 @@
// Copyright 2017, Google
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package transport provides http.RoundTrippers that may be useful to clients
// of Blazer.
package transport
import (
"context"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"strings"
"sync/atomic"
"time"
)
// WithFailures returns an http.RoundTripper that wraps an existing
// RoundTripper, causing failures according to the options given. If rt is
// nil, the http.DefaultTransport is wrapped.
func WithFailures(rt http.RoundTripper, opts ...FailureOption) http.RoundTripper {
if rt == nil {
rt = http.DefaultTransport
}
o := &options{
rt: rt,
}
for _, opt := range opts {
opt(o)
}
return o
}
type options struct {
pathSubstrings []string
failureRate float64
status int
stall time.Duration
rt http.RoundTripper
msg string
trg *triggerReaderGroup
}
func (o *options) doRequest(req *http.Request) (*http.Response, error) {
if o.trg != nil && req.Body != nil {
req.Body = o.trg.new(req.Body)
}
resp, err := o.rt.RoundTrip(req)
if resp != nil && o.trg != nil {
resp.Body = o.trg.new(resp.Body)
}
return resp, err
}
func (o *options) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: fix triggering conditions
if rand.Float64() > o.failureRate {
return o.doRequest(req)
}
var match bool
if len(o.pathSubstrings) == 0 {
match = true
}
for _, ss := range o.pathSubstrings {
if strings.Contains(req.URL.Path, ss) {
match = true
break
}
}
if !match {
return o.doRequest(req)
}
if o.status > 0 {
resp := &http.Response{
Status: fmt.Sprintf("%d %s", o.status, http.StatusText(o.status)),
StatusCode: o.status,
Body: ioutil.NopCloser(strings.NewReader(o.msg)),
Request: req,
}
return resp, nil
}
if o.stall > 0 {
ctx := req.Context()
select {
case <-time.After(o.stall):
case <-ctx.Done():
}
}
return o.doRequest(req)
}
// A FailureOption specifies the kind of failure that the RoundTripper should
// display.
type FailureOption func(*options)
// MatchPathSubstring restricts the RoundTripper to URLs whose paths contain
// the given string. The default behavior is to match all paths.
func MatchPathSubstring(s string) FailureOption {
return func(o *options) {
o.pathSubstrings = append(o.pathSubstrings, s)
}
}
// FailureRate causes the RoundTripper to fail a certain percentage of the
// time. rate should be a number between 0 and 1, where 0 will never fail and
// 1 will always fail. The default is never to fail.
func FailureRate(rate float64) FailureOption {
return func(o *options) {
o.failureRate = rate
}
}
// Response simulates a given status code. The returned http.Response will
// have its Status, StatusCode, and Body (with any predefined message) set.
func Response(status int) FailureOption {
return func(o *options) {
o.status = status
}
}
// Stall simulates a network connection failure by stalling for the given
// duration.
func Stall(dur time.Duration) FailureOption {
return func(o *options) {
o.stall = dur
}
}
// If a specific Response is requested, the body will have the given message
// set.
func Body(msg string) FailureOption {
return func(o *options) {
o.msg = msg
}
}
// Trigger will raise the RoundTripper's failure rate to 100% when the given
// context is closed.
func Trigger(ctx context.Context) FailureOption {
return func(o *options) {
go func() {
<-ctx.Done()
o.failureRate = 1
}()
}
}
// AfterNBytes will call effect once (roughly) n bytes have gone over the wire.
// Both sent and received bytes are counted against the total. Only bytes in
// the body of an HTTP request are currently counted; this may change in the
// future. effect will only be called once, and it will block (allowing
// callers to simulate connection hangs).
func AfterNBytes(n int, effect func()) FailureOption {
return func(o *options) {
o.trg = &triggerReaderGroup{
bytes: int64(n),
trigger: effect,
}
}
}
type triggerReaderGroup struct {
bytes int64
trigger func()
triggered int64
}
func (rg *triggerReaderGroup) new(rc io.ReadCloser) io.ReadCloser {
return &triggerReader{
ReadCloser: rc,
bytes: &rg.bytes,
trigger: rg.trigger,
triggered: &rg.triggered,
}
}
type triggerReader struct {
io.ReadCloser
bytes *int64
trigger func()
triggered *int64
}
func (r *triggerReader) Read(p []byte) (int, error) {
n, err := r.ReadCloser.Read(p)
if atomic.AddInt64(r.bytes, -int64(n)) < 0 && atomic.CompareAndSwapInt64(r.triggered, 0, 1) {
// Can't use sync.Once because it blocks for *all* callers until Do returns.
r.trigger()
}
return n, err
}