restic/vendor/github.com/pkg/sftp/request_test.go

183 lines
4.1 KiB
Go
Raw Normal View History

2017-07-23 12:24:45 +00:00
package sftp
import (
"sync"
"github.com/stretchr/testify/assert"
"bytes"
"errors"
"io"
"os"
"testing"
)
// testHandler is a canned request handler for these tests. It serves
// filecontents to readers, hands out output to writers, and every one of its
// methods short-circuits with err whenever err is non-nil.
type testHandler struct {
	filecontents []byte      // bytes returned by Fileread
	output       io.WriterAt // writer returned by Filewrite
	err          error       // injected failure; should be file related
}
// Fileread returns a reader over the handler's canned contents, or the
// injected error when one is set. The request itself is ignored.
func (t *testHandler) Fileread(r Request) (io.ReaderAt, error) {
	if failure := t.err; failure != nil {
		return nil, failure
	}
	reader := bytes.NewReader(t.filecontents)
	return reader, nil
}
// Filewrite hands out the handler's output writer, or the injected error when
// one is set. The request itself is ignored.
func (t *testHandler) Filewrite(r Request) (io.WriterAt, error) {
	if t.err == nil {
		return t.output, nil
	}
	return nil, t.err
}
// Filecmd accepts every command. It reports only the injected error, which is
// nil unless a test has set one.
func (t *testHandler) Filecmd(r Request) error {
	return t.err
}
// Fileinfo returns the os.FileInfo of r.Filepath as a one-element slice, or
// the injected error when one is set.
//
// The path is stat'ed directly rather than opened, so no file descriptor is
// left open after the call.
func (t *testHandler) Fileinfo(r Request) ([]os.FileInfo, error) {
	if t.err != nil {
		return nil, t.err
	}
	fi, err := os.Stat(r.Filepath)
	if err != nil {
		return nil, err
	}
	return []os.FileInfo{fi}, nil
}
// requestTestData is the canned file body used by these tests. It is a
// constant so that fakefile's size can be derived from it at compile time,
// which guarantees len(fakefile) == len(filecontents).
const requestTestData = "file-data."

// fakefile is a fixed-size in-memory file that receives writes during the
// tests. Its length always equals len(filecontents).
type fakefile [len(requestTestData)]byte

// filecontents is the byte form of requestTestData. It is served by the test
// handler and split into packets by testRequest.
var filecontents = []byte(requestTestData)
// testRequest builds a Request for the given method that targets this test
// file. Two packets covering filecontents are already queued on it: id 1 holds
// bytes [0,5) at offset 0 and id 2 holds bytes [5,10) at offset 5.
//
// The sends below have no receiver, so this relies on the packets channel
// being buffered for at least two entries (sftpServerWorkerCount >= 2).
func testRequest(method string) Request {
	queued := []packet_data{
		{id: 1, data: filecontents[:5], length: 5},
		{id: 2, data: filecontents[5:], length: 5, offset: 5},
	}
	req := Request{
		Filepath:  "./request_test.go",
		Method:    method,
		Attrs:     []byte("foo"),
		Target:    "foo",
		packets:   make(chan packet_data, sftpServerWorkerCount),
		state:     &state{},
		stateLock: &sync.RWMutex{},
	}
	for i := range queued {
		req.packets <- queued[i]
	}
	return req
}
// WriteAt copies p into the fake file starting at off.
//
// It honours the io.WriterAt contract. If fewer than len(p) bytes fit, it
// returns the number written together with io.ErrShortWrite instead of a nil
// error. A negative offset is rejected with an error instead of panicking. An
// offset at or past the end writes nothing, which also counts as a short write
// when p is non-empty.
func (ff *fakefile) WriteAt(p []byte, off int64) (int, error) {
	if off < 0 {
		return 0, errors.New("fakefile: negative offset")
	}
	var n int
	if off < int64(len(ff)) {
		n = copy(ff[off:], p)
	}
	if n < len(p) {
		return n, io.ErrShortWrite
	}
	return n, nil
}
// string returns the fake file's full contents, including any unwritten zero
// bytes, as a Go string. The conversion copies the bytes, so later writes do
// not affect the returned value.
func (ff fakefile) string() string {
	return string(ff[:])
}
// newTestHandlers installs one shared testHandler in every Handlers slot. The
// handler serves filecontents and writes into a fresh fakefile. Because the
// handler is shared, setting its err (as returnError does) makes every
// operation fail.
func newTestHandlers() Handlers {
	shared := &testHandler{
		filecontents: filecontents,
		output:       &fakefile{},
	}
	var hs Handlers
	hs.FileGet, hs.FilePut, hs.FileCmd, hs.FileInfo = shared, shared, shared, shared
	return hs
}
// getOutString returns everything written so far into the fakefile behind the
// FilePut handler. It panics if FilePut is not a *testHandler or its output
// is not a *fakefile.
func (h Handlers) getOutString() string {
	out := h.FilePut.(*testHandler).output.(*fakefile)
	return out.string()
}
var (
	// errTest is the sentinel failure injected by returnError. Tests compare
	// returned errors against it.
	errTest = errors.New("test error")
)
// returnError makes every later handler call fail with errTest. It only
// touches the FilePut handler, but newTestHandlers shares one *testHandler
// across all slots, so all four operations are affected. It panics if FilePut
// is not a *testHandler.
func (h *Handlers) returnError() {
	h.FilePut.(*testHandler).err = errTest
}
// statusOk asserts that p is an SSH_FXP_STATUS packet whose code is
// ssh_FX_OK.
//
// Both pointer and value forms of sshFxpStatusPacket are accepted. Any other
// type is reported as a test failure, so the check can never be skipped
// silently.
func statusOk(t *testing.T, p interface{}) {
	var pkt *sshFxpStatusPacket
	switch v := p.(type) {
	case *sshFxpStatusPacket:
		pkt = v
	case sshFxpStatusPacket:
		pkt = &v
	default:
		t.Errorf("expected a status packet, got %T", p)
		return
	}
	if pkt == nil {
		t.Errorf("expected a status packet, got nil %T", p)
		return
	}
	assert.Equal(t, uint32(ssh_FX_OK), pkt.StatusError.Code)
}
// TestRequestGet reads the test file through a "Get" request. testRequest
// queues two 5-byte packets, so each handle call should return one 5-byte
// data chunk whose id matches its packet.
func TestRequestGet(t *testing.T) {
	handlers := newTestHandlers()
	request := testRequest("Get")
	for i, txt := range []string{"file-", "data."} {
		pkt, err := request.handle(handlers)
		if !assert.Nil(t, err) {
			return
		}
		// Checked assertion: a wrong packet type fails the test instead of
		// panicking.
		dpkt, ok := pkt.(*sshFxpDataPacket)
		if !assert.True(t, ok, "expected *sshFxpDataPacket, got %T", pkt) {
			return
		}
		assert.Equal(t, uint32(i+1), dpkt.id())
		assert.Equal(t, txt, string(dpkt.Data))
	}
}
// TestRequestPut writes the two packets queued by testRequest through a "Put"
// request. Each write must be acknowledged with an OK status, and together
// they must reassemble "file-data." in the fake output file.
func TestRequestPut(t *testing.T) {
	handlers := newTestHandlers()
	request := testRequest("Put")
	for i := 0; i < 2; i++ {
		pkt, err := request.handle(handlers)
		assert.Nil(t, err)
		statusOk(t, pkt)
	}
	assert.Equal(t, "file-data.", handlers.getOutString())
}
// TestRequestCmdr runs a "Mkdir" command twice. The first call should succeed
// with an OK status. After returnError is called, the second call should
// return no packet and exactly errTest.
func TestRequestCmdr(t *testing.T) {
	handlers := newTestHandlers()
	request := testRequest("Mkdir")
	pkt, err := request.handle(handlers)
	assert.Nil(t, err)
	statusOk(t, pkt)

	handlers.returnError()
	pkt, err = request.handle(handlers)
	assert.Nil(t, pkt)
	// testify convention: expected value first, actual second.
	assert.Equal(t, errTest, err)
}
// TestRequestInfoList runs the shared name-packet checks of testInfoMethod
// for a "List" request.
func TestRequestInfoList(t *testing.T) {
	testInfoMethod(t, "List")
}

// TestRequestInfoReadlink runs the shared name-packet checks of
// testInfoMethod for a "Readlink" request.
func TestRequestInfoReadlink(t *testing.T) {
	testInfoMethod(t, "Readlink")
}
// TestRequestInfoStat checks that a "Stat" request returns a stat response
// describing request_test.go.
func TestRequestInfoStat(t *testing.T) {
	handlers := newTestHandlers()
	request := testRequest("Stat")
	pkt, err := request.handle(handlers)
	if !assert.Nil(t, err) {
		return
	}
	// Stop here on a type mismatch; otherwise spkt would be nil and
	// spkt.info would panic.
	spkt, ok := pkt.(*sshFxpStatResponse)
	if !assert.True(t, ok, "expected *sshFxpStatResponse, got %T", pkt) {
		return
	}
	assert.Equal(t, "request_test.go", spkt.info.Name())
}
// testInfoMethod handles a request of the given method and asserts that the
// response is a name packet whose first entry is request_test.go.
//
// The test stops cleanly on a wrong packet type or an empty entry list
// instead of panicking on a nil pointer or an out-of-range index.
func testInfoMethod(t *testing.T, method string) {
	handlers := newTestHandlers()
	request := testRequest(method)
	pkt, err := request.handle(handlers)
	if !assert.Nil(t, err) {
		return
	}
	npkt, ok := pkt.(*sshFxpNamePacket)
	if !assert.True(t, ok, "expected *sshFxpNamePacket, got %T", pkt) {
		return
	}
	if !assert.NotEmpty(t, npkt.NameAttrs) {
		return
	}
	assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
	assert.Equal(t, "request_test.go", npkt.NameAttrs[0].Name)
}