restic/vendor/github.com/pkg/sftp/packet-manager.go

204 lines
5.2 KiB
Go
Raw Normal View History

2017-07-23 12:24:45 +00:00
package sftp
import (
"encoding"
"sort"
2017-07-23 12:24:45 +00:00
"sync"
)
// The goal of the packetManager is to keep the outgoing packets in the same
2018-09-03 18:23:56 +00:00
// order as the incoming as is requires by section 7 of the RFC.
2017-07-23 12:24:45 +00:00
2018-09-03 18:23:56 +00:00
type packetManager struct {
requests chan orderedPacket
responses chan orderedPacket
fini chan struct{}
incoming orderedPackets
outgoing orderedPackets
sender packetSender // connection object
working *sync.WaitGroup
packetCount uint32
2017-07-23 12:24:45 +00:00
}
2018-09-03 18:23:56 +00:00
// packetSender is the part of the connection the packetManager needs: a way
// to write a single marshalable packet.
type packetSender interface {
	// sendPacket writes the given packet and reports any failure.
	sendPacket(encoding.BinaryMarshaler) error
}
2017-09-13 12:09:48 +00:00
func newPktMgr(sender packetSender) *packetManager {
s := &packetManager{
2018-09-03 18:23:56 +00:00
requests: make(chan orderedPacket, SftpServerWorkerCount),
responses: make(chan orderedPacket, SftpServerWorkerCount),
2017-07-23 12:24:45 +00:00
fini: make(chan struct{}),
2018-09-03 18:23:56 +00:00
incoming: make([]orderedPacket, 0, SftpServerWorkerCount),
outgoing: make([]orderedPacket, 0, SftpServerWorkerCount),
2017-07-23 12:24:45 +00:00
sender: sender,
working: &sync.WaitGroup{},
}
go s.controller()
return s
}
2018-09-03 18:23:56 +00:00
//// packet ordering
// newOrderId advances the packet counter and returns its new value, so the
// first id handed out by a fresh manager is 1.
//
// The counter is a plain uint32 that is updated without any synchronization
// in this method.
func (s *packetManager) newOrderId() uint32 {
	next := s.packetCount + 1
	s.packetCount = next
	return next
}
2018-09-03 18:23:56 +00:00
// orderedRequest is a request packet tagged with the order id it was given
// on arrival.
type orderedRequest struct {
	requestPacket
	// orderid is the position of this request in the incoming sequence.
	orderid uint32
}
2018-09-03 18:23:56 +00:00
// newOrderedRequest wraps p in an orderedRequest and stamps it with the next
// order id taken from the manager's counter.
func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
	req := orderedRequest{requestPacket: p}
	req.orderid = s.newOrderId()
	return req
}
// orderId returns the order id this request was tagged with.
func (p orderedRequest) orderId() uint32 {
	return p.orderid
}
// setOrderId assigns oid to the receiver's order id.
//
// NOTE(review): the receiver is a value, so the assignment only changes the
// method's local copy and the caller's orderedRequest is left untouched.
func (p orderedRequest) setOrderId(oid uint32) {
	p.orderid = oid
}
2018-09-03 18:23:56 +00:00
// orderedResponse is a response packet tagged with the order id of the
// request it answers.
type orderedResponse struct {
	responsePacket
	// orderid is the order id used to line this response up with its request.
	orderid uint32
}
// newOrderedResponse wraps p in an orderedResponse carrying the given order
// id. The manager's own state is not consulted or modified.
func (s *packetManager) newOrderedResponse(p responsePacket, id uint32) orderedResponse {
	var resp orderedResponse
	resp.responsePacket = p
	resp.orderid = id
	return resp
}
// orderId returns the order id this response was tagged with.
func (p orderedResponse) orderId() uint32 {
	return p.orderid
}
// setOrderId assigns oid to the receiver's order id.
//
// NOTE(review): the receiver is a value, so the assignment only changes the
// method's local copy and the caller's orderedResponse is left untouched.
func (p orderedResponse) setOrderId(oid uint32) {
	p.orderid = oid
}
// orderedPacket is what the packetManager needs to know about a packet in
// order to queue it: its packet id and the order id it was tagged with.
type orderedPacket interface {
	id() uint32
	orderId() uint32
}

// orderedPackets is a queue of packets that can be sorted by order id.
type orderedPackets []orderedPacket

// Sort reorders the slice in place so that order ids are ascending.
func (o orderedPackets) Sort() {
	byOrderID := func(a, b int) bool {
		return o[a].orderId() < o[b].orderId()
	}
	sort.Slice(o, byOrderID)
}
2018-09-03 18:23:56 +00:00
//// packet registry
2017-07-23 12:24:45 +00:00
// register incoming packets to be handled
2018-09-03 18:23:56 +00:00
func (s *packetManager) incomingPacket(pkt orderedRequest) {
2017-07-23 12:24:45 +00:00
s.working.Add(1)
2018-09-03 18:23:56 +00:00
s.requests <- pkt
2017-07-23 12:24:45 +00:00
}
// register outgoing packets as being ready
2018-09-03 18:23:56 +00:00
func (s *packetManager) readyPacket(pkt orderedResponse) {
2017-07-23 12:24:45 +00:00
s.responses <- pkt
s.working.Done()
}
// shut down packetManager controller
2017-09-13 12:09:48 +00:00
func (s *packetManager) close() {
2017-07-23 12:24:45 +00:00
// pause until current packets are processed
s.working.Wait()
close(s.fini)
}
// Passed a worker function, returns a channel for incoming packets.
2018-09-03 18:23:56 +00:00
// Keep process packet responses in the order they are received while
// maximizing throughput of file transfers.
func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
) chan orderedRequest {
2017-07-23 12:24:45 +00:00
2018-09-03 18:23:56 +00:00
// multiple workers for faster read/writes
rwChan := make(chan orderedRequest, SftpServerWorkerCount)
2017-09-13 12:09:48 +00:00
for i := 0; i < SftpServerWorkerCount; i++ {
2017-07-23 12:24:45 +00:00
runWorker(rwChan)
}
2018-09-03 18:23:56 +00:00
// single worker to enforce sequential processing of everything else
cmdChan := make(chan orderedRequest)
2017-07-23 12:24:45 +00:00
runWorker(cmdChan)
2018-09-03 18:23:56 +00:00
pktChan := make(chan orderedRequest, SftpServerWorkerCount)
2017-07-23 12:24:45 +00:00
go func() {
for pkt := range pktChan {
2018-09-03 18:23:56 +00:00
switch pkt.requestPacket.(type) {
case *sshFxpReadPacket, *sshFxpWritePacket:
s.incomingPacket(pkt)
rwChan <- pkt
continue
2017-07-23 12:24:45 +00:00
case *sshFxpClosePacket:
2018-09-03 18:23:56 +00:00
// wait for reads/writes to finish when file is closed
// incomingPacket() call must occur after this
2017-07-23 12:24:45 +00:00
s.working.Wait()
}
s.incomingPacket(pkt)
2018-09-03 18:23:56 +00:00
// all non-RW use sequential cmdChan
cmdChan <- pkt
2017-07-23 12:24:45 +00:00
}
close(rwChan)
close(cmdChan)
s.close()
}()
return pktChan
}
// process packets
func (s *packetManager) controller() {
for {
select {
case pkt := <-s.requests:
2018-09-03 18:23:56 +00:00
debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderId())
s.incoming = append(s.incoming, pkt)
s.incoming.Sort()
2017-07-23 12:24:45 +00:00
case pkt := <-s.responses:
2018-09-03 18:23:56 +00:00
debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderId())
2017-07-23 12:24:45 +00:00
s.outgoing = append(s.outgoing, pkt)
2018-09-03 18:23:56 +00:00
s.outgoing.Sort()
2017-07-23 12:24:45 +00:00
case <-s.fini:
return
}
s.maybeSendPackets()
}
}
// send as many packets as are ready
func (s *packetManager) maybeSendPackets() {
for {
if len(s.outgoing) == 0 || len(s.incoming) == 0 {
debug("break! -- outgoing: %v; incoming: %v",
len(s.outgoing), len(s.incoming))
break
}
out := s.outgoing[0]
in := s.incoming[0]
2018-09-03 18:23:56 +00:00
// debug("incoming: %v", ids(s.incoming))
// debug("outgoing: %v", ids(s.outgoing))
if in.orderId() == out.orderId() {
debug("Sending packet: %v", out.id())
s.sender.sendPacket(out.(encoding.BinaryMarshaler))
2017-07-23 12:24:45 +00:00
// pop off heads
copy(s.incoming, s.incoming[1:]) // shift left
2019-01-27 20:07:57 +00:00
s.incoming[len(s.incoming)-1] = nil // clear last
2017-07-23 12:24:45 +00:00
s.incoming = s.incoming[:len(s.incoming)-1] // remove last
copy(s.outgoing, s.outgoing[1:]) // shift left
2019-01-27 20:07:57 +00:00
s.outgoing[len(s.outgoing)-1] = nil // clear last
2017-07-23 12:24:45 +00:00
s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
} else {
break
}
}
}
2018-09-03 18:23:56 +00:00
// func oids(o []orderedPacket) []uint32 {
// res := make([]uint32, 0, len(o))
// for _, v := range o {
// res = append(res, v.orderId())
// }
// return res
// }
// func ids(o []orderedPacket) []uint32 {
// res := make([]uint32, 0, len(o))
// for _, v := range o {
// res = append(res, v.id())
// }
// return res
// }