diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 45c1c5ac2..2a2299c5c 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -22,6 +22,11 @@ "ImportPath": "github.com/AdRoll/goamz/s3", "Rev": "d3664b76d90508cdda5a6c92042f26eab5db3103" }, + { + "ImportPath": "github.com/MSOpenTech/azure-sdk-for-go/clients/storage", + "Comment": "v1.1-119-g0fbd371", + "Rev": "0fbd37144de3adc2aef74db867c0e15e41c7f74a" + }, { "ImportPath": "github.com/Sirupsen/logrus", "Comment": "v0.6.4-12-g467d9d5", diff --git a/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/blob.go b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/blob.go new file mode 100644 index 000000000..f451e51fe --- /dev/null +++ b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/blob.go @@ -0,0 +1,884 @@ +package storage + +import ( + "bytes" + "encoding/base64" + "encoding/xml" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +type BlobStorageClient struct { + client StorageClient +} + +// A Container is an entry in ContainerListResponse. +type Container struct { + Name string `xml:"Name"` + Properties ContainerProperties `xml:"Properties"` + // TODO (ahmetalpbalkan) Metadata +} + +// ContainerProperties contains various properties of a +// container returned from various endpoints like ListContainers. +type ContainerProperties struct { + LastModified string `xml:"Last-Modified"` + Etag string `xml:"Etag"` + LeaseStatus string `xml:"LeaseStatus"` + LeaseState string `xml:"LeaseState"` + LeaseDuration string `xml:"LeaseDuration"` + // TODO (ahmetalpbalkan) remaining fields +} + +// ContainerListResponse contains the response fields from +// ListContainers call. https://msdn.microsoft.com/en-us/library/azure/dd179352.aspx +type ContainerListResponse struct { + XMLName xml.Name `xml:"EnumerationResults"` + Xmlns string `xml:"xmlns,attr"` + Prefix string `xml:"Prefix"` + Marker string `xml:"Marker"` + NextMarker string `xml:"NextMarker"` + MaxResults int64 `xml:"MaxResults"` + Containers []Container `xml:"Containers>Container"` +} + +// A Blob is an entry in BlobListResponse. +type Blob struct { + Name string `xml:"Name"` + Properties BlobProperties `xml:"Properties"` + // TODO (ahmetalpbalkan) Metadata +} + +// BlobProperties contains various properties of a blob +// returned in various endpoints like ListBlobs or GetBlobProperties. +type BlobProperties struct { + LastModified string `xml:"Last-Modified"` + Etag string `xml:"Etag"` + ContentMD5 string `xml:"Content-MD5"` + ContentLength int64 `xml:"Content-Length"` + ContentType string `xml:"Content-Type"` + ContentEncoding string `xml:"Content-Encoding"` + BlobType BlobType `xml:"x-ms-blob-blob-type"` + SequenceNumber int64 `xml:"x-ms-blob-sequence-number"` + CopyId string `xml:"CopyId"` + CopyStatus string `xml:"CopyStatus"` + CopySource string `xml:"CopySource"` + CopyProgress string `xml:"CopyProgress"` + CopyCompletionTime string `xml:"CopyCompletionTime"` + CopyStatusDescription string `xml:"CopyStatusDescription"` +} + +// BlobListResponse contains the response fields from +// ListBlobs call. https://msdn.microsoft.com/en-us/library/azure/dd135734.aspx +type BlobListResponse struct { + XMLName xml.Name `xml:"EnumerationResults"` + Xmlns string `xml:"xmlns,attr"` + Prefix string `xml:"Prefix"` + Marker string `xml:"Marker"` + NextMarker string `xml:"NextMarker"` + MaxResults int64 `xml:"MaxResults"` + Blobs []Blob `xml:"Blobs>Blob"` +} + +// ListContainersParameters defines the set of customizable +// parameters to make a List Containers call. https://msdn.microsoft.com/en-us/library/azure/dd179352.aspx +type ListContainersParameters struct { + Prefix string + Marker string + Include string + MaxResults uint + Timeout uint +} + +func (p ListContainersParameters) getParameters() url.Values { + out := url.Values{} + + if p.Prefix != "" { + out.Set("prefix", p.Prefix) + } + if p.Marker != "" { + out.Set("marker", p.Marker) + } + if p.Include != "" { + out.Set("include", p.Include) + } + if p.MaxResults != 0 { + out.Set("maxresults", fmt.Sprintf("%v", p.MaxResults)) + } + if p.Timeout != 0 { + out.Set("timeout", fmt.Sprintf("%v", p.Timeout)) + } + + return out +} + +// ListBlobsParameters defines the set of customizable +// parameters to make a List Blobs call. https://msdn.microsoft.com/en-us/library/azure/dd135734.aspx +type ListBlobsParameters struct { + Prefix string + Delimiter string + Marker string + Include string + MaxResults uint + Timeout uint +} + +func (p ListBlobsParameters) getParameters() url.Values { + out := url.Values{} + + if p.Prefix != "" { + out.Set("prefix", p.Prefix) + } + if p.Delimiter != "" { + out.Set("delimiter", p.Delimiter) + } + if p.Marker != "" { + out.Set("marker", p.Marker) + } + if p.Include != "" { + out.Set("include", p.Include) + } + if p.MaxResults != 0 { + out.Set("maxresults", fmt.Sprintf("%v", p.MaxResults)) + } + if p.Timeout != 0 { + out.Set("timeout", fmt.Sprintf("%v", p.Timeout)) + } + + return out +} + +// BlobType defines the type of the Azure Blob. +type BlobType string + +const ( + BlobTypeBlock BlobType = "BlockBlob" + BlobTypePage BlobType = "PageBlob" +) + +// PageWriteType defines the type updates that are going to be +// done on the page blob. +type PageWriteType string + +const ( + PageWriteTypeUpdate PageWriteType = "update" + PageWriteTypeClear PageWriteType = "clear" +) + +const ( + blobCopyStatusPending = "pending" + blobCopyStatusSuccess = "success" + blobCopyStatusAborted = "aborted" + blobCopyStatusFailed = "failed" +) + +// BlockListType is used to filter out types of blocks +// in a Get Blocks List call for a block blob. See +// https://msdn.microsoft.com/en-us/library/azure/dd179400.aspx +// for all block types. +type BlockListType string + +const ( + BlockListTypeAll BlockListType = "all" + BlockListTypeCommitted BlockListType = "committed" + BlockListTypeUncommitted BlockListType = "uncommitted" +) + +// ContainerAccessType defines the access level to the container +// from a public request. See https://msdn.microsoft.com/en-us/library/azure/dd179468.aspx +// and "x-ms-blob-public-access" header. +type ContainerAccessType string + +const ( + ContainerAccessTypePrivate ContainerAccessType = "" + ContainerAccessTypeBlob ContainerAccessType = "blob" + ContainerAccessTypeContainer ContainerAccessType = "container" +) + +const ( + MaxBlobBlockSize = 4 * 1024 * 1024 + MaxBlobPageSize = 4 * 1024 * 1024 +) + +// BlockStatus defines states a block for a block blob can +// be in. +type BlockStatus string + +const ( + BlockStatusUncommitted BlockStatus = "Uncommitted" + BlockStatusCommitted BlockStatus = "Committed" + BlockStatusLatest BlockStatus = "Latest" +) + +// Block is used to create Block entities for Put Block List +// call. +type Block struct { + Id string + Status BlockStatus +} + +// BlockListResponse contains the response fields from +// Get Block List call. https://msdn.microsoft.com/en-us/library/azure/dd179400.aspx +type BlockListResponse struct { + XMLName xml.Name `xml:"BlockList"` + CommittedBlocks []BlockResponse `xml:"CommittedBlocks>Block"` + UncommittedBlocks []BlockResponse `xml:"UncommittedBlocks>Block"` +} + +// BlockResponse contains the block information returned +// in the GetBlockListCall. +type BlockResponse struct { + Name string `xml:"Name"` + Size int64 `xml:"Size"` +} + +// GetPageRangesResponse contains the reponse fields from +// Get Page Ranges call. https://msdn.microsoft.com/en-us/library/azure/ee691973.aspx +type GetPageRangesResponse struct { + XMLName xml.Name `xml:"PageList"` + PageList []PageRange `xml:"PageRange"` +} + +// PageRange contains information about a page of a page blob from +// Get Pages Range call. https://msdn.microsoft.com/en-us/library/azure/ee691973.aspx +type PageRange struct { + Start int64 `xml:"Start"` + End int64 `xml:"End"` +} + +var ( + ErrNotCreated = errors.New("storage: operation has returned a successful error code other than 201 Created.") + ErrNotAccepted = errors.New("storage: operation has returned a successful error code other than 202 Accepted.") + + errBlobCopyAborted = errors.New("storage: blob copy is aborted") + errBlobCopyIdMismatch = errors.New("storage: blob copy id is a mismatch") +) + +const errUnexpectedStatus = "storage: was expecting status code: %d, got: %d" + +// ListContainers returns the list of containers in a storage account along with +// pagination token and other response details. See https://msdn.microsoft.com/en-us/library/azure/dd179352.aspx +func (b BlobStorageClient) ListContainers(params ListContainersParameters) (ContainerListResponse, error) { + q := mergeParams(params.getParameters(), url.Values{"comp": {"list"}}) + uri := b.client.getEndpoint(blobServiceName, "", q) + headers := b.client.getStandardHeaders() + + var out ContainerListResponse + resp, err := b.client.exec("GET", uri, headers, nil) + if err != nil { + return out, err + } + + err = xmlUnmarshal(resp.body, &out) + return out, err +} + +// CreateContainer creates a blob container within the storage account +// with given name and access level. See https://msdn.microsoft.com/en-us/library/azure/dd179468.aspx +// Returns error if container already exists. +func (b BlobStorageClient) CreateContainer(name string, access ContainerAccessType) error { + resp, err := b.createContainer(name, access) + if err != nil { + return err + } + if resp.statusCode != http.StatusCreated { + return ErrNotCreated + } + return nil +} + +// CreateContainerIfNotExists creates a blob container if it does not exist. Returns +// true if container is newly created or false if container already exists. +func (b BlobStorageClient) CreateContainerIfNotExists(name string, access ContainerAccessType) (bool, error) { + resp, err := b.createContainer(name, access) + if resp != nil && (resp.statusCode == http.StatusCreated || resp.statusCode == http.StatusConflict) { + return resp.statusCode == http.StatusCreated, nil + } + return false, err +} + +func (b BlobStorageClient) createContainer(name string, access ContainerAccessType) (*storageResponse, error) { + verb := "PUT" + uri := b.client.getEndpoint(blobServiceName, pathForContainer(name), url.Values{"restype": {"container"}}) + + headers := b.client.getStandardHeaders() + headers["Content-Length"] = "0" + if access != "" { + headers["x-ms-blob-public-access"] = string(access) + } + return b.client.exec(verb, uri, headers, nil) +} + +// ContainerExists returns true if a container with given name exists +// on the storage account, otherwise returns false. +func (b BlobStorageClient) ContainerExists(name string) (bool, error) { + verb := "HEAD" + uri := b.client.getEndpoint(blobServiceName, pathForContainer(name), url.Values{"restype": {"container"}}) + headers := b.client.getStandardHeaders() + + resp, err := b.client.exec(verb, uri, headers, nil) + if resp != nil && (resp.statusCode == http.StatusOK || resp.statusCode == http.StatusNotFound) { + return resp.statusCode == http.StatusOK, nil + } + return false, err +} + +// DeleteContainer deletes the container with given name on the storage +// account. See https://msdn.microsoft.com/en-us/library/azure/dd179408.aspx +// If the container does not exist returns error. +func (b BlobStorageClient) DeleteContainer(name string) error { + resp, err := b.deleteContainer(name) + if err != nil { + return err + } + if resp.statusCode != http.StatusAccepted { + return ErrNotAccepted + } + return nil +} + +// DeleteContainer deletes the container with given name on the storage +// account if it exists. See https://msdn.microsoft.com/en-us/library/azure/dd179408.aspx +// Returns true if container is deleted with this call, or false +// if the container did not exist at the time of the Delete Container operation. +func (b BlobStorageClient) DeleteContainerIfExists(name string) (bool, error) { + resp, err := b.deleteContainer(name) + if resp != nil && (resp.statusCode == http.StatusAccepted || resp.statusCode == http.StatusNotFound) { + return resp.statusCode == http.StatusAccepted, nil + } + return false, err +} + +func (b BlobStorageClient) deleteContainer(name string) (*storageResponse, error) { + verb := "DELETE" + uri := b.client.getEndpoint(blobServiceName, pathForContainer(name), url.Values{"restype": {"container"}}) + + headers := b.client.getStandardHeaders() + return b.client.exec(verb, uri, headers, nil) +} + +// ListBlobs returns an object that contains list of blobs in the container, +// pagination token and other information in the response of List Blobs call. +// See https://msdn.microsoft.com/en-us/library/azure/dd135734.aspx +func (b BlobStorageClient) ListBlobs(container string, params ListBlobsParameters) (BlobListResponse, error) { + q := mergeParams(params.getParameters(), url.Values{ + "restype": {"container"}, + "comp": {"list"}}) + uri := b.client.getEndpoint(blobServiceName, pathForContainer(container), q) + headers := b.client.getStandardHeaders() + + var out BlobListResponse + resp, err := b.client.exec("GET", uri, headers, nil) + if err != nil { + return out, err + } + + err = xmlUnmarshal(resp.body, &out) + return out, err +} + +// BlobExists returns true if a blob with given name exists on the +// specified container of the storage account. +func (b BlobStorageClient) BlobExists(container, name string) (bool, error) { + verb := "HEAD" + uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{}) + + headers := b.client.getStandardHeaders() + resp, err := b.client.exec(verb, uri, headers, nil) + if resp != nil && (resp.statusCode == http.StatusOK || resp.statusCode == http.StatusNotFound) { + return resp.statusCode == http.StatusOK, nil + } + return false, err +} + +// GetBlobUrl gets the canonical URL to the blob with the specified +// name in the specified container. This method does not create a +// publicly accessible URL if the blob or container is private and this +// method does not check if the blob exists. +func (b BlobStorageClient) GetBlobUrl(container, name string) string { + if container == "" { + container = "$root" + } + return b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{}) +} + +// GetBlob downloads a blob to a stream. See https://msdn.microsoft.com/en-us/library/azure/dd179440.aspx +func (b BlobStorageClient) GetBlob(container, name string) (io.ReadCloser, error) { + resp, err := b.getBlobRange(container, name, "") + if err != nil { + return nil, err + } + + if resp.statusCode != http.StatusOK { + return nil, fmt.Errorf(errUnexpectedStatus, http.StatusOK, resp.statusCode) + } + return resp.body, nil +} + +// GetBlobRange reads the specified range of a blob to a stream. +// The bytesRange string must be in a format like "0-", "10-100" +// as defined in HTTP 1.1 spec. See https://msdn.microsoft.com/en-us/library/azure/dd179440.aspx +func (b BlobStorageClient) GetBlobRange(container, name, bytesRange string) (io.ReadCloser, error) { + resp, err := b.getBlobRange(container, name, bytesRange) + if err != nil { + return nil, err + } + + if resp.statusCode != http.StatusPartialContent { + return nil, fmt.Errorf(errUnexpectedStatus, http.StatusPartialContent, resp.statusCode) + } + return resp.body, nil +} + +func (b BlobStorageClient) getBlobRange(container, name, bytesRange string) (*storageResponse, error) { + verb := "GET" + uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{}) + + headers := b.client.getStandardHeaders() + if bytesRange != "" { + headers["Range"] = fmt.Sprintf("bytes=%s", bytesRange) + } + resp, err := b.client.exec(verb, uri, headers, nil) + if err != nil { + return nil, err + } + return resp, err +} + +// GetBlobProperties provides various information about the specified +// blob. See https://msdn.microsoft.com/en-us/library/azure/dd179394.aspx +func (b BlobStorageClient) GetBlobProperties(container, name string) (*BlobProperties, error) { + verb := "HEAD" + uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{}) + + headers := b.client.getStandardHeaders() + resp, err := b.client.exec(verb, uri, headers, nil) + if err != nil { + return nil, err + } + + if resp.statusCode != http.StatusOK { + return nil, fmt.Errorf(errUnexpectedStatus, http.StatusOK, resp.statusCode) + } + + var contentLength int64 + contentLengthStr := resp.headers.Get("Content-Length") + if contentLengthStr != "" { + contentLength, err = strconv.ParseInt(contentLengthStr, 0, 64) + if err != nil { + return nil, err + } + } + + var sequenceNum int64 + sequenceNumStr := resp.headers.Get("x-ms-blob-sequence-number") + if sequenceNumStr != "" { + sequenceNum, err = strconv.ParseInt(sequenceNumStr, 0, 64) + if err != nil { + return nil, err + } + } + + return &BlobProperties{ + LastModified: resp.headers.Get("Last-Modified"), + Etag: resp.headers.Get("Etag"), + ContentMD5: resp.headers.Get("Content-MD5"), + ContentLength: contentLength, + ContentEncoding: resp.headers.Get("Content-Encoding"), + SequenceNumber: sequenceNum, + CopyCompletionTime: resp.headers.Get("x-ms-copy-completion-time"), + CopyStatusDescription: resp.headers.Get("x-ms-copy-status-description"), + CopyId: resp.headers.Get("x-ms-copy-id"), + CopyProgress: resp.headers.Get("x-ms-copy-progress"), + CopySource: resp.headers.Get("x-ms-copy-source"), + CopyStatus: resp.headers.Get("x-ms-copy-status"), + BlobType: BlobType(resp.headers.Get("x-ms-blob-type")), + }, nil +} + +// CreateBlockBlob initializes an empty block blob with no blocks. +// See https://msdn.microsoft.com/en-us/library/azure/dd179451.aspx +func (b BlobStorageClient) CreateBlockBlob(container, name string) error { + path := fmt.Sprintf("%s/%s", container, name) + uri := b.client.getEndpoint(blobServiceName, path, url.Values{}) + headers := b.client.getStandardHeaders() + headers["x-ms-blob-type"] = string(BlobTypeBlock) + headers["Content-Length"] = fmt.Sprintf("%v", 0) + + resp, err := b.client.exec("PUT", uri, headers, nil) + if err != nil { + return err + } + if resp.statusCode != http.StatusCreated { + return ErrNotCreated + } + return nil +} + +// PutBlockBlob uploads given stream into a block blob by splitting +// data stream into chunks and uploading as blocks. Commits the block +// list at the end. This is a helper method built on top of PutBlock +// and PutBlockList methods with sequential block ID counting logic. +func (b BlobStorageClient) PutBlockBlob(container, name string, blob io.Reader) error { // TODO (ahmetalpbalkan) consider ReadCloser and closing + return b.putBlockBlob(container, name, blob, MaxBlobBlockSize) +} + +func (b BlobStorageClient) putBlockBlob(container, name string, blob io.Reader, chunkSize int) error { + if chunkSize <= 0 || chunkSize > MaxBlobBlockSize { + chunkSize = MaxBlobBlockSize + } + + chunk := make([]byte, chunkSize) + n, err := blob.Read(chunk) + if err != nil && err != io.EOF { + return err + } + + if err == io.EOF { + // Fits into one block + return b.putSingleBlockBlob(container, name, chunk[:n]) + } else { + // Does not fit into one block. Upload block by block then commit the block list + blockList := []Block{} + + // Put blocks + for blockNum := 0; ; blockNum++ { + id := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%011d", blockNum))) + data := chunk[:n] + err = b.PutBlock(container, name, id, data) + if err != nil { + return err + } + blockList = append(blockList, Block{id, BlockStatusLatest}) + + // Read next block + n, err = blob.Read(chunk) + if err != nil && err != io.EOF { + return err + } + if err == io.EOF { + break + } + } + + // Commit block list + return b.PutBlockList(container, name, blockList) + } +} + +func (b BlobStorageClient) putSingleBlockBlob(container, name string, chunk []byte) error { + if len(chunk) > MaxBlobBlockSize { + return fmt.Errorf("storage: provided chunk (%d bytes) cannot fit into single-block blob (max %d bytes)", len(chunk), MaxBlobBlockSize) + } + + uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{}) + headers := b.client.getStandardHeaders() + headers["x-ms-blob-type"] = string(BlobTypeBlock) + headers["Content-Length"] = fmt.Sprintf("%v", len(chunk)) + + resp, err := b.client.exec("PUT", uri, headers, bytes.NewReader(chunk)) + if err != nil { + return err + } + if resp.statusCode != http.StatusCreated { + return ErrNotCreated + } + + return nil +} + +// PutBlock saves the given data chunk to the specified block blob with +// given ID. See https://msdn.microsoft.com/en-us/library/azure/dd135726.aspx +func (b BlobStorageClient) PutBlock(container, name, blockId string, chunk []byte) error { + return b.PutBlockWithLength(container, name, blockId, uint64(len(chunk)), bytes.NewReader(chunk)) +} + +// PutBlockWithLength saves the given data stream of exactly specified size to the block blob +// with given ID. See https://msdn.microsoft.com/en-us/library/azure/dd135726.aspx +// It is an alternative to PutBlocks where data comes as stream but the length is +// known in advance. +func (b BlobStorageClient) PutBlockWithLength(container, name, blockId string, size uint64, blob io.Reader) error { + uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{"comp": {"block"}, "blockid": {blockId}}) + headers := b.client.getStandardHeaders() + headers["x-ms-blob-type"] = string(BlobTypeBlock) + headers["Content-Length"] = fmt.Sprintf("%v", size) + + resp, err := b.client.exec("PUT", uri, headers, blob) + if err != nil { + return err + } + if resp.statusCode != http.StatusCreated { + return ErrNotCreated + } + + return nil +} + +// PutBlockList saves list of blocks to the specified block blob. See +// https://msdn.microsoft.com/en-us/library/azure/dd179467.aspx +func (b BlobStorageClient) PutBlockList(container, name string, blocks []Block) error { + blockListXml := prepareBlockListRequest(blocks) + + uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{"comp": {"blocklist"}}) + headers := b.client.getStandardHeaders() + headers["Content-Length"] = fmt.Sprintf("%v", len(blockListXml)) + + resp, err := b.client.exec("PUT", uri, headers, strings.NewReader(blockListXml)) + if err != nil { + return err + } + if resp.statusCode != http.StatusCreated { + return ErrNotCreated + } + return nil +} + +// GetBlockList retrieves list of blocks in the specified block blob. See +// https://msdn.microsoft.com/en-us/library/azure/dd179400.aspx +func (b BlobStorageClient) GetBlockList(container, name string, blockType BlockListType) (BlockListResponse, error) { + params := url.Values{"comp": {"blocklist"}, "blocklisttype": {string(blockType)}} + uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), params) + headers := b.client.getStandardHeaders() + + var out BlockListResponse + resp, err := b.client.exec("GET", uri, headers, nil) + if err != nil { + return out, err + } + + err = xmlUnmarshal(resp.body, &out) + return out, err +} + +// PutPageBlob initializes an empty page blob with specified name and maximum +// size in bytes (size must be aligned to a 512-byte boundary). A page blob must +// be created using this method before writing pages. +// See https://msdn.microsoft.com/en-us/library/azure/dd179451.aspx +func (b BlobStorageClient) PutPageBlob(container, name string, size int64) error { + path := fmt.Sprintf("%s/%s", container, name) + uri := b.client.getEndpoint(blobServiceName, path, url.Values{}) + headers := b.client.getStandardHeaders() + headers["x-ms-blob-type"] = string(BlobTypePage) + headers["x-ms-blob-content-length"] = fmt.Sprintf("%v", size) + headers["Content-Length"] = fmt.Sprintf("%v", 0) + + resp, err := b.client.exec("PUT", uri, headers, nil) + if err != nil { + return err + } + if resp.statusCode != http.StatusCreated { + return ErrNotCreated + } + return nil +} + +// PutPage writes a range of pages to a page blob or clears the given range. +// In case of 'clear' writes, given chunk is discarded. Ranges must be aligned +// with 512-byte boundaries and chunk must be of size multiplies by 512. +// See https://msdn.microsoft.com/en-us/library/ee691975.aspx +func (b BlobStorageClient) PutPage(container, name string, startByte, endByte int64, writeType PageWriteType, chunk []byte) error { + path := fmt.Sprintf("%s/%s", container, name) + uri := b.client.getEndpoint(blobServiceName, path, url.Values{"comp": {"page"}}) + headers := b.client.getStandardHeaders() + headers["x-ms-blob-type"] = string(BlobTypePage) + headers["x-ms-page-write"] = string(writeType) + headers["x-ms-range"] = fmt.Sprintf("bytes=%v-%v", startByte, endByte) + + var contentLength int64 + var data io.Reader + if writeType == PageWriteTypeClear { + contentLength = 0 + data = bytes.NewReader([]byte{}) + } else { + contentLength = int64(len(chunk)) + data = bytes.NewReader(chunk) + } + headers["Content-Length"] = fmt.Sprintf("%v", contentLength) + + resp, err := b.client.exec("PUT", uri, headers, data) + if err != nil { + return err + } + if resp.statusCode != http.StatusCreated { + return ErrNotCreated + } + return nil +} + +// GetPageRanges returns the list of valid page ranges for a page blob. +// See https://msdn.microsoft.com/en-us/library/azure/ee691973.aspx +func (b BlobStorageClient) GetPageRanges(container, name string) (GetPageRangesResponse, error) { + path := fmt.Sprintf("%s/%s", container, name) + uri := b.client.getEndpoint(blobServiceName, path, url.Values{"comp": {"pagelist"}}) + headers := b.client.getStandardHeaders() + + var out GetPageRangesResponse + resp, err := b.client.exec("GET", uri, headers, nil) + if err != nil { + return out, err + } + + if resp.statusCode != http.StatusOK { + return out, fmt.Errorf(errUnexpectedStatus, http.StatusOK, resp.statusCode) + } + + err = xmlUnmarshal(resp.body, &out) + return out, err +} + +// CopyBlob starts a blob copy operation and waits for the operation to complete. +// sourceBlob parameter must be a canonical URL to the blob (can be obtained using +// GetBlobURL method.) There is no SLA on blob copy and therefore this helper +// method works faster on smaller files. See https://msdn.microsoft.com/en-us/library/azure/dd894037.aspx +func (b BlobStorageClient) CopyBlob(container, name, sourceBlob string) error { + copyId, err := b.startBlobCopy(container, name, sourceBlob) + if err != nil { + return err + } + + return b.waitForBlobCopy(container, name, copyId) +} + +func (b BlobStorageClient) startBlobCopy(container, name, sourceBlob string) (string, error) { + uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{}) + + headers := b.client.getStandardHeaders() + headers["Content-Length"] = "0" + headers["x-ms-copy-source"] = sourceBlob + + resp, err := b.client.exec("PUT", uri, headers, nil) + if err != nil { + return "", err + } + if resp.statusCode != http.StatusAccepted && resp.statusCode != http.StatusCreated { + return "", fmt.Errorf(errUnexpectedStatus, []int{http.StatusAccepted, http.StatusCreated}, resp.statusCode) + } + + copyId := resp.headers.Get("x-ms-copy-id") + if copyId == "" { + return "", errors.New("Got empty copy id header") + } + return copyId, nil +} + +func (b BlobStorageClient) waitForBlobCopy(container, name, copyId string) error { + for { + props, err := b.GetBlobProperties(container, name) + if err != nil { + return err + } + + if props.CopyId != copyId { + return errBlobCopyIdMismatch + } + + switch props.CopyStatus { + case blobCopyStatusSuccess: + return nil + case blobCopyStatusPending: + continue + case blobCopyStatusAborted: + return errBlobCopyAborted + case blobCopyStatusFailed: + return fmt.Errorf("storage: blob copy failed. Id=%s Description=%s", props.CopyId, props.CopyStatusDescription) + default: + return fmt.Errorf("storage: unhandled blob copy status: '%s'", props.CopyStatus) + } + } +} + +// DeleteBlob deletes the given blob from the specified container. +// If the blob does not exists at the time of the Delete Blob operation, it +// returns error. See https://msdn.microsoft.com/en-us/library/azure/dd179413.aspx +func (b BlobStorageClient) DeleteBlob(container, name string) error { + resp, err := b.deleteBlob(container, name) + if err != nil { + return err + } + if resp.statusCode != http.StatusAccepted { + return ErrNotAccepted + } + return nil +} + +// DeleteBlobIfExists deletes the given blob from the specified container +// If the blob is deleted with this call, returns true. Otherwise returns +// false. See https://msdn.microsoft.com/en-us/library/azure/dd179413.aspx +func (b BlobStorageClient) DeleteBlobIfExists(container, name string) (bool, error) { + resp, err := b.deleteBlob(container, name) + if resp != nil && (resp.statusCode == http.StatusAccepted || resp.statusCode == http.StatusNotFound) { + return resp.statusCode == http.StatusAccepted, nil + } + return false, err +} + +func (b BlobStorageClient) deleteBlob(container, name string) (*storageResponse, error) { + verb := "DELETE" + uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{}) + headers := b.client.getStandardHeaders() + + return b.client.exec(verb, uri, headers, nil) +} + +// helper method to construct the path to a container given its name +func pathForContainer(name string) string { + return fmt.Sprintf("/%s", name) +} + +// helper method to construct the path to a blob given its container and blob name +func pathForBlob(container, name string) string { + return fmt.Sprintf("/%s/%s", container, name) +} + +// GetBlobSASURI creates an URL to the specified blob which contains the Shared Access Signature +// with specified permissions and expiration time. See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx +func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Time, permissions string) (string, error) { + var ( + signedPermissions = permissions + blobUrl = b.GetBlobUrl(container, name) + ) + canonicalizedResource, err := b.client.buildCanonicalizedResource(blobUrl) + if err != nil { + return "", err + } + signedExpiry := expiry.Format(time.RFC3339) + signedResource := "b" + + stringToSign, err := blobSASStringToSign(b.client.apiVersion, canonicalizedResource, signedExpiry, signedPermissions) + if err != nil { + return "", err + } + + sig := b.client.computeHmac256(stringToSign) + sasParams := url.Values{ + "sv": {b.client.apiVersion}, + "se": {signedExpiry}, + "sr": {signedResource}, + "sp": {signedPermissions}, + "sig": {sig}, + } + + sasUrl, err := url.Parse(blobUrl) + if err != nil { + return "", err + } + sasUrl.RawQuery = sasParams.Encode() + return sasUrl.String(), nil +} + +func blobSASStringToSign(signedVersion, canonicalizedResource, signedExpiry, signedPermissions string) (string, error) { + var signedStart, signedIdentifier, rscc, rscd, rsce, rscl, rsct string + + // reference: http://msdn.microsoft.com/en-us/library/azure/dn140255.aspx + if signedVersion >= "2013-08-15" { + return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s", signedPermissions, signedStart, signedExpiry, canonicalizedResource, signedIdentifier, signedVersion, rscc, rscd, rsce, rscl, rsct), nil + } else { + return "", errors.New("storage: not implemented SAS for versions earlier than 2013-08-15") + } +} diff --git a/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/blob_test.go b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/blob_test.go new file mode 100644 index 000000000..33e3d173d --- /dev/null +++ b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/blob_test.go @@ -0,0 +1,1123 @@ +package storage + +import ( + "bytes" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "reflect" + "sort" + "strings" + "sync" + "testing" + "time" +) + +const testContainerPrefix = "zzzztest-" + +func Test_pathForContainer(t *testing.T) { + out := pathForContainer("foo") + if expected := "/foo"; out != expected { + t.Errorf("Wrong pathForContainer. Expected: '%s', got: '%s'", expected, out) + } +} + +func Test_pathForBlob(t *testing.T) { + out := pathForBlob("foo", "blob") + if expected := "/foo/blob"; out != expected { + t.Errorf("Wrong pathForBlob. Expected: '%s', got: '%s'", expected, out) + } +} + +func Test_blobSASStringToSign(t *testing.T) { + _, err := blobSASStringToSign("2012-02-12", "CS", "SE", "SP") + if err == nil { + t.Fatal("Expected error, got nil") + } + + out, err := blobSASStringToSign("2013-08-15", "CS", "SE", "SP") + if err != nil { + t.Fatal(err) + } + if expected := "SP\n\nSE\nCS\n\n2013-08-15\n\n\n\n\n"; out != expected { + t.Errorf("Wrong stringToSign. Expected: '%s', got: '%s'", expected, out) + } +} + +func TestGetBlobSASURI(t *testing.T) { + api, err := NewClient("foo", "YmFy", DefaultBaseUrl, "2013-08-15", true) + if err != nil { + t.Fatal(err) + } + cli := api.GetBlobService() + expiry := time.Time{} + + expectedParts := url.URL{ + Scheme: "https", + Host: "foo.blob.core.windows.net", + Path: "container/name", + RawQuery: url.Values{ + "sv": {"2013-08-15"}, + "sig": {"/OXG7rWh08jYwtU03GzJM0DHZtidRGpC6g69rSGm3I0="}, + "sr": {"b"}, + "sp": {"r"}, + "se": {"0001-01-01T00:00:00Z"}, + }.Encode()} + + u, err := cli.GetBlobSASURI("container", "name", expiry, "r") + if err != nil { + t.Fatal(err) + } + sasParts, err := url.Parse(u) + if err != nil { + t.Fatal(err) + } + + expectedQuery := expectedParts.Query() + sasQuery := sasParts.Query() + + expectedParts.RawQuery = "" // reset + sasParts.RawQuery = "" + + if expectedParts.String() != sasParts.String() { + t.Fatalf("Base URL wrong for SAS. Expected: '%s', got: '%s'", expectedParts, sasParts) + } + + if len(expectedQuery) != len(sasQuery) { + t.Fatalf("Query string wrong for SAS URL. Expected: '%d keys', got: '%d keys'", len(expectedQuery), len(sasQuery)) + } + + for k, v := range expectedQuery { + out, ok := sasQuery[k] + if !ok { + t.Fatalf("Query parameter '%s' not found in generated SAS query. Expected: '%s'", k, v) + } + if !reflect.DeepEqual(v, out) { + t.Fatalf("Wrong value for query parameter '%s'. Expected: '%s', got: '%s'", k, v, out) + } + } +} + +func TestBlobSASURICorrectness(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + cnt := randContainer() + blob := randString(20) + body := []byte(randString(100)) + expiry := time.Now().UTC().Add(time.Hour) + permissions := "r" + + err = cli.CreateContainer(cnt, ContainerAccessTypePrivate) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteContainer(cnt) + + err = cli.PutBlockBlob(cnt, blob, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + + sasUri, err := cli.GetBlobSASURI(cnt, blob, expiry, permissions) + if err != nil { + t.Fatal(err) + } + + resp, err := http.Get(sasUri) + if err != nil { + t.Logf("SAS URI: %s", sasUri) + t.Fatal(err) + } + + blobResp, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Non-ok status code: %s", resp.Status) + } + + if len(blobResp) != len(body) { + t.Fatalf("Wrong blob size on SAS URI. Expected: %d, Got: %d", len(body), len(blobResp)) + } +} + +func TestListContainersPagination(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + err = deleteTestContainers(cli) + if err != nil { + t.Fatal(err) + } + + const n = 5 + const pageSize = 2 + + // Create test containers + created := []string{} + for i := 0; i < n; i++ { + name := randContainer() + err := cli.CreateContainer(name, ContainerAccessTypePrivate) + if err != nil { + t.Fatalf("Error creating test container: %s", err) + } + created = append(created, name) + } + sort.Strings(created) + + // Defer test container deletions + defer func() { + var wg sync.WaitGroup + for _, cnt := range created { + wg.Add(1) + go func(name string) { + err := cli.DeleteContainer(name) + if err != nil { + t.Logf("Error while deleting test container: %s", err) + } + wg.Done() + }(cnt) + } + wg.Wait() + }() + + // Paginate results + seen := []string{} + marker := "" + for { + resp, err := cli.ListContainers(ListContainersParameters{ + Prefix: testContainerPrefix, + MaxResults: pageSize, + Marker: marker}) + + if err != nil { + t.Fatal(err) + } + + containers := resp.Containers + + if len(containers) > pageSize { + t.Fatalf("Got a bigger page. Expected: %d, got: %d", pageSize, len(containers)) + } + + for _, c := range containers { + seen = append(seen, c.Name) + } + + marker = resp.NextMarker + if marker == "" || len(containers) == 0 { + break + } + } + + // Compare + if !reflect.DeepEqual(created, seen) { + t.Fatalf("Wrong pagination results:\nExpected:\t\t%v\nGot:\t\t%v", created, seen) + } +} + +func TestContainerExists(t *testing.T) { + cnt := randContainer() + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + ok, err := cli.ContainerExists(cnt) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatalf("Non-existing container returned as existing: %s", cnt) + } + + err = cli.CreateContainer(cnt, ContainerAccessTypeBlob) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteContainer(cnt) + + ok, err = cli.ContainerExists(cnt) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatalf("Existing container returned as non-existing: %s", cnt) + } +} + +func TestCreateDeleteContainer(t *testing.T) { + cnt := randContainer() + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + err = cli.CreateContainer(cnt, ContainerAccessTypePrivate) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteContainer(cnt) + + err = cli.DeleteContainer(cnt) + if err != nil { + t.Fatal(err) + } +} + +func TestCreateContainerIfNotExists(t *testing.T) { + cnt := randContainer() + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + // First create + ok, err := cli.CreateContainerIfNotExists(cnt, ContainerAccessTypePrivate) + if err != nil { + t.Fatal(err) + } + if expected := true; ok != expected { + t.Fatalf("Wrong creation status. Expected: %v; Got: %v", expected, ok) + } + + // Second create, should not give errors + ok, err = cli.CreateContainerIfNotExists(cnt, ContainerAccessTypePrivate) + if err != nil { + t.Fatal(err) + } + if expected := false; ok != expected { + t.Fatalf("Wrong creation status. Expected: %v; Got: %v", expected, ok) + } + + defer cli.DeleteContainer(cnt) +} + +func TestDeleteContainerIfExists(t *testing.T) { + cnt := randContainer() + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + // Nonexisting container + err = cli.DeleteContainer(cnt) + if err == nil { + t.Fatal("Expected error, got nil") + } + + ok, err := cli.DeleteContainerIfExists(cnt) + if err != nil { + t.Fatalf("Not supposed to return error, got: %s", err) + } + if expected := false; ok != expected { + t.Fatalf("Wrong deletion status. Expected: %v; Got: %v", expected, ok) + } + + // Existing container + err = cli.CreateContainer(cnt, ContainerAccessTypePrivate) + if err != nil { + t.Fatal(err) + } + ok, err = cli.DeleteContainerIfExists(cnt) + if err != nil { + t.Fatalf("Not supposed to return error, got: %s", err) + } + if expected := true; ok != expected { + t.Fatalf("Wrong deletion status. Expected: %v; Got: %v", expected, ok) + } +} + +func TestBlobExists(t *testing.T) { + cnt := randContainer() + blob := randString(20) + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + err = cli.CreateContainer(cnt, ContainerAccessTypeBlob) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteContainer(cnt) + err = cli.PutBlockBlob(cnt, blob, strings.NewReader("Hello!")) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteBlob(cnt, blob) + + ok, err := cli.BlobExists(cnt, blob+".foo") + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("Non-existing blob returned as existing: %s/%s", cnt, blob) + } + + ok, err = cli.BlobExists(cnt, blob) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("Existing blob returned as non-existing: %s/%s", cnt, blob) + } +} + +func TestGetBlobUrl(t *testing.T) { + api, err := NewBasicClient("foo", "YmFy") + if err != nil { + t.Fatal(err) + } + cli := api.GetBlobService() + + out := cli.GetBlobUrl("c", "nested/blob") + if expected := "https://foo.blob.core.windows.net/c/nested/blob"; out != expected { + t.Fatalf("Wrong blob URL. Expected: '%s', got:'%s'", expected, out) + } + + out = cli.GetBlobUrl("", "blob") + if expected := "https://foo.blob.core.windows.net/$root/blob"; out != expected { + t.Fatalf("Wrong blob URL. Expected: '%s', got:'%s'", expected, out) + } + + out = cli.GetBlobUrl("", "nested/blob") + if expected := "https://foo.blob.core.windows.net/$root/nested/blob"; out != expected { + t.Fatalf("Wrong blob URL. Expected: '%s', got:'%s'", expected, out) + } +} + +func TestBlobCopy(t *testing.T) { + if testing.Short() { + t.Skip("skipping blob copy in short mode, no SLA on async operation") + } + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + cnt := randContainer() + src := randString(20) + dst := randString(20) + body := []byte(randString(1024)) + + err = cli.CreateContainer(cnt, ContainerAccessTypePrivate) + if err != nil { + t.Fatal(err) + } + defer cli.deleteContainer(cnt) + + err = cli.PutBlockBlob(cnt, src, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteBlob(cnt, src) + + err = cli.CopyBlob(cnt, dst, cli.GetBlobUrl(cnt, src)) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteBlob(cnt, dst) + + blobBody, err := cli.GetBlob(cnt, dst) + if err != nil { + t.Fatal(err) + } + + b, err := ioutil.ReadAll(blobBody) + defer blobBody.Close() + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(body, b) { + t.Fatalf("Copied blob is wrong. Expected: %d bytes, got: %d bytes\n%s\n%s", len(body), len(b), body, b) + } +} + +func TestDeleteBlobIfExists(t *testing.T) { + cnt := randContainer() + blob := randString(20) + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + err = cli.DeleteBlob(cnt, blob) + if err == nil { + t.Fatal("Nonexisting blob did not return error") + } + + ok, err := cli.DeleteBlobIfExists(cnt, blob) + if err != nil { + t.Fatalf("Not supposed to return error: %s", err) + } + if expected := false; ok != expected { + t.Fatalf("Wrong deletion status. Expected: %v; Got: %v", expected, ok) + } +} + +func TestGetBlobProperties(t *testing.T) { + cnt := randContainer() + blob := randString(20) + contents := randString(64) + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + err = cli.CreateContainer(cnt, ContainerAccessTypePrivate) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteContainer(cnt) + + // Nonexisting blob + _, err = cli.GetBlobProperties(cnt, blob) + if err == nil { + t.Fatal("Did not return error for non-existing blob") + } + + // Put the blob + err = cli.PutBlockBlob(cnt, blob, strings.NewReader(contents)) + if err != nil { + t.Fatal(err) + } + + // Get blob properties + props, err := cli.GetBlobProperties(cnt, blob) + if err != nil { + t.Fatal(err) + } + + if props.ContentLength != int64(len(contents)) { + t.Fatalf("Got wrong Content-Length: '%d', expected: %d", props.ContentLength, len(contents)) + } + if props.BlobType != BlobTypeBlock { + t.Fatalf("Got wrong BlobType. Expected:'%s', got:'%s'", BlobTypeBlock, props.BlobType) + } +} + +func TestListBlobsPagination(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + cnt := randContainer() + err = cli.CreateContainer(cnt, ContainerAccessTypePrivate) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteContainer(cnt) + + blobs := []string{} + const n = 5 + const pageSize = 2 + for i := 0; i < n; i++ { + name := randString(20) + err := cli.PutBlockBlob(cnt, name, strings.NewReader("Hello, world!")) + if err != nil { + t.Fatal(err) + } + blobs = append(blobs, name) + } + sort.Strings(blobs) + + // Paginate + seen := []string{} + marker := "" + for { + resp, err := cli.ListBlobs(cnt, ListBlobsParameters{ + MaxResults: pageSize, + Marker: marker}) + if err != nil { + t.Fatal(err) + } + + for _, v := range resp.Blobs { + seen = append(seen, v.Name) + } + + marker = resp.NextMarker + if marker == "" || len(resp.Blobs) == 0 { + break + } + } + + // Compare + if !reflect.DeepEqual(blobs, seen) { + t.Fatalf("Got wrong list of blobs. Expected: %s, Got: %s", blobs, seen) + } + + err = cli.DeleteContainer(cnt) + if err != nil { + t.Fatal(err) + } +} + +func TestPutEmptyBlockBlob(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + cnt := randContainer() + if err := cli.CreateContainer(cnt, ContainerAccessTypePrivate); err != nil { + t.Fatal(err) + } + defer cli.deleteContainer(cnt) + + blob := randString(20) + err = cli.PutBlockBlob(cnt, blob, bytes.NewReader([]byte{})) + if err != nil { + t.Fatal(err) + } + + props, err := cli.GetBlobProperties(cnt, blob) + if err != nil { + t.Fatal(err) + } + if props.ContentLength != 0 { + t.Fatalf("Wrong content length for empty blob: %d", props.ContentLength) + } +} + +func TestPutSingleBlockBlob(t *testing.T) { + cnt := randContainer() + blob := randString(20) + body := []byte(randString(1024)) + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + err = cli.CreateContainer(cnt, ContainerAccessTypeBlob) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteContainer(cnt) + + err = cli.PutBlockBlob(cnt, blob, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteBlob(cnt, blob) + + resp, err := cli.GetBlob(cnt, blob) + if err != nil { + t.Fatal(err) + } + + // Verify contents + respBody, err := ioutil.ReadAll(resp) + defer resp.Close() + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(body, respBody) { + t.Fatalf("Wrong blob contents.\nExpected: %d bytes, Got: %d byes", len(body), len(respBody)) + } + + // Verify block list + blocks, err := cli.GetBlockList(cnt, blob, BlockListTypeAll) + if err != nil { + t.Fatal(err) + } + if expected := 1; len(blocks.CommittedBlocks) != expected { + t.Fatalf("Wrong committed block count. Expected: %d, Got: %d", expected, len(blocks.CommittedBlocks)) + } + if expected := 0; len(blocks.UncommittedBlocks) != expected { + t.Fatalf("Wrong unccommitted block count. Expected: %d, Got: %d", expected, len(blocks.UncommittedBlocks)) + } + thatBlock := blocks.CommittedBlocks[0] + if expected := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%011d", 0))); thatBlock.Name != expected { + t.Fatalf("Wrong block name. Expected: %s, Got: %s", expected, thatBlock.Name) + } +} + +func TestGetBlobRange(t *testing.T) { + cnt := randContainer() + blob := randString(20) + body := "0123456789" + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + err = cli.CreateContainer(cnt, ContainerAccessTypeBlob) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteContainer(cnt) + + err = cli.PutBlockBlob(cnt, blob, strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteBlob(cnt, blob) + + // Read 1-3 + for _, r := range []struct { + rangeStr string + expected string + }{ + {"0-", body}, + {"1-3", body[1 : 3+1]}, + {"3-", body[3:]}, + } { + resp, err := cli.GetBlobRange(cnt, blob, r.rangeStr) + if err != nil { + t.Fatal(err) + } + blobBody, err := ioutil.ReadAll(resp) + if err != nil { + t.Fatal(err) + } + str := string(blobBody) + if str != r.expected { + t.Fatalf("Got wrong range. Expected: '%s'; Got:'%s'", r.expected, str) + } + } +} + +func TestPutBlock(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + cnt := randContainer() + if err := cli.CreateContainer(cnt, ContainerAccessTypePrivate); err != nil { + t.Fatal(err) + } + defer cli.deleteContainer(cnt) + + blob := randString(20) + chunk := []byte(randString(1024)) + blockId := base64.StdEncoding.EncodeToString([]byte("foo")) + err = cli.PutBlock(cnt, blob, blockId, chunk) + if err != nil { + t.Fatal(err) + } +} + +func TestPutMultiBlockBlob(t *testing.T) { + var ( + cnt = randContainer() + blob = randString(20) + blockSize = 32 * 1024 // 32 KB + body = []byte(randString(blockSize*2 + blockSize/2)) // 3 blocks + ) + + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + err = cli.CreateContainer(cnt, ContainerAccessTypeBlob) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteContainer(cnt) + + err = cli.putBlockBlob(cnt, blob, bytes.NewReader(body), blockSize) + if err != nil { + t.Fatal(err) + } + defer cli.DeleteBlob(cnt, blob) + + resp, err := cli.GetBlob(cnt, blob) + if err != nil { + t.Fatal(err) + } + + // Verify contents + respBody, err := ioutil.ReadAll(resp) + defer resp.Close() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(body, respBody) { + t.Fatalf("Wrong blob contents.\nExpected: %d bytes, Got: %d byes", len(body), len(respBody)) + } + + err = cli.DeleteBlob(cnt, blob) + if err != nil { + t.Fatal(err) + } + + err = cli.DeleteContainer(cnt) + if err != nil { + t.Fatal(err) + } +} + +func TestGetBlockList_PutBlockList(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + cnt := randContainer() + if err := cli.CreateContainer(cnt, ContainerAccessTypePrivate); err != nil { + t.Fatal(err) + } + defer cli.deleteContainer(cnt) + + blob := randString(20) + chunk := []byte(randString(1024)) + blockId := base64.StdEncoding.EncodeToString([]byte("foo")) + + // Put one block + err = cli.PutBlock(cnt, blob, blockId, chunk) + if err != nil { + t.Fatal(err) + } + defer cli.deleteBlob(cnt, blob) + + // Get committed blocks + committed, err := cli.GetBlockList(cnt, blob, BlockListTypeCommitted) + if err != nil { + t.Fatal(err) + } + + if len(committed.CommittedBlocks) > 0 { + t.Fatal("There are committed blocks") + } + + // Get uncommitted blocks + uncommitted, err := cli.GetBlockList(cnt, blob, BlockListTypeUncommitted) + if err != nil { + t.Fatal(err) + } + + if expected := 1; len(uncommitted.UncommittedBlocks) != expected { + t.Fatalf("Uncommitted blocks wrong. Expected: %d, got: %d", expected, len(uncommitted.UncommittedBlocks)) + } + + // Commit block list + err = cli.PutBlockList(cnt, blob, []Block{{blockId, BlockStatusUncommitted}}) + if err != nil { + t.Fatal(err) + } + + // Get all blocks + all, err := cli.GetBlockList(cnt, blob, BlockListTypeAll) + if err != nil { + t.Fatal(err) + } + + if expected := 1; len(all.CommittedBlocks) != expected { + t.Fatalf("Uncommitted blocks wrong. Expected: %d, got: %d", expected, len(uncommitted.CommittedBlocks)) + } + if expected := 0; len(all.UncommittedBlocks) != expected { + t.Fatalf("Uncommitted blocks wrong. Expected: %d, got: %d", expected, len(uncommitted.UncommittedBlocks)) + } + + // Verify the block + thatBlock := all.CommittedBlocks[0] + if expected := blockId; expected != thatBlock.Name { + t.Fatalf("Wrong block name. Expected: %s, got: %s", expected, thatBlock.Name) + } + if expected := int64(len(chunk)); expected != thatBlock.Size { + t.Fatalf("Wrong block name. Expected: %d, got: %d", expected, thatBlock.Size) + } +} + +func TestCreateBlockBlob(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + cnt := randContainer() + if err := cli.CreateContainer(cnt, ContainerAccessTypePrivate); err != nil { + t.Fatal(err) + } + defer cli.deleteContainer(cnt) + + blob := randString(20) + if err := cli.CreateBlockBlob(cnt, blob); err != nil { + t.Fatal(err) + } + + // Verify + blocks, err := cli.GetBlockList(cnt, blob, BlockListTypeAll) + if err != nil { + t.Fatal(err) + } + if expected, got := 0, len(blocks.CommittedBlocks); expected != got { + t.Fatalf("Got wrong committed block count. Expected: %v, Got:%v ", expected, got) + } + if expected, got := 0, len(blocks.UncommittedBlocks); expected != got { + t.Fatalf("Got wrong uncommitted block count. Expected: %v, Got:%v ", expected, got) + } +} + +func TestPutPageBlob(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + cnt := randContainer() + if err := cli.CreateContainer(cnt, ContainerAccessTypePrivate); err != nil { + t.Fatal(err) + } + defer cli.deleteContainer(cnt) + + blob := randString(20) + size := int64(10 * 1024 * 1024) + if err := cli.PutPageBlob(cnt, blob, size); err != nil { + t.Fatal(err) + } + + // Verify + props, err := cli.GetBlobProperties(cnt, blob) + if err != nil { + t.Fatal(err) + } + if expected := size; expected != props.ContentLength { + t.Fatalf("Got wrong Content-Length. Expected: %v, Got:%v ", expected, props.ContentLength) + } + if expected := BlobTypePage; expected != props.BlobType { + t.Fatalf("Got wrong x-ms-blob-type. Expected: %v, Got:%v ", expected, props.BlobType) + } +} + +func TestPutPagesUpdate(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + cnt := randContainer() + if err := cli.CreateContainer(cnt, ContainerAccessTypePrivate); err != nil { + t.Fatal(err) + } + defer cli.deleteContainer(cnt) + + blob := randString(20) + size := int64(10 * 1024 * 1024) // larger than we'll use + if err := cli.PutPageBlob(cnt, blob, size); err != nil { + t.Fatal(err) + } + + chunk1 := []byte(randString(1024)) + chunk2 := []byte(randString(512)) + // Append chunks + if err := cli.PutPage(cnt, blob, 0, int64(len(chunk1)-1), PageWriteTypeUpdate, chunk1); err != nil { + t.Fatal(err) + } + if err := cli.PutPage(cnt, blob, int64(len(chunk1)), int64(len(chunk1)+len(chunk2)-1), PageWriteTypeUpdate, chunk2); err != nil { + t.Fatal(err) + } + + // Verify contents + out, err := cli.GetBlobRange(cnt, blob, fmt.Sprintf("%v-%v", 0, len(chunk1)+len(chunk2))) + if err != nil { + t.Fatal(err) + } + blobContents, err := ioutil.ReadAll(out) + defer out.Close() + if err != nil { + t.Fatal(err) + } + if expected := append(chunk1, chunk2...); reflect.DeepEqual(blobContents, expected) { + t.Fatalf("Got wrong blob.\nGot:%d bytes, Expected:%d bytes", len(blobContents), len(expected)) + } + out.Close() + + // Overwrite first half of chunk1 + chunk0 := []byte(randString(512)) + if err := cli.PutPage(cnt, blob, 0, int64(len(chunk0)-1), PageWriteTypeUpdate, chunk0); err != nil { + t.Fatal(err) + } + + // Verify contents + out, err = cli.GetBlobRange(cnt, blob, fmt.Sprintf("%v-%v", 0, len(chunk1)+len(chunk2))) + if err != nil { + t.Fatal(err) + } + blobContents, err = ioutil.ReadAll(out) + defer out.Close() + if err != nil { + t.Fatal(err) + } + if expected := append(append(chunk0, chunk1[512:]...), chunk2...); reflect.DeepEqual(blobContents, expected) { + t.Fatalf("Got wrong blob.\nGot:%d bytes, Expected:%d bytes", len(blobContents), len(expected)) + } +} + +func TestPutPagesClear(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + cnt := randContainer() + if err := cli.CreateContainer(cnt, ContainerAccessTypePrivate); err != nil { + t.Fatal(err) + } + defer cli.deleteContainer(cnt) + + blob := randString(20) + size := int64(10 * 1024 * 1024) // larger than we'll use + + if err := cli.PutPageBlob(cnt, blob, size); err != nil { + t.Fatal(err) + } + + // Put 0-2047 + chunk := []byte(randString(2048)) + if err := cli.PutPage(cnt, blob, 0, 2047, PageWriteTypeUpdate, chunk); err != nil { + t.Fatal(err) + } + + // Clear 512-1023 + if err := cli.PutPage(cnt, blob, 512, 1023, PageWriteTypeClear, nil); err != nil { + t.Fatal(err) + } + + // Get blob contents + if out, err := cli.GetBlobRange(cnt, blob, "0-2048"); err != nil { + t.Fatal(err) + } else { + contents, err := ioutil.ReadAll(out) + defer out.Close() + if err != nil { + t.Fatal(err) + } + + if expected := append(append(chunk[:512], make([]byte, 512)...), chunk[1024:]...); reflect.DeepEqual(contents, expected) { + t.Fatalf("Cleared blob is not the same. Expected: (%d) %v; got: (%d) %v", len(expected), expected, len(contents), contents) + } + } +} + +func TestGetPageRanges(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + cnt := randContainer() + if err := cli.CreateContainer(cnt, ContainerAccessTypePrivate); err != nil { + t.Fatal(err) + } + defer cli.deleteContainer(cnt) + + blob := randString(20) + size := int64(10 * 1024 * 1024) // larger than we'll use + + if err := cli.PutPageBlob(cnt, blob, size); err != nil { + t.Fatal(err) + } + + // Get page ranges on empty blob + if out, err := cli.GetPageRanges(cnt, blob); err != nil { + t.Fatal(err) + } else if len(out.PageList) != 0 { + t.Fatal("Blob has pages") + } + + // Add 0-512 page + err = cli.PutPage(cnt, blob, 0, 511, PageWriteTypeUpdate, []byte(randString(512))) + if err != nil { + t.Fatal(err) + } + + if out, err := cli.GetPageRanges(cnt, blob); err != nil { + t.Fatal(err) + } else if expected := 1; len(out.PageList) != expected { + t.Fatalf("Expected %d pages, got: %d -- %v", expected, len(out.PageList), out.PageList) + } + + // Add 1024-2048 + err = cli.PutPage(cnt, blob, 1024, 2047, PageWriteTypeUpdate, []byte(randString(1024))) + if err != nil { + t.Fatal(err) + } + + if out, err := cli.GetPageRanges(cnt, blob); err != nil { + t.Fatal(err) + } else if expected := 2; len(out.PageList) != expected { + t.Fatalf("Expected %d pages, got: %d -- %v", expected, len(out.PageList), out.PageList) + } +} + +func deleteTestContainers(cli *BlobStorageClient) error { + for { + resp, err := cli.ListContainers(ListContainersParameters{Prefix: testContainerPrefix}) + if err != nil { + return err + } + if len(resp.Containers) == 0 { + break + } + for _, c := range resp.Containers { + err = cli.DeleteContainer(c.Name) + if err != nil { + return err + } + } + } + return nil +} + +func getBlobClient() (*BlobStorageClient, error) { + name := os.Getenv("ACCOUNT_NAME") + if name == "" { + return nil, errors.New("ACCOUNT_NAME not set, need an empty storage account to test") + } + key := os.Getenv("ACCOUNT_KEY") + if key == "" { + return nil, errors.New("ACCOUNT_KEY not set") + } + cli, err := NewBasicClient(name, key) + if err != nil { + return nil, err + } + return cli.GetBlobService(), nil +} + +func randContainer() string { + return testContainerPrefix + randString(32-len(testContainerPrefix)) +} + +func randString(n int) string { + if n <= 0 { + panic("negative number") + } + const alphanum = "0123456789abcdefghijklmnopqrstuvwxyz" + var bytes = make([]byte, n) + rand.Read(bytes) + for i, b := range bytes { + bytes[i] = alphanum[b%byte(len(alphanum))] + } + return string(bytes) +} diff --git a/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/client.go b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/client.go new file mode 100644 index 000000000..1cbabaee3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/client.go @@ -0,0 +1,315 @@ +package storage + +import ( + "bytes" + "encoding/base64" + "encoding/xml" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "regexp" + "sort" + "strings" +) + +const ( + DefaultBaseUrl = "core.windows.net" + DefaultApiVersion = "2014-02-14" + defaultUseHttps = true + + blobServiceName = "blob" + tableServiceName = "table" + queueServiceName = "queue" +) + +// StorageClient is the object that needs to be constructed +// to perform operations on the storage account. +type StorageClient struct { + accountName string + accountKey []byte + useHttps bool + baseUrl string + apiVersion string +} + +type storageResponse struct { + statusCode int + headers http.Header + body io.ReadCloser +} + +// StorageServiceError contains fields of the error response from +// Azure Storage Service REST API. See https://msdn.microsoft.com/en-us/library/azure/dd179382.aspx +// Some fields might be specific to certain calls. +type StorageServiceError struct { + Code string `xml:"Code"` + Message string `xml:"Message"` + AuthenticationErrorDetail string `xml:"AuthenticationErrorDetail"` + QueryParameterName string `xml:"QueryParameterName"` + QueryParameterValue string `xml:"QueryParameterValue"` + Reason string `xml:"Reason"` + StatusCode int + RequestId string +} + +// NewBasicClient constructs a StorageClient with given storage service name +// and key. +func NewBasicClient(accountName, accountKey string) (*StorageClient, error) { + return NewClient(accountName, accountKey, DefaultBaseUrl, DefaultApiVersion, defaultUseHttps) +} + +// NewClient constructs a StorageClient. This should be used if the caller +// wants to specify whether to use HTTPS, a specific REST API version or a +// custom storage endpoint than Azure Public Cloud. +func NewClient(accountName, accountKey, blobServiceBaseUrl, apiVersion string, useHttps bool) (*StorageClient, error) { + if accountName == "" { + return nil, fmt.Errorf("azure: account name required") + } else if accountKey == "" { + return nil, fmt.Errorf("azure: account key required") + } else if blobServiceBaseUrl == "" { + return nil, fmt.Errorf("azure: base storage service url required") + } + + key, err := base64.StdEncoding.DecodeString(accountKey) + if err != nil { + return nil, err + } + + return &StorageClient{ + accountName: accountName, + accountKey: key, + useHttps: useHttps, + baseUrl: blobServiceBaseUrl, + apiVersion: apiVersion}, nil +} + +func (c StorageClient) getBaseUrl(service string) string { + scheme := "http" + if c.useHttps { + scheme = "https" + } + + host := fmt.Sprintf("%s.%s.%s", c.accountName, service, c.baseUrl) + + u := &url.URL{ + Scheme: scheme, + Host: host} + return u.String() +} + +func (c StorageClient) getEndpoint(service, path string, params url.Values) string { + u, err := url.Parse(c.getBaseUrl(service)) + if err != nil { + // really should not be happening + panic(err) + } + + if path == "" { + path = "/" // API doesn't accept path segments not starting with '/' + } + + u.Path = path + u.RawQuery = params.Encode() + return u.String() +} + +// GetBlobService returns a BlobStorageClient which can operate on the +// blob service of the storage account. +func (c StorageClient) GetBlobService() *BlobStorageClient { + return &BlobStorageClient{c} +} + +func (c StorageClient) createAuthorizationHeader(canonicalizedString string) string { + signature := c.computeHmac256(canonicalizedString) + return fmt.Sprintf("%s %s:%s", "SharedKey", c.accountName, signature) +} + +func (c StorageClient) getAuthorizationHeader(verb, url string, headers map[string]string) (string, error) { + canonicalizedResource, err := c.buildCanonicalizedResource(url) + if err != nil { + return "", err + } + + canonicalizedString := c.buildCanonicalizedString(verb, headers, canonicalizedResource) + return c.createAuthorizationHeader(canonicalizedString), nil +} + +func (c StorageClient) getStandardHeaders() map[string]string { + return map[string]string{ + "x-ms-version": c.apiVersion, + "x-ms-date": currentTimeRfc1123Formatted(), + } +} + +func (c StorageClient) buildCanonicalizedHeader(headers map[string]string) string { + cm := make(map[string]string) + + for k, v := range headers { + headerName := strings.TrimSpace(strings.ToLower(k)) + match, _ := regexp.MatchString("x-ms-", headerName) + if match { + cm[headerName] = v + } + } + + if len(cm) == 0 { + return "" + } + + keys := make([]string, 0, len(cm)) + for key := range cm { + keys = append(keys, key) + } + + sort.Strings(keys) + + ch := "" + + for i, key := range keys { + if i == len(keys)-1 { + ch += fmt.Sprintf("%s:%s", key, cm[key]) + } else { + ch += fmt.Sprintf("%s:%s\n", key, cm[key]) + } + } + return ch +} + +func (c StorageClient) buildCanonicalizedResource(uri string) (string, error) { + errMsg := "buildCanonicalizedResource error: %s" + u, err := url.Parse(uri) + if err != nil { + return "", fmt.Errorf(errMsg, err.Error()) + } + + cr := "/" + c.accountName + if len(u.Path) > 0 { + cr += u.Path + } + + params, err := url.ParseQuery(u.RawQuery) + if err != nil { + return "", fmt.Errorf(errMsg, err.Error()) + } + + if len(params) > 0 { + cr += "\n" + keys := make([]string, 0, len(params)) + for key := range params { + keys = append(keys, key) + } + + sort.Strings(keys) + + for i, key := range keys { + if len(params[key]) > 1 { + sort.Strings(params[key]) + } + + if i == len(keys)-1 { + cr += fmt.Sprintf("%s:%s", key, strings.Join(params[key], ",")) + } else { + cr += fmt.Sprintf("%s:%s\n", key, strings.Join(params[key], ",")) + } + } + } + return cr, nil +} + +func (c StorageClient) buildCanonicalizedString(verb string, headers map[string]string, canonicalizedResource string) string { + canonicalizedString := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s", + verb, + headers["Content-Encoding"], + headers["Content-Language"], + headers["Content-Length"], + headers["Content-MD5"], + headers["Content-Type"], + headers["Date"], + headers["If-Modified-Singe"], + headers["If-Match"], + headers["If-None-Match"], + headers["If-Unmodified-Singe"], + headers["Range"], + c.buildCanonicalizedHeader(headers), + canonicalizedResource) + + return canonicalizedString +} + +func (c StorageClient) exec(verb, url string, headers map[string]string, body io.Reader) (*storageResponse, error) { + authHeader, err := c.getAuthorizationHeader(verb, url, headers) + if err != nil { + return nil, err + } + headers["Authorization"] = authHeader + + if err != nil { + return nil, err + } + + req, err := http.NewRequest(verb, url, body) + for k, v := range headers { + req.Header.Add(k, v) + } + httpClient := http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + + statusCode := resp.StatusCode + if statusCode >= 400 && statusCode <= 505 { + var respBody []byte + respBody, err = readResponseBody(resp) + if err != nil { + return nil, err + } + + if len(respBody) == 0 { + // no error in response body + err = fmt.Errorf("storage: service returned without a response body (%s).", resp.Status) + } else { + // response contains storage service error object, unmarshal + storageErr, errIn := serviceErrFromXml(respBody, resp.StatusCode, resp.Header.Get("x-ms-request-id")) + if err != nil { // error unmarshaling the error response + err = errIn + } + err = storageErr + } + return &storageResponse{ + statusCode: resp.StatusCode, + headers: resp.Header, + body: ioutil.NopCloser(bytes.NewReader(respBody)), /* restore the body */ + }, err + } + + return &storageResponse{ + statusCode: resp.StatusCode, + headers: resp.Header, + body: resp.Body}, nil +} + +func readResponseBody(resp *http.Response) ([]byte, error) { + defer resp.Body.Close() + out, err := ioutil.ReadAll(resp.Body) + if err == io.EOF { + err = nil + } + return out, err +} + +func serviceErrFromXml(body []byte, statusCode int, requestId string) (StorageServiceError, error) { + var storageErr StorageServiceError + if err := xml.Unmarshal(body, &storageErr); err != nil { + return storageErr, err + } + storageErr.StatusCode = statusCode + storageErr.RequestId = requestId + return storageErr, nil +} + +func (e StorageServiceError) Error() string { + return fmt.Sprintf("storage: remote server returned error. StatusCode=%d, ErrorCode=%s, ErrorMessage=%s, RequestId=%s", e.StatusCode, e.Code, e.Message, e.RequestId) +} diff --git a/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/client_test.go b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/client_test.go new file mode 100644 index 000000000..844962a60 --- /dev/null +++ b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/client_test.go @@ -0,0 +1,203 @@ +package storage + +import ( + "encoding/base64" + "net/url" + "testing" +) + +func TestGetBaseUrl_Basic_Https(t *testing.T) { + cli, err := NewBasicClient("foo", "YmFy") + if err != nil { + t.Fatal(err) + } + + if cli.apiVersion != DefaultApiVersion { + t.Fatalf("Wrong api version. Expected: '%s', got: '%s'", DefaultApiVersion, cli.apiVersion) + } + + if err != nil { + t.Fatal(err) + } + output := cli.getBaseUrl("table") + + if expected := "https://foo.table.core.windows.net"; output != expected { + t.Fatalf("Wrong base url. Expected: '%s', got: '%s'", expected, output) + } +} + +func TestGetBaseUrl_Custom_NoHttps(t *testing.T) { + apiVersion := DefaultApiVersion + cli, err := NewClient("foo", "YmFy", "core.chinacloudapi.cn", apiVersion, false) + if err != nil { + t.Fatal(err) + } + + if cli.apiVersion != apiVersion { + t.Fatalf("Wrong api version. Expected: '%s', got: '%s'", apiVersion, cli.apiVersion) + } + + output := cli.getBaseUrl("table") + + if expected := "http://foo.table.core.chinacloudapi.cn"; output != expected { + t.Fatalf("Wrong base url. Expected: '%s', got: '%s'", expected, output) + } +} + +func TestGetEndpoint_None(t *testing.T) { + cli, err := NewBasicClient("foo", "YmFy") + if err != nil { + t.Fatal(err) + } + output := cli.getEndpoint(blobServiceName, "", url.Values{}) + + if expected := "https://foo.blob.core.windows.net/"; output != expected { + t.Fatalf("Wrong endpoint url. Expected: '%s', got: '%s'", expected, output) + } +} + +func TestGetEndpoint_PathOnly(t *testing.T) { + cli, err := NewBasicClient("foo", "YmFy") + if err != nil { + t.Fatal(err) + } + output := cli.getEndpoint(blobServiceName, "path", url.Values{}) + + if expected := "https://foo.blob.core.windows.net/path"; output != expected { + t.Fatalf("Wrong endpoint url. Expected: '%s', got: '%s'", expected, output) + } +} + +func TestGetEndpoint_ParamsOnly(t *testing.T) { + cli, err := NewBasicClient("foo", "YmFy") + if err != nil { + t.Fatal(err) + } + params := url.Values{} + params.Set("a", "b") + params.Set("c", "d") + output := cli.getEndpoint(blobServiceName, "", params) + + if expected := "https://foo.blob.core.windows.net/?a=b&c=d"; output != expected { + t.Fatalf("Wrong endpoint url. Expected: '%s', got: '%s'", expected, output) + } +} + +func TestGetEndpoint_Mixed(t *testing.T) { + cli, err := NewBasicClient("foo", "YmFy") + if err != nil { + t.Fatal(err) + } + params := url.Values{} + params.Set("a", "b") + params.Set("c", "d") + output := cli.getEndpoint(blobServiceName, "path", params) + + if expected := "https://foo.blob.core.windows.net/path?a=b&c=d"; output != expected { + t.Fatalf("Wrong endpoint url. Expected: '%s', got: '%s'", expected, output) + } +} + +func Test_getStandardHeaders(t *testing.T) { + cli, err := NewBasicClient("foo", "YmFy") + if err != nil { + t.Fatal(err) + } + + headers := cli.getStandardHeaders() + if len(headers) != 2 { + t.Fatal("Wrong standard header count") + } + if v, ok := headers["x-ms-version"]; !ok || v != cli.apiVersion { + t.Fatal("Wrong version header") + } + if _, ok := headers["x-ms-date"]; !ok { + t.Fatal("Missing date header") + } +} + +func Test_buildCanonicalizedResource(t *testing.T) { + cli, err := NewBasicClient("foo", "YmFy") + if err != nil { + t.Fatal(err) + } + + type test struct{ url, expected string } + tests := []test{ + {"https://foo.blob.core.windows.net/path?a=b&c=d", "/foo/path\na:b\nc:d"}, + {"https://foo.blob.core.windows.net/?comp=list", "/foo/\ncomp:list"}, + {"https://foo.blob.core.windows.net/cnt/blob", "/foo/cnt/blob"}, + } + + for _, i := range tests { + if out, err := cli.buildCanonicalizedResource(i.url); err != nil { + t.Fatal(err) + } else if out != i.expected { + t.Fatalf("Wrong canonicalized resource. Expected:\n'%s', Got:\n'%s'", i.expected, out) + } + } +} + +func Test_buildCanonicalizedHeader(t *testing.T) { + cli, err := NewBasicClient("foo", "YmFy") + if err != nil { + t.Fatal(err) + } + + type test struct { + headers map[string]string + expected string + } + tests := []test{ + {map[string]string{}, ""}, + {map[string]string{"x-ms-foo": "bar"}, "x-ms-foo:bar"}, + {map[string]string{"foo:": "bar"}, ""}, + {map[string]string{"foo:": "bar", "x-ms-foo": "bar"}, "x-ms-foo:bar"}, + {map[string]string{ + "x-ms-version": "9999-99-99", + "x-ms-blob-type": "BlockBlob"}, "x-ms-blob-type:BlockBlob\nx-ms-version:9999-99-99"}} + + for _, i := range tests { + if out := cli.buildCanonicalizedHeader(i.headers); out != i.expected { + t.Fatalf("Wrong canonicalized resource. Expected:\n'%s', Got:\n'%s'", i.expected, out) + } + } +} + +func TestReturnsStorageServiceError(t *testing.T) { + cli, err := getBlobClient() + if err != nil { + t.Fatal(err) + } + + // attempt to delete a nonexisting container + _, err = cli.deleteContainer(randContainer()) + if err == nil { + t.Fatal("Service has not returned an error") + } + + if v, ok := err.(StorageServiceError); !ok { + t.Fatal("Cannot assert to specific error") + } else if v.StatusCode != 404 { + t.Fatalf("Expected status:%d, got: %d", 404, v.StatusCode) + } else if v.Code != "ContainerNotFound" { + t.Fatalf("Expected code: %s, got: %s", "ContainerNotFound", v.Code) + } else if v.RequestId == "" { + t.Fatalf("RequestId does not exist") + } +} + +func Test_createAuthorizationHeader(t *testing.T) { + key := base64.StdEncoding.EncodeToString([]byte("bar")) + cli, err := NewBasicClient("foo", key) + if err != nil { + t.Fatal(err) + } + + canonicalizedString := `foobarzoo` + expected := `SharedKey foo:h5U0ATVX6SpbFX1H6GNuxIMeXXCILLoIvhflPtuQZ30=` + + if out := cli.createAuthorizationHeader(canonicalizedString); out != expected { + t.Fatalf("Wrong authorization header. Expected: '%s', Got:'%s'", expected, out) + } +} diff --git a/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/util.go b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/util.go new file mode 100644 index 000000000..8a0f7b945 --- /dev/null +++ b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/util.go @@ -0,0 +1,63 @@ +package storage + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/xml" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "time" +) + +func (c StorageClient) computeHmac256(message string) string { + h := hmac.New(sha256.New, c.accountKey) + h.Write([]byte(message)) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func currentTimeRfc1123Formatted() string { + return timeRfc1123Formatted(time.Now().UTC()) +} + +func timeRfc1123Formatted(t time.Time) string { + return t.Format(http.TimeFormat) +} + +func mergeParams(v1, v2 url.Values) url.Values { + out := url.Values{} + for k, v := range v1 { + out[k] = v + } + for k, v := range v2 { + vals, ok := out[k] + if ok { + vals = append(vals, v...) + out[k] = vals + } else { + out[k] = v + } + } + return out +} + +func prepareBlockListRequest(blocks []Block) string { + s := `` + for _, v := range blocks { + s += fmt.Sprintf("<%s>%s", v.Status, v.Id, v.Status) + } + s += `` + return s +} + +func xmlUnmarshal(body io.ReadCloser, v interface{}) error { + data, err := ioutil.ReadAll(body) + if err != nil { + return err + } + defer body.Close() + return xml.Unmarshal(data, v) +} diff --git a/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/util_test.go b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/util_test.go new file mode 100644 index 000000000..d1b2c7949 --- /dev/null +++ b/Godeps/_workspace/src/github.com/MSOpenTech/azure-sdk-for-go/clients/storage/util_test.go @@ -0,0 +1,80 @@ +package storage + +import ( + "io/ioutil" + "net/url" + "reflect" + "strings" + "testing" + "time" +) + +func Test_timeRfc1123Formatted(t *testing.T) { + now := time.Now().UTC() + + expectedLayout := "Mon, 02 Jan 2006 15:04:05 GMT" + expected := now.Format(expectedLayout) + + if output := timeRfc1123Formatted(now); output != expected { + t.Errorf("Expected: %s, got: %s", expected, output) + } +} + +func Test_mergeParams(t *testing.T) { + v1 := url.Values{ + "k1": {"v1"}, + "k2": {"v2"}} + v2 := url.Values{ + "k1": {"v11"}, + "k3": {"v3"}} + + out := mergeParams(v1, v2) + if v := out.Get("k1"); v != "v1" { + t.Errorf("Wrong value for k1: %s", v) + } + + if v := out.Get("k2"); v != "v2" { + t.Errorf("Wrong value for k2: %s", v) + } + + if v := out.Get("k3"); v != "v3" { + t.Errorf("Wrong value for k3: %s", v) + } + + if v := out["k1"]; !reflect.DeepEqual(v, []string{"v1", "v11"}) { + t.Errorf("Wrong multi-value for k1: %s", v) + } +} + +func Test_prepareBlockListRequest(t *testing.T) { + empty := []Block{} + expected := `` + if out := prepareBlockListRequest(empty); expected != out { + t.Errorf("Wrong block list. Expected: '%s', got: '%s'", expected, out) + } + + blocks := []Block{{"foo", BlockStatusLatest}, {"bar", BlockStatusUncommitted}} + expected = `foobar` + if out := prepareBlockListRequest(blocks); expected != out { + t.Errorf("Wrong block list. Expected: '%s', got: '%s'", expected, out) + } +} + +func Test_xmlUnmarshal(t *testing.T) { + xml := ` + + myblob + ` + + body := ioutil.NopCloser(strings.NewReader(xml)) + + var blob Blob + err := xmlUnmarshal(body, &blob) + if err != nil { + t.Fatal(err) + } + + if blob.Name != "myblob" { + t.Fatal("Got wrong value") + } +} diff --git a/cmd/registry-storagedriver-azure/main.go b/cmd/registry-storagedriver-azure/main.go index 17881b50e..71b1faaf6 100644 --- a/cmd/registry-storagedriver-azure/main.go +++ b/cmd/registry-storagedriver-azure/main.go @@ -14,7 +14,7 @@ import ( // An out-of-process Azure Storage driver, intended to be run by ipc.NewDriverClient func main() { parametersBytes := []byte(os.Args[1]) - var parameters map[string]string + var parameters map[string]interface{} err := json.Unmarshal(parametersBytes, ¶meters) if err != nil { panic(err) diff --git a/storagedriver/azure/azure.go b/storagedriver/azure/azure.go index ee3230ff7..b6c42f666 100644 --- a/storagedriver/azure/azure.go +++ b/storagedriver/azure/azure.go @@ -1,19 +1,18 @@ -// +build ignore - // Package azure provides a storagedriver.StorageDriver implementation to // store blobs in Microsoft Azure Blob Storage Service. package azure import ( "bytes" - "encoding/base64" "fmt" "io" "io/ioutil" - "strconv" + "net/http" "strings" + "time" "github.com/docker/distribution/storagedriver" + "github.com/docker/distribution/storagedriver/base" "github.com/docker/distribution/storagedriver/factory" azure "github.com/MSOpenTech/azure-sdk-for-go/clients/storage" @@ -27,41 +26,45 @@ const ( paramContainer = "container" ) -// Driver is a storagedriver.StorageDriver implementation backed by -// Microsoft Azure Blob Storage Service. -type Driver struct { - client *azure.BlobStorageClient +type driver struct { + client azure.BlobStorageClient container string } +type baseEmbed struct{ base.Base } + +// Driver is a storagedriver.StorageDriver implementation backed by +// Microsoft Azure Blob Storage Service. +type Driver struct{ baseEmbed } + func init() { factory.Register(driverName, &azureDriverFactory{}) } type azureDriverFactory struct{} -func (factory *azureDriverFactory) Create(parameters map[string]string) (storagedriver.StorageDriver, error) { +func (factory *azureDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { return FromParameters(parameters) } // FromParameters constructs a new Driver with a given parameters map. -func FromParameters(parameters map[string]string) (*Driver, error) { +func FromParameters(parameters map[string]interface{}) (*Driver, error) { accountName, ok := parameters[paramAccountName] - if !ok { + if !ok || fmt.Sprint(accountName) == "" { return nil, fmt.Errorf("No %s parameter provided", paramAccountName) } accountKey, ok := parameters[paramAccountKey] - if !ok { + if !ok || fmt.Sprint(accountKey) == "" { return nil, fmt.Errorf("No %s parameter provided", paramAccountKey) } container, ok := parameters[paramContainer] - if !ok { + if !ok || fmt.Sprint(container) == "" { return nil, fmt.Errorf("No %s parameter provided", paramContainer) } - return New(accountName, accountKey, container) + return New(fmt.Sprint(accountName), fmt.Sprint(accountKey), fmt.Sprint(container)) } // New constructs a new Driver with the given Azure Storage Account credentials @@ -78,15 +81,16 @@ func New(accountName, accountKey, container string) (*Driver, error) { return nil, err } - return &Driver{ - client: blobClient, - container: container}, nil + d := &driver{ + client: *blobClient, + container: container} + return &Driver{baseEmbed: baseEmbed{Base: base.Base{StorageDriver: d}}}, nil } // Implement the storagedriver.StorageDriver interface. // GetContent retrieves the content stored at "path" as a []byte. -func (d *Driver) GetContent(path string) ([]byte, error) { +func (d *driver) GetContent(path string) ([]byte, error) { blob, err := d.client.GetBlob(d.container, path) if err != nil { if is404(err) { @@ -99,26 +103,27 @@ func (d *Driver) GetContent(path string) ([]byte, error) { } // PutContent stores the []byte content at a location designated by "path". -func (d *Driver) PutContent(path string, contents []byte) error { +func (d *driver) PutContent(path string, contents []byte) error { return d.client.PutBlockBlob(d.container, path, ioutil.NopCloser(bytes.NewReader(contents))) } // ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // given byte offset. -func (d *Driver) ReadStream(path string, offset int64) (io.ReadCloser, error) { +func (d *driver) ReadStream(path string, offset int64) (io.ReadCloser, error) { if ok, err := d.client.BlobExists(d.container, path); err != nil { return nil, err } else if !ok { return nil, storagedriver.PathNotFoundError{Path: path} } - size, err := d.CurrentSize(path) + info, err := d.client.GetBlobProperties(d.container, path) if err != nil { return nil, err } - if offset >= int64(size) { - return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset} + size := int64(info.ContentLength) + if offset >= size { + return ioutil.NopCloser(bytes.NewReader(nil)), nil } bytesRange := fmt.Sprintf("%v-", offset) @@ -131,91 +136,77 @@ func (d *Driver) ReadStream(path string, offset int64) (io.ReadCloser, error) { // WriteStream stores the contents of the provided io.ReadCloser at a location // designated by the given path. -func (d *Driver) WriteStream(path string, offset, size int64, reader io.ReadCloser) error { - var ( - lastBlockNum int - resumableOffset int64 - blocks []azure.Block - ) - +func (d *driver) WriteStream(path string, offset int64, reader io.Reader) (int64, error) { if blobExists, err := d.client.BlobExists(d.container, path); err != nil { - return err - } else if !blobExists { // new blob - lastBlockNum = 0 - resumableOffset = 0 - } else { // append - if parts, err := d.client.GetBlockList(d.container, path, azure.BlockListTypeCommitted); err != nil { - return err - } else if len(parts.CommittedBlocks) == 0 { - lastBlockNum = 0 - resumableOffset = 0 - } else { - lastBlock := parts.CommittedBlocks[len(parts.CommittedBlocks)-1] - if lastBlockNum, err = blockNum(lastBlock.Name); err != nil { - return fmt.Errorf("Cannot parse block name as number '%s': %s", lastBlock.Name, err.Error()) - } - - var totalSize int64 - for _, v := range parts.CommittedBlocks { - blocks = append(blocks, azure.Block{ - Id: v.Name, - Status: azure.BlockStatusCommitted}) - totalSize += int64(v.Size) - } - - // NOTE: Azure driver currently supports only append mode (resumable - // index is exactly where the committed blocks of the blob end). - // In order to support writing to offsets other than last index, - // adjacent blocks overlapping with the [offset:offset+size] area - // must be fetched, splitted and should be overwritten accordingly. - // As the current use of this method is append only, that implementation - // is omitted. - resumableOffset = totalSize + return 0, err + } else if !blobExists { + err := d.client.CreateBlockBlob(d.container, path) + if err != nil { + return 0, err } } - - if offset != resumableOffset { - return storagedriver.InvalidOffsetError{Path: path, Offset: offset} + if offset < 0 { + return 0, storagedriver.InvalidOffsetError{Path: path, Offset: offset} } - // Put content - buf := make([]byte, azure.MaxBlobBlockSize) - for { - // Read chunks of exactly size N except the last chunk to - // maximize block size and minimize block count. - n, err := io.ReadFull(reader, buf) - if err == io.EOF { - break - } - - data := buf[:n] - blockID := toBlockID(lastBlockNum + 1) - if err = d.client.PutBlock(d.container, path, blockID, data); err != nil { - return err - } - blocks = append(blocks, azure.Block{ - Id: blockID, - Status: azure.BlockStatusLatest}) - lastBlockNum++ - } - - // Commit block list - return d.client.PutBlockList(d.container, path, blocks) + bs := newAzureBlockStorage(d.client) + bw := newRandomBlobWriter(&bs, azure.MaxBlobBlockSize) + zw := newZeroFillWriter(&bw) + return zw.Write(d.container, path, offset, reader) } -// CurrentSize retrieves the curernt size in bytes of the object at the given -// path. -func (d *Driver) CurrentSize(path string) (uint64, error) { - props, err := d.client.GetBlobProperties(d.container, path) - if err != nil { - return 0, err +// Stat retrieves the FileInfo for the given path, including the current size +// in bytes and the creation time. +func (d *driver) Stat(path string) (storagedriver.FileInfo, error) { + // Check if the path is a blob + if ok, err := d.client.BlobExists(d.container, path); err != nil { + return nil, err + } else if ok { + blob, err := d.client.GetBlobProperties(d.container, path) + if err != nil { + return nil, err + } + + mtim, err := time.Parse(http.TimeFormat, blob.LastModified) + if err != nil { + return nil, err + } + + return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{ + Path: path, + Size: int64(blob.ContentLength), + ModTime: mtim, + IsDir: false, + }}, nil } - return props.ContentLength, nil + + // Check if path is a virtual container + virtContainerPath := path + if !strings.HasSuffix(virtContainerPath, "/") { + virtContainerPath += "/" + } + blobs, err := d.client.ListBlobs(d.container, azure.ListBlobsParameters{ + Prefix: virtContainerPath, + MaxResults: 1, + }) + if err != nil { + return nil, err + } + if len(blobs.Blobs) > 0 { + // path is a virtual container + return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{ + Path: path, + IsDir: true, + }}, nil + } + + // path is not a blob or virtual container + return nil, storagedriver.PathNotFoundError{Path: path} } // List returns a list of the objects that are direct descendants of the given // path. -func (d *Driver) List(path string) ([]string, error) { +func (d *driver) List(path string) ([]string, error) { if path == "/" { path = "" } @@ -231,7 +222,7 @@ func (d *Driver) List(path string) ([]string, error) { // Move moves an object stored at sourcePath to destPath, removing the original // object. -func (d *Driver) Move(sourcePath string, destPath string) error { +func (d *driver) Move(sourcePath string, destPath string) error { sourceBlobURL := d.client.GetBlobUrl(d.container, sourcePath) err := d.client.CopyBlob(d.container, destPath, sourceBlobURL) if err != nil { @@ -245,7 +236,7 @@ func (d *Driver) Move(sourcePath string, destPath string) error { } // Delete recursively deletes all objects stored at "path" and its subpaths. -func (d *Driver) Delete(path string) error { +func (d *driver) Delete(path string) error { ok, err := d.client.DeleteBlobIfExists(d.container, path) if err != nil { return err @@ -272,6 +263,21 @@ func (d *Driver) Delete(path string) error { return nil } +// URLFor returns a publicly accessible URL for the blob stored at given path +// for specified duration by making use of Azure Storage Shared Access Signatures (SAS). +// See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx for more info. +func (d *driver) URLFor(path string, options map[string]interface{}) (string, error) { + expiresTime := time.Now().UTC().Add(20 * time.Minute) // default expiration + expires, ok := options["expiry"] + if ok { + t, ok := expires.(time.Time) + if ok { + expiresTime = t + } + } + return d.client.GetBlobSASURI(d.container, path, expiresTime, "r") +} + // directDescendants will find direct descendants (blobs or virtual containers) // of from list of blob paths and will return their full paths. Elements in blobs // list must be prefixed with a "/" and @@ -306,7 +312,7 @@ func directDescendants(blobs []string, prefix string) []string { return keys } -func (d *Driver) listBlobs(container, virtPath string) ([]string, error) { +func (d *driver) listBlobs(container, virtPath string) ([]string, error) { if virtPath != "" && !strings.HasSuffix(virtPath, "/") { // containerify the path virtPath += "/" } @@ -339,16 +345,3 @@ func is404(err error) bool { e, ok := err.(azure.StorageServiceError) return ok && e.StatusCode == 404 } - -func blockNum(b64Name string) (int, error) { - s, err := base64.StdEncoding.DecodeString(b64Name) - if err != nil { - return 0, err - } - - return strconv.Atoi(string(s)) -} - -func toBlockID(i int) string { - return base64.StdEncoding.EncodeToString([]byte(strconv.Itoa(i))) -} diff --git a/storagedriver/azure/azure_test.go b/storagedriver/azure/azure_test.go index 4e8ac59de..170e20f87 100644 --- a/storagedriver/azure/azure_test.go +++ b/storagedriver/azure/azure_test.go @@ -1,5 +1,3 @@ -// +build ignore - package azure import ( @@ -59,9 +57,9 @@ func init() { } testsuites.RegisterInProcessSuite(azureDriverConstructor, skipCheck) - testsuites.RegisterIPCSuite(driverName, map[string]string{ - paramAccountName: accountName, - paramAccountKey: accountKey, - paramContainer: container, - }, skipCheck) + // testsuites.RegisterIPCSuite(driverName, map[string]string{ + // paramAccountName: accountName, + // paramAccountKey: accountKey, + // paramContainer: container, + // }, skipCheck) } diff --git a/storagedriver/azure/blockblob.go b/storagedriver/azure/blockblob.go new file mode 100644 index 000000000..d868453f1 --- /dev/null +++ b/storagedriver/azure/blockblob.go @@ -0,0 +1,24 @@ +package azure + +import ( + "fmt" + "io" + + azure "github.com/MSOpenTech/azure-sdk-for-go/clients/storage" +) + +// azureBlockStorage is adaptor between azure.BlobStorageClient and +// blockStorage interface. +type azureBlockStorage struct { + azure.BlobStorageClient +} + +func (b *azureBlockStorage) GetSectionReader(container, blob string, start, length int64) (io.ReadCloser, error) { + return b.BlobStorageClient.GetBlobRange(container, blob, fmt.Sprintf("%v-%v", start, start+length-1)) +} + +func newAzureBlockStorage(b azure.BlobStorageClient) azureBlockStorage { + a := azureBlockStorage{} + a.BlobStorageClient = b + return a +} diff --git a/storagedriver/azure/blockblob_test.go b/storagedriver/azure/blockblob_test.go new file mode 100644 index 000000000..f1e390277 --- /dev/null +++ b/storagedriver/azure/blockblob_test.go @@ -0,0 +1,155 @@ +package azure + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + + azure "github.com/MSOpenTech/azure-sdk-for-go/clients/storage" +) + +type StorageSimulator struct { + blobs map[string]*BlockBlob +} + +type BlockBlob struct { + blocks map[string]*DataBlock + blockList []string +} + +type DataBlock struct { + data []byte + committed bool +} + +func (s *StorageSimulator) path(container, blob string) string { + return fmt.Sprintf("%s/%s", container, blob) +} + +func (s *StorageSimulator) BlobExists(container, blob string) (bool, error) { + _, ok := s.blobs[s.path(container, blob)] + return ok, nil +} + +func (s *StorageSimulator) GetBlob(container, blob string) (io.ReadCloser, error) { + bb, ok := s.blobs[s.path(container, blob)] + if !ok { + return nil, fmt.Errorf("blob not found") + } + + var readers []io.Reader + for _, bID := range bb.blockList { + readers = append(readers, bytes.NewReader(bb.blocks[bID].data)) + } + return ioutil.NopCloser(io.MultiReader(readers...)), nil +} + +func (s *StorageSimulator) GetSectionReader(container, blob string, start, length int64) (io.ReadCloser, error) { + r, err := s.GetBlob(container, blob) + if err != nil { + return nil, err + } + b, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + return ioutil.NopCloser(bytes.NewReader(b[start : start+length])), nil +} + +func (s *StorageSimulator) CreateBlockBlob(container, blob string) error { + path := s.path(container, blob) + bb := &BlockBlob{ + blocks: make(map[string]*DataBlock), + blockList: []string{}, + } + s.blobs[path] = bb + return nil +} + +func (s *StorageSimulator) PutBlock(container, blob, blockID string, chunk []byte) error { + path := s.path(container, blob) + bb, ok := s.blobs[path] + if !ok { + return fmt.Errorf("blob not found") + } + data := make([]byte, len(chunk)) + copy(data, chunk) + bb.blocks[blockID] = &DataBlock{data: data, committed: false} // add block to blob + return nil +} + +func (s *StorageSimulator) GetBlockList(container, blob string, blockType azure.BlockListType) (azure.BlockListResponse, error) { + resp := azure.BlockListResponse{} + bb, ok := s.blobs[s.path(container, blob)] + if !ok { + return resp, fmt.Errorf("blob not found") + } + + // Iterate committed blocks (in order) + if blockType == azure.BlockListTypeAll || blockType == azure.BlockListTypeCommitted { + for _, blockID := range bb.blockList { + b := bb.blocks[blockID] + block := azure.BlockResponse{ + Name: blockID, + Size: int64(len(b.data)), + } + resp.CommittedBlocks = append(resp.CommittedBlocks, block) + } + + } + + // Iterate uncommitted blocks (in no order) + if blockType == azure.BlockListTypeAll || blockType == azure.BlockListTypeCommitted { + for blockID, b := range bb.blocks { + block := azure.BlockResponse{ + Name: blockID, + Size: int64(len(b.data)), + } + if !b.committed { + resp.UncommittedBlocks = append(resp.UncommittedBlocks, block) + } + } + } + return resp, nil +} + +func (s *StorageSimulator) PutBlockList(container, blob string, blocks []azure.Block) error { + bb, ok := s.blobs[s.path(container, blob)] + if !ok { + return fmt.Errorf("blob not found") + } + + var blockIDs []string + for _, v := range blocks { + bl, ok := bb.blocks[v.Id] + if !ok { // check if block ID exists + return fmt.Errorf("Block id '%s' not found", v.Id) + } + bl.committed = true + blockIDs = append(blockIDs, v.Id) + } + + // Mark all other blocks uncommitted + for k, b := range bb.blocks { + inList := false + for _, v := range blockIDs { + if k == v { + inList = true + break + } + } + if !inList { + b.committed = false + } + } + + bb.blockList = blockIDs + return nil +} + +func NewStorageSimulator() StorageSimulator { + return StorageSimulator{ + blobs: make(map[string]*BlockBlob), + } +} diff --git a/storagedriver/azure/blockid.go b/storagedriver/azure/blockid.go new file mode 100644 index 000000000..61f41ebcf --- /dev/null +++ b/storagedriver/azure/blockid.go @@ -0,0 +1,60 @@ +package azure + +import ( + "encoding/base64" + "fmt" + "math/rand" + "sync" + "time" + + azure "github.com/MSOpenTech/azure-sdk-for-go/clients/storage" +) + +type blockIDGenerator struct { + pool map[string]bool + r *rand.Rand + m sync.Mutex +} + +// Generate returns an unused random block id and adds the generated ID +// to list of used IDs so that the same block name is not used again. +func (b *blockIDGenerator) Generate() string { + b.m.Lock() + defer b.m.Unlock() + + var id string + for { + id = toBlockID(int(b.r.Int())) + if !b.exists(id) { + break + } + } + b.pool[id] = true + return id +} + +func (b *blockIDGenerator) exists(id string) bool { + _, used := b.pool[id] + return used +} + +func (b *blockIDGenerator) Feed(blocks azure.BlockListResponse) { + b.m.Lock() + defer b.m.Unlock() + + for _, bl := range append(blocks.CommittedBlocks, blocks.UncommittedBlocks...) { + b.pool[bl.Name] = true + } +} + +func newBlockIDGenerator() *blockIDGenerator { + return &blockIDGenerator{ + pool: make(map[string]bool), + r: rand.New(rand.NewSource(time.Now().UnixNano()))} +} + +// toBlockId converts given integer to base64-encoded block ID of a fixed length. +func toBlockID(i int) string { + s := fmt.Sprintf("%029d", i) // add zero padding for same length-blobs + return base64.StdEncoding.EncodeToString([]byte(s)) +} diff --git a/storagedriver/azure/blockid_test.go b/storagedriver/azure/blockid_test.go new file mode 100644 index 000000000..46d52a342 --- /dev/null +++ b/storagedriver/azure/blockid_test.go @@ -0,0 +1,74 @@ +package azure + +import ( + "math" + "testing" + + azure "github.com/MSOpenTech/azure-sdk-for-go/clients/storage" +) + +func Test_blockIdGenerator(t *testing.T) { + r := newBlockIDGenerator() + + for i := 1; i <= 10; i++ { + if expected := i - 1; len(r.pool) != expected { + t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected) + } + if id := r.Generate(); id == "" { + t.Fatal("returned empty id") + } + if expected := i; len(r.pool) != expected { + t.Fatalf("rand pool has wrong number of items: %d, expected:%d", len(r.pool), expected) + } + } +} + +func Test_blockIdGenerator_Feed(t *testing.T) { + r := newBlockIDGenerator() + if expected := 0; len(r.pool) != expected { + t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected) + } + + // feed empty list + blocks := azure.BlockListResponse{} + r.Feed(blocks) + if expected := 0; len(r.pool) != expected { + t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected) + } + + // feed blocks + blocks = azure.BlockListResponse{ + CommittedBlocks: []azure.BlockResponse{ + {"1", 1}, + {"2", 2}, + }, + UncommittedBlocks: []azure.BlockResponse{ + {"3", 3}, + }} + r.Feed(blocks) + if expected := 3; len(r.pool) != expected { + t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected) + } + + // feed same block IDs with committed/uncommitted place changed + blocks = azure.BlockListResponse{ + CommittedBlocks: []azure.BlockResponse{ + {"3", 3}, + }, + UncommittedBlocks: []azure.BlockResponse{ + {"1", 1}, + }} + r.Feed(blocks) + if expected := 3; len(r.pool) != expected { + t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected) + } +} + +func Test_toBlockId(t *testing.T) { + min := 0 + max := math.MaxInt64 + + if len(toBlockID(min)) != len(toBlockID(max)) { + t.Fatalf("different-sized blockIDs are returned") + } +} diff --git a/storagedriver/azure/randomwriter.go b/storagedriver/azure/randomwriter.go new file mode 100644 index 000000000..c89dd0a34 --- /dev/null +++ b/storagedriver/azure/randomwriter.go @@ -0,0 +1,208 @@ +package azure + +import ( + "fmt" + "io" + "io/ioutil" + + azure "github.com/MSOpenTech/azure-sdk-for-go/clients/storage" +) + +// blockStorage is the interface required from a block storage service +// client implementation +type blockStorage interface { + CreateBlockBlob(container, blob string) error + GetBlob(container, blob string) (io.ReadCloser, error) + GetSectionReader(container, blob string, start, length int64) (io.ReadCloser, error) + PutBlock(container, blob, blockID string, chunk []byte) error + GetBlockList(container, blob string, blockType azure.BlockListType) (azure.BlockListResponse, error) + PutBlockList(container, blob string, blocks []azure.Block) error +} + +// randomBlobWriter enables random access semantics on Azure block blobs +// by enabling writing arbitrary length of chunks to arbitrary write offsets +// within the blob. Normally, Azure Blob Storage does not support random +// access semantics on block blobs; however, this writer can download, split and +// reupload the overlapping blocks and discards those being overwritten entirely. +type randomBlobWriter struct { + bs blockStorage + blockSize int +} + +func newRandomBlobWriter(bs blockStorage, blockSize int) randomBlobWriter { + return randomBlobWriter{bs: bs, blockSize: blockSize} +} + +// WriteBlobAt writes the given chunk to the specified position of an existing blob. +// The offset must be equals to size of the blob or smaller than it. +func (r *randomBlobWriter) WriteBlobAt(container, blob string, offset int64, chunk io.Reader) (int64, error) { + rand := newBlockIDGenerator() + + blocks, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeCommitted) + if err != nil { + return 0, err + } + rand.Feed(blocks) // load existing block IDs + + // Check for write offset for existing blob + size := getBlobSize(blocks) + if offset < 0 || offset > size { + return 0, fmt.Errorf("wrong offset for Write: %v", offset) + } + + // Upload the new chunk as blocks + blockList, nn, err := r.writeChunkToBlocks(container, blob, chunk, rand) + if err != nil { + return 0, err + } + + // For non-append operations, existing blocks may need to be splitted + if offset != size { + // Split the block on the left end (if any) + leftBlocks, err := r.blocksLeftSide(container, blob, offset, rand) + if err != nil { + return 0, err + } + blockList = append(leftBlocks, blockList...) + + // Split the block on the right end (if any) + rightBlocks, err := r.blocksRightSide(container, blob, offset, nn, rand) + if err != nil { + return 0, err + } + blockList = append(blockList, rightBlocks...) + } else { + // Use existing block list + var existingBlocks []azure.Block + for _, v := range blocks.CommittedBlocks { + existingBlocks = append(existingBlocks, azure.Block{Id: v.Name, Status: azure.BlockStatusCommitted}) + } + blockList = append(existingBlocks, blockList...) + } + // Put block list + return nn, r.bs.PutBlockList(container, blob, blockList) +} + +func (r *randomBlobWriter) GetSize(container, blob string) (int64, error) { + blocks, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeCommitted) + if err != nil { + return 0, err + } + return getBlobSize(blocks), nil +} + +// writeChunkToBlocks writes given chunk to one or multiple blocks within specified +// blob and returns their block representations. Those blocks are not committed, yet +func (r *randomBlobWriter) writeChunkToBlocks(container, blob string, chunk io.Reader, rand *blockIDGenerator) ([]azure.Block, int64, error) { + var newBlocks []azure.Block + var nn int64 + + // Read chunks of at most size N except the last chunk to + // maximize block size and minimize block count. + buf := make([]byte, r.blockSize) + for { + n, err := io.ReadFull(chunk, buf) + if err == io.EOF { + break + } + nn += int64(n) + data := buf[:n] + blockID := rand.Generate() + if err := r.bs.PutBlock(container, blob, blockID, data); err != nil { + return newBlocks, nn, err + } + newBlocks = append(newBlocks, azure.Block{Id: blockID, Status: azure.BlockStatusUncommitted}) + } + return newBlocks, nn, nil +} + +// blocksLeftSide returns the blocks that are going to be at the left side of +// the writeOffset: [0, writeOffset) by identifying blocks that will remain +// the same and splitting blocks and reuploading them as needed. +func (r *randomBlobWriter) blocksLeftSide(container, blob string, writeOffset int64, rand *blockIDGenerator) ([]azure.Block, error) { + var left []azure.Block + bx, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeAll) + if err != nil { + return left, err + } + + o := writeOffset + elapsed := int64(0) + for _, v := range bx.CommittedBlocks { + blkSize := int64(v.Size) + if o >= blkSize { // use existing block + left = append(left, azure.Block{Id: v.Name, Status: azure.BlockStatusCommitted}) + o -= blkSize + elapsed += blkSize + } else if o > 0 { // current block needs to be splitted + start := elapsed + size := o + part, err := r.bs.GetSectionReader(container, blob, start, size) + if err != nil { + return left, err + } + newBlockID := rand.Generate() + + data, err := ioutil.ReadAll(part) + if err != nil { + return left, err + } + if err = r.bs.PutBlock(container, blob, newBlockID, data); err != nil { + return left, err + } + left = append(left, azure.Block{Id: newBlockID, Status: azure.BlockStatusUncommitted}) + break + } + } + return left, nil +} + +// blocksRightSide returns the blocks that are going to be at the right side of +// the written chunk: [writeOffset+size, +inf) by identifying blocks that will remain +// the same and splitting blocks and reuploading them as needed. +func (r *randomBlobWriter) blocksRightSide(container, blob string, writeOffset int64, chunkSize int64, rand *blockIDGenerator) ([]azure.Block, error) { + var right []azure.Block + + bx, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeAll) + if err != nil { + return nil, err + } + + re := writeOffset + chunkSize - 1 // right end of written chunk + var elapsed int64 + for _, v := range bx.CommittedBlocks { + var ( + bs = elapsed // left end of current block + be = elapsed + int64(v.Size) - 1 // right end of current block + ) + + if bs > re { // take the block as is + right = append(right, azure.Block{Id: v.Name, Status: azure.BlockStatusCommitted}) + } else if be > re { // current block needs to be splitted + part, err := r.bs.GetSectionReader(container, blob, re+1, be-(re+1)+1) + if err != nil { + return right, err + } + newBlockID := rand.Generate() + + data, err := ioutil.ReadAll(part) + if err != nil { + return right, err + } + if err = r.bs.PutBlock(container, blob, newBlockID, data); err != nil { + return right, err + } + right = append(right, azure.Block{Id: newBlockID, Status: azure.BlockStatusUncommitted}) + } + elapsed += int64(v.Size) + } + return right, nil +} + +func getBlobSize(blocks azure.BlockListResponse) int64 { + var n int64 + for _, v := range blocks.CommittedBlocks { + n += int64(v.Size) + } + return n +} diff --git a/storagedriver/azure/randomwriter_test.go b/storagedriver/azure/randomwriter_test.go new file mode 100644 index 000000000..5201e3b49 --- /dev/null +++ b/storagedriver/azure/randomwriter_test.go @@ -0,0 +1,339 @@ +package azure + +import ( + "bytes" + "io" + "io/ioutil" + "math/rand" + "reflect" + "strings" + "testing" + + azure "github.com/MSOpenTech/azure-sdk-for-go/clients/storage" +) + +func TestRandomWriter_writeChunkToBlocks(t *testing.T) { + s := NewStorageSimulator() + rw := newRandomBlobWriter(&s, 3) + rand := newBlockIDGenerator() + c := []byte("AAABBBCCCD") + + if err := rw.bs.CreateBlockBlob("a", "b"); err != nil { + t.Fatal(err) + } + bw, nn, err := rw.writeChunkToBlocks("a", "b", bytes.NewReader(c), rand) + if err != nil { + t.Fatal(err) + } + if expected := int64(len(c)); nn != expected { + t.Fatalf("wrong nn:%v, expected:%v", nn, expected) + } + if expected := 4; len(bw) != expected { + t.Fatal("unexpected written block count") + } + + bx, err := s.GetBlockList("a", "b", azure.BlockListTypeAll) + if err != nil { + t.Fatal(err) + } + if expected := 0; len(bx.CommittedBlocks) != expected { + t.Fatal("unexpected committed block count") + } + if expected := 4; len(bx.UncommittedBlocks) != expected { + t.Fatalf("unexpected uncommitted block count: %d -- %#v", len(bx.UncommittedBlocks), bx) + } + + if err := rw.bs.PutBlockList("a", "b", bw); err != nil { + t.Fatal(err) + } + + r, err := rw.bs.GetBlob("a", "b") + if err != nil { + t.Fatal(err) + } + assertBlobContents(t, r, c) +} + +func TestRandomWriter_blocksLeftSide(t *testing.T) { + blob := "AAAAABBBBBCCC" + cases := []struct { + offset int64 + expectedBlob string + expectedPattern []azure.BlockStatus + }{ + {0, "", []azure.BlockStatus{}}, // write to beginning, discard all + {13, blob, []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusCommitted, azure.BlockStatusCommitted}}, // write to end, no change + {1, "A", []azure.BlockStatus{azure.BlockStatusUncommitted}}, // write at 1 + {5, "AAAAA", []azure.BlockStatus{azure.BlockStatusCommitted}}, // write just after first block + {6, "AAAAAB", []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusUncommitted}}, // split the second block + {9, "AAAAABBBB", []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusUncommitted}}, // write just after first block + } + + for _, c := range cases { + s := NewStorageSimulator() + rw := newRandomBlobWriter(&s, 5) + rand := newBlockIDGenerator() + + if err := rw.bs.CreateBlockBlob("a", "b"); err != nil { + t.Fatal(err) + } + bw, _, err := rw.writeChunkToBlocks("a", "b", strings.NewReader(blob), rand) + if err != nil { + t.Fatal(err) + } + if err := rw.bs.PutBlockList("a", "b", bw); err != nil { + t.Fatal(err) + } + bx, err := rw.blocksLeftSide("a", "b", c.offset, rand) + if err != nil { + t.Fatal(err) + } + + bs := []azure.BlockStatus{} + for _, v := range bx { + bs = append(bs, v.Status) + } + + if !reflect.DeepEqual(bs, c.expectedPattern) { + t.Logf("Committed blocks %v", bw) + t.Fatalf("For offset %v: Expected pattern: %v, Got: %v\n(Returned: %v)", c.offset, c.expectedPattern, bs, bx) + } + if rw.bs.PutBlockList("a", "b", bx); err != nil { + t.Fatal(err) + } + r, err := rw.bs.GetBlob("a", "b") + if err != nil { + t.Fatal(err) + } + cout, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + outBlob := string(cout) + if outBlob != c.expectedBlob { + t.Fatalf("wrong blob contents: %v, expected: %v", outBlob, c.expectedBlob) + } + } +} + +func TestRandomWriter_blocksRightSide(t *testing.T) { + blob := "AAAAABBBBBCCC" + cases := []struct { + offset int64 + size int64 + expectedBlob string + expectedPattern []azure.BlockStatus + }{ + {0, 100, "", []azure.BlockStatus{}}, // overwrite the entire blob + {0, 3, "AABBBBBCCC", []azure.BlockStatus{azure.BlockStatusUncommitted, azure.BlockStatusCommitted, azure.BlockStatusCommitted}}, // split first block + {4, 1, "BBBBBCCC", []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusCommitted}}, // write to last char of first block + {1, 6, "BBBCCC", []azure.BlockStatus{azure.BlockStatusUncommitted, azure.BlockStatusCommitted}}, // overwrite splits first and second block, last block remains + {3, 8, "CC", []azure.BlockStatus{azure.BlockStatusUncommitted}}, // overwrite a block in middle block, split end block + {10, 1, "CC", []azure.BlockStatus{azure.BlockStatusUncommitted}}, // overwrite first byte of rightmost block + {11, 2, "", []azure.BlockStatus{}}, // overwrite the rightmost index + {13, 20, "", []azure.BlockStatus{}}, // append to the end + } + + for _, c := range cases { + s := NewStorageSimulator() + rw := newRandomBlobWriter(&s, 5) + rand := newBlockIDGenerator() + + if err := rw.bs.CreateBlockBlob("a", "b"); err != nil { + t.Fatal(err) + } + bw, _, err := rw.writeChunkToBlocks("a", "b", strings.NewReader(blob), rand) + if err != nil { + t.Fatal(err) + } + if err := rw.bs.PutBlockList("a", "b", bw); err != nil { + t.Fatal(err) + } + bx, err := rw.blocksRightSide("a", "b", c.offset, c.size, rand) + if err != nil { + t.Fatal(err) + } + + bs := []azure.BlockStatus{} + for _, v := range bx { + bs = append(bs, v.Status) + } + + if !reflect.DeepEqual(bs, c.expectedPattern) { + t.Logf("Committed blocks %v", bw) + t.Fatalf("For offset %v-size:%v: Expected pattern: %v, Got: %v\n(Returned: %v)", c.offset, c.size, c.expectedPattern, bs, bx) + } + if rw.bs.PutBlockList("a", "b", bx); err != nil { + t.Fatal(err) + } + r, err := rw.bs.GetBlob("a", "b") + if err != nil { + t.Fatal(err) + } + cout, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + outBlob := string(cout) + if outBlob != c.expectedBlob { + t.Fatalf("For offset %v-size:%v: wrong blob contents: %v, expected: %v", c.offset, c.size, outBlob, c.expectedBlob) + } + } +} + +func TestRandomWriter_Write_NewBlob(t *testing.T) { + var ( + s = NewStorageSimulator() + rw = newRandomBlobWriter(&s, 1024*3) // 3 KB blocks + blob = randomContents(1024 * 7) // 7 KB blob + ) + if err := rw.bs.CreateBlockBlob("a", "b"); err != nil { + t.Fatal(err) + } + + if _, err := rw.WriteBlobAt("a", "b", 10, bytes.NewReader(blob)); err == nil { + t.Fatal("expected error, got nil") + } + if _, err := rw.WriteBlobAt("a", "b", 100000, bytes.NewReader(blob)); err == nil { + t.Fatal("expected error, got nil") + } + if nn, err := rw.WriteBlobAt("a", "b", 0, bytes.NewReader(blob)); err != nil { + t.Fatal(err) + } else if expected := int64(len(blob)); expected != nn { + t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected) + } + if out, err := rw.bs.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, blob) + } + if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil { + t.Fatal(err) + } else if len(bx.CommittedBlocks) != 3 { + t.Fatalf("got wrong number of committed blocks: %v", len(bx.CommittedBlocks)) + } + + // Replace first 512 bytes + leftChunk := randomContents(512) + blob = append(leftChunk, blob[512:]...) + if nn, err := rw.WriteBlobAt("a", "b", 0, bytes.NewReader(leftChunk)); err != nil { + t.Fatal(err) + } else if expected := int64(len(leftChunk)); expected != nn { + t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected) + } + if out, err := rw.bs.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, blob) + } + if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil { + t.Fatal(err) + } else if expected := 4; len(bx.CommittedBlocks) != expected { + t.Fatalf("got wrong number of committed blocks: %v, expected: %v", len(bx.CommittedBlocks), expected) + } + + // Replace last 512 bytes with 1024 bytes + rightChunk := randomContents(1024) + offset := int64(len(blob) - 512) + blob = append(blob[:offset], rightChunk...) + if nn, err := rw.WriteBlobAt("a", "b", offset, bytes.NewReader(rightChunk)); err != nil { + t.Fatal(err) + } else if expected := int64(len(rightChunk)); expected != nn { + t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected) + } + if out, err := rw.bs.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, blob) + } + if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil { + t.Fatal(err) + } else if expected := 5; len(bx.CommittedBlocks) != expected { + t.Fatalf("got wrong number of committed blocks: %v, expected: %v", len(bx.CommittedBlocks), expected) + } + + // Replace 2K-4K (overlaps 2 blocks from L/R) + newChunk := randomContents(1024 * 2) + offset = 1024 * 2 + blob = append(append(blob[:offset], newChunk...), blob[offset+int64(len(newChunk)):]...) + if nn, err := rw.WriteBlobAt("a", "b", offset, bytes.NewReader(newChunk)); err != nil { + t.Fatal(err) + } else if expected := int64(len(newChunk)); expected != nn { + t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected) + } + if out, err := rw.bs.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, blob) + } + if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil { + t.Fatal(err) + } else if expected := 6; len(bx.CommittedBlocks) != expected { + t.Fatalf("got wrong number of committed blocks: %v, expected: %v\n%v", len(bx.CommittedBlocks), expected, bx.CommittedBlocks) + } + + // Replace the entire blob + newBlob := randomContents(1024 * 30) + if nn, err := rw.WriteBlobAt("a", "b", 0, bytes.NewReader(newBlob)); err != nil { + t.Fatal(err) + } else if expected := int64(len(newBlob)); expected != nn { + t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected) + } + if out, err := rw.bs.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, newBlob) + } + if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil { + t.Fatal(err) + } else if expected := 10; len(bx.CommittedBlocks) != expected { + t.Fatalf("got wrong number of committed blocks: %v, expected: %v\n%v", len(bx.CommittedBlocks), expected, bx.CommittedBlocks) + } else if expected, size := int64(1024*30), getBlobSize(bx); size != expected { + t.Fatalf("committed block size does not indicate blob size") + } +} + +func Test_getBlobSize(t *testing.T) { + // with some committed blocks + if expected, size := int64(151), getBlobSize(azure.BlockListResponse{ + CommittedBlocks: []azure.BlockResponse{ + {"A", 100}, + {"B", 50}, + {"C", 1}, + }, + UncommittedBlocks: []azure.BlockResponse{ + {"D", 200}, + }}); expected != size { + t.Fatalf("wrong blob size: %v, expected: %v", size, expected) + } + + // with no committed blocks + if expected, size := int64(0), getBlobSize(azure.BlockListResponse{ + UncommittedBlocks: []azure.BlockResponse{ + {"A", 100}, + {"B", 50}, + {"C", 1}, + {"D", 200}, + }}); expected != size { + t.Fatalf("wrong blob size: %v, expected: %v", size, expected) + } +} + +func assertBlobContents(t *testing.T, r io.Reader, expected []byte) { + out, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(out, expected) { + t.Fatalf("wrong blob contents. size: %v, expected: %v", len(out), len(expected)) + } +} + +func randomContents(length int64) []byte { + b := make([]byte, length) + for i := range b { + b[i] = byte(rand.Intn(2 << 8)) + } + return b +} diff --git a/storagedriver/azure/zerofillwriter.go b/storagedriver/azure/zerofillwriter.go new file mode 100644 index 000000000..095489d22 --- /dev/null +++ b/storagedriver/azure/zerofillwriter.go @@ -0,0 +1,49 @@ +package azure + +import ( + "bytes" + "io" +) + +type blockBlobWriter interface { + GetSize(container, blob string) (int64, error) + WriteBlobAt(container, blob string, offset int64, chunk io.Reader) (int64, error) +} + +// zeroFillWriter enables writing to an offset outside a block blob's size +// by offering the chunk to the underlying writer as a contiguous data with +// the gap in between filled with NUL (zero) bytes. +type zeroFillWriter struct { + blockBlobWriter +} + +func newZeroFillWriter(b blockBlobWriter) zeroFillWriter { + w := zeroFillWriter{} + w.blockBlobWriter = b + return w +} + +// Write writes the given chunk to the specified existing blob even though +// offset is out of blob's size. The gaps are filled with zeros. Returned +// written number count does not include zeros written. +func (z *zeroFillWriter) Write(container, blob string, offset int64, chunk io.Reader) (int64, error) { + size, err := z.blockBlobWriter.GetSize(container, blob) + if err != nil { + return 0, err + } + + var reader io.Reader + var zeroPadding int64 + if offset <= size { + reader = chunk + } else { + zeroPadding = offset - size + offset = size // adjust offset to be the append index + zeros := bytes.NewReader(make([]byte, zeroPadding)) + reader = io.MultiReader(zeros, chunk) + } + + nn, err := z.blockBlobWriter.WriteBlobAt(container, blob, offset, reader) + nn -= zeroPadding + return nn, err +} diff --git a/storagedriver/azure/zerofillwriter_test.go b/storagedriver/azure/zerofillwriter_test.go new file mode 100644 index 000000000..49361791a --- /dev/null +++ b/storagedriver/azure/zerofillwriter_test.go @@ -0,0 +1,126 @@ +package azure + +import ( + "bytes" + "testing" +) + +func Test_zeroFillWrite_AppendNoGap(t *testing.T) { + s := NewStorageSimulator() + bw := newRandomBlobWriter(&s, 1024*1) + zw := newZeroFillWriter(&bw) + if err := s.CreateBlockBlob("a", "b"); err != nil { + t.Fatal(err) + } + + firstChunk := randomContents(1024*3 + 512) + if nn, err := zw.Write("a", "b", 0, bytes.NewReader(firstChunk)); err != nil { + t.Fatal(err) + } else if expected := int64(len(firstChunk)); expected != nn { + t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected) + } + if out, err := s.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, firstChunk) + } + + secondChunk := randomContents(256) + if nn, err := zw.Write("a", "b", int64(len(firstChunk)), bytes.NewReader(secondChunk)); err != nil { + t.Fatal(err) + } else if expected := int64(len(secondChunk)); expected != nn { + t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected) + } + if out, err := s.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, append(firstChunk, secondChunk...)) + } + +} + +func Test_zeroFillWrite_StartWithGap(t *testing.T) { + s := NewStorageSimulator() + bw := newRandomBlobWriter(&s, 1024*2) + zw := newZeroFillWriter(&bw) + if err := s.CreateBlockBlob("a", "b"); err != nil { + t.Fatal(err) + } + + chunk := randomContents(1024 * 5) + padding := int64(1024*2 + 256) + if nn, err := zw.Write("a", "b", padding, bytes.NewReader(chunk)); err != nil { + t.Fatal(err) + } else if expected := int64(len(chunk)); expected != nn { + t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected) + } + if out, err := s.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, append(make([]byte, padding), chunk...)) + } +} + +func Test_zeroFillWrite_AppendWithGap(t *testing.T) { + s := NewStorageSimulator() + bw := newRandomBlobWriter(&s, 1024*2) + zw := newZeroFillWriter(&bw) + if err := s.CreateBlockBlob("a", "b"); err != nil { + t.Fatal(err) + } + + firstChunk := randomContents(1024*3 + 512) + if _, err := zw.Write("a", "b", 0, bytes.NewReader(firstChunk)); err != nil { + t.Fatal(err) + } + if out, err := s.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, firstChunk) + } + + secondChunk := randomContents(256) + padding := int64(1024 * 4) + if nn, err := zw.Write("a", "b", int64(len(firstChunk))+padding, bytes.NewReader(secondChunk)); err != nil { + t.Fatal(err) + } else if expected := int64(len(secondChunk)); expected != nn { + t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected) + } + if out, err := s.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, append(firstChunk, append(make([]byte, padding), secondChunk...)...)) + } +} + +func Test_zeroFillWrite_LiesWithinSize(t *testing.T) { + s := NewStorageSimulator() + bw := newRandomBlobWriter(&s, 1024*2) + zw := newZeroFillWriter(&bw) + if err := s.CreateBlockBlob("a", "b"); err != nil { + t.Fatal(err) + } + + firstChunk := randomContents(1024 * 3) + if _, err := zw.Write("a", "b", 0, bytes.NewReader(firstChunk)); err != nil { + t.Fatal(err) + } + if out, err := s.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, firstChunk) + } + + // in this case, zerofill won't be used + secondChunk := randomContents(256) + if nn, err := zw.Write("a", "b", 0, bytes.NewReader(secondChunk)); err != nil { + t.Fatal(err) + } else if expected := int64(len(secondChunk)); expected != nn { + t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected) + } + if out, err := s.GetBlob("a", "b"); err != nil { + t.Fatal(err) + } else { + assertBlobContents(t, out, append(secondChunk, firstChunk[len(secondChunk):]...)) + } +}