159 lines
3.7 KiB
Go
159 lines
3.7 KiB
Go
|
// Copyright (C) 2019 Storj Labs, Inc.
|
||
|
// See LICENSE for copying information.
|
||
|
|
||
|
package drpcwire
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"io"
|
||
|
"sync"
|
||
|
|
||
|
"storj.io/drpc"
|
||
|
)
|
||
|
|
||
|
//
|
||
|
// Writer
|
||
|
//
|
||
|
|
||
|
// Writer is a helper to buffer and write packets and frames to an io.Writer.
|
||
|
type Writer struct {
|
||
|
w io.Writer
|
||
|
size int
|
||
|
mu sync.Mutex
|
||
|
buf []byte
|
||
|
}
|
||
|
|
||
|
// NewWriter returns a Writer that will attempt to buffer size data before
|
||
|
// sending it to the io.Writer.
|
||
|
func NewWriter(w io.Writer, size int) *Writer {
|
||
|
if size == 0 {
|
||
|
size = 1024
|
||
|
}
|
||
|
|
||
|
return &Writer{
|
||
|
w: w,
|
||
|
size: size,
|
||
|
buf: make([]byte, 0, size),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WritePacket writes the packet as a single frame, ignoring any size
|
||
|
// constraints.
|
||
|
func (b *Writer) WritePacket(pkt Packet) (err error) {
|
||
|
return b.WriteFrame(Frame{
|
||
|
Data: pkt.Data,
|
||
|
ID: pkt.ID,
|
||
|
Kind: pkt.Kind,
|
||
|
Done: true,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// WriteFrame appends the frame into the buffer, and if the buffer is larger
|
||
|
// than the configured size, flushes it.
|
||
|
func (b *Writer) WriteFrame(fr Frame) (err error) {
|
||
|
b.mu.Lock()
|
||
|
b.buf = AppendFrame(b.buf, fr)
|
||
|
if len(b.buf) >= b.size {
|
||
|
_, err = b.w.Write(b.buf)
|
||
|
b.buf = b.buf[:0]
|
||
|
}
|
||
|
b.mu.Unlock()
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Flush forces a flush of any buffered data to the io.Writer. It is a no-op if
|
||
|
// there is no data in the buffer.
|
||
|
func (b *Writer) Flush() (err error) {
|
||
|
defer mon.Task()(nil)(&err)
|
||
|
|
||
|
b.mu.Lock()
|
||
|
if len(b.buf) > 0 {
|
||
|
_, err = b.w.Write(b.buf)
|
||
|
b.buf = b.buf[:0]
|
||
|
}
|
||
|
b.mu.Unlock()
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
//
|
||
|
// Reader
|
||
|
//
|
||
|
|
||
|
// SplitFrame is used by bufio.Scanner to split frames out of a stream of bytes.
|
||
|
func SplitFrame(data []byte, atEOF bool) (int, []byte, error) {
|
||
|
rem, _, ok, err := ParseFrame(data)
|
||
|
switch advance := len(data) - len(rem); {
|
||
|
case err != nil:
|
||
|
return 0, nil, err
|
||
|
case len(data) > 0 && !ok && atEOF:
|
||
|
return 0, nil, drpc.ProtocolError.New("truncated frame")
|
||
|
case !ok:
|
||
|
return 0, nil, nil
|
||
|
case advance < 0, len(data) < advance:
|
||
|
return 0, nil, drpc.InternalError.New("scanner issue with advance value")
|
||
|
default:
|
||
|
return advance, data[:advance], nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Reader reconstructs packets from frames read from an io.Reader.
|
||
|
type Reader struct {
|
||
|
buf *bufio.Scanner
|
||
|
id ID
|
||
|
}
|
||
|
|
||
|
// NewReader constructs a Reader to read Packets from the io.Reader.
|
||
|
func NewReader(r io.Reader) *Reader {
|
||
|
buf := bufio.NewScanner(r)
|
||
|
buf.Buffer(make([]byte, 4<<10), 1<<20)
|
||
|
buf.Split(SplitFrame)
|
||
|
return &Reader{buf: buf}
|
||
|
}
|
||
|
|
||
|
// ReadPacket reads a packet from the io.Reader. IDs read from frames
|
||
|
// must be monotonically increasing. When a new ID is read, the old
|
||
|
// data is discarded. This allows for easier asynchronous interrupts.
|
||
|
// If the amount of data in the Packet becomes too large, an error is
|
||
|
// returned.
|
||
|
func (s *Reader) ReadPacket() (pkt Packet, err error) {
|
||
|
defer mon.Task()(nil)(&err)
|
||
|
|
||
|
for s.buf.Scan() {
|
||
|
rem, fr, ok, err := ParseFrame(s.buf.Bytes())
|
||
|
switch {
|
||
|
case err != nil:
|
||
|
return Packet{}, drpc.ProtocolError.Wrap(err)
|
||
|
case !ok, len(rem) > 0:
|
||
|
return Packet{}, drpc.InternalError.New("problem with scanner")
|
||
|
case fr.Control:
|
||
|
// Ignore any frames with the control bit set so that we can
|
||
|
// use it in the future to mean things to people who understand
|
||
|
// it.
|
||
|
continue
|
||
|
case fr.ID.Less(s.id):
|
||
|
return Packet{}, drpc.ProtocolError.New("id monotonicity violation")
|
||
|
case s.id.Less(fr.ID):
|
||
|
s.id = fr.ID
|
||
|
pkt = Packet{
|
||
|
Data: pkt.Data[:0],
|
||
|
ID: fr.ID,
|
||
|
Kind: fr.Kind,
|
||
|
}
|
||
|
case fr.Kind != pkt.Kind:
|
||
|
return Packet{}, drpc.ProtocolError.New("packet kind change")
|
||
|
}
|
||
|
|
||
|
pkt.Data = append(pkt.Data, fr.Data...)
|
||
|
switch {
|
||
|
case len(pkt.Data) > 4<<20:
|
||
|
return Packet{}, drpc.ProtocolError.New("data overflow")
|
||
|
case fr.Done:
|
||
|
return pkt, nil
|
||
|
}
|
||
|
}
|
||
|
if err := s.buf.Err(); err != nil {
|
||
|
return Packet{}, err
|
||
|
}
|
||
|
return Packet{}, io.EOF
|
||
|
}
|