vendor: switch to using go1.11 modules

This commit is contained in:
Nick Craig-Wood 2018-08-28 15:27:07 +01:00
parent 5c75453aba
commit da1682a30e
6142 changed files with 390 additions and 5155875 deletions

View file

@ -1,45 +0,0 @@
package sftp
import (
"bytes"
"os"
"reflect"
"testing"
"time"
)
// ensure that attrs implemenst os.FileInfo
var _ os.FileInfo = new(fileInfo)
var unmarshalAttrsTests = []struct {
b []byte
want *fileInfo
rest []byte
}{
{marshal(nil, struct{ Flags uint32 }{}), &fileInfo{mtime: time.Unix(int64(0), 0)}, nil},
{marshal(nil, struct {
Flags uint32
Size uint64
}{ssh_FILEXFER_ATTR_SIZE, 20}), &fileInfo{size: 20, mtime: time.Unix(int64(0), 0)}, nil},
{marshal(nil, struct {
Flags uint32
Size uint64
Permissions uint32
}{ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_PERMISSIONS, 20, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil},
{marshal(nil, struct {
Flags uint32
Size uint64
UID, GID, Permissions uint32
}{ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_UIDGID | ssh_FILEXFER_ATTR_UIDGID | ssh_FILEXFER_ATTR_PERMISSIONS, 20, 1000, 1000, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil},
}
func TestUnmarshalAttrs(t *testing.T) {
for _, tt := range unmarshalAttrsTests {
stat, rest := unmarshalAttrs(tt.b)
got := fileInfoFromStat(stat, "")
tt.want.sys = got.Sys()
if !reflect.DeepEqual(got, tt.want) || !bytes.Equal(tt.rest, rest) {
t.Errorf("unmarshalAttrs(%#v): want %#v, %#v, got: %#v, %#v", tt.b, tt.want, tt.rest, got, rest)
}
}
}

View file

@ -1,42 +0,0 @@
package sftp
import (
"syscall"
"testing"
)
const sftpServer = "/usr/libexec/sftp-server"
func TestClientStatVFS(t *testing.T) {
if *testServerImpl {
t.Skipf("go server does not support FXP_EXTENDED")
}
sftp, cmd := testClient(t, READWRITE, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
vfs, err := sftp.StatVFS("/")
if err != nil {
t.Fatal(err)
}
// get system stats
s := syscall.Statfs_t{}
err = syscall.Statfs("/", &s)
if err != nil {
t.Fatal(err)
}
// check some stats
if vfs.Files != uint64(s.Files) {
t.Fatal("fr_size does not match")
}
if vfs.Bfree != uint64(s.Bfree) {
t.Fatal("f_bsize does not match")
}
if vfs.Favail != uint64(s.Ffree) {
t.Fatal("f_namemax does not match")
}
}

View file

@ -1,42 +0,0 @@
package sftp
import (
"syscall"
"testing"
)
const sftpServer = "/usr/lib/openssh/sftp-server"
func TestClientStatVFS(t *testing.T) {
if *testServerImpl {
t.Skipf("go server does not support FXP_EXTENDED")
}
sftp, cmd := testClient(t, READWRITE, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
vfs, err := sftp.StatVFS("/")
if err != nil {
t.Fatal(err)
}
// get system stats
s := syscall.Statfs_t{}
err = syscall.Statfs("/", &s)
if err != nil {
t.Fatal(err)
}
// check some stats
if vfs.Frsize != uint64(s.Frsize) {
t.Fatalf("fr_size does not match, expected: %v, got: %v", s.Frsize, vfs.Frsize)
}
if vfs.Bsize != uint64(s.Bsize) {
t.Fatalf("f_bsize does not match, expected: %v, got: %v", s.Bsize, vfs.Bsize)
}
if vfs.Namemax != uint64(s.Namelen) {
t.Fatalf("f_namemax does not match, expected: %v, got: %v", s.Namelen, vfs.Namemax)
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,196 +0,0 @@
package sftp
import (
"errors"
"io"
"os"
"reflect"
"testing"
"github.com/kr/fs"
)
// assert that *Client implements fs.FileSystem
var _ fs.FileSystem = new(Client)
// assert that *File implements io.ReadWriteCloser
var _ io.ReadWriteCloser = new(File)
func TestNormaliseError(t *testing.T) {
var (
ok = &StatusError{Code: ssh_FX_OK}
eof = &StatusError{Code: ssh_FX_EOF}
fail = &StatusError{Code: ssh_FX_FAILURE}
noSuchFile = &StatusError{Code: ssh_FX_NO_SUCH_FILE}
foo = errors.New("foo")
)
var tests = []struct {
desc string
err error
want error
}{
{
desc: "nil error",
},
{
desc: "not *StatusError",
err: foo,
want: foo,
},
{
desc: "*StatusError with ssh_FX_EOF",
err: eof,
want: io.EOF,
},
{
desc: "*StatusError with ssh_FX_NO_SUCH_FILE",
err: noSuchFile,
want: os.ErrNotExist,
},
{
desc: "*StatusError with ssh_FX_OK",
err: ok,
},
{
desc: "*StatusError with ssh_FX_FAILURE",
err: fail,
want: fail,
},
}
for _, tt := range tests {
got := normaliseError(tt.err)
if got != tt.want {
t.Errorf("normaliseError(%#v), test %q\n- want: %#v\n- got: %#v",
tt.err, tt.desc, tt.want, got)
}
}
}
var flagsTests = []struct {
flags int
want uint32
}{
{os.O_RDONLY, ssh_FXF_READ},
{os.O_WRONLY, ssh_FXF_WRITE},
{os.O_RDWR, ssh_FXF_READ | ssh_FXF_WRITE},
{os.O_RDWR | os.O_CREATE | os.O_TRUNC, ssh_FXF_READ | ssh_FXF_WRITE | ssh_FXF_CREAT | ssh_FXF_TRUNC},
{os.O_WRONLY | os.O_APPEND, ssh_FXF_WRITE | ssh_FXF_APPEND},
}
func TestFlags(t *testing.T) {
for i, tt := range flagsTests {
got := flags(tt.flags)
if got != tt.want {
t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got)
}
}
}
func TestUnmarshalStatus(t *testing.T) {
requestID := uint32(1)
id := marshalUint32([]byte{}, requestID)
idCode := marshalUint32(id, ssh_FX_FAILURE)
idCodeMsg := marshalString(idCode, "err msg")
idCodeMsgLang := marshalString(idCodeMsg, "lang tag")
var tests = []struct {
desc string
reqID uint32
status []byte
want error
}{
{
desc: "well-formed status",
reqID: 1,
status: idCodeMsgLang,
want: &StatusError{
Code: ssh_FX_FAILURE,
msg: "err msg",
lang: "lang tag",
},
},
{
desc: "missing error message and language tag",
reqID: 1,
status: idCode,
want: &StatusError{
Code: ssh_FX_FAILURE,
},
},
{
desc: "missing language tag",
reqID: 1,
status: idCodeMsg,
want: &StatusError{
Code: ssh_FX_FAILURE,
msg: "err msg",
},
},
{
desc: "request identifier mismatch",
reqID: 2,
status: idCodeMsgLang,
want: &unexpectedIDErr{2, requestID},
},
}
for _, tt := range tests {
got := unmarshalStatus(tt.reqID, tt.status)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("unmarshalStatus(%v, %v), test %q\n- want: %#v\n- got: %#v",
requestID, tt.status, tt.desc, tt.want, got)
}
}
}
type packetSizeTest struct {
size int
valid bool
}
var maxPacketCheckedTests = []packetSizeTest{
{size: 0, valid: false},
{size: 1, valid: true},
{size: 32768, valid: true},
{size: 32769, valid: false},
}
var maxPacketUncheckedTests = []packetSizeTest{
{size: 0, valid: false},
{size: 1, valid: true},
{size: 32768, valid: true},
{size: 32769, valid: true},
}
func TestMaxPacketChecked(t *testing.T) {
for _, tt := range maxPacketCheckedTests {
testMaxPacketOption(t, MaxPacketChecked(tt.size), tt)
}
}
func TestMaxPacketUnchecked(t *testing.T) {
for _, tt := range maxPacketUncheckedTests {
testMaxPacketOption(t, MaxPacketUnchecked(tt.size), tt)
}
}
func TestMaxPacket(t *testing.T) {
for _, tt := range maxPacketCheckedTests {
testMaxPacketOption(t, MaxPacket(tt.size), tt)
}
}
func testMaxPacketOption(t *testing.T, o ClientOption, tt packetSizeTest) {
var c Client
err := o(&c)
if (err == nil) != tt.valid {
t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.valid, err == nil)
}
if c.maxPacket != tt.size && tt.valid {
t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.size, c.maxPacket)
}
}

View file

@ -1,164 +0,0 @@
package sftp_test
import (
"bufio"
"fmt"
"io"
"log"
"os"
"os/exec"
"path"
"strings"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
func Example() {
var conn *ssh.Client
// open an SFTP session over an existing ssh connection.
sftp, err := sftp.NewClient(conn)
if err != nil {
log.Fatal(err)
}
defer sftp.Close()
// walk a directory
w := sftp.Walk("/home/user")
for w.Step() {
if w.Err() != nil {
continue
}
log.Println(w.Path())
}
// leave your mark
f, err := sftp.Create("hello.txt")
if err != nil {
log.Fatal(err)
}
if _, err := f.Write([]byte("Hello world!")); err != nil {
log.Fatal(err)
}
// check it's there
fi, err := sftp.Lstat("hello.txt")
if err != nil {
log.Fatal(err)
}
log.Println(fi)
}
func ExampleNewClientPipe() {
// Connect to a remote host and request the sftp subsystem via the 'ssh'
// command. This assumes that passwordless login is correctly configured.
cmd := exec.Command("ssh", "example.com", "-s", "sftp")
// send errors from ssh to stderr
cmd.Stderr = os.Stderr
// get stdin and stdout
wr, err := cmd.StdinPipe()
if err != nil {
log.Fatal(err)
}
rd, err := cmd.StdoutPipe()
if err != nil {
log.Fatal(err)
}
// start the process
if err := cmd.Start(); err != nil {
log.Fatal(err)
}
defer cmd.Wait()
// open the SFTP session
client, err := sftp.NewClientPipe(rd, wr)
if err != nil {
log.Fatal(err)
}
// read a directory
list, err := client.ReadDir("/")
if err != nil {
log.Fatal(err)
}
// print contents
for _, item := range list {
fmt.Println(item.Name())
}
// close the connection
client.Close()
}
func ExampleClient_Mkdir_parents() {
// Example of mimicing 'mkdir --parents'; I.E. recursively create
// directoryies and don't error if any directories already exists.
var conn *ssh.Client
client, err := sftp.NewClient(conn)
if err != nil {
log.Fatal(err)
}
defer client.Close()
sshFxFailure := uint32(4)
mkdirParents := func(client *sftp.Client, dir string) (err error) {
var parents string
if path.IsAbs(dir) {
// Otherwise, an absolute path given below would be turned in to a relative one
// by splitting on "/"
parents = "/"
}
for _, name := range strings.Split(dir, "/") {
if name == "" {
// Paths with double-/ in them should just move along
// this will also catch the case of the first character being a "/", i.e. an absolute path
continue
}
parents = path.Join(parents, name)
err = client.Mkdir(parents)
if status, ok := err.(*sftp.StatusError); ok {
if status.Code == sshFxFailure {
var fi os.FileInfo
fi, err = client.Stat(parents)
if err == nil {
if !fi.IsDir() {
return fmt.Errorf("File exists: %s", parents)
}
}
}
}
if err != nil {
break
}
}
return err
}
err = mkdirParents(client, "/tmp/foo/bar")
if err != nil {
log.Fatal(err)
}
}
func ExampleFile_ReadFrom_bufio() {
// Using Bufio to buffer writes going to an sftp.File won't buffer as it
// skips buffering if the underlying writer support ReadFrom. The
// workaround is to wrap your writer in a struct that only implements
// io.Writer.
//
// For background see github.com/pkg/sftp/issues/125
var data_source io.Reader
var f *sftp.File
type writerOnly struct{ io.Writer }
bw := bufio.NewWriter(writerOnly{f}) // no ReadFrom()
bw.ReadFrom(data_source)
}

View file

@ -1,78 +0,0 @@
// buffered-read-benchmark benchmarks the peformance of reading
// from /dev/zero on the server to a []byte on the client via io.Copy.
package main
import (
"flag"
"fmt"
"io"
"log"
"net"
"os"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"github.com/pkg/sftp"
)
var (
USER = flag.String("user", os.Getenv("USER"), "ssh username")
HOST = flag.String("host", "localhost", "ssh server hostname")
PORT = flag.Int("port", 22, "ssh server port")
PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password")
SIZE = flag.Int("s", 1<<15, "set max packet size")
)
func init() {
flag.Parse()
}
func main() {
var auths []ssh.AuthMethod
if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers))
}
if *PASS != "" {
auths = append(auths, ssh.Password(*PASS))
}
config := ssh.ClientConfig{
User: *USER,
Auth: auths,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", *HOST, *PORT)
conn, err := ssh.Dial("tcp", addr, &config)
if err != nil {
log.Fatalf("unable to connect to [%s]: %v", addr, err)
}
defer conn.Close()
c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE))
if err != nil {
log.Fatalf("unable to start sftp subsytem: %v", err)
}
defer c.Close()
r, err := c.Open("/dev/zero")
if err != nil {
log.Fatal(err)
}
defer r.Close()
const size = 1e9
log.Printf("reading %v bytes", size)
t1 := time.Now()
n, err := io.ReadFull(r, make([]byte, size))
if err != nil {
log.Fatal(err)
}
if n != size {
log.Fatalf("copy: expected %v bytes, got %d", size, n)
}
log.Printf("read %v bytes in %s", size, time.Since(t1))
}

View file

@ -1,84 +0,0 @@
// buffered-write-benchmark benchmarks the peformance of writing
// a single large []byte on the client to /dev/null on the server via io.Copy.
package main
import (
"flag"
"fmt"
"log"
"net"
"os"
"syscall"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"github.com/pkg/sftp"
)
var (
USER = flag.String("user", os.Getenv("USER"), "ssh username")
HOST = flag.String("host", "localhost", "ssh server hostname")
PORT = flag.Int("port", 22, "ssh server port")
PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password")
SIZE = flag.Int("s", 1<<15, "set max packet size")
)
func init() {
flag.Parse()
}
func main() {
var auths []ssh.AuthMethod
if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers))
}
if *PASS != "" {
auths = append(auths, ssh.Password(*PASS))
}
config := ssh.ClientConfig{
User: *USER,
Auth: auths,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", *HOST, *PORT)
conn, err := ssh.Dial("tcp", addr, &config)
if err != nil {
log.Fatalf("unable to connect to [%s]: %v", addr, err)
}
defer conn.Close()
c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE))
if err != nil {
log.Fatalf("unable to start sftp subsytem: %v", err)
}
defer c.Close()
w, err := c.OpenFile("/dev/null", syscall.O_WRONLY)
if err != nil {
log.Fatal(err)
}
defer w.Close()
f, err := os.Open("/dev/zero")
if err != nil {
log.Fatal(err)
}
defer f.Close()
const size = 1e9
log.Printf("writing %v bytes", size)
t1 := time.Now()
n, err := w.Write(make([]byte, size))
if err != nil {
log.Fatal(err)
}
if n != size {
log.Fatalf("copy: expected %v bytes, got %d", size, n)
}
log.Printf("wrote %v bytes in %s", size, time.Since(t1))
}

View file

@ -1,131 +0,0 @@
// An example SFTP server implementation using the golang SSH package.
// Serves the whole filesystem visible to the user, and has a hard-coded username and password,
// so not for real use!
package main
import (
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// Based on example server code from golang.org/x/crypto/ssh and server_standalone
func main() {
var (
readOnly bool
debugStderr bool
)
flag.BoolVar(&readOnly, "R", false, "read-only server")
flag.BoolVar(&debugStderr, "e", false, "debug to stderr")
flag.Parse()
debugStream := ioutil.Discard
if debugStderr {
debugStream = os.Stderr
}
// An SSH server is represented by a ServerConfig, which holds
// certificate details and handles authentication of ServerConns.
config := &ssh.ServerConfig{
PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
// Should use constant-time compare (or better, salt+hash) in
// a production setting.
fmt.Fprintf(debugStream, "Login: %s\n", c.User())
if c.User() == "testuser" && string(pass) == "tiger" {
return nil, nil
}
return nil, fmt.Errorf("password rejected for %q", c.User())
},
}
privateBytes, err := ioutil.ReadFile("id_rsa")
if err != nil {
log.Fatal("Failed to load private key", err)
}
private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil {
log.Fatal("Failed to parse private key", err)
}
config.AddHostKey(private)
// Once a ServerConfig has been configured, connections can be
// accepted.
listener, err := net.Listen("tcp", "0.0.0.0:2022")
if err != nil {
log.Fatal("failed to listen for connection", err)
}
fmt.Printf("Listening on %v\n", listener.Addr())
nConn, err := listener.Accept()
if err != nil {
log.Fatal("failed to accept incoming connection", err)
}
// Before use, a handshake must be performed on the incoming net.Conn.
sconn, chans, reqs, err := ssh.NewServerConn(nConn, config)
if err != nil {
log.Fatal("failed to handshake", err)
}
log.Println("login detected:", sconn.User())
fmt.Fprintf(debugStream, "SSH server established\n")
// The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs)
// Service the incoming Channel channel.
for newChannel := range chans {
// Channels have a type, depending on the application level
// protocol intended. In the case of an SFTP session, this is "subsystem"
// with a payload string of "<length=4>sftp"
fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType())
if newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType())
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
log.Fatal("could not accept channel.", err)
}
fmt.Fprintf(debugStream, "Channel accepted\n")
// Sessions have out-of-band requests such as "shell",
// "pty-req" and "env". Here we handle only the
// "subsystem" request.
go func(in <-chan *ssh.Request) {
for req := range in {
fmt.Fprintf(debugStream, "Request: %v\n", req.Type)
ok := false
switch req.Type {
case "subsystem":
fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:])
if string(req.Payload[4:]) == "sftp" {
ok = true
}
}
fmt.Fprintf(debugStream, " - accepted: %v\n", ok)
req.Reply(ok, nil)
}
}(requests)
root := sftp.InMemHandler()
server := sftp.NewRequestServer(channel, root)
if err := server.Serve(); err == io.EOF {
server.Close()
log.Print("sftp client exited session.")
} else if err != nil {
log.Fatal("sftp server completed with error:", err)
}
}
}

View file

@ -1,12 +0,0 @@
Example SFTP server implementation
===
In order to use this example you will need an RSA key.
On linux-like systems with openssh installed, you can use the command:
```
ssh-keygen -t rsa -f id_rsa
```
Then you will be able to run the sftp-server command in the current directory.

View file

@ -1,147 +0,0 @@
// An example SFTP server implementation using the golang SSH package.
// Serves the whole filesystem visible to the user, and has a hard-coded username and password,
// so not for real use!
package main
import (
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// Based on example server code from golang.org/x/crypto/ssh and server_standalone
func main() {
var (
readOnly bool
debugStderr bool
)
flag.BoolVar(&readOnly, "R", false, "read-only server")
flag.BoolVar(&debugStderr, "e", false, "debug to stderr")
flag.Parse()
debugStream := ioutil.Discard
if debugStderr {
debugStream = os.Stderr
}
// An SSH server is represented by a ServerConfig, which holds
// certificate details and handles authentication of ServerConns.
config := &ssh.ServerConfig{
PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
// Should use constant-time compare (or better, salt+hash) in
// a production setting.
fmt.Fprintf(debugStream, "Login: %s\n", c.User())
if c.User() == "testuser" && string(pass) == "tiger" {
return nil, nil
}
return nil, fmt.Errorf("password rejected for %q", c.User())
},
}
privateBytes, err := ioutil.ReadFile("id_rsa")
if err != nil {
log.Fatal("Failed to load private key", err)
}
private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil {
log.Fatal("Failed to parse private key", err)
}
config.AddHostKey(private)
// Once a ServerConfig has been configured, connections can be
// accepted.
listener, err := net.Listen("tcp", "0.0.0.0:2022")
if err != nil {
log.Fatal("failed to listen for connection", err)
}
fmt.Printf("Listening on %v\n", listener.Addr())
nConn, err := listener.Accept()
if err != nil {
log.Fatal("failed to accept incoming connection", err)
}
// Before use, a handshake must be performed on the incoming
// net.Conn.
_, chans, reqs, err := ssh.NewServerConn(nConn, config)
if err != nil {
log.Fatal("failed to handshake", err)
}
fmt.Fprintf(debugStream, "SSH server established\n")
// The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs)
// Service the incoming Channel channel.
for newChannel := range chans {
// Channels have a type, depending on the application level
// protocol intended. In the case of an SFTP session, this is "subsystem"
// with a payload string of "<length=4>sftp"
fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType())
if newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType())
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
log.Fatal("could not accept channel.", err)
}
fmt.Fprintf(debugStream, "Channel accepted\n")
// Sessions have out-of-band requests such as "shell",
// "pty-req" and "env". Here we handle only the
// "subsystem" request.
go func(in <-chan *ssh.Request) {
for req := range in {
fmt.Fprintf(debugStream, "Request: %v\n", req.Type)
ok := false
switch req.Type {
case "subsystem":
fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:])
if string(req.Payload[4:]) == "sftp" {
ok = true
}
}
fmt.Fprintf(debugStream, " - accepted: %v\n", ok)
req.Reply(ok, nil)
}
}(requests)
serverOptions := []sftp.ServerOption{
sftp.WithDebug(debugStream),
}
if readOnly {
serverOptions = append(serverOptions, sftp.ReadOnly())
fmt.Fprintf(debugStream, "Read-only server\n")
} else {
fmt.Fprintf(debugStream, "Read write server\n")
}
server, err := sftp.NewServer(
channel,
serverOptions...,
)
if err != nil {
log.Fatal(err)
}
if err := server.Serve(); err == io.EOF {
server.Close()
log.Print("sftp client exited session.")
} else if err != nil {
log.Fatal("sftp server completed with error:", err)
}
}
}

View file

@ -1,85 +0,0 @@
// streaming-read-benchmark benchmarks the peformance of reading
// from /dev/zero on the server to /dev/null on the client via io.Copy.
package main
import (
"flag"
"fmt"
"io"
"log"
"net"
"os"
"syscall"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"github.com/pkg/sftp"
)
var (
USER = flag.String("user", os.Getenv("USER"), "ssh username")
HOST = flag.String("host", "localhost", "ssh server hostname")
PORT = flag.Int("port", 22, "ssh server port")
PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password")
SIZE = flag.Int("s", 1<<15, "set max packet size")
)
func init() {
flag.Parse()
}
func main() {
var auths []ssh.AuthMethod
if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers))
}
if *PASS != "" {
auths = append(auths, ssh.Password(*PASS))
}
config := ssh.ClientConfig{
User: *USER,
Auth: auths,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", *HOST, *PORT)
conn, err := ssh.Dial("tcp", addr, &config)
if err != nil {
log.Fatalf("unable to connect to [%s]: %v", addr, err)
}
defer conn.Close()
c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE))
if err != nil {
log.Fatalf("unable to start sftp subsytem: %v", err)
}
defer c.Close()
r, err := c.Open("/dev/zero")
if err != nil {
log.Fatal(err)
}
defer r.Close()
w, err := os.OpenFile("/dev/null", syscall.O_WRONLY, 0600)
if err != nil {
log.Fatal(err)
}
defer w.Close()
const size int64 = 1e9
log.Printf("reading %v bytes", size)
t1 := time.Now()
n, err := io.Copy(w, io.LimitReader(r, size))
if err != nil {
log.Fatal(err)
}
if n != size {
log.Fatalf("copy: expected %v bytes, got %d", size, n)
}
log.Printf("read %v bytes in %s", size, time.Since(t1))
}

View file

@ -1,85 +0,0 @@
// streaming-write-benchmark benchmarks the peformance of writing
// from /dev/zero on the client to /dev/null on the server via io.Copy.
package main
import (
"flag"
"fmt"
"io"
"log"
"net"
"os"
"syscall"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"github.com/pkg/sftp"
)
var (
USER = flag.String("user", os.Getenv("USER"), "ssh username")
HOST = flag.String("host", "localhost", "ssh server hostname")
PORT = flag.Int("port", 22, "ssh server port")
PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password")
SIZE = flag.Int("s", 1<<15, "set max packet size")
)
func init() {
flag.Parse()
}
func main() {
var auths []ssh.AuthMethod
if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers))
}
if *PASS != "" {
auths = append(auths, ssh.Password(*PASS))
}
config := ssh.ClientConfig{
User: *USER,
Auth: auths,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", *HOST, *PORT)
conn, err := ssh.Dial("tcp", addr, &config)
if err != nil {
log.Fatalf("unable to connect to [%s]: %v", addr, err)
}
defer conn.Close()
c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE))
if err != nil {
log.Fatalf("unable to start sftp subsytem: %v", err)
}
defer c.Close()
w, err := c.OpenFile("/dev/null", syscall.O_WRONLY)
if err != nil {
log.Fatal(err)
}
defer w.Close()
f, err := os.Open("/dev/zero")
if err != nil {
log.Fatal(err)
}
defer f.Close()
const size int64 = 1e9
log.Printf("writing %v bytes", size)
t1 := time.Now()
n, err := io.Copy(w, io.LimitReader(f, size))
if err != nil {
log.Fatal(err)
}
if n != size {
log.Fatalf("copy: expected %v bytes, got %d", size, n)
}
log.Printf("wrote %v bytes in %s", size, time.Since(t1))
}

View file

@ -1,5 +0,0 @@
// +build !linux,!darwin
package sftp
const sftpServer = "/usr/bin/false" // unsupported

View file

@ -1,102 +0,0 @@
package sftp
import (
"encoding"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
type _testSender struct {
sent chan encoding.BinaryMarshaler
}
func newTestSender() *_testSender {
return &_testSender{make(chan encoding.BinaryMarshaler)}
}
func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error {
s.sent <- p
return nil
}
type fakepacket uint32
func (fakepacket) MarshalBinary() ([]byte, error) {
return []byte{}, nil
}
func (fakepacket) UnmarshalBinary([]byte) error {
return nil
}
func (f fakepacket) id() uint32 {
return uint32(f)
}
type pair struct {
in fakepacket
out fakepacket
}
// basic test
var ttable1 = []pair{
pair{fakepacket(0), fakepacket(0)},
pair{fakepacket(1), fakepacket(1)},
pair{fakepacket(2), fakepacket(2)},
pair{fakepacket(3), fakepacket(3)},
}
// outgoing packets out of order
var ttable2 = []pair{
pair{fakepacket(0), fakepacket(0)},
pair{fakepacket(1), fakepacket(4)},
pair{fakepacket(2), fakepacket(1)},
pair{fakepacket(3), fakepacket(3)},
pair{fakepacket(4), fakepacket(2)},
}
// incoming packets out of order
var ttable3 = []pair{
pair{fakepacket(2), fakepacket(0)},
pair{fakepacket(1), fakepacket(1)},
pair{fakepacket(3), fakepacket(2)},
pair{fakepacket(0), fakepacket(3)},
}
var tables = [][]pair{ttable1, ttable2, ttable3}
func TestPacketManager(t *testing.T) {
sender := newTestSender()
s := newPktMgr(sender)
for i := range tables {
table := tables[i]
for _, p := range table {
s.incomingPacket(p.in)
}
for _, p := range table {
s.readyPacket(p.out)
}
for i := 0; i < len(table); i++ {
pkt := <-sender.sent
id := pkt.(fakepacket).id()
assert.Equal(t, id, uint32(i))
}
}
s.close()
}
func (p sshFxpRemovePacket) String() string {
return fmt.Sprintf("RmPct:%d", p.ID)
}
func (p sshFxpOpenPacket) String() string {
return fmt.Sprintf("OpPct:%d", p.ID)
}
func (p sshFxpWritePacket) String() string {
return fmt.Sprintf("WrPct:%d", p.ID)
}
func (p sshFxpClosePacket) String() string {
return fmt.Sprintf("ClPct:%d", p.ID)
}

View file

@ -1,345 +0,0 @@
package sftp
import (
"bytes"
"encoding"
"os"
"testing"
)
var marshalUint32Tests = []struct {
v uint32
want []byte
}{
{1, []byte{0, 0, 0, 1}},
{256, []byte{0, 0, 1, 0}},
{^uint32(0), []byte{255, 255, 255, 255}},
}
func TestMarshalUint32(t *testing.T) {
for _, tt := range marshalUint32Tests {
got := marshalUint32(nil, tt.v)
if !bytes.Equal(tt.want, got) {
t.Errorf("marshalUint32(%d): want %v, got %v", tt.v, tt.want, got)
}
}
}
var marshalUint64Tests = []struct {
v uint64
want []byte
}{
{1, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}},
{256, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0}},
{^uint64(0), []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
{1 << 32, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}},
}
func TestMarshalUint64(t *testing.T) {
for _, tt := range marshalUint64Tests {
got := marshalUint64(nil, tt.v)
if !bytes.Equal(tt.want, got) {
t.Errorf("marshalUint64(%d): want %#v, got %#v", tt.v, tt.want, got)
}
}
}
var marshalStringTests = []struct {
v string
want []byte
}{
{"", []byte{0, 0, 0, 0}},
{"/foo", []byte{0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f}},
}
func TestMarshalString(t *testing.T) {
for _, tt := range marshalStringTests {
got := marshalString(nil, tt.v)
if !bytes.Equal(tt.want, got) {
t.Errorf("marshalString(%q): want %#v, got %#v", tt.v, tt.want, got)
}
}
}
var marshalTests = []struct {
v interface{}
want []byte
}{
{uint8(1), []byte{1}},
{byte(1), []byte{1}},
{uint32(1), []byte{0, 0, 0, 1}},
{uint64(1), []byte{0, 0, 0, 0, 0, 0, 0, 1}},
{"foo", []byte{0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f}},
{[]uint32{1, 2, 3, 4}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x4}},
}
func TestMarshal(t *testing.T) {
for _, tt := range marshalTests {
got := marshal(nil, tt.v)
if !bytes.Equal(tt.want, got) {
t.Errorf("marshal(%v): want %#v, got %#v", tt.v, tt.want, got)
}
}
}
var unmarshalUint32Tests = []struct {
b []byte
want uint32
rest []byte
}{
{[]byte{0, 0, 0, 0}, 0, nil},
{[]byte{0, 0, 1, 0}, 256, nil},
{[]byte{255, 0, 0, 255}, 4278190335, nil},
}
func TestUnmarshalUint32(t *testing.T) {
for _, tt := range unmarshalUint32Tests {
got, rest := unmarshalUint32(tt.b)
if got != tt.want || !bytes.Equal(rest, tt.rest) {
t.Errorf("unmarshalUint32(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest)
}
}
}
var unmarshalUint64Tests = []struct {
b []byte
want uint64
rest []byte
}{
{[]byte{0, 0, 0, 0, 0, 0, 0, 0}, 0, nil},
{[]byte{0, 0, 0, 0, 0, 0, 1, 0}, 256, nil},
{[]byte{255, 0, 0, 0, 0, 0, 0, 255}, 18374686479671623935, nil},
}
func TestUnmarshalUint64(t *testing.T) {
for _, tt := range unmarshalUint64Tests {
got, rest := unmarshalUint64(tt.b)
if got != tt.want || !bytes.Equal(rest, tt.rest) {
t.Errorf("unmarshalUint64(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest)
}
}
}
var unmarshalStringTests = []struct {
b []byte
want string
rest []byte
}{
{marshalString(nil, ""), "", nil},
{marshalString(nil, "blah"), "blah", nil},
}
func TestUnmarshalString(t *testing.T) {
for _, tt := range unmarshalStringTests {
got, rest := unmarshalString(tt.b)
if got != tt.want || !bytes.Equal(rest, tt.rest) {
t.Errorf("unmarshalUint64(%v): want %q, %#v, got %q, %#v", tt.b, tt.want, tt.rest, got, rest)
}
}
}
var sendPacketTests = []struct {
p encoding.BinaryMarshaler
want []byte
}{
{sshFxInitPacket{
Version: 3,
Extensions: []extensionPair{
{"posix-rename@openssh.com", "1"},
},
}, []byte{0x0, 0x0, 0x0, 0x26, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}},
{sshFxpOpenPacket{
ID: 1,
Path: "/foo",
Pflags: flags(os.O_RDONLY),
}, []byte{0x0, 0x0, 0x0, 0x15, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}},
{sshFxpWritePacket{
ID: 124,
Handle: "foo",
Offset: 13,
Length: uint32(len([]byte("bar"))),
Data: []byte("bar"),
}, []byte{0x0, 0x0, 0x0, 0x1b, 0x6, 0x0, 0x0, 0x0, 0x7c, 0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, 0x0, 0x0, 0x0, 0x3, 0x62, 0x61, 0x72}},
{sshFxpSetstatPacket{
ID: 31,
Path: "/bar",
Flags: flags(os.O_WRONLY),
Attrs: struct {
UID uint32
GID uint32
}{1000, 100},
}, []byte{0x0, 0x0, 0x0, 0x19, 0x9, 0x0, 0x0, 0x0, 0x1f, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x62, 0x61, 0x72, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x3, 0xe8, 0x0, 0x0, 0x0, 0x64}},
}
func TestSendPacket(t *testing.T) {
for _, tt := range sendPacketTests {
var w bytes.Buffer
sendPacket(&w, tt.p)
if got := w.Bytes(); !bytes.Equal(tt.want, got) {
t.Errorf("sendPacket(%v): want %#v, got %#v", tt.p, tt.want, got)
}
}
}
func sp(p encoding.BinaryMarshaler) []byte {
var w bytes.Buffer
sendPacket(&w, p)
return w.Bytes()
}
var recvPacketTests = []struct {
b []byte
want uint8
rest []byte
}{
{sp(sshFxInitPacket{
Version: 3,
Extensions: []extensionPair{
{"posix-rename@openssh.com", "1"},
},
}), ssh_FXP_INIT, []byte{0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}},
}
func TestRecvPacket(t *testing.T) {
for _, tt := range recvPacketTests {
r := bytes.NewReader(tt.b)
got, rest, _ := recvPacket(r)
if got != tt.want || !bytes.Equal(rest, tt.rest) {
t.Errorf("recvPacket(%#v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest)
}
}
}
func TestSSHFxpOpenPacketreadonly(t *testing.T) {
var tests = []struct {
pflags uint32
ok bool
}{
{
pflags: ssh_FXF_READ,
ok: true,
},
{
pflags: ssh_FXF_WRITE,
ok: false,
},
{
pflags: ssh_FXF_READ | ssh_FXF_WRITE,
ok: false,
},
}
for _, tt := range tests {
p := &sshFxpOpenPacket{
Pflags: tt.pflags,
}
if want, got := tt.ok, p.readonly(); want != got {
t.Errorf("unexpected value for p.readonly(): want: %v, got: %v",
want, got)
}
}
}
func TestSSHFxpOpenPackethasPflags(t *testing.T) {
var tests = []struct {
desc string
haveFlags uint32
testFlags []uint32
ok bool
}{
{
desc: "have read, test against write",
haveFlags: ssh_FXF_READ,
testFlags: []uint32{ssh_FXF_WRITE},
ok: false,
},
{
desc: "have write, test against read",
haveFlags: ssh_FXF_WRITE,
testFlags: []uint32{ssh_FXF_READ},
ok: false,
},
{
desc: "have read+write, test against read",
haveFlags: ssh_FXF_READ | ssh_FXF_WRITE,
testFlags: []uint32{ssh_FXF_READ},
ok: true,
},
{
desc: "have read+write, test against write",
haveFlags: ssh_FXF_READ | ssh_FXF_WRITE,
testFlags: []uint32{ssh_FXF_WRITE},
ok: true,
},
{
desc: "have read+write, test against read+write",
haveFlags: ssh_FXF_READ | ssh_FXF_WRITE,
testFlags: []uint32{ssh_FXF_READ, ssh_FXF_WRITE},
ok: true,
},
}
for _, tt := range tests {
t.Log(tt.desc)
p := &sshFxpOpenPacket{
Pflags: tt.haveFlags,
}
if want, got := tt.ok, p.hasPflags(tt.testFlags...); want != got {
t.Errorf("unexpected value for p.hasPflags(%#v): want: %v, got: %v",
tt.testFlags, want, got)
}
}
}
func BenchmarkMarshalInit(b *testing.B) {
for i := 0; i < b.N; i++ {
sp(sshFxInitPacket{
Version: 3,
Extensions: []extensionPair{
{"posix-rename@openssh.com", "1"},
},
})
}
}
func BenchmarkMarshalOpen(b *testing.B) {
for i := 0; i < b.N; i++ {
sp(sshFxpOpenPacket{
ID: 1,
Path: "/home/test/some/random/path",
Pflags: flags(os.O_RDONLY),
})
}
}
func BenchmarkMarshalWriteWorstCase(b *testing.B) {
data := make([]byte, 32*1024)
for i := 0; i < b.N; i++ {
sp(sshFxpWritePacket{
ID: 1,
Handle: "someopaquehandle",
Offset: 0,
Length: uint32(len(data)),
Data: data,
})
}
}
func BenchmarkMarshalWrite1k(b *testing.B) {
data := make([]byte, 1024)
for i := 0; i < b.N; i++ {
sp(sshFxpWritePacket{
ID: 1,
Handle: "someopaquehandle",
Offset: 0,
Length: uint32(len(data)),
Data: data,
})
}
}

View file

@ -1,51 +0,0 @@
package sftp
import (
"os"
"github.com/stretchr/testify/assert"
"testing"
)
func TestRequestPflags(t *testing.T) {
pflags := newFileOpenFlags(ssh_FXF_READ | ssh_FXF_WRITE | ssh_FXF_APPEND)
assert.True(t, pflags.Read)
assert.True(t, pflags.Write)
assert.True(t, pflags.Append)
assert.False(t, pflags.Creat)
assert.False(t, pflags.Trunc)
assert.False(t, pflags.Excl)
}
func TestRequestAflags(t *testing.T) {
aflags := newFileAttrFlags(
ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_UIDGID)
assert.True(t, aflags.Size)
assert.True(t, aflags.UidGid)
assert.False(t, aflags.Acmodtime)
assert.False(t, aflags.Permissions)
}
func TestRequestAttributes(t *testing.T) {
// UID/GID
fa := FileStat{UID: 1, GID: 2}
fl := uint32(ssh_FILEXFER_ATTR_UIDGID)
at := []byte{}
at = marshalUint32(at, 1)
at = marshalUint32(at, 2)
test_fs, _ := getFileStat(fl, at)
assert.Equal(t, fa, *test_fs)
// Size and Mode
fa = FileStat{Mode: 700, Size: 99}
fl = uint32(ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_PERMISSIONS)
at = []byte{}
at = marshalUint64(at, 99)
at = marshalUint32(at, 700)
test_fs, _ = getFileStat(fl, at)
assert.Equal(t, fa, *test_fs)
// FileMode
assert.True(t, test_fs.FileMode().IsRegular())
assert.False(t, test_fs.FileMode().IsDir())
assert.Equal(t, test_fs.FileMode().Perm(), os.FileMode(700).Perm())
}

View file

@ -1,393 +0,0 @@
package sftp
import (
"context"
"fmt"
"io"
"net"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
var _ = fmt.Print
type csPair struct {
cli *Client
svr *RequestServer
}
// these must be closed in order, else client.Close will hang
func (cs csPair) Close() {
cs.svr.Close()
cs.cli.Close()
os.Remove(sock)
}
func (cs csPair) testHandler() *root {
return cs.svr.Handlers.FileGet.(*root)
}
const sock = "/tmp/rstest.sock"
func clientRequestServerPair(t *testing.T) *csPair {
ready := make(chan bool)
os.Remove(sock) // either this or signal handling
var server *RequestServer
go func() {
l, err := net.Listen("unix", sock)
if err != nil {
// neither assert nor t.Fatal reliably exit before Accept errors
panic(err)
}
ready <- true
fd, err := l.Accept()
assert.Nil(t, err)
handlers := InMemHandler()
server = NewRequestServer(fd, handlers)
server.Serve()
}()
<-ready
defer os.Remove(sock)
c, err := net.Dial("unix", sock)
assert.Nil(t, err)
client, err := NewClientPipe(c, c)
if err != nil {
t.Fatalf("%+v\n", err)
}
return &csPair{client, server}
}
// after adding logging, maybe check log to make sure packet handling
// was split over more than one worker
func TestRequestSplitWrite(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
w, err := p.cli.Create("/foo")
assert.Nil(t, err)
p.cli.maxPacket = 3 // force it to send in small chunks
contents := "one two three four five six seven eight nine ten"
w.Write([]byte(contents))
w.Close()
r := p.testHandler()
f, _ := r.fetch("/foo")
assert.Equal(t, contents, string(f.content))
}
func TestRequestCache(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
foo := NewRequest("", "foo")
foo.ctx, foo.cancelCtx = context.WithCancel(context.Background())
bar := NewRequest("", "bar")
fh := p.svr.nextRequest(foo)
bh := p.svr.nextRequest(bar)
assert.Len(t, p.svr.openRequests, 2)
_foo, ok := p.svr.getRequest(fh, "")
assert.Equal(t, foo.Method, _foo.Method)
assert.Equal(t, foo.Filepath, _foo.Filepath)
assert.Equal(t, foo.Target, _foo.Target)
assert.Equal(t, foo.Flags, _foo.Flags)
assert.Equal(t, foo.Attrs, _foo.Attrs)
assert.Equal(t, foo.state, _foo.state)
assert.NotNil(t, _foo.ctx)
assert.Equal(t, _foo.Context().Err(), nil, "context is still valid")
assert.True(t, ok)
_, ok = p.svr.getRequest("zed", "")
assert.False(t, ok)
p.svr.closeRequest(fh)
assert.Equal(t, _foo.Context().Err(), context.Canceled, "context is now canceled")
p.svr.closeRequest(bh)
assert.Len(t, p.svr.openRequests, 0)
}
func TestRequestCacheState(t *testing.T) {
// test operation that uses open/close
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
assert.Len(t, p.svr.openRequests, 0)
// test operation that doesn't open/close
err = p.cli.Remove("/foo")
assert.Nil(t, err)
assert.Len(t, p.svr.openRequests, 0)
}
func putTestFile(cli *Client, path, content string) (int, error) {
w, err := cli.Create(path)
if err == nil {
defer w.Close()
return w.Write([]byte(content))
}
return 0, err
}
func TestRequestWrite(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
n, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
assert.Equal(t, 5, n)
r := p.testHandler()
f, err := r.fetch("/foo")
assert.Nil(t, err)
assert.False(t, f.isdir)
assert.Equal(t, f.content, []byte("hello"))
}
func TestRequestWriteEmpty(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
n, err := putTestFile(p.cli, "/foo", "")
assert.NoError(t, err)
assert.Equal(t, 0, n)
r := p.testHandler()
f, err := r.fetch("/foo")
if assert.Nil(t, err) {
assert.False(t, f.isdir)
assert.Equal(t, f.content, []byte(""))
}
// lets test with an error
r.returnErr(os.ErrInvalid)
n, err = putTestFile(p.cli, "/bar", "")
assert.Error(t, err)
r.returnErr(nil)
assert.Equal(t, 0, n)
}
func TestRequestFilename(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.NoError(t, err)
r := p.testHandler()
f, err := r.fetch("/foo")
assert.NoError(t, err)
assert.Equal(t, f.Name(), "foo")
_, err = r.fetch("/bar")
assert.Error(t, err)
}
func TestRequestRead(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
rf, err := p.cli.Open("/foo")
assert.Nil(t, err)
defer rf.Close()
contents := make([]byte, 5)
n, err := rf.Read(contents)
if err != nil && err != io.EOF {
t.Fatalf("err: %v", err)
}
assert.Equal(t, 5, n)
assert.Equal(t, "hello", string(contents[0:5]))
}
func TestRequestReadFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
rf, err := p.cli.Open("/foo")
assert.Nil(t, err)
contents := make([]byte, 5)
n, err := rf.Read(contents)
assert.Equal(t, n, 0)
assert.Exactly(t, os.ErrNotExist, err)
}
func TestRequestOpen(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
fh, err := p.cli.Open("foo")
assert.Nil(t, err)
err = fh.Close()
assert.Nil(t, err)
}
func TestRequestMkdir(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
err := p.cli.Mkdir("/foo")
assert.Nil(t, err)
r := p.testHandler()
f, err := r.fetch("/foo")
assert.Nil(t, err)
assert.True(t, f.isdir)
}
func TestRequestRemove(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
r := p.testHandler()
_, err = r.fetch("/foo")
assert.Nil(t, err)
err = p.cli.Remove("/foo")
assert.Nil(t, err)
_, err = r.fetch("/foo")
assert.Equal(t, err, os.ErrNotExist)
}
func TestRequestRename(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
r := p.testHandler()
_, err = r.fetch("/foo")
assert.Nil(t, err)
err = p.cli.Rename("/foo", "/bar")
assert.Nil(t, err)
_, err = r.fetch("/bar")
assert.Nil(t, err)
_, err = r.fetch("/foo")
assert.Equal(t, err, os.ErrNotExist)
}
func TestRequestRenameFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
_, err = putTestFile(p.cli, "/bar", "goodbye")
assert.Nil(t, err)
err = p.cli.Rename("/foo", "/bar")
assert.IsType(t, &StatusError{}, err)
}
func TestRequestStat(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
fi, err := p.cli.Stat("/foo")
assert.Equal(t, fi.Name(), "foo")
assert.Equal(t, fi.Size(), int64(5))
assert.Equal(t, fi.Mode(), os.FileMode(0644))
assert.NoError(t, testOsSys(fi.Sys()))
assert.NoError(t, err)
}
// NOTE: Setstat is a noop in the request server tests, but we want to test
// that is does nothing without crapping out.
func TestRequestSetstat(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
mode := os.FileMode(0644)
err = p.cli.Chmod("/foo", mode)
assert.Nil(t, err)
fi, err := p.cli.Stat("/foo")
assert.Nil(t, err)
assert.Equal(t, fi.Name(), "foo")
assert.Equal(t, fi.Size(), int64(5))
assert.Equal(t, fi.Mode(), os.FileMode(0644))
assert.NoError(t, testOsSys(fi.Sys()))
}
func TestRequestFstat(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
fp, err := p.cli.Open("/foo")
assert.Nil(t, err)
fi, err := fp.Stat()
if assert.NoError(t, err) {
assert.Equal(t, fi.Name(), "foo")
assert.Equal(t, fi.Size(), int64(5))
assert.Equal(t, fi.Mode(), os.FileMode(0644))
assert.NoError(t, testOsSys(fi.Sys()))
}
}
func TestRequestStatFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
fi, err := p.cli.Stat("/foo")
assert.Nil(t, fi)
assert.True(t, os.IsNotExist(err))
}
func TestRequestSymlink(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
err = p.cli.Symlink("/foo", "/bar")
assert.Nil(t, err)
r := p.testHandler()
fi, err := r.fetch("/bar")
assert.Nil(t, err)
assert.True(t, fi.Mode()&os.ModeSymlink == os.ModeSymlink)
}
func TestRequestSymlinkFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
err := p.cli.Symlink("/foo", "/bar")
assert.True(t, os.IsNotExist(err))
}
func TestRequestReadlink(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
err = p.cli.Symlink("/foo", "/bar")
assert.Nil(t, err)
rl, err := p.cli.ReadLink("/bar")
assert.Nil(t, err)
assert.Equal(t, "foo", rl)
}
func TestRequestReaddir(t *testing.T) {
p := clientRequestServerPair(t)
MaxFilelist = 22 // make not divisible by our test amount (100)
defer p.Close()
for i := 0; i < 100; i++ {
fname := fmt.Sprintf("/foo_%02d", i)
_, err := putTestFile(p.cli, fname, fname)
assert.Nil(t, err)
}
_, err := p.cli.ReadDir("/foo_01")
assert.Equal(t, &StatusError{Code: ssh_FX_FAILURE,
msg: " /foo_01: not a directory"}, err)
_, err = p.cli.ReadDir("/does_not_exist")
assert.Equal(t, os.ErrNotExist, err)
di, err := p.cli.ReadDir("/")
assert.Nil(t, err)
assert.Len(t, di, 100)
names := []string{di[18].Name(), di[81].Name()}
assert.Equal(t, []string{"foo_18", "foo_81"}, names)
}
func TestCleanPath(t *testing.T) {
assert.Equal(t, "/", cleanPath("/"))
assert.Equal(t, "/", cleanPath("."))
assert.Equal(t, "/", cleanPath("/."))
assert.Equal(t, "/", cleanPath("/a/.."))
assert.Equal(t, "/a/c", cleanPath("/a/b/../c"))
assert.Equal(t, "/a/c", cleanPath("/a/b/../c/"))
assert.Equal(t, "/a", cleanPath("/a/b/.."))
assert.Equal(t, "/a/b/c", cleanPath("/a/b/c"))
assert.Equal(t, "/", cleanPath("//"))
assert.Equal(t, "/a", cleanPath("/a/"))
assert.Equal(t, "/a", cleanPath("a/"))
assert.Equal(t, "/a/b/c", cleanPath("/a//b//c/"))
// filepath.ToSlash does not touch \ as char on unix systems
// so os.PathSeparator is used for windows compatible tests
bslash := string(os.PathSeparator)
assert.Equal(t, "/", cleanPath(bslash))
assert.Equal(t, "/", cleanPath(bslash+bslash))
assert.Equal(t, "/a", cleanPath(bslash+"a"+bslash))
assert.Equal(t, "/a", cleanPath("a"+bslash))
assert.Equal(t, "/a/b/c",
cleanPath(bslash+"a"+bslash+bslash+"b"+bslash+bslash+"c"+bslash))
}

View file

@ -1,213 +0,0 @@
package sftp
import (
"sync"
"github.com/stretchr/testify/assert"
"bytes"
"errors"
"io"
"os"
"testing"
)
type testHandler struct {
filecontents []byte // dummy contents
output io.WriterAt // dummy file out
err error // dummy error, should be file related
}
func (t *testHandler) Fileread(r *Request) (io.ReaderAt, error) {
if t.err != nil {
return nil, t.err
}
return bytes.NewReader(t.filecontents), nil
}
func (t *testHandler) Filewrite(r *Request) (io.WriterAt, error) {
if t.err != nil {
return nil, t.err
}
return io.WriterAt(t.output), nil
}
func (t *testHandler) Filecmd(r *Request) error {
return t.err
}
func (t *testHandler) Filelist(r *Request) (ListerAt, error) {
if t.err != nil {
return nil, t.err
}
f, err := os.Open(r.Filepath)
if err != nil {
return nil, err
}
fi, err := f.Stat()
if err != nil {
return nil, err
}
return listerat([]os.FileInfo{fi}), nil
}
// make sure len(fakefile) == len(filecontents)
type fakefile [10]byte
var filecontents = []byte("file-data.")
func testRequest(method string) *Request {
request := &Request{
Filepath: "./request_test.go",
Method: method,
Attrs: []byte("foo"),
Target: "foo",
state: state{RWMutex: new(sync.RWMutex)},
}
return request
}
func (ff *fakefile) WriteAt(p []byte, off int64) (int, error) {
n := copy(ff[off:], p)
return n, nil
}
func (ff fakefile) string() string {
b := make([]byte, len(ff))
copy(b, ff[:])
return string(b)
}
func newTestHandlers() Handlers {
handler := &testHandler{
filecontents: filecontents,
output: &fakefile{},
err: nil,
}
return Handlers{
FileGet: handler,
FilePut: handler,
FileCmd: handler,
FileList: handler,
}
}
func (h Handlers) getOutString() string {
handler := h.FilePut.(*testHandler)
return handler.output.(*fakefile).string()
}
var errTest = errors.New("test error")
func (h *Handlers) returnError(err error) {
handler := h.FilePut.(*testHandler)
handler.err = err
}
func getStatusMsg(p interface{}) string {
pkt := p.(sshFxpStatusPacket)
return pkt.StatusError.msg
}
func checkOkStatus(t *testing.T, p interface{}) {
pkt := p.(sshFxpStatusPacket)
assert.Equal(t, pkt.StatusError.Code, uint32(ssh_FX_OK),
"sshFxpStatusPacket not OK\n", pkt.StatusError.msg)
}
// fake/test packet
type fakePacket struct {
myid uint32
handle string
}
func (f fakePacket) id() uint32 {
return f.myid
}
func (f fakePacket) getHandle() string {
return f.handle
}
func (fakePacket) UnmarshalBinary(d []byte) error { return nil }
func TestRequestGet(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Get")
// req.length is 5, so we test reads in 5 byte chunks
for i, txt := range []string{"file-", "data."} {
pkt := &sshFxpReadPacket{uint32(i), "a", uint64(i * 5), 5}
rpkt := request.call(handlers, pkt)
dpkt := rpkt.(*sshFxpDataPacket)
assert.Equal(t, dpkt.id(), uint32(i))
assert.Equal(t, string(dpkt.Data), txt)
}
}
func TestRequestCustomError(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Stat")
pkt := fakePacket{myid: 1}
cmdErr := errors.New("stat not supported")
handlers.returnError(cmdErr)
rpkt := request.call(handlers, pkt)
assert.Equal(t, rpkt, statusFromError(rpkt, cmdErr))
}
func TestRequestPut(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Put")
pkt := &sshFxpWritePacket{0, "a", 0, 5, []byte("file-")}
rpkt := request.call(handlers, pkt)
checkOkStatus(t, rpkt)
pkt = &sshFxpWritePacket{1, "a", 5, 5, []byte("data.")}
rpkt = request.call(handlers, pkt)
checkOkStatus(t, rpkt)
assert.Equal(t, "file-data.", handlers.getOutString())
}
func TestRequestCmdr(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Mkdir")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt)
checkOkStatus(t, rpkt)
handlers.returnError(errTest)
rpkt = request.call(handlers, pkt)
assert.Equal(t, rpkt, statusFromError(rpkt, errTest))
}
func TestRequestInfoStat(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Stat")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt)
spkt, ok := rpkt.(*sshFxpStatResponse)
assert.True(t, ok)
assert.Equal(t, spkt.info.Name(), "request_test.go")
}
func TestRequestInfoList(t *testing.T) { testInfoMethod(t, "List") }
func TestRequestInfoReadlink(t *testing.T) { testInfoMethod(t, "Readlink") }
func testInfoMethod(t *testing.T, method string) {
handlers := newTestHandlers()
request := testRequest(method)
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt)
npkt, ok := rpkt.(*sshFxpNamePacket)
assert.True(t, ok)
assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
}
func TestOpendirHandleReuse(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Stat")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt)
assert.IsType(t, &sshFxpStatResponse{}, rpkt)
request.Method = "List"
pkt = fakePacket{myid: 2}
rpkt = request.call(handlers, pkt)
assert.IsType(t, &sshFxpNamePacket{}, rpkt)
}

View file

@ -1,720 +0,0 @@
package sftp
// sftp server integration tests
// enable with -integration
// example invokation (darwin): gofmt -w `find . -name \*.go` && (cd server_standalone/ ; go build -tags debug) && go test -tags debug github.com/pkg/sftp -integration -v -sftp /usr/libexec/sftp-server -run ServerCompareSubsystems
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
crand "crypto/rand"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"math/rand"
"net"
"os"
"os/exec"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"testing"
"time"
"github.com/kr/fs"
"golang.org/x/crypto/ssh"
)
var testSftpClientBin = flag.String("sftp_client", "/usr/bin/sftp", "location of the sftp client binary")
var sshServerDebugStream = ioutil.Discard
var sftpServerDebugStream = ioutil.Discard
var sftpClientDebugStream = ioutil.Discard
const (
GOLANG_SFTP = true
OPENSSH_SFTP = false
)
var (
hostPrivateKeySigner ssh.Signer
privKey = []byte(`
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEArhp7SqFnXVZAgWREL9Ogs+miy4IU/m0vmdkoK6M97G9NX/Pj
wf8I/3/ynxmcArbt8Rc4JgkjT2uxx/NqR0yN42N1PjO5Czu0dms1PSqcKIJdeUBV
7gdrKSm9Co4d2vwfQp5mg47eG4w63pz7Drk9+VIyi9YiYH4bve7WnGDswn4ycvYZ
slV5kKnjlfCdPig+g5P7yQYud0cDWVwyA0+kxvL6H3Ip+Fu8rLDZn4/P1WlFAIuc
PAf4uEKDGGmC2URowi5eesYR7f6GN/HnBs2776laNlAVXZUmYTUfOGagwLsEkx8x
XdNqntfbs2MOOoK+myJrNtcB9pCrM0H6um19uQIDAQABAoIBABkWr9WdVKvalgkP
TdQmhu3mKRNyd1wCl+1voZ5IM9Ayac/98UAvZDiNU4Uhx52MhtVLJ0gz4Oa8+i16
IkKMAZZW6ro/8dZwkBzQbieWUFJ2Fso2PyvB3etcnGU8/Yhk9IxBDzy+BbuqhYE2
1ebVQtz+v1HvVZzaD11bYYm/Xd7Y28QREVfFen30Q/v3dv7dOteDE/RgDS8Czz7w
jMW32Q8JL5grz7zPkMK39BLXsTcSYcaasT2ParROhGJZDmbgd3l33zKCVc1zcj9B
SA47QljGd09Tys958WWHgtj2o7bp9v1Ufs4LnyKgzrB80WX1ovaSQKvd5THTLchO
kLIhUAECgYEA2doGXy9wMBmTn/hjiVvggR1aKiBwUpnB87Hn5xCMgoECVhFZlT6l
WmZe7R2klbtG1aYlw+y+uzHhoVDAJW9AUSV8qoDUwbRXvBVlp+In5wIqJ+VjfivK
zgIfzomL5NvDz37cvPmzqIeySTowEfbQyq7CUQSoDtE9H97E2wWZhDkCgYEAzJdJ
k+NSFoTkHhfD3L0xCDHpRV3gvaOeew8524fVtVUq53X8m91ng4AX1r74dCUYwwiF
gqTtSSJfx2iH1xKnNq28M9uKg7wOrCKrRqNPnYUO3LehZEC7rwUr26z4iJDHjjoB
uBcS7nw0LJ+0Zeg1IF+aIdZGV3MrAKnrzWPixYECgYBsffX6ZWebrMEmQ89eUtFF
u9ZxcGI/4K8ErC7vlgBD5ffB4TYZ627xzFWuBLs4jmHCeNIJ9tct5rOVYN+wRO1k
/CRPzYUnSqb+1jEgILL6istvvv+DkE+ZtNkeRMXUndWwel94BWsBnUKe0UmrSJ3G
sq23J3iCmJW2T3z+DpXbkQKBgQCK+LUVDNPE0i42NsRnm+fDfkvLP7Kafpr3Umdl
tMY474o+QYn+wg0/aPJIf9463rwMNyyhirBX/k57IIktUdFdtfPicd2MEGETElWv
nN1GzYxD50Rs2f/jKisZhEwqT9YNyV9DkgDdGGdEbJNYqbv0qpwDIg8T9foe8E1p
bdErgQKBgAt290I3L316cdxIQTkJh1DlScN/unFffITwu127WMr28Jt3mq3cZpuM
Aecey/eEKCj+Rlas5NDYKsB18QIuAw+qqWyq0LAKLiAvP1965Rkc4PLScl3MgJtO
QYa37FK0p8NcDeUuF86zXBVutwS5nJLchHhKfd590ks57OROtm29
-----END RSA PRIVATE KEY-----
`)
)
func init() {
var err error
hostPrivateKeySigner, err = ssh.ParsePrivateKey(privKey)
if err != nil {
panic(err)
}
}
func keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
permissions := &ssh.Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{},
}
return permissions, nil
}
func pwAuth(conn ssh.ConnMetadata, pw []byte) (*ssh.Permissions, error) {
permissions := &ssh.Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{},
}
return permissions, nil
}
func basicServerConfig() *ssh.ServerConfig {
config := ssh.ServerConfig{
Config: ssh.Config{
MACs: []string{"hmac-sha1"},
},
PasswordCallback: pwAuth,
PublicKeyCallback: keyAuth,
}
config.AddHostKey(hostPrivateKeySigner)
return &config
}
type sshServer struct {
useSubsystem bool
conn net.Conn
config *ssh.ServerConfig
sshConn *ssh.ServerConn
newChans <-chan ssh.NewChannel
newReqs <-chan *ssh.Request
}
func sshServerFromConn(conn net.Conn, useSubsystem bool, config *ssh.ServerConfig) (*sshServer, error) {
// From a standard TCP connection to an encrypted SSH connection
sshConn, newChans, newReqs, err := ssh.NewServerConn(conn, config)
if err != nil {
return nil, err
}
svr := &sshServer{useSubsystem, conn, config, sshConn, newChans, newReqs}
svr.listenChannels()
return svr, nil
}
func (svr *sshServer) Wait() error {
return svr.sshConn.Wait()
}
func (svr *sshServer) Close() error {
return svr.sshConn.Close()
}
func (svr *sshServer) listenChannels() {
go func() {
for chanReq := range svr.newChans {
go svr.handleChanReq(chanReq)
}
}()
go func() {
for req := range svr.newReqs {
go svr.handleReq(req)
}
}()
}
func (svr *sshServer) handleReq(req *ssh.Request) {
switch req.Type {
default:
rejectRequest(req)
}
}
type sshChannelServer struct {
svr *sshServer
chanReq ssh.NewChannel
ch ssh.Channel
newReqs <-chan *ssh.Request
}
type sshSessionChannelServer struct {
*sshChannelServer
env []string
}
func (svr *sshServer) handleChanReq(chanReq ssh.NewChannel) {
fmt.Fprintf(sshServerDebugStream, "channel request: %v, extra: '%v'\n", chanReq.ChannelType(), hex.EncodeToString(chanReq.ExtraData()))
switch chanReq.ChannelType() {
case "session":
if ch, reqs, err := chanReq.Accept(); err != nil {
fmt.Fprintf(sshServerDebugStream, "fail to accept channel request: %v\n", err)
chanReq.Reject(ssh.ResourceShortage, "channel accept failure")
} else {
chsvr := &sshSessionChannelServer{
sshChannelServer: &sshChannelServer{svr, chanReq, ch, reqs},
env: append([]string{}, os.Environ()...),
}
chsvr.handle()
}
default:
chanReq.Reject(ssh.UnknownChannelType, "channel type is not a session")
}
}
func (chsvr *sshSessionChannelServer) handle() {
// should maybe do something here...
go chsvr.handleReqs()
}
func (chsvr *sshSessionChannelServer) handleReqs() {
for req := range chsvr.newReqs {
chsvr.handleReq(req)
}
fmt.Fprintf(sshServerDebugStream, "ssh server session channel complete\n")
}
func (chsvr *sshSessionChannelServer) handleReq(req *ssh.Request) {
switch req.Type {
case "env":
chsvr.handleEnv(req)
case "subsystem":
chsvr.handleSubsystem(req)
default:
rejectRequest(req)
}
}
func rejectRequest(req *ssh.Request) error {
fmt.Fprintf(sshServerDebugStream, "ssh rejecting request, type: %s\n", req.Type)
err := req.Reply(false, []byte{})
if err != nil {
fmt.Fprintf(sshServerDebugStream, "ssh request reply had error: %v\n", err)
}
return err
}
func rejectRequestUnmarshalError(req *ssh.Request, s interface{}, err error) error {
fmt.Fprintf(sshServerDebugStream, "ssh request unmarshaling error, type '%T': %v\n", s, err)
rejectRequest(req)
return err
}
// env request form:
type sshEnvRequest struct {
Envvar string
Value string
}
func (chsvr *sshSessionChannelServer) handleEnv(req *ssh.Request) error {
envReq := &sshEnvRequest{}
if err := ssh.Unmarshal(req.Payload, envReq); err != nil {
return rejectRequestUnmarshalError(req, envReq, err)
}
req.Reply(true, nil)
found := false
for i, envstr := range chsvr.env {
if strings.HasPrefix(envstr, envReq.Envvar+"=") {
found = true
chsvr.env[i] = envReq.Envvar + "=" + envReq.Value
}
}
if !found {
chsvr.env = append(chsvr.env, envReq.Envvar+"="+envReq.Value)
}
return nil
}
// Payload: int: command size, string: command
type sshSubsystemRequest struct {
Name string
}
type sshSubsystemExitStatus struct {
Status uint32
}
func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error {
defer func() {
err1 := chsvr.ch.CloseWrite()
err2 := chsvr.ch.Close()
fmt.Fprintf(sshServerDebugStream, "ssh server subsystem request complete, err: %v %v\n", err1, err2)
}()
subsystemReq := &sshSubsystemRequest{}
if err := ssh.Unmarshal(req.Payload, subsystemReq); err != nil {
return rejectRequestUnmarshalError(req, subsystemReq, err)
}
// reply to the ssh client
// no idea if this is actually correct spec-wise.
// just enough for an sftp server to start.
if subsystemReq.Name != "sftp" {
return req.Reply(false, nil)
}
req.Reply(true, nil)
if !chsvr.svr.useSubsystem {
// use the openssh sftp server backend; this is to test the ssh code, not the sftp code,
// or is used for comparison between our sftp subsystem and the openssh sftp subsystem
cmd := exec.Command(*testSftp, "-e", "-l", "DEBUG") // log to stderr
cmd.Stdin = chsvr.ch
cmd.Stdout = chsvr.ch
cmd.Stderr = sftpServerDebugStream
if err := cmd.Start(); err != nil {
return err
}
return cmd.Wait()
}
sftpServer, err := NewServer(
chsvr.ch,
WithDebug(sftpServerDebugStream),
)
if err != nil {
return err
}
// wait for the session to close
runErr := sftpServer.Serve()
exitStatus := uint32(1)
if runErr == nil {
exitStatus = uint32(0)
}
_, exitStatusErr := chsvr.ch.SendRequest("exit-status", false, ssh.Marshal(sshSubsystemExitStatus{exitStatus}))
return exitStatusErr
}
// starts an ssh server to test. returns: host string and port
func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, string, int) {
if !*testIntegration {
t.Skip("skipping integration test")
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
host, portStr, err := net.SplitHostPort(listener.Addr().String())
if err != nil {
t.Fatal(err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
t.Fatal(err)
}
go func() {
for {
conn, err := listener.Accept()
if err != nil {
fmt.Fprintf(sshServerDebugStream, "ssh server socket closed: %v\n", err)
break
}
go func() {
defer conn.Close()
sshSvr, err := sshServerFromConn(conn, useSubsystem, basicServerConfig())
if err != nil {
t.Error(err)
return
}
err = sshSvr.Wait()
fmt.Fprintf(sshServerDebugStream, "ssh server finished, err: %v\n", err)
}()
}
}()
return listener, host, port
}
func makeDummyKey() (string, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
if err != nil {
return "", fmt.Errorf("cannot generate key: %v", err)
}
der, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return "", fmt.Errorf("cannot marshal key: %v", err)
}
block := &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
f, err := ioutil.TempFile("", "sftp-test-key-")
if err != nil {
return "", fmt.Errorf("cannot create temp file: %v", err)
}
defer func() {
if f != nil {
_ = f.Close()
_ = os.Remove(f.Name())
}
}()
if err := pem.Encode(f, block); err != nil {
return "", fmt.Errorf("cannot write key: %v", err)
}
if err := f.Close(); err != nil {
return "", fmt.Errorf("error closing key file: %v", err)
}
path := f.Name()
f = nil
return path, nil
}
func runSftpClient(t *testing.T, script string, path string, host string, port int) (string, error) {
// if sftp client binary is unavailable, skip test
if _, err := os.Stat(*testSftpClientBin); err != nil {
t.Skip("sftp client binary unavailable")
}
// make a dummy key so we don't rely on ssh-agent
dummyKey, err := makeDummyKey()
if err != nil {
return "", err
}
defer os.Remove(dummyKey)
args := []string{
// "-vvvv",
"-b", "-",
"-o", "StrictHostKeyChecking=no",
"-o", "LogLevel=ERROR",
"-o", "UserKnownHostsFile /dev/null",
// do not trigger ssh-agent prompting
"-o", "IdentityFile=" + dummyKey,
"-o", "IdentitiesOnly=yes",
"-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path),
}
cmd := exec.Command(*testSftpClientBin, args...)
var stdout bytes.Buffer
cmd.Stdin = bytes.NewBufferString(script)
cmd.Stdout = &stdout
cmd.Stderr = sftpClientDebugStream
if err := cmd.Start(); err != nil {
return "", err
}
err = cmd.Wait()
return stdout.String(), err
}
func TestServerCompareSubsystems(t *testing.T) {
listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY)
listenerOp, hostOp, portOp := testServer(t, OPENSSH_SFTP, READONLY)
defer listenerGo.Close()
defer listenerOp.Close()
script := `
ls /
ls -l /
ls /dev/
ls -l /dev/
ls -l /etc/
ls -l /bin/
ls -l /usr/bin/
`
outputGo, err := runSftpClient(t, script, "/", hostGo, portGo)
if err != nil {
t.Fatal(err)
}
outputOp, err := runSftpClient(t, script, "/", hostOp, portOp)
if err != nil {
t.Fatal(err)
}
newlineRegex := regexp.MustCompile(`\r*\n`)
spaceRegex := regexp.MustCompile(`\s+`)
outputGoLines := newlineRegex.Split(outputGo, -1)
outputOpLines := newlineRegex.Split(outputOp, -1)
for i, goLine := range outputGoLines {
if i > len(outputOpLines) {
t.Fatalf("output line count differs")
}
opLine := outputOpLines[i]
bad := false
if goLine != opLine {
goWords := spaceRegex.Split(goLine, -1)
opWords := spaceRegex.Split(opLine, -1)
// some fields are allowed to be different..
// words[2] and [3] as these are users & groups
// words[1] as the link count for directories like proc is unstable
// during testing as processes are created/destroyed.
// words[7] as timestamp on dirs can very for things like /tmp
for j, goWord := range goWords {
if j > len(opWords) {
bad = true
}
opWord := opWords[j]
if goWord != opWord && j != 1 && j != 2 && j != 3 && j != 7 {
bad = true
}
}
}
if bad {
t.Errorf("outputs differ, go:\n%v\nopenssh:\n%v\n", goLine, opLine)
}
}
}
var rng = rand.New(rand.NewSource(time.Now().Unix()))
func randData(length int) []byte {
data := make([]byte, length)
for i := 0; i < length; i++ {
data[i] = byte(rng.Uint32())
}
return data
}
func randName() string {
return "sftp." + hex.EncodeToString(randData(16))
}
func TestServerMkdirRmdir(t *testing.T) {
listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY)
defer listenerGo.Close()
tmpDir := "/tmp/" + randName()
defer os.RemoveAll(tmpDir)
// mkdir remote
if _, err := runSftpClient(t, "mkdir "+tmpDir, "/", hostGo, portGo); err != nil {
t.Fatal(err)
}
// directory should now exist
if _, err := os.Stat(tmpDir); err != nil {
t.Fatal(err)
}
// now remove the directory
if _, err := runSftpClient(t, "rmdir "+tmpDir, "/", hostGo, portGo); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(tmpDir); err == nil {
t.Fatal("should have error after deleting the directory")
}
}
func TestServerSymlink(t *testing.T) {
listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY)
defer listenerGo.Close()
link := "/tmp/" + randName()
defer os.RemoveAll(link)
// now create a symbolic link within the new directory
if output, err := runSftpClient(t, "symlink /bin/sh "+link, "/", hostGo, portGo); err != nil {
t.Fatalf("failed: %v %v", err, string(output))
}
// symlink should now exist
if stat, err := os.Lstat(link); err != nil {
t.Fatal(err)
} else if (stat.Mode() & os.ModeSymlink) != os.ModeSymlink {
t.Fatalf("is not a symlink: %v", stat.Mode())
}
}
func TestServerPut(t *testing.T) {
listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY)
defer listenerGo.Close()
tmpFileLocal := "/tmp/" + randName()
tmpFileRemote := "/tmp/" + randName()
defer os.RemoveAll(tmpFileLocal)
defer os.RemoveAll(tmpFileRemote)
t.Logf("put: local %v remote %v", tmpFileLocal, tmpFileRemote)
// create a file with random contents. This will be the local file pushed to the server
tmpFileLocalData := randData(10 * 1024 * 1024)
if err := ioutil.WriteFile(tmpFileLocal, tmpFileLocalData, 0644); err != nil {
t.Fatal(err)
}
// sftp the file to the server
if output, err := runSftpClient(t, "put "+tmpFileLocal+" "+tmpFileRemote, "/", hostGo, portGo); err != nil {
t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
}
// tmpFile2 should now exist, with the same contents
if tmpFileRemoteData, err := ioutil.ReadFile(tmpFileRemote); err != nil {
t.Fatal(err)
} else if string(tmpFileLocalData) != string(tmpFileRemoteData) {
t.Fatal("contents of file incorrect after put")
}
}
func TestServerGet(t *testing.T) {
listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY)
defer listenerGo.Close()
tmpFileLocal := "/tmp/" + randName()
tmpFileRemote := "/tmp/" + randName()
defer os.RemoveAll(tmpFileLocal)
defer os.RemoveAll(tmpFileRemote)
t.Logf("get: local %v remote %v", tmpFileLocal, tmpFileRemote)
// create a file with random contents. This will be the remote file pulled from the server
tmpFileRemoteData := randData(10 * 1024 * 1024)
if err := ioutil.WriteFile(tmpFileRemote, tmpFileRemoteData, 0644); err != nil {
t.Fatal(err)
}
// sftp the file to the server
if output, err := runSftpClient(t, "get "+tmpFileRemote+" "+tmpFileLocal, "/", hostGo, portGo); err != nil {
t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
}
// tmpFile2 should now exist, with the same contents
if tmpFileLocalData, err := ioutil.ReadFile(tmpFileLocal); err != nil {
t.Fatal(err)
} else if string(tmpFileLocalData) != string(tmpFileRemoteData) {
t.Fatal("contents of file incorrect after put")
}
}
func compareDirectoriesRecursive(t *testing.T, aroot, broot string) {
walker := fs.Walk(aroot)
for walker.Step() {
if err := walker.Err(); err != nil {
t.Fatal(err)
}
// find paths
aPath := walker.Path()
aRel, err := filepath.Rel(aroot, aPath)
if err != nil {
t.Fatalf("could not find relative path for %v: %v", aPath, err)
}
bPath := path.Join(broot, aRel)
if aRel == "." {
continue
}
//t.Logf("comparing: %v a: %v b %v", aRel, aPath, bPath)
// if a is a link, the sftp recursive copy won't have copied it. ignore
aLink, err := os.Lstat(aPath)
if err != nil {
t.Fatalf("could not lstat %v: %v", aPath, err)
}
if aLink.Mode()&os.ModeSymlink != 0 {
continue
}
// stat the files
aFile, err := os.Stat(aPath)
if err != nil {
t.Fatalf("could not stat %v: %v", aPath, err)
}
bFile, err := os.Stat(bPath)
if err != nil {
t.Fatalf("could not stat %v: %v", bPath, err)
}
// compare stats, with some leniency for the timestamp
if aFile.Mode() != bFile.Mode() {
t.Fatalf("modes different for %v: %v vs %v", aRel, aFile.Mode(), bFile.Mode())
}
if !aFile.IsDir() {
if aFile.Size() != bFile.Size() {
t.Fatalf("sizes different for %v: %v vs %v", aRel, aFile.Size(), bFile.Size())
}
}
timeDiff := aFile.ModTime().Sub(bFile.ModTime())
if timeDiff > time.Second || timeDiff < -time.Second {
t.Fatalf("mtimes different for %v: %v vs %v", aRel, aFile.ModTime(), bFile.ModTime())
}
// compare contents
if !aFile.IsDir() {
if aContents, err := ioutil.ReadFile(aPath); err != nil {
t.Fatal(err)
} else if bContents, err := ioutil.ReadFile(bPath); err != nil {
t.Fatal(err)
} else if string(aContents) != string(bContents) {
t.Fatalf("contents different for %v", aRel)
}
}
}
}
func TestServerPutRecursive(t *testing.T) {
listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY)
defer listenerGo.Close()
dirLocal, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
tmpDirRemote := "/tmp/" + randName()
defer os.RemoveAll(tmpDirRemote)
t.Logf("put recursive: local %v remote %v", dirLocal, tmpDirRemote)
// push this directory (source code etc) recursively to the server
if output, err := runSftpClient(t, "mkdir "+tmpDirRemote+"\r\nput -r -P "+dirLocal+"/ "+tmpDirRemote+"/", "/", hostGo, portGo); err != nil {
t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
}
compareDirectoriesRecursive(t, dirLocal, path.Join(tmpDirRemote, path.Base(dirLocal)))
}
func TestServerGetRecursive(t *testing.T) {
listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY)
defer listenerGo.Close()
dirRemote, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
tmpDirLocal := "/tmp/" + randName()
defer os.RemoveAll(tmpDirLocal)
t.Logf("get recursive: local %v remote %v", tmpDirLocal, dirRemote)
// pull this directory (source code etc) recursively from the server
if output, err := runSftpClient(t, "lmkdir "+tmpDirLocal+"\r\nget -r -P "+dirRemote+"/ "+tmpDirLocal+"/", "/", hostGo, portGo); err != nil {
t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
}
compareDirectoriesRecursive(t, dirRemote, path.Join(tmpDirLocal, path.Base(dirRemote)))
}

View file

@ -1,52 +0,0 @@
package main
// small wrapper around sftp server that allows it to be used as a separate process subsystem call by the ssh server.
// in practice this will statically link; however this allows unit testing from the sftp client.
import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/pkg/sftp"
)
func main() {
var (
readOnly bool
debugStderr bool
debugLevel string
options []sftp.ServerOption
)
flag.BoolVar(&readOnly, "R", false, "read-only server")
flag.BoolVar(&debugStderr, "e", false, "debug to stderr")
flag.StringVar(&debugLevel, "l", "none", "debug level (ignored)")
flag.Parse()
debugStream := ioutil.Discard
if debugStderr {
debugStream = os.Stderr
}
options = append(options, sftp.WithDebug(debugStream))
if readOnly {
options = append(options, sftp.ReadOnly())
}
svr, _ := sftp.NewServer(
struct {
io.Reader
io.WriteCloser
}{os.Stdin,
os.Stdout,
},
options...,
)
if err := svr.Serve(); err != nil {
fmt.Fprintf(debugStream, "sftp server completed with error: %v", err)
os.Exit(1)
}
}

View file

@ -1,280 +0,0 @@
package sftp
import (
"io"
"os"
"regexp"
"sync"
"syscall"
"testing"
"time"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
const (
typeDirectory = "d"
typeFile = "[^d]"
)
func TestRunLsWithExamplesDirectory(t *testing.T) {
path := "examples"
item, _ := os.Stat(path)
result := runLs(path, item)
runLsTestHelper(t, result, typeDirectory, path)
}
func TestRunLsWithLicensesFile(t *testing.T) {
path := "LICENSE"
item, _ := os.Stat(path)
result := runLs(path, item)
runLsTestHelper(t, result, typeFile, path)
}
/*
The format of the `longname' field is unspecified by this protocol.
It MUST be suitable for use in the output of a directory listing
command (in fact, the recommended operation for a directory listing
command is to simply display this data). However, clients SHOULD NOT
attempt to parse the longname field for file attributes; they SHOULD
use the attrs field instead.
The recommended format for the longname field is as follows:
-rwxr-xr-x 1 mjos staff 348911 Mar 25 14:29 t-filexfer
1234567890 123 12345678 12345678 12345678 123456789012
Here, the first line is sample output, and the second field indicates
widths of the various fields. Fields are separated by spaces. The
first field lists file permissions for user, group, and others; the
second field is link count; the third field is the name of the user
who owns the file; the fourth field is the name of the group that
owns the file; the fifth field is the size of the file in bytes; the
sixth field (which actually may contain spaces, but is fixed to 12
characters) is the file modification time, and the seventh field is
the file name. Each field is specified to be a minimum of certain
number of character positions (indicated by the second line above),
but may also be longer if the data does not fit in the specified
length.
The SSH_FXP_ATTRS response has the following format:
uint32 id
ATTRS attrs
where `id' is the request identifier, and `attrs' is the returned
file attributes as described in Section ``File Attributes''.
*/
func runLsTestHelper(t *testing.T, result, expectedType, path string) {
// using regular expressions to make tests work on all systems
// a virtual file system (like afero) would be needed to mock valid filesystem checks
// expected layout is:
// drwxr-xr-x 8 501 20 272 Aug 9 19:46 examples
// permissions (len 10, "drwxr-xr-x")
got := result[0:10]
if ok, err := regexp.MatchString("^"+expectedType+"[rwx-]{9}$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): permission field mismatch, expected dir, got: %#v, err: %#v", path, got, err)
}
// space
got = result[10:11]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 1 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
}
// link count (len 3, number)
got = result[12:15]
if ok, err := regexp.MatchString("^\\s*[0-9]+$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): link count field mismatch, got: %#v, err: %#v", path, got, err)
}
// spacer
got = result[15:16]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 2 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
}
// username / uid (len 8, number or string)
got = result[16:24]
if ok, err := regexp.MatchString("^[^\\s]{1,8}\\s*$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): username / uid mismatch, expected user, got: %#v, err: %#v", path, got, err)
}
// spacer
got = result[24:25]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 3 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
}
// groupname / gid (len 8, number or string)
got = result[25:33]
if ok, err := regexp.MatchString("^[^\\s]{1,8}\\s*$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): groupname / gid mismatch, expected group, got: %#v, err: %#v", path, got, err)
}
// spacer
got = result[33:34]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 4 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
}
// filesize (len 8)
got = result[34:42]
if ok, err := regexp.MatchString("^\\s*[0-9]+$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): filesize field mismatch, expected size in bytes, got: %#v, err: %#v", path, got, err)
}
// spacer
got = result[42:43]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 5 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
}
// mod time (len 12, e.g. Aug 9 19:46)
got = result[43:55]
layout := "Jan 2 15:04"
_, err := time.Parse(layout, got)
if err != nil {
layout = "Jan 2 2006"
_, err = time.Parse(layout, got)
}
if err != nil {
t.Errorf("runLs(%#v, *FileInfo): mod time field mismatch, expected date layout %s, got: %#v, err: %#v", path, layout, got, err)
}
// spacer
got = result[55:56]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 6 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
}
// filename
got = result[56:]
if ok, err := regexp.MatchString("^"+path+"$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): name field mismatch, expected examples, got: %#v, err: %#v", path, got, err)
}
}
func clientServerPair(t *testing.T) (*Client, *Server) {
cr, sw := io.Pipe()
sr, cw := io.Pipe()
server, err := NewServer(struct {
io.Reader
io.WriteCloser
}{sr, sw})
if err != nil {
t.Fatal(err)
}
go server.Serve()
client, err := NewClientPipe(cr, cw)
if err != nil {
t.Fatalf("%+v\n", err)
}
return client, server
}
type sshFxpTestBadExtendedPacket struct {
ID uint32
Extension string
Data string
}
func (p sshFxpTestBadExtendedPacket) id() uint32 { return p.ID }
func (p sshFxpTestBadExtendedPacket) MarshalBinary() ([]byte, error) {
l := 1 + 4 + 4 + // type(byte) + uint32 + uint32
len(p.Extension) +
len(p.Data)
b := make([]byte, 0, l)
b = append(b, ssh_FXP_EXTENDED)
b = marshalUint32(b, p.ID)
b = marshalString(b, p.Extension)
b = marshalString(b, p.Data)
return b, nil
}
// test that errors are sent back when we request an invalid extended packet operation
// this validates the following rfc draft is followed https://tools.ietf.org/html/draft-ietf-secsh-filexfer-extensions-00
func TestInvalidExtendedPacket(t *testing.T) {
client, server := clientServerPair(t)
defer client.Close()
defer server.Close()
badPacket := sshFxpTestBadExtendedPacket{client.nextID(), "thisDoesn'tExist", "foobar"}
typ, data, err := client.clientConn.sendPacket(badPacket)
if err != nil {
t.Fatalf("unexpected error from sendPacket: %s", err)
}
if typ != ssh_FXP_STATUS {
t.Fatalf("received non-FPX_STATUS packet: %v", typ)
}
err = unmarshalStatus(badPacket.id(), data)
statusErr, ok := err.(*StatusError)
if !ok {
t.Fatal("failed to convert error from unmarshalStatus to *StatusError")
}
if statusErr.Code != ssh_FX_OP_UNSUPPORTED {
t.Errorf("statusErr.Code => %d, wanted %d", statusErr.Code, ssh_FX_OP_UNSUPPORTED)
}
}
// test that server handles concurrent requests correctly
func TestConcurrentRequests(t *testing.T) {
client, server := clientServerPair(t)
defer client.Close()
defer server.Close()
concurrency := 2
var wg sync.WaitGroup
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
for j := 0; j < 1024; j++ {
f, err := client.Open("/etc/passwd")
if err != nil {
t.Errorf("failed to open file: %v", err)
}
if err := f.Close(); err != nil {
t.Errorf("failed t close file: %v", err)
}
}
}()
}
wg.Wait()
}
// Test error conversion
func TestStatusFromError(t *testing.T) {
type test struct {
err error
pkt sshFxpStatusPacket
}
tpkt := func(id, code uint32) sshFxpStatusPacket {
return sshFxpStatusPacket{
ID: id,
StatusError: StatusError{Code: code},
}
}
test_cases := []test{
test{syscall.ENOENT, tpkt(1, ssh_FX_NO_SUCH_FILE)},
test{&os.PathError{Err: syscall.ENOENT},
tpkt(2, ssh_FX_NO_SUCH_FILE)},
test{&os.PathError{Err: errors.New("foo")}, tpkt(3, ssh_FX_FAILURE)},
test{ErrSshFxEof, tpkt(4, ssh_FX_EOF)},
test{ErrSshFxOpUnsupported, tpkt(5, ssh_FX_OP_UNSUPPORTED)},
test{io.EOF, tpkt(6, ssh_FX_EOF)},
test{os.ErrNotExist, tpkt(7, ssh_FX_NO_SUCH_FILE)},
}
for _, tc := range test_cases {
tc.pkt.StatusError.msg = tc.err.Error()
assert.Equal(t, tc.pkt, statusFromError(tc.pkt, tc.err))
}
}