Factor CloseNotifier use into a new function

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
This commit is contained in:
Aaron Lehmann 2015-07-29 18:18:50 -07:00
parent 6cb5670ba5
commit 9c58954a6e
3 changed files with 48 additions and 75 deletions

View file

@ -2,7 +2,6 @@ package handlers
import ( import (
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -170,30 +169,8 @@ func (buh *blobUploadHandler) PatchBlobData(w http.ResponseWriter, r *http.Reque
// TODO(dmcgowan): support Content-Range header to seek and write range // TODO(dmcgowan): support Content-Range header to seek and write range
// Get a channel that tells us if the client disconnects if err := copyFullPayload(w, r, buh.Upload, buh, "blob PATCH", &buh.Errors); err != nil {
var clientClosed <-chan bool // copyFullPayload reports the error if necessary
if notifier, ok := w.(http.CloseNotifier); ok {
clientClosed = notifier.CloseNotify()
} else {
panic("the ResponseWriter does not implement CloseNotifier")
}
// Copy the data
copied, err := io.Copy(buh.Upload, r.Body)
if clientClosed != nil && (err != nil || (r.ContentLength > 0 && copied < r.ContentLength)) {
// Didn't recieve as much content as expected. Did the client
// disconnect during the request? If so, avoid returning a 400
// error to keep the logs cleaner.
select {
case <-clientClosed:
ctxu.GetLogger(buh).Error("client disconnected during blob PATCH")
return
default:
}
}
if err != nil {
ctxu.GetLogger(buh).Errorf("unknown error copying into upload: %v", err)
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return return
} }
@ -231,30 +208,8 @@ func (buh *blobUploadHandler) PutBlobUploadComplete(w http.ResponseWriter, r *ht
return return
} }
// Get a channel that tells us if the client disconnects if err := copyFullPayload(w, r, buh.Upload, buh, "blob PUT", &buh.Errors); err != nil {
var clientClosed <-chan bool // copyFullPayload reports the error if necessary
if notifier, ok := w.(http.CloseNotifier); ok {
clientClosed = notifier.CloseNotify()
} else {
panic("the ResponseWriter does not implement CloseNotifier")
}
// Read in the data, if any.
copied, err := io.Copy(buh.Upload, r.Body)
if clientClosed != nil && (err != nil || (r.ContentLength > 0 && copied < r.ContentLength)) {
// Didn't recieve as much content as expected. Did the client
// disconnect during the request? If so, avoid returning a 400
// error to keep the logs cleaner.
select {
case <-clientClosed:
ctxu.GetLogger(buh).Error("client disconnected during blob PUT")
return
default:
}
}
if err != nil {
ctxu.GetLogger(buh).Errorf("unknown error copying into upload: %v", err)
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return return
} }

View file

@ -1,8 +1,12 @@
package handlers package handlers
import ( import (
"errors"
"io" "io"
"net/http" "net/http"
ctxu "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/api/errcode"
) )
// closeResources closes all the provided resources after running the target // closeResources closes all the provided resources after running the target
@ -15,3 +19,38 @@ func closeResources(handler http.Handler, closers ...io.Closer) http.Handler {
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
}) })
} }
// copyFullPayload copies the payload of a HTTP request to destWriter. If it
// receives less content than expected, and the client disconnected during the
// upload, it avoids sending a 400 error to keep the logs cleaner.
func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWriter io.Writer, context ctxu.Context, action string, errSlice *errcode.Errors) error {
// Get a channel that tells us if the client disconnects
var clientClosed <-chan bool
if notifier, ok := responseWriter.(http.CloseNotifier); ok {
clientClosed = notifier.CloseNotify()
} else {
panic("the ResponseWriter does not implement CloseNotifier")
}
// Read in the data, if any.
copied, err := io.Copy(destWriter, r.Body)
if clientClosed != nil && (err != nil || (r.ContentLength > 0 && copied < r.ContentLength)) {
// Didn't recieve as much content as expected. Did the client
// disconnect during the request? If so, avoid returning a 400
// error to keep the logs cleaner.
select {
case <-clientClosed:
ctxu.GetLogger(context).Error("client disconnected during " + action)
return errors.New("client disconnected")
default:
}
}
if err != nil {
ctxu.GetLogger(context).Errorf("unknown error reading request payload: %v", err)
*errSlice = append(*errSlice, errcode.ErrorCodeUnknown.WithDetail(err))
return err
}
return nil
}

View file

@ -1,9 +1,9 @@
package handlers package handlers
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"strings" "strings"
@ -113,35 +113,14 @@ func (imh *imageManifestHandler) PutImageManifest(w http.ResponseWriter, r *http
return return
} }
// Get a channel that tells us if the client disconnects var jsonBuf bytes.Buffer
var clientClosed <-chan bool if err := copyFullPayload(w, r, &jsonBuf, imh, "image manifest PUT", &imh.Errors); err != nil {
if notifier, ok := w.(http.CloseNotifier); ok { // copyFullPayload reports the error if necessary
clientClosed = notifier.CloseNotify()
} else {
panic("the ResponseWriter does not implement CloseNotifier")
}
// Copy the data
jsonBytes, err := ioutil.ReadAll(r.Body)
if clientClosed != nil && (err != nil || (r.ContentLength > 0 && int64(len(jsonBytes)) < r.ContentLength)) {
// Didn't recieve as much content as expected. Did the client
// disconnect during the request? If so, avoid returning a 400
// error to keep the logs cleaner.
select {
case <-clientClosed:
ctxu.GetLogger(imh).Error("client disconnected during image manifest PUT")
return
default:
}
}
if err != nil {
ctxu.GetLogger(imh).Errorf("unknown error reading payload: %v", err)
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return return
} }
var manifest manifest.SignedManifest var manifest manifest.SignedManifest
if err := json.Unmarshal(jsonBytes, &manifest); err != nil { if err := json.Unmarshal(jsonBuf.Bytes(), &manifest); err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestInvalid.WithDetail(err)) imh.Errors = append(imh.Errors, v2.ErrorCodeManifestInvalid.WithDetail(err))
return return
} }