Update Gopkg.toml to remove the constraint on zipkin-go-opentracing (#1231)

* Update vendor directory for latest changes.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Update Gopkg.toml to remove the constraint on zipkin-go-opentracing

The issue with zipkin-go-opentracing has been fixed; see #1193 for details.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Author: Yong Tang <yong.tang.github@outlook.com> (2017-11-13 11:54:46 -08:00)
Committed by: Miek Gieben
parent 7c7a233b83
commit 9018451dd3
60 changed files with 3115 additions and 511 deletions

Gopkg.lock (generated)

@@ -16,19 +16,20 @@
 [[projects]]
   name = "github.com/Shopify/sarama"
   packages = ["."]
-  revision = "bbdbe644099b7fdc8327d5cc69c030945188b2e9"
-  version = "v1.13.0"
+  revision = "240fd146ce68bcafb034cc5dc977229ffbafa8ea"
+  version = "v1.14.0"

 [[projects]]
+  branch = "master"
   name = "github.com/apache/thrift"
   packages = ["lib/go/thrift"]
-  revision = "4f77ab8e296d64c57e6ea1c6e3f0f152bc7d6a3a"
+  revision = "95d5fb3a1e38125b9eabcbe9cda1a6c7bbe3e93d"

 [[projects]]
   name = "github.com/asaskevich/govalidator"
   packages = ["."]
-  revision = "73945b6115bfbbcc57d89b7316e28109364124e1"
-  version = "v7"
+  revision = "521b25f4b05fd26bec69d9dedeb8f9c9a83939a8"
+  version = "v8"

 [[projects]]
   branch = "master"

@@ -160,7 +161,7 @@
   branch = "master"
   name = "github.com/go-openapi/swag"
   packages = ["."]
-  revision = "f3f9494671f93fcff853e3c6e9e948b3eb71e590"
+  revision = "cf0bdb963811675a4d7e74901cefc7411a1df939"

 [[projects]]
   name = "github.com/gogo/protobuf"

@@ -262,6 +263,7 @@
   name = "github.com/openzipkin/zipkin-go-opentracing"
   packages = [".","flag","thrift/gen-go/scribe","thrift/gen-go/zipkincore","types","wire"]
   revision = "45e90b00710a4c34a1a7d8a78d90f9b010b0bd4d"
+  version = "v0.3.2"

 [[projects]]
   name = "github.com/pierrec/lz4"

@@ -326,7 +328,7 @@
   branch = "master"
   name = "golang.org/x/sys"
   packages = ["unix","windows"]
-  revision = "75813c647272dd855bda156405bf844a5414f5bf"
+  revision = "1e2299c37cc91a509f1b12369872d27be0ce98a6"

 [[projects]]
   branch = "master"

@@ -378,6 +380,6 @@
 [solve-meta]
   analyzer-name = "dep"
   analyzer-version = 1
-  inputs-digest = "c7279ef091bb11a42d1421f51e53d761113ea23d9e9b993823605883da0f80ff"
+  inputs-digest = "be9300a30414c93aa44756868a7906a0a295b0910a662880741bcfac58b7b679"
   solver-name = "gps-cdcl"
   solver-version = 1

Gopkg.toml

@@ -13,12 +13,9 @@ ignored = [
   "golang.org/x/net/trace",
 ]

-[[constraint]]
-  name = "github.com/openzipkin/zipkin-go-opentracing"
-  revision = "45e90b00710a4c34a1a7d8a78d90f9b010b0bd4d"
-
 [[override]]
   name = "github.com/apache/thrift"
-  revision = "4f77ab8e296d64c57e6ea1c6e3f0f152bc7d6a3a"
+  branch = "master"

 [[override]]
   name = "github.com/ugorji/go"

vendor/github.com/Shopify/sarama/.gitignore

@@ -22,3 +22,5 @@ _cgo_export.*
 _testmain.go

 *.exe
+
+coverage.txt

vendor/github.com/Shopify/sarama/.travis.yml

@@ -31,4 +31,7 @@ script:
 - make errcheck
 - make fmt

+after_success:
+- bash <(curl -s https://codecov.io/bash)
+
 sudo: false

vendor/github.com/Shopify/sarama/CHANGELOG.md

@@ -1,5 +1,22 @@
 # Changelog

+#### Version 1.14.0 (2017-11-13)
+
+New Features:
+ - Add support for the new Kafka 0.11 record-batch format, including the wire
+   protocol and the necessary behavioural changes in the producer and consumer.
+   Transactions and idempotency are not yet supported, but producing and
+   consuming should work with all the existing bells and whistles (batching,
+   compression, etc) as well as the new custom headers. Thanks to Vlad Hanciuta
+   of Arista Networks for this work. Part of
+   ([#901](https://github.com/Shopify/sarama/issues/901)).
+
+Bug Fixes:
+ - Fix encoding of ProduceResponse versions in test
+   ([#970](https://github.com/Shopify/sarama/pull/970)).
+ - Return partial replicas list when we have it
+   ([#975](https://github.com/Shopify/sarama/pull/975)).
+
 #### Version 1.13.0 (2017-10-04)

 New Features:
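The record-headers feature described in this changelog entry is what the vendored changes below thread through ProducerMessage and ConsumerMessage. As a quick illustration, here is a minimal sketch of producing a message with custom headers against sarama v1.14; the broker address and topic are placeholders, and error handling is trimmed to the essentials:

    package main

    import (
        "log"

        "github.com/Shopify/sarama"
    )

    func main() {
        cfg := sarama.NewConfig()
        cfg.Version = sarama.V0_11_0_0 // headers need the Kafka 0.11 record format
        cfg.Producer.Return.Successes = true

        producer, err := sarama.NewSyncProducer([]string{"localhost:9092"}, cfg)
        if err != nil {
            log.Fatal(err)
        }
        defer producer.Close()

        _, _, err = producer.SendMessage(&sarama.ProducerMessage{
            Topic: "my_topic", // placeholder
            Value: sarama.StringEncoder("hello"),
            Headers: []sarama.RecordHeader{
                {Key: []byte("trace-id"), Value: []byte("abc123")},
            },
        })
        if err != nil {
            log.Fatal(err)
        }
    }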

vendor/github.com/Shopify/sarama/Makefile

@@ -1,7 +1,15 @@
 default: fmt vet errcheck test

+# Taken from https://github.com/codecov/example-go#caveat-multiple-files
 test:
-	go test -v -timeout 60s -race ./...
+	echo "" > coverage.txt
+	for d in `go list ./... | grep -v vendor`; do \
+	    go test -v -timeout 60s -race -coverprofile=profile.out -covermode=atomic $$d; \
+	    if [ -f profile.out ]; then \
+	        cat profile.out >> coverage.txt; \
+	        rm profile.out; \
+	    fi \
+	done

 vet:
 	go vet ./...

vendor/github.com/Shopify/sarama/README.md

@@ -3,6 +3,7 @@ sarama

 [![GoDoc](https://godoc.org/github.com/Shopify/sarama?status.png)](https://godoc.org/github.com/Shopify/sarama)
 [![Build Status](https://travis-ci.org/Shopify/sarama.svg?branch=master)](https://travis-ci.org/Shopify/sarama)
+[![Coverage](https://codecov.io/gh/Shopify/sarama/branch/master/graph/badge.svg)](https://codecov.io/gh/Shopify/sarama)

 Sarama is an MIT-licensed Go client library for [Apache Kafka](https://kafka.apache.org/) version 0.8 (and later).

vendor/github.com/Shopify/sarama/async_producer.go

@@ -1,6 +1,7 @@
 package sarama

 import (
+    "encoding/binary"
     "fmt"
     "sync"
     "time"

@@ -119,6 +120,10 @@ type ProducerMessage struct {
     // StringEncoder and ByteEncoder.
     Value Encoder

+    // The headers are key-value pairs that are transparently passed
+    // by Kafka between producers and consumers.
+    Headers []RecordHeader
+
     // This field is used to hold arbitrary data you wish to include so it
     // will be available when receiving on the Successes and Errors channels.
     // Sarama completely ignores this field and is only to be used for

@@ -146,8 +151,16 @@ type ProducerMessage struct {

 const producerMessageOverhead = 26 // the metadata overhead of CRC, flags, etc.

-func (m *ProducerMessage) byteSize() int {
-    size := producerMessageOverhead
+func (m *ProducerMessage) byteSize(version int) int {
+    var size int
+    if version >= 2 {
+        size = maximumRecordOverhead
+        for _, h := range m.Headers {
+            size += len(h.Key) + len(h.Value) + 2*binary.MaxVarintLen32
+        }
+    } else {
+        size = producerMessageOverhead
+    }
     if m.Key != nil {
         size += m.Key.Length()
     }

@@ -254,7 +267,11 @@ func (p *asyncProducer) dispatcher() {
             p.inFlight.Add(1)
         }

-        if msg.byteSize() > p.conf.Producer.MaxMessageBytes {
+        version := 1
+        if p.conf.Version.IsAtLeast(V0_11_0_0) {
+            version = 2
+        }
+        if msg.byteSize(version) > p.conf.Producer.MaxMessageBytes {
             p.returnError(msg, ErrMessageSizeTooLarge)
             continue
         }
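A worked instance of the size estimate above, assuming one header with a 3-byte key and a 2-byte value: binary.MaxVarintLen32 is 5, so that header contributes 3 + 2 + 2*5 = 15 bytes on top of maximumRecordOverhead. The estimate deliberately reserves the worst-case varint width for the two length prefixes rather than their actual encoded size, so it can only over-count, never under-count, against Producer.MaxMessageBytes.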

vendor/github.com/Shopify/sarama/client.go

@@ -49,9 +49,9 @@ type Client interface {
     RefreshMetadata(topics ...string) error

     // GetOffset queries the cluster to get the most recent available offset at the
-    // given time on the topic/partition combination. Time should be OffsetOldest for
-    // the earliest available offset, OffsetNewest for the offset of the message that
-    // will be produced next, or a time.
+    // given time (in milliseconds) on the topic/partition combination.
+    // Time should be OffsetOldest for the earliest available offset,
+    // OffsetNewest for the offset of the message that will be produced next, or a time.
     GetOffset(topic string, partitionID int32, time int64) (int64, error)

     // Coordinator returns the coordinating broker for a consumer group. It will

@@ -297,7 +297,7 @@ func (client *client) Replicas(topic string, partitionID int32) ([]int32, error) {
     }

     if metadata.Err == ErrReplicaNotAvailable {
-        return nil, metadata.Err
+        return dupInt32Slice(metadata.Replicas), metadata.Err
     }
     return dupInt32Slice(metadata.Replicas), nil
 }

@@ -322,7 +322,7 @@ func (client *client) InSyncReplicas(topic string, partitionID int32) ([]int32, error) {
     }

     if metadata.Err == ErrReplicaNotAvailable {
-        return nil, metadata.Err
+        return dupInt32Slice(metadata.Isr), metadata.Err
     }
     return dupInt32Slice(metadata.Isr), nil
 }
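A minimal sketch of the clarified GetOffset contract, assuming a reachable broker and an existing topic (both names are placeholders):

    client, err := sarama.NewClient([]string{"localhost:9092"}, sarama.NewConfig())
    if err != nil {
        log.Fatal(err)
    }
    defer client.Close()

    // Earliest offset still retained on my_topic/0:
    oldest, err := client.GetOffset("my_topic", 0, sarama.OffsetOldest)
    if err != nil {
        log.Fatal(err)
    }

    // Offset that the next produced message will receive:
    next, err := client.GetOffset("my_topic", 0, sarama.OffsetNewest)
    if err != nil {
        log.Fatal(err)
    }
    log.Printf("retained offsets: [%d, %d)", oldest, next)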

vendor/github.com/Shopify/sarama/config_test.go

@@ -33,6 +33,169 @@ func TestEmptyClientIDConfigValidates(t *testing.T) {
     }
 }

+func TestNetConfigValidates(t *testing.T) {
+    tests := []struct {
+        name string
+        cfg  func(*Config) // resorting to using a function as a param because of internal composite structs
+        err  string
+    }{
+        {
+            "OpenRequests",
+            func(cfg *Config) {
+                cfg.Net.MaxOpenRequests = 0
+            },
+            "Net.MaxOpenRequests must be > 0"},
+        {"DialTimeout",
+            func(cfg *Config) {
+                cfg.Net.DialTimeout = 0
+            },
+            "Net.DialTimeout must be > 0"},
+        {"ReadTimeout",
+            func(cfg *Config) {
+                cfg.Net.ReadTimeout = 0
+            },
+            "Net.ReadTimeout must be > 0"},
+        {"WriteTimeout",
+            func(cfg *Config) {
+                cfg.Net.WriteTimeout = 0
+            },
+            "Net.WriteTimeout must be > 0"},
+        {"KeepAlive",
+            func(cfg *Config) {
+                cfg.Net.KeepAlive = -1
+            },
+            "Net.KeepAlive must be >= 0"},
+        {"SASL.User",
+            func(cfg *Config) {
+                cfg.Net.SASL.Enable = true
+                cfg.Net.SASL.User = ""
+            },
+            "Net.SASL.User must not be empty when SASL is enabled"},
+        {"SASL.Password",
+            func(cfg *Config) {
+                cfg.Net.SASL.Enable = true
+                cfg.Net.SASL.User = "user"
+                cfg.Net.SASL.Password = ""
+            },
+            "Net.SASL.Password must not be empty when SASL is enabled"},
+    }
+
+    for i, test := range tests {
+        c := NewConfig()
+        test.cfg(c)
+        if err := c.Validate(); string(err.(ConfigurationError)) != test.err {
+            t.Errorf("[%d]:[%s] Expected %s, Got %s\n", i, test.name, test.err, err)
+        }
+    }
+}
+
+func TestMetadataConfigValidates(t *testing.T) {
+    tests := []struct {
+        name string
+        cfg  func(*Config) // resorting to using a function as a param because of internal composite structs
+        err  string
+    }{
+        {
+            "Retry.Max",
+            func(cfg *Config) {
+                cfg.Metadata.Retry.Max = -1
+            },
+            "Metadata.Retry.Max must be >= 0"},
+        {"Retry.Backoff",
+            func(cfg *Config) {
+                cfg.Metadata.Retry.Backoff = -1
+            },
+            "Metadata.Retry.Backoff must be >= 0"},
+        {"RefreshFrequency",
+            func(cfg *Config) {
+                cfg.Metadata.RefreshFrequency = -1
+            },
+            "Metadata.RefreshFrequency must be >= 0"},
+    }
+
+    for i, test := range tests {
+        c := NewConfig()
+        test.cfg(c)
+        if err := c.Validate(); string(err.(ConfigurationError)) != test.err {
+            t.Errorf("[%d]:[%s] Expected %s, Got %s\n", i, test.name, test.err, err)
+        }
+    }
+}
+
+func TestProducerConfigValidates(t *testing.T) {
+    tests := []struct {
+        name string
+        cfg  func(*Config) // resorting to using a function as a param because of internal composite structs
+        err  string
+    }{
+        {
+            "MaxMessageBytes",
+            func(cfg *Config) {
+                cfg.Producer.MaxMessageBytes = 0
+            },
+            "Producer.MaxMessageBytes must be > 0"},
+        {"RequiredAcks",
+            func(cfg *Config) {
+                cfg.Producer.RequiredAcks = -2
+            },
+            "Producer.RequiredAcks must be >= -1"},
+        {"Timeout",
+            func(cfg *Config) {
+                cfg.Producer.Timeout = 0
+            },
+            "Producer.Timeout must be > 0"},
+        {"Partitioner",
+            func(cfg *Config) {
+                cfg.Producer.Partitioner = nil
+            },
+            "Producer.Partitioner must not be nil"},
+        {"Flush.Bytes",
+            func(cfg *Config) {
+                cfg.Producer.Flush.Bytes = -1
+            },
+            "Producer.Flush.Bytes must be >= 0"},
+        {"Flush.Messages",
+            func(cfg *Config) {
+                cfg.Producer.Flush.Messages = -1
+            },
+            "Producer.Flush.Messages must be >= 0"},
+        {"Flush.Frequency",
+            func(cfg *Config) {
+                cfg.Producer.Flush.Frequency = -1
+            },
+            "Producer.Flush.Frequency must be >= 0"},
+        {"Flush.MaxMessages",
+            func(cfg *Config) {
+                cfg.Producer.Flush.MaxMessages = -1
+            },
+            "Producer.Flush.MaxMessages must be >= 0"},
+        {"Flush.MaxMessages with Producer.Flush.Messages",
+            func(cfg *Config) {
+                cfg.Producer.Flush.MaxMessages = 1
+                cfg.Producer.Flush.Messages = 2
+            },
+            "Producer.Flush.MaxMessages must be >= Producer.Flush.Messages when set"},
+        {"Flush.Retry.Max",
+            func(cfg *Config) {
+                cfg.Producer.Retry.Max = -1
+            },
+            "Producer.Retry.Max must be >= 0"},
+        {"Flush.Retry.Backoff",
+            func(cfg *Config) {
+                cfg.Producer.Retry.Backoff = -1
+            },
+            "Producer.Retry.Backoff must be >= 0"},
+    }
+
+    for i, test := range tests {
+        c := NewConfig()
+        test.cfg(c)
+        if err := c.Validate(); string(err.(ConfigurationError)) != test.err {
+            t.Errorf("[%d]:[%s] Expected %s, Got %s\n", i, test.name, test.err, err)
+        }
+    }
+}
+
 func TestLZ4ConfigValidation(t *testing.T) {
     config := NewConfig()
     config.Producer.Compression = CompressionLZ4
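All three test tables exercise the same pattern: mutate one field of a fresh Config and assert on the exact ConfigurationError string. A minimal sketch of that contract from the caller's side, reusing a field and message from the table above:

    cfg := sarama.NewConfig()
    cfg.Net.MaxOpenRequests = 0
    if err := cfg.Validate(); err != nil {
        msg := string(err.(sarama.ConfigurationError)) // "Net.MaxOpenRequests must be > 0"
        fmt.Println(msg)
    }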

vendor/github.com/Shopify/sarama/consumer.go

@@ -16,6 +16,7 @@ type ConsumerMessage struct {
     Offset         int64
     Timestamp      time.Time       // only set if kafka is version 0.10+, inner message timestamp
     BlockTimestamp time.Time       // only set if kafka is version 0.10+, outer (compressed) block timestamp
+    Headers        []*RecordHeader // only set if kafka is version 0.11+
 }

 // ConsumerError is what is provided to the user when an error occurs.

@@ -478,44 +479,12 @@ feederLoop:
     close(child.errors)
 }

-func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*ConsumerMessage, error) {
-    block := response.GetBlock(child.topic, child.partition)
-    if block == nil {
-        return nil, ErrIncompleteResponse
-    }
-
-    if block.Err != ErrNoError {
-        return nil, block.Err
-    }
-
-    if len(block.MsgSet.Messages) == 0 {
-        // We got no messages. If we got a trailing one then we need to ask for more data.
-        // Otherwise we just poll again and wait for one to be produced...
-        if block.MsgSet.PartialTrailingMessage {
-            if child.conf.Consumer.Fetch.Max > 0 && child.fetchSize == child.conf.Consumer.Fetch.Max {
-                // we can't ask for more data, we've hit the configured limit
-                child.sendError(ErrMessageTooLarge)
-                child.offset++ // skip this one so we can keep processing future messages
-            } else {
-                child.fetchSize *= 2
-                if child.conf.Consumer.Fetch.Max > 0 && child.fetchSize > child.conf.Consumer.Fetch.Max {
-                    child.fetchSize = child.conf.Consumer.Fetch.Max
-                }
-            }
-        }
-
-        return nil, nil
-    }
-
-    // we got messages, reset our fetch size in case it was increased for a previous request
-    child.fetchSize = child.conf.Consumer.Fetch.Default
-    atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)
-
-    incomplete := false
-    prelude := true
+func (child *partitionConsumer) parseMessages(msgSet *MessageSet) ([]*ConsumerMessage, error) {
     var messages []*ConsumerMessage
-    for _, msgBlock := range block.MsgSet.Messages {
+    var incomplete bool
+    prelude := true

+    for _, msgBlock := range msgSet.Messages {
         for _, msg := range msgBlock.Messages() {
             offset := msg.Offset
             if msg.Msg.Version >= 1 {

@@ -542,7 +511,6 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*ConsumerMessage, error) {
                 incomplete = true
             }
         }
-    }

     if incomplete || len(messages) == 0 {

@@ -551,6 +519,97 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*ConsumerMessage, error) {
     return messages, nil
 }

+func (child *partitionConsumer) parseRecords(block *FetchResponseBlock) ([]*ConsumerMessage, error) {
+    var messages []*ConsumerMessage
+    var incomplete bool
+    prelude := true
+    batch := block.Records.recordBatch
+
+    for _, rec := range batch.Records {
+        offset := batch.FirstOffset + rec.OffsetDelta
+        if prelude && offset < child.offset {
+            continue
+        }
+        prelude = false
+
+        if offset >= child.offset {
+            messages = append(messages, &ConsumerMessage{
+                Topic:     child.topic,
+                Partition: child.partition,
+                Key:       rec.Key,
+                Value:     rec.Value,
+                Offset:    offset,
+                Timestamp: batch.FirstTimestamp.Add(rec.TimestampDelta),
+                Headers:   rec.Headers,
+            })
+            child.offset = offset + 1
+        } else {
+            incomplete = true
+        }
+
+        if child.offset > block.LastStableOffset {
+            // We reached the end of closed transactions
+            break
+        }
+    }
+
+    if incomplete || len(messages) == 0 {
+        return nil, ErrIncompleteResponse
+    }
+    return messages, nil
+}
+
+func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*ConsumerMessage, error) {
+    block := response.GetBlock(child.topic, child.partition)
+    if block == nil {
+        return nil, ErrIncompleteResponse
+    }
+
+    if block.Err != ErrNoError {
+        return nil, block.Err
+    }
+
+    nRecs, err := block.Records.numRecords()
+    if err != nil {
+        return nil, err
+    }
+    if nRecs == 0 {
+        partialTrailingMessage, err := block.Records.isPartial()
+        if err != nil {
+            return nil, err
+        }
+        // We got no messages. If we got a trailing one then we need to ask for more data.
+        // Otherwise we just poll again and wait for one to be produced...
+        if partialTrailingMessage {
+            if child.conf.Consumer.Fetch.Max > 0 && child.fetchSize == child.conf.Consumer.Fetch.Max {
+                // we can't ask for more data, we've hit the configured limit
+                child.sendError(ErrMessageTooLarge)
+                child.offset++ // skip this one so we can keep processing future messages
+            } else {
+                child.fetchSize *= 2
+                if child.conf.Consumer.Fetch.Max > 0 && child.fetchSize > child.conf.Consumer.Fetch.Max {
+                    child.fetchSize = child.conf.Consumer.Fetch.Max
+                }
+            }
+        }
+
+        return nil, nil
+    }
+
+    // we got messages, reset our fetch size in case it was increased for a previous request
+    child.fetchSize = child.conf.Consumer.Fetch.Default
+    atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)
+
+    if control, err := block.Records.isControl(); err != nil || control {
+        return nil, err
+    }
+
+    if response.Version < 4 {
+        return child.parseMessages(block.Records.msgSet)
+    }
+    return child.parseRecords(block)
+}
+
 // brokerConsumer

 type brokerConsumer struct {

@@ -740,6 +799,10 @@ func (bc *brokerConsumer) fetchNewMessages() (*FetchResponse, error) {
         request.Version = 3
         request.MaxBytes = MaxResponseSize
     }
+    if bc.consumer.conf.Version.IsAtLeast(V0_11_0_0) {
+        request.Version = 4
+        request.Isolation = ReadUncommitted // We don't support yet transactions.
+    }

     for child := range bc.subscriptions {
         request.AddBlock(child.topic, child.partition, child.offset, child.fetchSize)
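On the consuming side, the new v4 fetch path surfaces record headers on ConsumerMessage. A minimal sketch, again with placeholder address and topic:

    cfg := sarama.NewConfig()
    cfg.Version = sarama.V0_11_0_0 // consumer now issues v4 FetchRequests with ReadUncommitted

    consumer, err := sarama.NewConsumer([]string{"localhost:9092"}, cfg)
    if err != nil {
        log.Fatal(err)
    }
    defer consumer.Close()

    pc, err := consumer.ConsumePartition("my_topic", 0, sarama.OffsetOldest)
    if err != nil {
        log.Fatal(err)
    }
    defer pc.Close()

    for msg := range pc.Messages() {
        for _, h := range msg.Headers { // []*RecordHeader; only set on 0.11+ brokers
            log.Printf("header %s=%s", h.Key, h.Value)
        }
    }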

vendor/github.com/Shopify/sarama/consumer_test.go

@@ -379,25 +379,41 @@ func TestConsumerShutsDownOutOfRange(t *testing.T) {
 // requested, then such messages are ignored.
 func TestConsumerExtraOffsets(t *testing.T) {
     // Given
+    legacyFetchResponse := &FetchResponse{}
+    legacyFetchResponse.AddMessage("my_topic", 0, nil, testMsg, 1)
+    legacyFetchResponse.AddMessage("my_topic", 0, nil, testMsg, 2)
+    legacyFetchResponse.AddMessage("my_topic", 0, nil, testMsg, 3)
+    legacyFetchResponse.AddMessage("my_topic", 0, nil, testMsg, 4)
+    newFetchResponse := &FetchResponse{Version: 4}
+    newFetchResponse.AddRecord("my_topic", 0, nil, testMsg, 1)
+    newFetchResponse.AddRecord("my_topic", 0, nil, testMsg, 2)
+    newFetchResponse.AddRecord("my_topic", 0, nil, testMsg, 3)
+    newFetchResponse.AddRecord("my_topic", 0, nil, testMsg, 4)
+    newFetchResponse.SetLastStableOffset("my_topic", 0, 4)
+    for _, fetchResponse1 := range []*FetchResponse{legacyFetchResponse, newFetchResponse} {
+        var offsetResponseVersion int16
+        cfg := NewConfig()
+        if fetchResponse1.Version >= 4 {
+            cfg.Version = V0_11_0_0
+            offsetResponseVersion = 1
+        }
+
         broker0 := NewMockBroker(t, 0)
-    fetchResponse1 := &FetchResponse{}
-    fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 1)
-    fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 2)
-    fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 3)
-    fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 4)
         fetchResponse2 := &FetchResponse{}
+        fetchResponse2.Version = fetchResponse1.Version
         fetchResponse2.AddError("my_topic", 0, ErrNoError)
         broker0.SetHandlerByMap(map[string]MockResponse{
             "MetadataRequest": NewMockMetadataResponse(t).
                 SetBroker(broker0.Addr(), broker0.BrokerID()).
                 SetLeader("my_topic", 0, broker0.BrokerID()),
             "OffsetRequest": NewMockOffsetResponse(t).
+                SetVersion(offsetResponseVersion).
                 SetOffset("my_topic", 0, OffsetNewest, 1234).
                 SetOffset("my_topic", 0, OffsetOldest, 0),
             "FetchRequest": NewMockSequence(fetchResponse1, fetchResponse2),
         })

-    master, err := NewConsumer([]string{broker0.Addr()}, nil)
+        master, err := NewConsumer([]string{broker0.Addr()}, cfg)
         if err != nil {
             t.Fatal(err)
         }

@@ -416,30 +432,45 @@ func TestConsumerExtraOffsets(t *testing.T) {
         safeClose(t, consumer)
         safeClose(t, master)
         broker0.Close()
+    }
 }

 // It is fine if offsets of fetched messages are not sequential (although
 // strictly increasing!).
 func TestConsumerNonSequentialOffsets(t *testing.T) {
     // Given
+    legacyFetchResponse := &FetchResponse{}
+    legacyFetchResponse.AddMessage("my_topic", 0, nil, testMsg, 5)
+    legacyFetchResponse.AddMessage("my_topic", 0, nil, testMsg, 7)
+    legacyFetchResponse.AddMessage("my_topic", 0, nil, testMsg, 11)
+    newFetchResponse := &FetchResponse{Version: 4}
+    newFetchResponse.AddRecord("my_topic", 0, nil, testMsg, 5)
+    newFetchResponse.AddRecord("my_topic", 0, nil, testMsg, 7)
+    newFetchResponse.AddRecord("my_topic", 0, nil, testMsg, 11)
+    newFetchResponse.SetLastStableOffset("my_topic", 0, 11)
+    for _, fetchResponse1 := range []*FetchResponse{legacyFetchResponse, newFetchResponse} {
+        var offsetResponseVersion int16
+        cfg := NewConfig()
+        if fetchResponse1.Version >= 4 {
+            cfg.Version = V0_11_0_0
+            offsetResponseVersion = 1
+        }
+
         broker0 := NewMockBroker(t, 0)
-    fetchResponse1 := &FetchResponse{}
-    fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 5)
-    fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 7)
-    fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 11)
-    fetchResponse2 := &FetchResponse{}
+        fetchResponse2 := &FetchResponse{Version: fetchResponse1.Version}
         fetchResponse2.AddError("my_topic", 0, ErrNoError)
         broker0.SetHandlerByMap(map[string]MockResponse{
             "MetadataRequest": NewMockMetadataResponse(t).
                 SetBroker(broker0.Addr(), broker0.BrokerID()).
                 SetLeader("my_topic", 0, broker0.BrokerID()),
             "OffsetRequest": NewMockOffsetResponse(t).
+                SetVersion(offsetResponseVersion).
                 SetOffset("my_topic", 0, OffsetNewest, 1234).
                 SetOffset("my_topic", 0, OffsetOldest, 0),
             "FetchRequest": NewMockSequence(fetchResponse1, fetchResponse2),
         })

-    master, err := NewConsumer([]string{broker0.Addr()}, nil)
+        master, err := NewConsumer([]string{broker0.Addr()}, cfg)
         if err != nil {
             t.Fatal(err)
         }

@@ -459,6 +490,7 @@ func TestConsumerNonSequentialOffsets(t *testing.T) {
         safeClose(t, consumer)
         safeClose(t, master)
         broker0.Close()
+    }
 }

 // If leadership for a partition is changing then consumer resolves the new

vendor/github.com/Shopify/sarama/crc32_field.go

@@ -6,9 +6,19 @@ import (
     "hash/crc32"
 )

+type crcPolynomial int8
+
+const (
+    crcIEEE crcPolynomial = iota
+    crcCastagnoli
+)
+
+var castagnoliTable = crc32.MakeTable(crc32.Castagnoli)
+
 // crc32Field implements the pushEncoder and pushDecoder interfaces for calculating CRC32s.
 type crc32Field struct {
     startOffset int
+    polynomial  crcPolynomial
 }

 func (c *crc32Field) saveOffset(in int) {

@@ -19,14 +29,24 @@ func (c *crc32Field) reserveLength() int {
     return 4
 }

+func newCRC32Field(polynomial crcPolynomial) *crc32Field {
+    return &crc32Field{polynomial: polynomial}
+}
+
 func (c *crc32Field) run(curOffset int, buf []byte) error {
-    crc := crc32.ChecksumIEEE(buf[c.startOffset+4 : curOffset])
+    crc, err := c.crc(curOffset, buf)
+    if err != nil {
+        return err
+    }
     binary.BigEndian.PutUint32(buf[c.startOffset:], crc)
     return nil
 }

 func (c *crc32Field) check(curOffset int, buf []byte) error {
-    crc := crc32.ChecksumIEEE(buf[c.startOffset+4 : curOffset])
+    crc, err := c.crc(curOffset, buf)
+    if err != nil {
+        return err
+    }

     expected := binary.BigEndian.Uint32(buf[c.startOffset:])
     if crc != expected {

@@ -35,3 +55,15 @@ func (c *crc32Field) check(curOffset int, buf []byte) error {

     return nil
 }
+
+func (c *crc32Field) crc(curOffset int, buf []byte) (uint32, error) {
+    var tab *crc32.Table
+    switch c.polynomial {
+    case crcIEEE:
+        tab = crc32.IEEETable
+    case crcCastagnoli:
+        tab = castagnoliTable
+    default:
+        return 0, PacketDecodingError{"invalid CRC type"}
+    }
+    return crc32.Checksum(buf[c.startOffset+4:curOffset], tab), nil
+}
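The Castagnoli table above comes straight from the standard library; a self-contained sketch of the two checksum variants the vendored code now switches between:

    package main

    import (
        "fmt"
        "hash/crc32"
    )

    func main() {
        data := []byte("kafka record batch")

        // v0/v1 message sets checksum with the IEEE polynomial...
        fmt.Println(crc32.ChecksumIEEE(data))

        // ...while v2 record batches use Castagnoli (CRC-32C).
        castagnoli := crc32.MakeTable(crc32.Castagnoli)
        fmt.Println(crc32.Checksum(data, castagnoli))
    }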

vendor/github.com/Shopify/sarama/fetch_request.go

@@ -29,16 +29,27 @@ type FetchRequest struct {
     MinBytes    int32
     MaxBytes    int32
     Version     int16
+    Isolation   IsolationLevel
     blocks      map[string]map[int32]*fetchRequestBlock
 }

+type IsolationLevel int8
+
+const (
+    ReadUncommitted IsolationLevel = 0
+    ReadCommitted   IsolationLevel = 1
+)
+
 func (r *FetchRequest) encode(pe packetEncoder) (err error) {
     pe.putInt32(-1) // replica ID is always -1 for clients
     pe.putInt32(r.MaxWaitTime)
     pe.putInt32(r.MinBytes)
-    if r.Version == 3 {
+    if r.Version >= 3 {
         pe.putInt32(r.MaxBytes)
     }
+    if r.Version >= 4 {
+        pe.putInt8(int8(r.Isolation))
+    }
     err = pe.putArrayLength(len(r.blocks))
     if err != nil {
         return err

@@ -74,11 +85,18 @@ func (r *FetchRequest) decode(pd packetDecoder, version int16) (err error) {
     if r.MinBytes, err = pd.getInt32(); err != nil {
         return err
     }
-    if r.Version == 3 {
+    if r.Version >= 3 {
         if r.MaxBytes, err = pd.getInt32(); err != nil {
             return err
         }
     }
+    if r.Version >= 4 {
+        isolation, err := pd.getInt8()
+        if err != nil {
+            return err
+        }
+        r.Isolation = IsolationLevel(isolation)
+    }
     topicCount, err := pd.getArrayLength()
     if err != nil {
         return err

@@ -128,6 +146,8 @@ func (r *FetchRequest) requiredVersion() KafkaVersion {
         return V0_10_0_0
     case 3:
         return V0_10_1_0
+    case 4:
+        return V0_11_0_0
     default:
         return minVersion
     }

vendor/github.com/Shopify/sarama/fetch_request_test.go

@@ -17,6 +17,15 @@ var (
         0x00, 0x05, 't', 'o', 'p', 'i', 'c',
         0x00, 0x00, 0x00, 0x01,
         0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x56}
+
+    fetchRequestOneBlockV4 = []byte{
+        0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+        0x00, 0x00, 0x00, 0xFF,
+        0x01,
+        0x00, 0x00, 0x00, 0x01,
+        0x00, 0x05, 't', 'o', 'p', 'i', 'c',
+        0x00, 0x00, 0x00, 0x01,
+        0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x56}
 )

 func TestFetchRequest(t *testing.T) {

@@ -31,4 +40,9 @@ func TestFetchRequest(t *testing.T) {
     request.MinBytes = 0
     request.AddBlock("topic", 0x12, 0x34, 0x56)
     testRequest(t, "one block", request, fetchRequestOneBlock)
+
+    request.Version = 4
+    request.MaxBytes = 0xFF
+    request.Isolation = ReadCommitted
+    testRequest(t, "one block v4", request, fetchRequestOneBlockV4)
 }

vendor/github.com/Shopify/sarama/fetch_response.go

@@ -2,13 +2,39 @@ package sarama

 import "time"

+type AbortedTransaction struct {
+    ProducerID  int64
+    FirstOffset int64
+}
+
+func (t *AbortedTransaction) decode(pd packetDecoder) (err error) {
+    if t.ProducerID, err = pd.getInt64(); err != nil {
+        return err
+    }
+
+    if t.FirstOffset, err = pd.getInt64(); err != nil {
+        return err
+    }
+
+    return nil
+}
+
+func (t *AbortedTransaction) encode(pe packetEncoder) (err error) {
+    pe.putInt64(t.ProducerID)
+    pe.putInt64(t.FirstOffset)
+
+    return nil
+}
+
 type FetchResponseBlock struct {
     Err                 KError
     HighWaterMarkOffset int64
-    MsgSet              MessageSet
+    LastStableOffset    int64
+    AbortedTransactions []*AbortedTransaction
+    Records             Records
 }

-func (b *FetchResponseBlock) decode(pd packetDecoder) (err error) {
+func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error) {
     tmp, err := pd.getInt16()
     if err != nil {
         return err

@@ -20,27 +46,75 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error) {
         return err
     }

-    msgSetSize, err := pd.getInt32()
+    if version >= 4 {
+        b.LastStableOffset, err = pd.getInt64()
+        if err != nil {
+            return err
+        }
+
+        numTransact, err := pd.getArrayLength()
+        if err != nil {
+            return err
+        }
+
+        if numTransact >= 0 {
+            b.AbortedTransactions = make([]*AbortedTransaction, numTransact)
+        }
+
+        for i := 0; i < numTransact; i++ {
+            transact := new(AbortedTransaction)
+            if err = transact.decode(pd); err != nil {
+                return err
+            }
+            b.AbortedTransactions[i] = transact
+        }
+    }
+
+    recordsSize, err := pd.getInt32()
     if err != nil {
         return err
     }

-    msgSetDecoder, err := pd.getSubset(int(msgSetSize))
+    recordsDecoder, err := pd.getSubset(int(recordsSize))
     if err != nil {
         return err
     }
-    err = (&b.MsgSet).decode(msgSetDecoder)
+
+    var records Records
+    if version >= 4 {
+        records = newDefaultRecords(nil)
+    } else {
+        records = newLegacyRecords(nil)
+    }
+    if recordsSize > 0 {
+        if err = records.decode(recordsDecoder); err != nil {
+            return err
+        }
+    }
+    b.Records = records

     return nil
 }

-func (b *FetchResponseBlock) encode(pe packetEncoder) (err error) {
+func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error) {
     pe.putInt16(int16(b.Err))
     pe.putInt64(b.HighWaterMarkOffset)

+    if version >= 4 {
+        pe.putInt64(b.LastStableOffset)
+
+        if err = pe.putArrayLength(len(b.AbortedTransactions)); err != nil {
+            return err
+        }
+
+        for _, transact := range b.AbortedTransactions {
+            if err = transact.encode(pe); err != nil {
+                return err
+            }
+        }
+    }
+
     pe.push(&lengthField{})
-    err = b.MsgSet.encode(pe)
+    err = b.Records.encode(pe)
     if err != nil {
         return err
     }

@@ -90,7 +164,7 @@ func (r *FetchResponse) decode(pd packetDecoder, version int16) (err error) {
             }

             block := new(FetchResponseBlock)
-            err = block.decode(pd)
+            err = block.decode(pd, version)
             if err != nil {
                 return err
             }

@@ -124,7 +198,7 @@ func (r *FetchResponse) encode(pe packetEncoder) (err error) {
         for id, block := range partitions {
             pe.putInt32(id)

-            err = block.encode(pe)
+            err = block.encode(pe, r.Version)
             if err != nil {
                 return err
             }

@@ -148,6 +222,10 @@ func (r *FetchResponse) requiredVersion() KafkaVersion {
         return V0_9_0_0
     case 2:
         return V0_10_0_0
+    case 3:
+        return V0_10_1_0
+    case 4:
+        return V0_11_0_0
     default:
         return minVersion
     }

@@ -182,7 +260,7 @@ func (r *FetchResponse) AddError(topic string, partition int32, err KError) {
     frb.Err = err
 }

-func (r *FetchResponse) AddMessage(topic string, partition int32, key, value Encoder, offset int64) {
+func (r *FetchResponse) getOrCreateBlock(topic string, partition int32) *FetchResponseBlock {
     if r.Blocks == nil {
         r.Blocks = make(map[string]map[int32]*FetchResponseBlock)
     }

@@ -196,6 +274,11 @@ func (r *FetchResponse) getOrCreateBlock(topic string, partition int32) *FetchResponseBlock {
         frb = new(FetchResponseBlock)
         partitions[partition] = frb
     }
+
+    return frb
+}
+
+func encodeKV(key, value Encoder) ([]byte, []byte) {
     var kb []byte
     var vb []byte
     if key != nil {

@@ -204,7 +287,36 @@ func (r *FetchResponse) AddMessage(topic string, partition int32, key, value Encoder, offset int64) {
     if value != nil {
         vb, _ = value.Encode()
     }
+    return kb, vb
+}
+
+func (r *FetchResponse) AddMessage(topic string, partition int32, key, value Encoder, offset int64) {
+    frb := r.getOrCreateBlock(topic, partition)
+    kb, vb := encodeKV(key, value)
     msg := &Message{Key: kb, Value: vb}
     msgBlock := &MessageBlock{Msg: msg, Offset: offset}
-    frb.MsgSet.Messages = append(frb.MsgSet.Messages, msgBlock)
+    set := frb.Records.msgSet
+    if set == nil {
+        set = &MessageSet{}
+        frb.Records = newLegacyRecords(set)
+    }
+    set.Messages = append(set.Messages, msgBlock)
+}
+
+func (r *FetchResponse) AddRecord(topic string, partition int32, key, value Encoder, offset int64) {
+    frb := r.getOrCreateBlock(topic, partition)
+    kb, vb := encodeKV(key, value)
+    rec := &Record{Key: kb, Value: vb, OffsetDelta: offset}
+    batch := frb.Records.recordBatch
+    if batch == nil {
+        batch = &RecordBatch{Version: 2}
+        frb.Records = newDefaultRecords(batch)
+    }
+    batch.addRecord(rec)
+}
+
+func (r *FetchResponse) SetLastStableOffset(topic string, partition int32, offset int64) {
+    frb := r.getOrCreateBlock(topic, partition)
+    frb.LastStableOffset = offset
 }

vendor/github.com/Shopify/sarama/fetch_response_test.go

@@ -26,6 +26,43 @@ var (
         0x00, 0x00,
         0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
         0x00, 0x00, 0x00, 0x02, 0x00, 0xEE}
+
+    oneRecordFetchResponse = []byte{
+        0x00, 0x00, 0x00, 0x00, // ThrottleTime
+        0x00, 0x00, 0x00, 0x01, // Number of Topics
+        0x00, 0x05, 't', 'o', 'p', 'i', 'c', // Topic
+        0x00, 0x00, 0x00, 0x01, // Number of Partitions
+        0x00, 0x00, 0x00, 0x05, // Partition
+        0x00, 0x01, // Error
+        0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x10, 0x10, // High Watermark Offset
+        0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x10, 0x10, // Last Stable Offset
+        0x00, 0x00, 0x00, 0x00, // Number of Aborted Transactions
+        0x00, 0x00, 0x00, 0x52, // Records length
+        // recordBatch
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+        0x00, 0x00, 0x00, 0x46,
+        0x00, 0x00, 0x00, 0x00,
+        0x02,
+        0xDB, 0x47, 0x14, 0xC9,
+        0x00, 0x00,
+        0x00, 0x00, 0x00, 0x00,
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0A,
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+        0x00, 0x00,
+        0x00, 0x00, 0x00, 0x00,
+        0x00, 0x00, 0x00, 0x01,
+        // record
+        0x28,
+        0x00,
+        0x0A,
+        0x00,
+        0x08, 0x01, 0x02, 0x03, 0x04,
+        0x06, 0x05, 0x06, 0x07,
+        0x02,
+        0x06, 0x08, 0x09, 0x0A,
+        0x04, 0x0B, 0x0C,
+    }
 )

 func TestEmptyFetchResponse(t *testing.T) {

@@ -60,14 +97,22 @@ func TestOneMessageFetchResponse(t *testing.T) {
     if block.HighWaterMarkOffset != 0x10101010 {
         t.Error("Decoding didn't produce correct high water mark offset.")
     }
-    if block.MsgSet.PartialTrailingMessage {
+    partial, err := block.Records.isPartial()
+    if err != nil {
+        t.Fatalf("Unexpected error: %v", err)
+    }
+    if partial {
         t.Error("Decoding detected a partial trailing message where there wasn't one.")
     }
-    if len(block.MsgSet.Messages) != 1 {
+    n, err := block.Records.numRecords()
+    if err != nil {
+        t.Fatalf("Unexpected error: %v", err)
+    }
+    if n != 1 {
         t.Fatal("Decoding produced incorrect number of messages.")
     }
-    msgBlock := block.MsgSet.Messages[0]
+    msgBlock := block.Records.msgSet.Messages[0]
     if msgBlock.Offset != 0x550000 {
         t.Error("Decoding produced incorrect message offset.")
     }

@@ -82,3 +127,49 @@ func TestOneMessageFetchResponse(t *testing.T) {
         t.Error("Decoding produced incorrect message value.")
     }
 }
+
+func TestOneRecordFetchResponse(t *testing.T) {
+    response := FetchResponse{}
+    testVersionDecodable(t, "one record", &response, oneRecordFetchResponse, 4)
+    if len(response.Blocks) != 1 {
+        t.Fatal("Decoding produced incorrect number of topic blocks.")
+    }
+    if len(response.Blocks["topic"]) != 1 {
+        t.Fatal("Decoding produced incorrect number of partition blocks for topic.")
+    }
+    block := response.GetBlock("topic", 5)
+    if block == nil {
+        t.Fatal("GetBlock didn't return block.")
+    }
+    if block.Err != ErrOffsetOutOfRange {
+        t.Error("Decoding didn't produce correct error code.")
+    }
+    if block.HighWaterMarkOffset != 0x10101010 {
+        t.Error("Decoding didn't produce correct high water mark offset.")
+    }
+    partial, err := block.Records.isPartial()
+    if err != nil {
+        t.Fatalf("Unexpected error: %v", err)
+    }
+    if partial {
+        t.Error("Decoding detected a partial trailing record where there wasn't one.")
+    }
+    n, err := block.Records.numRecords()
+    if err != nil {
+        t.Fatalf("Unexpected error: %v", err)
+    }
+    if n != 1 {
+        t.Fatal("Decoding produced incorrect number of records.")
+    }
+    rec := block.Records.recordBatch.Records[0]
+    if !bytes.Equal(rec.Key, []byte{0x01, 0x02, 0x03, 0x04}) {
+        t.Error("Decoding produced incorrect record key.")
+    }
+    if !bytes.Equal(rec.Value, []byte{0x05, 0x06, 0x07}) {
+        t.Error("Decoding produced incorrect record value.")
+    }
+}

vendor/github.com/Shopify/sarama/length_field.go

@@ -27,3 +27,43 @@ func (l *lengthField) check(curOffset int, buf []byte) error {

     return nil
 }
+
+type varintLengthField struct {
+    startOffset int
+    length      int64
+}
+
+func (l *varintLengthField) decode(pd packetDecoder) error {
+    var err error
+    l.length, err = pd.getVarint()
+    return err
+}
+
+func (l *varintLengthField) saveOffset(in int) {
+    l.startOffset = in
+}
+
+func (l *varintLengthField) adjustLength(currOffset int) int {
+    oldFieldSize := l.reserveLength()
+    l.length = int64(currOffset - l.startOffset - oldFieldSize)
+
+    return l.reserveLength() - oldFieldSize
+}
+
+func (l *varintLengthField) reserveLength() int {
+    var tmp [binary.MaxVarintLen64]byte
+    return binary.PutVarint(tmp[:], l.length)
+}
+
+func (l *varintLengthField) run(curOffset int, buf []byte) error {
+    binary.PutVarint(buf[l.startOffset:], l.length)
+    return nil
+}
+
+func (l *varintLengthField) check(curOffset int, buf []byte) error {
+    if int64(curOffset-l.startOffset-l.reserveLength()) != l.length {
+        return PacketDecodingError{"length field invalid"}
+    }
+
+    return nil
+}
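varintLengthField sizes its reservation with the stdlib varint encoder, so the reserved width tracks the current value of length. A quick round-trip showing that variable width, stdlib only:

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    func main() {
        buf := make([]byte, binary.MaxVarintLen64)
        for _, v := range []int64{1, 300, -70000} {
            n := binary.PutVarint(buf, v) // encoded width depends on the value
            got, _ := binary.Varint(buf[:n])
            fmt.Printf("%d -> %d byte(s) -> %d\n", v, n, got)
        }
    }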

vendor/github.com/Shopify/sarama/message.go

@@ -37,7 +37,7 @@ type Message struct {
 }

 func (m *Message) encode(pe packetEncoder) error {
-    pe.push(&crc32Field{})
+    pe.push(newCRC32Field(crcIEEE))

     pe.putInt8(m.Version)

@@ -45,15 +45,9 @@ func (m *Message) encode(pe packetEncoder) error {
     pe.putInt8(attributes)

     if m.Version >= 1 {
-        timestamp := int64(-1)
-
-        if !m.Timestamp.Before(time.Unix(0, 0)) {
-            timestamp = m.Timestamp.UnixNano() / int64(time.Millisecond)
-        } else if !m.Timestamp.IsZero() {
-            return PacketEncodingError{fmt.Sprintf("invalid timestamp (%v)", m.Timestamp)}
+        if err := (Timestamp{&m.Timestamp}).encode(pe); err != nil {
+            return err
         }
-
-        pe.putInt64(timestamp)
     }

     err := pe.putBytes(m.Key)

@@ -112,7 +106,7 @@ func (m *Message) encode(pe packetEncoder) error {
 }

 func (m *Message) decode(pd packetDecoder) (err error) {
-    err = pd.push(&crc32Field{})
+    err = pd.push(newCRC32Field(crcIEEE))
     if err != nil {
         return err
     }

@@ -133,19 +127,9 @@ func (m *Message) decode(pd packetDecoder) (err error) {
     m.Codec = CompressionCodec(attribute & compressionCodecMask)

     if m.Version == 1 {
-        millis, err := pd.getInt64()
-        if err != nil {
+        if err := (Timestamp{&m.Timestamp}).decode(pd); err != nil {
             return err
         }
-
-        // negative timestamps are invalid, in these cases we should return
-        // a zero time
-        timestamp := time.Time{}
-        if millis >= 0 {
-            timestamp = time.Unix(millis/1000, (millis%1000)*int64(time.Millisecond))
-        }
-        m.Timestamp = timestamp
     }

     m.Key, err = pd.getBytes()

vendor/github.com/Shopify/sarama/mockresponses.go

@@ -122,6 +122,7 @@ func (mmr *MockMetadataResponse) For(reqBody versionedDecoder) encoder {
 type MockOffsetResponse struct {
     offsets map[string]map[int32]map[int64]int64
     t       TestReporter
+    version int16
 }

 func NewMockOffsetResponse(t TestReporter) *MockOffsetResponse {

@@ -131,6 +132,11 @@ func NewMockOffsetResponse(t TestReporter) *MockOffsetResponse {
     }
 }

+func (mor *MockOffsetResponse) SetVersion(version int16) *MockOffsetResponse {
+    mor.version = version
+    return mor
+}
+
 func (mor *MockOffsetResponse) SetOffset(topic string, partition int32, time, offset int64) *MockOffsetResponse {
     partitions := mor.offsets[topic]
     if partitions == nil {

@@ -148,7 +154,7 @@ func (mor *MockOffsetResponse) SetOffset(topic string, partition int32, time, offset int64) *MockOffsetResponse {

 func (mor *MockOffsetResponse) For(reqBody versionedDecoder) encoder {
     offsetRequest := reqBody.(*OffsetRequest)
-    offsetResponse := &OffsetResponse{}
+    offsetResponse := &OffsetResponse{Version: mor.version}
     for topic, partitions := range offsetRequest.blocks {
         for partition, block := range partitions {
             offset := mor.getOffset(topic, partition, block.time)

@@ -402,7 +408,7 @@ func (mr *MockProduceResponse) SetError(topic string, partition int32, kerror KError) {
 func (mr *MockProduceResponse) For(reqBody versionedDecoder) encoder {
     req := reqBody.(*ProduceRequest)
     res := &ProduceResponse{}
-    for topic, partitions := range req.msgSets {
+    for topic, partitions := range req.records {
         for partition := range partitions {
             res.AddTopicPartition(topic, partition, mr.getError(topic, partition))
         }

vendor/github.com/Shopify/sarama/packet_decoder.go

@@ -9,11 +9,15 @@ type packetDecoder interface {
     getInt16() (int16, error)
     getInt32() (int32, error)
     getInt64() (int64, error)
+    getVarint() (int64, error)
     getArrayLength() (int, error)

     // Collections
     getBytes() ([]byte, error)
+    getVarintBytes() ([]byte, error)
+    getRawBytes(length int) ([]byte, error)
     getString() (string, error)
+    getNullableString() (*string, error)
     getInt32Array() ([]int32, error)
     getInt64Array() ([]int64, error)
     getStringArray() ([]string, error)

@@ -43,3 +47,12 @@ type pushDecoder interface {
     // of data from the saved offset, and verify it based on the data between the saved offset and curOffset.
     check(curOffset int, buf []byte) error
 }
+
+// dynamicPushDecoder extends the interface of pushDecoder for uses cases where the length of the
+// fields itself is unknown until its value was decoded (for instance varint encoded length
+// fields).
+// During push, dynamicPushDecoder.decode() method will be called instead of reserveLength()
+type dynamicPushDecoder interface {
+    pushDecoder
+    decoder
+}

vendor/github.com/Shopify/sarama/packet_encoder.go

@@ -11,12 +11,15 @@ type packetEncoder interface {
     putInt16(in int16)
     putInt32(in int32)
     putInt64(in int64)
+    putVarint(in int64)
     putArrayLength(in int) error

     // Collections
     putBytes(in []byte) error
+    putVarintBytes(in []byte) error
     putRawBytes(in []byte) error
     putString(in string) error
+    putNullableString(in *string) error
     putStringArray(in []string) error
     putInt32Array(in []int32) error
     putInt64Array(in []int64) error

@@ -48,3 +51,14 @@ type pushEncoder interface {
     // of data to the saved offset, based on the data between the saved offset and curOffset.
     run(curOffset int, buf []byte) error
 }
+
+// dynamicPushEncoder extends the interface of pushEncoder for uses cases where the length of the
+// fields itself is unknown until its value was computed (for instance varint encoded length
+// fields).
+type dynamicPushEncoder interface {
+    pushEncoder
+
+    // Called during pop() to adjust the length of the field.
+    // It should return the difference in bytes between the last computed length and current length.
+    adjustLength(currOffset int) int
+}

vendor/github.com/Shopify/sarama/prep_encoder.go

@@ -1,6 +1,7 @@
 package sarama

 import (
+    "encoding/binary"
     "fmt"
     "math"

@@ -8,6 +9,7 @@ import (
 )

 type prepEncoder struct {
+    stack  []pushEncoder
     length int
 }

@@ -29,6 +31,11 @@ func (pe *prepEncoder) putInt64(in int64) {
     pe.length += 8
 }

+func (pe *prepEncoder) putVarint(in int64) {
+    var buf [binary.MaxVarintLen64]byte
+    pe.length += binary.PutVarint(buf[:], in)
+}
+
 func (pe *prepEncoder) putArrayLength(in int) error {
     if in > math.MaxInt32 {
         return PacketEncodingError{fmt.Sprintf("array too long (%d)", in)}

@@ -44,11 +51,16 @@ func (pe *prepEncoder) putBytes(in []byte) error {
     if in == nil {
         return nil
     }
-    if len(in) > math.MaxInt32 {
-        return PacketEncodingError{fmt.Sprintf("byteslice too long (%d)", len(in))}
-    }
-    pe.length += len(in)
+    return pe.putRawBytes(in)
+}
+
+func (pe *prepEncoder) putVarintBytes(in []byte) error {
+    if in == nil {
+        pe.putVarint(-1)
         return nil
+    }
+    pe.putVarint(int64(len(in)))
+    return pe.putRawBytes(in)
 }

 func (pe *prepEncoder) putRawBytes(in []byte) error {

@@ -59,6 +71,14 @@ func (pe *prepEncoder) putRawBytes(in []byte) error {
     return nil
 }

+func (pe *prepEncoder) putNullableString(in *string) error {
+    if in == nil {
+        pe.length += 2
+        return nil
+    }
+    return pe.putString(*in)
+}
+
 func (pe *prepEncoder) putString(in string) error {
     pe.length += 2
     if len(in) > math.MaxInt16 {

@@ -108,10 +128,18 @@ func (pe *prepEncoder) offset() int {

 // stackable

 func (pe *prepEncoder) push(in pushEncoder) {
+    in.saveOffset(pe.length)
     pe.length += in.reserveLength()
+    pe.stack = append(pe.stack, in)
 }

 func (pe *prepEncoder) pop() error {
+    in := pe.stack[len(pe.stack)-1]
+    pe.stack = pe.stack[:len(pe.stack)-1]
+    if dpe, ok := in.(dynamicPushEncoder); ok {
+        pe.length += dpe.adjustLength(pe.length)
+    }
+
     return nil
 }
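A note on the design, as far as it can be read from this diff: prepEncoder is the sizing pass of sarama's two-pass encoder, so historically push() and pop() only needed to add a fixed reserveLength(). With varint length fields the final width depends on the value being written, which is why the prep pass now keeps its own stack and lets a dynamicPushEncoder correct the running length via adjustLength() during pop().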

View file

@ -21,55 +21,16 @@ const (
) )
type ProduceRequest struct { type ProduceRequest struct {
TransactionalID *string
RequiredAcks RequiredAcks RequiredAcks RequiredAcks
Timeout int32 Timeout int32
Version int16 // v1 requires Kafka 0.9, v2 requires Kafka 0.10 Version int16 // v1 requires Kafka 0.9, v2 requires Kafka 0.10, v3 requires Kafka 0.11
msgSets map[string]map[int32]*MessageSet records map[string]map[int32]Records
} }
func (r *ProduceRequest) encode(pe packetEncoder) error { func updateMsgSetMetrics(msgSet *MessageSet, compressionRatioMetric metrics.Histogram,
pe.putInt16(int16(r.RequiredAcks)) topicCompressionRatioMetric metrics.Histogram) int64 {
pe.putInt32(r.Timeout) var topicRecordCount int64
err := pe.putArrayLength(len(r.msgSets))
if err != nil {
return err
}
metricRegistry := pe.metricRegistry()
var batchSizeMetric metrics.Histogram
var compressionRatioMetric metrics.Histogram
if metricRegistry != nil {
batchSizeMetric = getOrRegisterHistogram("batch-size", metricRegistry)
compressionRatioMetric = getOrRegisterHistogram("compression-ratio", metricRegistry)
}
totalRecordCount := int64(0)
for topic, partitions := range r.msgSets {
err = pe.putString(topic)
if err != nil {
return err
}
err = pe.putArrayLength(len(partitions))
if err != nil {
return err
}
topicRecordCount := int64(0)
var topicCompressionRatioMetric metrics.Histogram
if metricRegistry != nil {
topicCompressionRatioMetric = getOrRegisterTopicHistogram("compression-ratio", topic, metricRegistry)
}
for id, msgSet := range partitions {
startOffset := pe.offset()
pe.putInt32(id)
pe.push(&lengthField{})
err = msgSet.encode(pe)
if err != nil {
return err
}
err = pe.pop()
if err != nil {
return err
}
if metricRegistry != nil {
for _, messageBlock := range msgSet.Messages { for _, messageBlock := range msgSet.Messages {
// Is this a fake "message" wrapping real messages? // Is this a fake "message" wrapping real messages?
if messageBlock.Msg.Set != nil { if messageBlock.Msg.Set != nil {
@ -88,6 +49,74 @@ func (r *ProduceRequest) encode(pe packetEncoder) error {
topicCompressionRatioMetric.Update(intCompressionRatio) topicCompressionRatioMetric.Update(intCompressionRatio)
} }
} }
return topicRecordCount
}
func updateBatchMetrics(recordBatch *RecordBatch, compressionRatioMetric metrics.Histogram,
topicCompressionRatioMetric metrics.Histogram) int64 {
if recordBatch.compressedRecords != nil {
compressionRatio := int64(float64(recordBatch.recordsLen) / float64(len(recordBatch.compressedRecords)) * 100)
compressionRatioMetric.Update(compressionRatio)
topicCompressionRatioMetric.Update(compressionRatio)
}
return int64(len(recordBatch.Records))
}
func (r *ProduceRequest) encode(pe packetEncoder) error {
if r.Version >= 3 {
if err := pe.putNullableString(r.TransactionalID); err != nil {
return err
}
}
pe.putInt16(int16(r.RequiredAcks))
pe.putInt32(r.Timeout)
metricRegistry := pe.metricRegistry()
var batchSizeMetric metrics.Histogram
var compressionRatioMetric metrics.Histogram
if metricRegistry != nil {
batchSizeMetric = getOrRegisterHistogram("batch-size", metricRegistry)
compressionRatioMetric = getOrRegisterHistogram("compression-ratio", metricRegistry)
}
totalRecordCount := int64(0)
err := pe.putArrayLength(len(r.records))
if err != nil {
return err
}
for topic, partitions := range r.records {
err = pe.putString(topic)
if err != nil {
return err
}
err = pe.putArrayLength(len(partitions))
if err != nil {
return err
}
topicRecordCount := int64(0)
var topicCompressionRatioMetric metrics.Histogram
if metricRegistry != nil {
topicCompressionRatioMetric = getOrRegisterTopicHistogram("compression-ratio", topic, metricRegistry)
}
for id, records := range partitions {
startOffset := pe.offset()
pe.putInt32(id)
pe.push(&lengthField{})
err = records.encode(pe)
if err != nil {
return err
}
err = pe.pop()
if err != nil {
return err
}
if metricRegistry != nil {
if r.Version >= 3 {
topicRecordCount += updateBatchMetrics(records.recordBatch, compressionRatioMetric, topicCompressionRatioMetric)
} else {
topicRecordCount += updateMsgSetMetrics(records.msgSet, compressionRatioMetric, topicCompressionRatioMetric)
}
batchSize := int64(pe.offset() - startOffset) batchSize := int64(pe.offset() - startOffset)
batchSizeMetric.Update(batchSize) batchSizeMetric.Update(batchSize)
getOrRegisterTopicHistogram("batch-size", topic, metricRegistry).Update(batchSize) getOrRegisterTopicHistogram("batch-size", topic, metricRegistry).Update(batchSize)
@ -108,6 +137,15 @@ func (r *ProduceRequest) encode(pe packetEncoder) error {
} }
func (r *ProduceRequest) decode(pd packetDecoder, version int16) error { func (r *ProduceRequest) decode(pd packetDecoder, version int16) error {
r.Version = version
if version >= 3 {
id, err := pd.getNullableString()
if err != nil {
return err
}
r.TransactionalID = id
}
requiredAcks, err := pd.getInt16() requiredAcks, err := pd.getInt16()
if err != nil { if err != nil {
return err return err
@ -123,7 +161,8 @@ func (r *ProduceRequest) decode(pd packetDecoder, version int16) error {
if topicCount == 0 { if topicCount == 0 {
return nil return nil
} }
r.msgSets = make(map[string]map[int32]*MessageSet)
r.records = make(map[string]map[int32]Records)
for i := 0; i < topicCount; i++ { for i := 0; i < topicCount; i++ {
topic, err := pd.getString() topic, err := pd.getString()
if err != nil { if err != nil {
@ -133,28 +172,34 @@ func (r *ProduceRequest) decode(pd packetDecoder, version int16) error {
if err != nil { if err != nil {
return err return err
} }
r.msgSets[topic] = make(map[int32]*MessageSet) r.records[topic] = make(map[int32]Records)
for j := 0; j < partitionCount; j++ { for j := 0; j < partitionCount; j++ {
partition, err := pd.getInt32() partition, err := pd.getInt32()
if err != nil { if err != nil {
return err return err
} }
messageSetSize, err := pd.getInt32() size, err := pd.getInt32()
if err != nil { if err != nil {
return err return err
} }
msgSetDecoder, err := pd.getSubset(int(messageSetSize)) recordsDecoder, err := pd.getSubset(int(size))
if err != nil { if err != nil {
return err return err
} }
msgSet := &MessageSet{} var records Records
err = msgSet.decode(msgSetDecoder) if version >= 3 {
if err != nil { records = newDefaultRecords(nil)
} else {
records = newLegacyRecords(nil)
}
if err := records.decode(recordsDecoder); err != nil {
return err return err
} }
r.msgSets[topic][partition] = msgSet r.records[topic][partition] = records
} }
} }
return nil return nil
} }
@@ -172,38 +217,41 @@ func (r *ProduceRequest) requiredVersion() KafkaVersion {
return V0_9_0_0
case 2:
return V0_10_0_0
case 3:
return V0_11_0_0
default:
return minVersion
}
}
func (r *ProduceRequest) ensureRecords(topic string, partition int32) {
if r.records == nil {
r.records = make(map[string]map[int32]Records)
}
if r.records[topic] == nil {
r.records[topic] = make(map[int32]Records)
}
}
func (r *ProduceRequest) AddMessage(topic string, partition int32, msg *Message) {
if r.msgSets == nil {
r.msgSets = make(map[string]map[int32]*MessageSet)
}
if r.msgSets[topic] == nil {
r.msgSets[topic] = make(map[int32]*MessageSet)
}
set := r.msgSets[topic][partition]
r.ensureRecords(topic, partition)
set := r.records[topic][partition].msgSet
if set == nil {
set = new(MessageSet)
r.msgSets[topic][partition] = set
r.records[topic][partition] = newLegacyRecords(set)
}
set.addMessage(msg)
}
func (r *ProduceRequest) AddSet(topic string, partition int32, set *MessageSet) {
if r.msgSets == nil {
r.msgSets = make(map[string]map[int32]*MessageSet)
}
if r.msgSets[topic] == nil {
r.msgSets[topic] = make(map[int32]*MessageSet)
}
r.msgSets[topic][partition] = set
r.ensureRecords(topic, partition)
r.records[topic][partition] = newLegacyRecords(set)
}
func (r *ProduceRequest) AddBatch(topic string, partition int32, batch *RecordBatch) {
r.ensureRecords(topic, partition)
r.records[topic][partition] = newDefaultRecords(batch)
}


@@ -2,6 +2,7 @@ package sarama
import (
"testing"
"time"
)
var (
@@ -32,6 +33,41 @@ var (
0x00, 0x00,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0x00, 0x00, 0x00, 0x02, 0x00, 0xEE}
produceRequestOneRecord = []byte{
0xFF, 0xFF, // Transaction ID
0x01, 0x23, // Required Acks
0x00, 0x00, 0x04, 0x44, // Timeout
0x00, 0x00, 0x00, 0x01, // Number of Topics
0x00, 0x05, 't', 'o', 'p', 'i', 'c', // Topic
0x00, 0x00, 0x00, 0x01, // Number of Partitions
0x00, 0x00, 0x00, 0xAD, // Partition
0x00, 0x00, 0x00, 0x52, // Records length
// recordBatch
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x46,
0x00, 0x00, 0x00, 0x00,
0x02,
0x54, 0x79, 0x61, 0xFD,
0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x01, 0x58, 0x8D, 0xCD, 0x59, 0x38,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01,
// record
0x28,
0x00,
0x0A,
0x00,
0x08, 0x01, 0x02, 0x03, 0x04,
0x06, 0x05, 0x06, 0x07,
0x02,
0x06, 0x08, 0x09, 0x0A,
0x04, 0x0B, 0x0C,
}
)
func TestProduceRequest(t *testing.T) {
@@ -44,4 +80,24 @@ func TestProduceRequest(t *testing.T) {
request.AddMessage("topic", 0xAD, &Message{Codec: CompressionNone, Key: nil, Value: []byte{0x00, 0xEE}})
testRequest(t, "one message", request, produceRequestOneMessage)
request.Version = 3
batch := &RecordBatch{
Version: 2,
FirstTimestamp: time.Unix(1479847795, 0),
MaxTimestamp: time.Unix(0, 0),
Records: []*Record{{
TimestampDelta: 5 * time.Millisecond,
Key: []byte{0x01, 0x02, 0x03, 0x04},
Value: []byte{0x05, 0x06, 0x07},
Headers: []*RecordHeader{{
Key: []byte{0x08, 0x09, 0x0A},
Value: []byte{0x0B, 0x0C},
}},
}},
}
request.AddBatch("topic", 0xAD, batch)
packet := testRequestEncode(t, "one record", request, produceRequestOneRecord)
batch.Records[0].length.startOffset = 0
testRequestDecode(t, "one record", request, packet)
}


@@ -1,6 +1,9 @@
package sarama
import "time"
import (
"fmt"
"time"
)
type ProduceResponseBlock struct {
Err KError
@@ -32,6 +35,23 @@ func (b *ProduceResponseBlock) decode(pd packetDecoder, version int16) (err erro
return nil
}
func (b *ProduceResponseBlock) encode(pe packetEncoder, version int16) (err error) {
pe.putInt16(int16(b.Err))
pe.putInt64(b.Offset)
if version >= 2 {
timestamp := int64(-1)
if !b.Timestamp.Before(time.Unix(0, 0)) {
timestamp = b.Timestamp.UnixNano() / int64(time.Millisecond)
} else if !b.Timestamp.IsZero() {
return PacketEncodingError{fmt.Sprintf("invalid timestamp (%v)", b.Timestamp)}
}
pe.putInt64(timestamp)
}
return nil
}
type ProduceResponse struct {
Blocks map[string]map[int32]*ProduceResponseBlock
Version int16
@@ -103,8 +123,10 @@ func (r *ProduceResponse) encode(pe packetEncoder) error {
}
for id, prb := range partitions {
pe.putInt32(id)
pe.putInt16(int16(prb.Err))
pe.putInt64(prb.Offset)
err = prb.encode(pe, r.Version)
if err != nil {
return err
}
}
}
if r.Version >= 1 {
@@ -127,6 +149,8 @@ func (r *ProduceResponse) requiredVersion() KafkaVersion {
return V0_9_0_0
case 2:
return V0_10_0_0
case 3:
return V0_11_0_0
default:
return minVersion
}


@@ -1,67 +1,128 @@
package sarama
import "testing"
import (
"fmt"
"testing"
"time"
)
var (
produceResponseNoBlocks = []byte{
0x00, 0x00, 0x00, 0x00}
produceResponseManyBlocks = []byte{
0x00, 0x00, 0x00, 0x02,
0x00, 0x03, 'f', 'o', 'o',
0x00, 0x00, 0x00, 0x00,
0x00, 0x03, 'b', 'a', 'r',
0x00, 0x00, 0x00, 0x02,
0x00, 0x00, 0x00, 0x01,
0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF,
0x00, 0x00, 0x00, 0x02,
0x00, 0x02,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}
)
func TestProduceResponse(t *testing.T) {
var (
produceResponseNoBlocksV0 = []byte{
0x00, 0x00, 0x00, 0x00}
produceResponseManyBlocksVersions = [][]byte{
{
0x00, 0x00, 0x00, 0x01,
0x00, 0x03, 'f', 'o', 'o',
0x00, 0x00, 0x00, 0x01,
0x00, 0x00, 0x00, 0x01, // Partition 1
0x00, 0x02, // ErrInvalidMessage
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, // Offset 255
}, {
0x00, 0x00, 0x00, 0x01,
0x00, 0x03, 'f', 'o', 'o',
0x00, 0x00, 0x00, 0x01,
0x00, 0x00, 0x00, 0x01, // Partition 1
0x00, 0x02, // ErrInvalidMessage
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, // Offset 255
0x00, 0x00, 0x00, 0x64, // 100 ms throttle time
}, {
0x00, 0x00, 0x00, 0x01,
0x00, 0x03, 'f', 'o', 'o',
0x00, 0x00, 0x00, 0x01,
0x00, 0x00, 0x00, 0x01, // Partition 1
0x00, 0x02, // ErrInvalidMessage
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, // Offset 255
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xE8, // Timestamp January 1st 0001 at 00:00:01,000 UTC (LogAppendTime was used)
0x00, 0x00, 0x00, 0x64, // 100 ms throttle time
},
}
)
func TestProduceResponseDecode(t *testing.T) {
response := ProduceResponse{}
testVersionDecodable(t, "no blocks", &response, produceResponseNoBlocks, 0)
testVersionDecodable(t, "no blocks", &response, produceResponseNoBlocksV0, 0)
if len(response.Blocks) != 0 {
t.Error("Decoding produced", len(response.Blocks), "topics where there were none")
}
testVersionDecodable(t, "many blocks", &response, produceResponseManyBlocks, 0)
if len(response.Blocks) != 2 {
t.Error("Decoding produced", len(response.Blocks), "topics where there were 2")
for v, produceResponseManyBlocks := range produceResponseManyBlocksVersions {
t.Logf("Decoding produceResponseManyBlocks version %d", v)
testVersionDecodable(t, "many blocks", &response, produceResponseManyBlocks, int16(v))
if len(response.Blocks) != 1 {
t.Error("Decoding produced", len(response.Blocks), "topics where there was 1")
}
if len(response.Blocks["foo"]) != 0 {
t.Error("Decoding produced", len(response.Blocks["foo"]), "partitions for 'foo' where there were none")
if len(response.Blocks["foo"]) != 1 {
t.Error("Decoding produced", len(response.Blocks["foo"]), "partitions for 'foo' where there was one")
}
if len(response.Blocks["bar"]) != 2 {
t.Error("Decoding produced", len(response.Blocks["bar"]), "partitions for 'bar' where there were two")
}
block := response.GetBlock("bar", 1)
block := response.GetBlock("foo", 1)
if block == nil {
t.Error("Decoding did not produce a block for bar/1")
t.Error("Decoding did not produce a block for foo/1")
} else {
if block.Err != ErrNoError {
t.Error("Decoding failed for bar/1/Err, got:", int16(block.Err))
}
if block.Offset != 0xFF {
t.Error("Decoding failed for bar/1/Offset, got:", block.Offset)
}
}
block = response.GetBlock("bar", 2)
if block == nil {
t.Error("Decoding did not produce a block for bar/2")
} else {
if block.Err != ErrInvalidMessage {
t.Error("Decoding failed for bar/2/Err, got:", int16(block.Err))
t.Error("Decoding failed for foo/2/Err, got:", int16(block.Err))
}
if block.Offset != 255 {
t.Error("Decoding failed for foo/1/Offset, got:", block.Offset)
}
if v >= 2 {
if block.Timestamp != time.Unix(1, 0) {
t.Error("Decoding failed for foo/2/Timestamp, got:", block.Timestamp)
}
}
}
if v >= 1 {
if expected := 100 * time.Millisecond; response.ThrottleTime != expected {
t.Error("Failed decoding produced throttle time, expected:", expected, ", got:", response.ThrottleTime)
}
if block.Offset != 0 {
t.Error("Decoding failed for bar/2/Offset, got:", block.Offset)
}
}
}
}
func TestProduceResponseEncode(t *testing.T) {
response := ProduceResponse{}
response.Blocks = make(map[string]map[int32]*ProduceResponseBlock)
testEncodable(t, "empty", &response, produceResponseNoBlocksV0)
response.Blocks["foo"] = make(map[int32]*ProduceResponseBlock)
response.Blocks["foo"][1] = &ProduceResponseBlock{
Err: ErrInvalidMessage,
Offset: 255,
Timestamp: time.Unix(1, 0),
}
response.ThrottleTime = 100 * time.Millisecond
for v, produceResponseManyBlocks := range produceResponseManyBlocksVersions {
response.Version = int16(v)
testEncodable(t, fmt.Sprintf("many blocks version %d", v), &response, produceResponseManyBlocks)
}
}
func TestProduceResponseEncodeInvalidTimestamp(t *testing.T) {
response := ProduceResponse{}
response.Version = 2
response.Blocks = make(map[string]map[int32]*ProduceResponseBlock)
response.Blocks["t"] = make(map[int32]*ProduceResponseBlock)
response.Blocks["t"][0] = &ProduceResponseBlock{
Err: ErrNoError,
Offset: 0,
// Use a timestamp before Unix time
Timestamp: time.Unix(0, 0).Add(-1 * time.Millisecond),
}
response.ThrottleTime = 100 * time.Millisecond
_, err := encode(&response, nil)
if err == nil {
t.Error("Expecting error, got nil")
}
if _, ok := err.(PacketEncodingError); !ok {
t.Error("Expecting PacketEncodingError, got:", err)
}
}


@@ -1,10 +1,13 @@
package sarama
import "time"
import (
"encoding/binary"
"time"
)
type partitionSet struct {
msgs []*ProducerMessage
setToSend *MessageSet
recordsToSend Records
bufferBytes int
}
@@ -39,31 +42,64 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
}
}
timestamp := msg.Timestamp
if msg.Timestamp.IsZero() {
timestamp = time.Now()
}
partitions := ps.msgs[msg.Topic]
if partitions == nil {
partitions = make(map[int32]*partitionSet)
ps.msgs[msg.Topic] = partitions
}
var size int
set := partitions[msg.Partition]
if set == nil {
set = &partitionSet{setToSend: new(MessageSet)}
if ps.parent.conf.Version.IsAtLeast(V0_11_0_0) {
batch := &RecordBatch{
FirstTimestamp: timestamp,
Version: 2,
ProducerID: -1, /* No producer id */
Codec: ps.parent.conf.Producer.Compression,
}
set = &partitionSet{recordsToSend: newDefaultRecords(batch)}
size = recordBatchOverhead
} else {
set = &partitionSet{recordsToSend: newLegacyRecords(new(MessageSet))}
}
partitions[msg.Partition] = set
}
set.msgs = append(set.msgs, msg)
if ps.parent.conf.Version.IsAtLeast(V0_11_0_0) {
// We are being conservative here to avoid having to prep encode the record
size += maximumRecordOverhead
rec := &Record{
Key: key,
Value: val,
TimestampDelta: timestamp.Sub(set.recordsToSend.recordBatch.FirstTimestamp),
}
size += len(key) + len(val)
if len(msg.Headers) > 0 {
rec.Headers = make([]*RecordHeader, len(msg.Headers))
for i, h := range msg.Headers {
rec.Headers[i] = &msg.Headers[i] // take the slice element's address; &h would alias the loop variable
size += len(h.Key) + len(h.Value) + 2*binary.MaxVarintLen32
}
}
set.recordsToSend.recordBatch.addRecord(rec)
} else {
msgToSend := &Message{Codec: CompressionNone, Key: key, Value: val}
if ps.parent.conf.Version.IsAtLeast(V0_10_0_0) {
if msg.Timestamp.IsZero() {
msgToSend.Timestamp = time.Now()
} else {
msgToSend.Timestamp = msg.Timestamp
}
msgToSend.Timestamp = timestamp
msgToSend.Version = 1
}
set.setToSend.addMessage(msgToSend)
set.recordsToSend.msgSet.addMessage(msgToSend)
size = producerMessageOverhead + len(key) + len(val)
}
size := producerMessageOverhead + len(key) + len(val)
set.bufferBytes += size
ps.bufferBytes += size
ps.bufferCount++
@@ -79,17 +115,24 @@ func (ps *produceSet) buildRequest() *ProduceRequest {
if ps.parent.conf.Version.IsAtLeast(V0_10_0_0) {
req.Version = 2
}
if ps.parent.conf.Version.IsAtLeast(V0_11_0_0) {
req.Version = 3
}
for topic, partitionSet := range ps.msgs {
for partition, set := range partitionSet {
if req.Version >= 3 {
req.AddBatch(topic, partition, set.recordsToSend.recordBatch)
continue
}
if ps.parent.conf.Producer.Compression == CompressionNone {
req.AddSet(topic, partition, set.setToSend)
req.AddSet(topic, partition, set.recordsToSend.msgSet)
} else {
// When compression is enabled, the entire set for each partition is compressed
// and sent as the payload of a single fake "message" with the appropriate codec
// set and no key. When the server sees a message with a compression codec, it
// decompresses the payload and treats the result as its message set.
payload, err := encode(set.setToSend, ps.parent.conf.MetricRegistry)
payload, err := encode(set.recordsToSend.msgSet, ps.parent.conf.MetricRegistry)
if err != nil {
Logger.Println(err) // if this happens, it's basically our fault.
panic(err)
@@ -98,11 +141,11 @@ func (ps *produceSet) buildRequest() *ProduceRequest {
Codec: ps.parent.conf.Producer.Compression,
Key: nil,
Value: payload,
Set: set.setToSend, // Provide the underlying message set for accurate metrics
Set: set.recordsToSend.msgSet, // Provide the underlying message set for accurate metrics
}
if ps.parent.conf.Version.IsAtLeast(V0_10_0_0) {
compMsg.Version = 1
compMsg.Timestamp = set.setToSend.Messages[0].Msg.Timestamp
compMsg.Timestamp = set.recordsToSend.msgSet.Messages[0].Msg.Timestamp
}
req.AddMessage(topic, partition, compMsg)
}
@@ -135,14 +178,19 @@ func (ps *produceSet) dropPartition(topic string, partition int32) []*ProducerMe
}
func (ps *produceSet) wouldOverflow(msg *ProducerMessage) bool {
version := 1
if ps.parent.conf.Version.IsAtLeast(V0_11_0_0) {
version = 2
}
switch {
// Would we overflow our maximum possible size-on-the-wire? 10KiB is arbitrary overhead for safety.
case ps.bufferBytes+msg.byteSize() >= int(MaxRequestSize-(10*1024)):
case ps.bufferBytes+msg.byteSize(version) >= int(MaxRequestSize-(10*1024)):
return true
// Would we overflow the size-limit of a compressed message-batch for this partition?
case ps.parent.conf.Producer.Compression != CompressionNone &&
ps.msgs[msg.Topic] != nil && ps.msgs[msg.Topic][msg.Partition] != nil &&
ps.msgs[msg.Topic][msg.Partition].bufferBytes+msg.byteSize() >= ps.parent.conf.Producer.MaxMessageBytes:
ps.msgs[msg.Topic][msg.Partition].bufferBytes+msg.byteSize(version) >= ps.parent.conf.Producer.MaxMessageBytes:
return true
// Would we overflow simply in number of messages?
case ps.parent.conf.Producer.Flush.MaxMessages > 0 && ps.bufferCount >= ps.parent.conf.Producer.Flush.MaxMessages:


@@ -137,7 +137,7 @@ func TestProduceSetRequestBuilding(t *testing.T) {
t.Error("Timeout not set properly")
}
if len(req.msgSets) != 2 {
if len(req.records) != 2 {
t.Error("Wrong number of topics in request")
}
}
@@ -166,7 +166,7 @@ func TestProduceSetCompressedRequestBuilding(t *testing.T) {
t.Error("Wrong request version")
}
for _, msgBlock := range req.msgSets["t1"][0].Messages {
for _, msgBlock := range req.records["t1"][0].msgSet.Messages {
msg := msgBlock.Msg
err := msg.decodeSet()
if err != nil {
@@ -183,3 +183,40 @@ func TestProduceSetCompressedRequestBuilding(t *testing.T) {
}
}
}
func TestProduceSetV3RequestBuilding(t *testing.T) {
parent, ps := makeProduceSet()
parent.conf.Producer.RequiredAcks = WaitForAll
parent.conf.Producer.Timeout = 10 * time.Second
parent.conf.Version = V0_11_0_0
now := time.Now()
msg := &ProducerMessage{
Topic: "t1",
Partition: 0,
Key: StringEncoder(TestMessage),
Value: StringEncoder(TestMessage),
Timestamp: now,
}
for i := 0; i < 10; i++ {
safeAddMessage(t, ps, msg)
msg.Timestamp = msg.Timestamp.Add(time.Second)
}
req := ps.buildRequest()
if req.Version != 3 {
t.Error("Wrong request version")
}
batch := req.records["t1"][0].recordBatch
if batch.FirstTimestamp != now {
t.Errorf("Wrong first timestamp: %v", batch.FirstTimestamp)
}
for i := 0; i < 10; i++ {
rec := batch.Records[i]
if rec.TimestampDelta != time.Duration(i)*time.Second {
t.Errorf("Wrong timestamp delta: %v", rec.TimestampDelta)
}
}
}


@@ -7,8 +7,10 @@ import (
var errInvalidArrayLength = PacketDecodingError{"invalid array length"}
var errInvalidByteSliceLength = PacketDecodingError{"invalid byteslice length"}
var errInvalidByteSliceLengthType = PacketDecodingError{"invalid byteslice length type"}
var errInvalidStringLength = PacketDecodingError{"invalid string length"}
var errInvalidSubsetSize = PacketDecodingError{"invalid subset size"}
var errVarintOverflow = PacketDecodingError{"varint overflow"}
type realDecoder struct {
raw []byte
@@ -58,12 +60,26 @@ func (rd *realDecoder) getInt64() (int64, error) {
return tmp, nil
}
func (rd *realDecoder) getVarint() (int64, error) {
tmp, n := binary.Varint(rd.raw[rd.off:])
if n == 0 {
rd.off = len(rd.raw)
return -1, ErrInsufficientData
}
if n < 0 {
rd.off -= n
return -1, errVarintOverflow
}
rd.off += n
return tmp, nil
}
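The two failure branches mirror the contract of Go's `encoding/binary`: `binary.Varint` reports n == 0 when the buffer ends mid-varint and n < 0 on 64-bit overflow, and the record format uses the same zig-zag encoding. A standalone sketch (not sarama code):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	buf := make([]byte, binary.MaxVarintLen64)
	n := binary.PutVarint(buf, -3) // zig-zag encoding: -3 becomes the single byte 0x05
	v, read := binary.Varint(buf[:n])
	fmt.Println(v, read) // -3 1
}
```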
func (rd *realDecoder) getArrayLength() (int, error) {
if rd.remaining() < 4 {
rd.off = len(rd.raw)
return -1, ErrInsufficientData
}
tmp := int(binary.BigEndian.Uint32(rd.raw[rd.off:]))
tmp := int(int32(binary.BigEndian.Uint32(rd.raw[rd.off:])))
rd.off += 4
if tmp > rd.remaining() {
rd.off = len(rd.raw)
@@ -78,28 +94,26 @@ func (rd *realDecoder) getArrayLength() (int, error) {
func (rd *realDecoder) getBytes() ([]byte, error) {
tmp, err := rd.getInt32()
if err != nil {
return nil, err
}
if tmp == -1 {
return nil, nil
}
tmpStr := rd.raw[rd.off : rd.off+n]
rd.off += n
return tmpStr, nil
n := int(tmp)
switch {
case n < -1:
return nil, errInvalidByteSliceLength
case n == -1:
return nil, nil
case n == 0:
return make([]byte, 0), nil
case n > rd.remaining():
rd.off = len(rd.raw)
return nil, ErrInsufficientData
}
return rd.getRawBytes(int(tmp))
}
func (rd *realDecoder) getVarintBytes() ([]byte, error) {
tmp, err := rd.getVarint()
if err != nil {
return nil, err
}
if tmp == -1 {
return nil, nil
}
return rd.getRawBytes(int(tmp))
}
func (rd *realDecoder) getString() (string, error) {
@@ -128,6 +142,15 @@ func (rd *realDecoder) getString() (string, error) {
return tmpStr, nil
}
func (rd *realDecoder) getNullableString() (*string, error) {
tmp, err := rd.getInt16()
if err != nil || tmp == -1 {
return nil, err
}
str, err := rd.getString()
return &str, err
}
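`getNullableString` follows the Kafka convention for nullable strings: a big-endian int16 length prefix, where -1 stands for null. A standalone sketch of the wire shape (illustrative bytes only):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	// []byte{0xFF, 0xFF} would carry length -1, i.e. a nil *string.
	raw := []byte{0x00, 0x02, 'h', 'i'} // length 2, then the UTF-8 bytes
	n := int16(binary.BigEndian.Uint16(raw))
	fmt.Println(string(raw[2 : 2+int(n)])) // hi
}
```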
func (rd *realDecoder) getInt32Array() ([]int32, error) {
if rd.remaining() < 4 {
rd.off = len(rd.raw)
@@ -221,8 +244,16 @@ func (rd *realDecoder) remaining() int {
}
func (rd *realDecoder) getSubset(length int) (packetDecoder, error) {
buf, err := rd.getRawBytes(length)
if err != nil {
return nil, err
}
return &realDecoder{raw: buf}, nil
}
func (rd *realDecoder) getRawBytes(length int) ([]byte, error) {
if length < 0 {
return nil, errInvalidSubsetSize
return nil, errInvalidByteSliceLength
} else if length > rd.remaining() {
rd.off = len(rd.raw)
return nil, ErrInsufficientData
@@ -230,7 +261,7 @@ func (rd *realDecoder) getSubset(length int) (packetDecoder, error) {
start := rd.off
rd.off += length
return &realDecoder{raw: rd.raw[start:rd.off]}, nil
return rd.raw[start:rd.off], nil
}
// stacks
@@ -238,11 +269,18 @@ func (rd *realDecoder) getSubset(length int) (packetDecoder, error) {
func (rd *realDecoder) push(in pushDecoder) error {
in.saveOffset(rd.off)
reserve := in.reserveLength()
var reserve int
if dpd, ok := in.(dynamicPushDecoder); ok {
if err := dpd.decode(rd); err != nil {
return err
}
} else {
reserve = in.reserveLength()
if rd.remaining() < reserve {
rd.off = len(rd.raw)
return ErrInsufficientData
}
}
rd.stack = append(rd.stack, in)


@@ -35,6 +35,10 @@ func (re *realEncoder) putInt64(in int64) {
re.off += 8
}
func (re *realEncoder) putVarint(in int64) {
re.off += binary.PutVarint(re.raw[re.off:], in)
}
func (re *realEncoder) putArrayLength(in int) error {
re.putInt32(int32(in))
return nil
@@ -54,9 +58,16 @@ func (re *realEncoder) putBytes(in []byte) error {
return nil
}
re.putInt32(int32(len(in)))
copy(re.raw[re.off:], in)
re.off += len(in)
return nil
return re.putRawBytes(in)
}
func (re *realEncoder) putVarintBytes(in []byte) error {
if in == nil {
re.putVarint(-1)
return nil
}
re.putVarint(int64(len(in)))
return re.putRawBytes(in)
}
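Note the asymmetry with `putBytes`: the null marker here is a varint -1, which zig-zag encodes to the single byte 0x01 rather than the four-byte int32 sentinel. A quick standalone check:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	buf := make([]byte, binary.MaxVarintLen64)
	n := binary.PutVarint(buf, -1)
	fmt.Printf("%x\n", buf[:n]) // 01, the one-byte "null bytes" marker
}
```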
func (re *realEncoder) putString(in string) error {
@@ -66,6 +77,14 @@ func (re *realEncoder) putString(in string) error {
return nil
}
func (re *realEncoder) putNullableString(in *string) error {
if in == nil {
re.putInt16(-1)
return nil
}
return re.putString(*in)
}
func (re *realEncoder) putStringArray(in []string) error {
err := re.putArrayLength(len(in))
if err != nil {

vendor/github.com/Shopify/sarama/record.go generated vendored Normal file

@@ -0,0 +1,113 @@
package sarama
import (
"encoding/binary"
"time"
)
const (
controlMask = 0x20
maximumRecordOverhead = 5*binary.MaxVarintLen32 + binary.MaxVarintLen64 + 1
)
type RecordHeader struct {
Key []byte
Value []byte
}
func (h *RecordHeader) encode(pe packetEncoder) error {
if err := pe.putVarintBytes(h.Key); err != nil {
return err
}
return pe.putVarintBytes(h.Value)
}
func (h *RecordHeader) decode(pd packetDecoder) (err error) {
if h.Key, err = pd.getVarintBytes(); err != nil {
return err
}
if h.Value, err = pd.getVarintBytes(); err != nil {
return err
}
return nil
}
type Record struct {
Attributes int8
TimestampDelta time.Duration
OffsetDelta int64
Key []byte
Value []byte
Headers []*RecordHeader
length varintLengthField
}
func (r *Record) encode(pe packetEncoder) error {
pe.push(&r.length)
pe.putInt8(r.Attributes)
pe.putVarint(int64(r.TimestampDelta / time.Millisecond))
pe.putVarint(r.OffsetDelta)
if err := pe.putVarintBytes(r.Key); err != nil {
return err
}
if err := pe.putVarintBytes(r.Value); err != nil {
return err
}
pe.putVarint(int64(len(r.Headers)))
for _, h := range r.Headers {
if err := h.encode(pe); err != nil {
return err
}
}
return pe.pop()
}
func (r *Record) decode(pd packetDecoder) (err error) {
if err = pd.push(&r.length); err != nil {
return err
}
if r.Attributes, err = pd.getInt8(); err != nil {
return err
}
timestamp, err := pd.getVarint()
if err != nil {
return err
}
r.TimestampDelta = time.Duration(timestamp) * time.Millisecond
if r.OffsetDelta, err = pd.getVarint(); err != nil {
return err
}
if r.Key, err = pd.getVarintBytes(); err != nil {
return err
}
if r.Value, err = pd.getVarintBytes(); err != nil {
return err
}
numHeaders, err := pd.getVarint()
if err != nil {
return err
}
if numHeaders >= 0 {
r.Headers = make([]*RecordHeader, numHeaders)
}
for i := int64(0); i < numHeaders; i++ {
hdr := new(RecordHeader)
if err := hdr.decode(pd); err != nil {
return err
}
r.Headers[i] = hdr
}
return pd.pop()
}
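As a cross-check (a standalone sketch, not sarama code), the encoding above can be reproduced by hand; this rebuilds the varint length-prefixed record body that appears verbatim in the "one record" produce-request fixture and the uncompressed record-batch fixture later in this commit:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// appendVarint appends v as a zig-zag varint, the integer encoding used
// for every length and delta in the v2 record format.
func appendVarint(b []byte, v int64) []byte {
	tmp := make([]byte, binary.MaxVarintLen64)
	return append(b, tmp[:binary.PutVarint(tmp, v)]...)
}

func main() {
	key, value := []byte{1, 2, 3, 4}, []byte{5, 6, 7}
	hdrKey, hdrVal := []byte{8, 9, 10}, []byte{11, 12}

	var body []byte
	body = append(body, 0)                       // attributes (currently unused)
	body = appendVarint(body, 5)                 // timestamp delta in ms
	body = appendVarint(body, 0)                 // offset delta
	body = appendVarint(body, int64(len(key)))   // key length
	body = append(body, key...)
	body = appendVarint(body, int64(len(value))) // value length
	body = append(body, value...)
	body = appendVarint(body, 1)                 // header count
	body = appendVarint(body, int64(len(hdrKey)))
	body = append(body, hdrKey...)
	body = appendVarint(body, int64(len(hdrVal)))
	body = append(body, hdrVal...)

	rec := appendVarint(nil, int64(len(body))) // records are varint length-prefixed
	rec = append(rec, body...)
	fmt.Printf("% x\n", rec) // 28 00 0a 00 08 01 02 03 04 06 05 06 07 02 06 08 09 0a 04 0b 0c
}
```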

vendor/github.com/Shopify/sarama/record_batch.go generated vendored Normal file

@@ -0,0 +1,260 @@
package sarama
import (
"bytes"
"compress/gzip"
"fmt"
"io/ioutil"
"time"
"github.com/eapache/go-xerial-snappy"
"github.com/pierrec/lz4"
)
const recordBatchOverhead = 49
type recordsArray []*Record
func (e recordsArray) encode(pe packetEncoder) error {
for _, r := range e {
if err := r.encode(pe); err != nil {
return err
}
}
return nil
}
func (e recordsArray) decode(pd packetDecoder) error {
for i := range e {
rec := &Record{}
if err := rec.decode(pd); err != nil {
return err
}
e[i] = rec
}
return nil
}
type RecordBatch struct {
FirstOffset int64
PartitionLeaderEpoch int32
Version int8
Codec CompressionCodec
Control bool
LastOffsetDelta int32
FirstTimestamp time.Time
MaxTimestamp time.Time
ProducerID int64
ProducerEpoch int16
FirstSequence int32
Records []*Record
PartialTrailingRecord bool
compressedRecords []byte
recordsLen int // uncompressed records size
}
func (b *RecordBatch) encode(pe packetEncoder) error {
if b.Version != 2 {
return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
}
pe.putInt64(b.FirstOffset)
pe.push(&lengthField{})
pe.putInt32(b.PartitionLeaderEpoch)
pe.putInt8(b.Version)
pe.push(newCRC32Field(crcCastagnoli))
pe.putInt16(b.computeAttributes())
pe.putInt32(b.LastOffsetDelta)
if err := (Timestamp{&b.FirstTimestamp}).encode(pe); err != nil {
return err
}
if err := (Timestamp{&b.MaxTimestamp}).encode(pe); err != nil {
return err
}
pe.putInt64(b.ProducerID)
pe.putInt16(b.ProducerEpoch)
pe.putInt32(b.FirstSequence)
if err := pe.putArrayLength(len(b.Records)); err != nil {
return err
}
if b.compressedRecords == nil {
if err := b.encodeRecords(pe); err != nil {
return err
}
}
if err := pe.putRawBytes(b.compressedRecords); err != nil {
return err
}
if err := pe.pop(); err != nil {
return err
}
return pe.pop()
}
func (b *RecordBatch) decode(pd packetDecoder) (err error) {
if b.FirstOffset, err = pd.getInt64(); err != nil {
return err
}
batchLen, err := pd.getInt32()
if err != nil {
return err
}
if b.PartitionLeaderEpoch, err = pd.getInt32(); err != nil {
return err
}
if b.Version, err = pd.getInt8(); err != nil {
return err
}
if err = pd.push(&crc32Field{polynomial: crcCastagnoli}); err != nil {
return err
}
attributes, err := pd.getInt16()
if err != nil {
return err
}
b.Codec = CompressionCodec(int8(attributes) & compressionCodecMask)
b.Control = attributes&controlMask == controlMask
if b.LastOffsetDelta, err = pd.getInt32(); err != nil {
return err
}
if err = (Timestamp{&b.FirstTimestamp}).decode(pd); err != nil {
return err
}
if err = (Timestamp{&b.MaxTimestamp}).decode(pd); err != nil {
return err
}
if b.ProducerID, err = pd.getInt64(); err != nil {
return err
}
if b.ProducerEpoch, err = pd.getInt16(); err != nil {
return err
}
if b.FirstSequence, err = pd.getInt32(); err != nil {
return err
}
numRecs, err := pd.getArrayLength()
if err != nil {
return err
}
if numRecs >= 0 {
b.Records = make([]*Record, numRecs)
}
bufSize := int(batchLen) - recordBatchOverhead
recBuffer, err := pd.getRawBytes(bufSize)
if err != nil {
return err
}
if err = pd.pop(); err != nil {
return err
}
switch b.Codec {
case CompressionNone:
case CompressionGZIP:
reader, err := gzip.NewReader(bytes.NewReader(recBuffer))
if err != nil {
return err
}
if recBuffer, err = ioutil.ReadAll(reader); err != nil {
return err
}
case CompressionSnappy:
if recBuffer, err = snappy.Decode(recBuffer); err != nil {
return err
}
case CompressionLZ4:
reader := lz4.NewReader(bytes.NewReader(recBuffer))
if recBuffer, err = ioutil.ReadAll(reader); err != nil {
return err
}
default:
return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
}
b.recordsLen = len(recBuffer)
err = decode(recBuffer, recordsArray(b.Records))
if err == ErrInsufficientData {
b.PartialTrailingRecord = true
b.Records = nil
return nil
}
return err
}
func (b *RecordBatch) encodeRecords(pe packetEncoder) error {
var raw []byte
if b.Codec != CompressionNone {
var err error
if raw, err = encode(recordsArray(b.Records), nil); err != nil {
return err
}
b.recordsLen = len(raw)
}
switch b.Codec {
case CompressionNone:
offset := pe.offset()
if err := recordsArray(b.Records).encode(pe); err != nil {
return err
}
b.recordsLen = pe.offset() - offset
case CompressionGZIP:
var buf bytes.Buffer
writer := gzip.NewWriter(&buf)
if _, err := writer.Write(raw); err != nil {
return err
}
if err := writer.Close(); err != nil {
return err
}
b.compressedRecords = buf.Bytes()
case CompressionSnappy:
b.compressedRecords = snappy.Encode(raw)
case CompressionLZ4:
var buf bytes.Buffer
writer := lz4.NewWriter(&buf)
if _, err := writer.Write(raw); err != nil {
return err
}
if err := writer.Close(); err != nil {
return err
}
b.compressedRecords = buf.Bytes()
default:
return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
}
return nil
}
func (b *RecordBatch) computeAttributes() int16 {
attr := int16(b.Codec) & int16(compressionCodecMask)
if b.Control {
attr |= controlMask
}
return attr
}
func (b *RecordBatch) addRecord(r *Record) {
b.Records = append(b.Records, r)
}
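For reference, `computeAttributes` packs the codec into the low three bits and the control flag into bit 5 (`controlMask = 0x20`), which is why the fixtures in record_test.go below show attributes 0x20 for a control batch and 0x01/0x02/0x03 for gzip/snappy/lz4. A sketch, assuming those codec values:

```go
package main

import "fmt"

func main() {
	codec := int16(2) // snappy, per the 0x02 attribute in the snappy fixture
	attr := codec & 0x07
	attr |= 0x20 // control flag, the 0x20 seen in the control-batch fixture
	fmt.Printf("%#x\n", attr) // 0x22
}
```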

vendor/github.com/Shopify/sarama/record_test.go generated vendored Normal file

@@ -0,0 +1,284 @@
package sarama
import (
"reflect"
"runtime"
"strconv"
"strings"
"testing"
"time"
"github.com/davecgh/go-spew/spew"
)
var recordBatchTestCases = []struct {
name string
batch RecordBatch
encoded []byte
oldGoEncoded []byte // used in case of gzipped content for go versions prior to 1.8
}{
{
name: "empty record",
batch: RecordBatch{
Version: 2,
FirstTimestamp: time.Unix(0, 0),
MaxTimestamp: time.Unix(0, 0),
Records: []*Record{},
},
encoded: []byte{
0, 0, 0, 0, 0, 0, 0, 0, // First Offset
0, 0, 0, 49, // Length
0, 0, 0, 0, // Partition Leader Epoch
2, // Version
89, 95, 183, 221, // CRC
0, 0, // Attributes
0, 0, 0, 0, // Last Offset Delta
0, 0, 0, 0, 0, 0, 0, 0, // First Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
0, 0, // Producer Epoch
0, 0, 0, 0, // First Sequence
0, 0, 0, 0, // Number of Records
},
},
{
name: "control batch",
batch: RecordBatch{
Version: 2,
Control: true,
FirstTimestamp: time.Unix(0, 0),
MaxTimestamp: time.Unix(0, 0),
Records: []*Record{},
},
encoded: []byte{
0, 0, 0, 0, 0, 0, 0, 0, // First Offset
0, 0, 0, 49, // Length
0, 0, 0, 0, // Partition Leader Epoch
2, // Version
81, 46, 67, 217, // CRC
0, 32, // Attributes
0, 0, 0, 0, // Last Offset Delta
0, 0, 0, 0, 0, 0, 0, 0, // First Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
0, 0, // Producer Epoch
0, 0, 0, 0, // First Sequence
0, 0, 0, 0, // Number of Records
},
},
{
name: "uncompressed record",
batch: RecordBatch{
Version: 2,
FirstTimestamp: time.Unix(1479847795, 0),
MaxTimestamp: time.Unix(0, 0),
Records: []*Record{{
TimestampDelta: 5 * time.Millisecond,
Key: []byte{1, 2, 3, 4},
Value: []byte{5, 6, 7},
Headers: []*RecordHeader{{
Key: []byte{8, 9, 10},
Value: []byte{11, 12},
}},
}},
recordsLen: 21,
},
encoded: []byte{
0, 0, 0, 0, 0, 0, 0, 0, // First Offset
0, 0, 0, 70, // Length
0, 0, 0, 0, // Partition Leader Epoch
2, // Version
84, 121, 97, 253, // CRC
0, 0, // Attributes
0, 0, 0, 0, // Last Offset Delta
0, 0, 1, 88, 141, 205, 89, 56, // First Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
0, 0, // Producer Epoch
0, 0, 0, 0, // First Sequence
0, 0, 0, 1, // Number of Records
40, // Record Length
0, // Attributes
10, // Timestamp Delta
0, // Offset Delta
8, // Key Length
1, 2, 3, 4,
6, // Value Length
5, 6, 7,
2, // Number of Headers
6, // Header Key Length
8, 9, 10, // Header Key
4, // Header Value Length
11, 12, // Header Value
},
},
{
name: "gzipped record",
batch: RecordBatch{
Version: 2,
Codec: CompressionGZIP,
FirstTimestamp: time.Unix(1479847795, 0),
MaxTimestamp: time.Unix(0, 0),
Records: []*Record{{
TimestampDelta: 5 * time.Millisecond,
Key: []byte{1, 2, 3, 4},
Value: []byte{5, 6, 7},
Headers: []*RecordHeader{{
Key: []byte{8, 9, 10},
Value: []byte{11, 12},
}},
}},
recordsLen: 21,
},
encoded: []byte{
0, 0, 0, 0, 0, 0, 0, 0, // First Offset
0, 0, 0, 94, // Length
0, 0, 0, 0, // Partition Leader Epoch
2, // Version
159, 236, 182, 189, // CRC
0, 1, // Attributes
0, 0, 0, 0, // Last Offset Delta
0, 0, 1, 88, 141, 205, 89, 56, // First Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
0, 0, // Producer Epoch
0, 0, 0, 0, // First Sequence
0, 0, 0, 1, // Number of Records
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 210, 96, 224, 98, 224, 96, 100, 98, 102, 97, 99, 101,
99, 103, 98, 227, 224, 228, 98, 225, 230, 1, 4, 0, 0, 255, 255, 173, 201, 88, 103, 21, 0, 0, 0,
},
oldGoEncoded: []byte{
0, 0, 0, 0, 0, 0, 0, 0, // First Offset
0, 0, 0, 94, // Length
0, 0, 0, 0, // Partition Leader Epoch
2, // Version
0, 216, 14, 210, // CRC
0, 1, // Attributes
0, 0, 0, 0, // Last Offset Delta
0, 0, 1, 88, 141, 205, 89, 56, // First Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
0, 0, // Producer Epoch
0, 0, 0, 0, // First Sequence
0, 0, 0, 1, // Number of Records
31, 139, 8, 0, 0, 9, 110, 136, 0, 255, 210, 96, 224, 98, 224, 96, 100, 98, 102, 97, 99, 101,
99, 103, 98, 227, 224, 228, 98, 225, 230, 1, 4, 0, 0, 255, 255, 173, 201, 88, 103, 21, 0, 0, 0,
},
},
{
name: "snappy compressed record",
batch: RecordBatch{
Version: 2,
Codec: CompressionSnappy,
FirstTimestamp: time.Unix(1479847795, 0),
MaxTimestamp: time.Unix(0, 0),
Records: []*Record{{
TimestampDelta: 5 * time.Millisecond,
Key: []byte{1, 2, 3, 4},
Value: []byte{5, 6, 7},
Headers: []*RecordHeader{{
Key: []byte{8, 9, 10},
Value: []byte{11, 12},
}},
}},
recordsLen: 21,
},
encoded: []byte{
0, 0, 0, 0, 0, 0, 0, 0, // First Offset
0, 0, 0, 72, // Length
0, 0, 0, 0, // Partition Leader Epoch
2, // Version
21, 0, 159, 97, // CRC
0, 2, // Attributes
0, 0, 0, 0, // Last Offset Delta
0, 0, 1, 88, 141, 205, 89, 56, // First Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
0, 0, // Producer Epoch
0, 0, 0, 0, // First Sequence
0, 0, 0, 1, // Number of Records
21, 80, 40, 0, 10, 0, 8, 1, 2, 3, 4, 6, 5, 6, 7, 2, 6, 8, 9, 10, 4, 11, 12,
},
},
{
name: "lz4 compressed record",
batch: RecordBatch{
Version: 2,
Codec: CompressionLZ4,
FirstTimestamp: time.Unix(1479847795, 0),
MaxTimestamp: time.Unix(0, 0),
Records: []*Record{{
TimestampDelta: 5 * time.Millisecond,
Key: []byte{1, 2, 3, 4},
Value: []byte{5, 6, 7},
Headers: []*RecordHeader{{
Key: []byte{8, 9, 10},
Value: []byte{11, 12},
}},
}},
recordsLen: 21,
},
encoded: []byte{
0, 0, 0, 0, 0, 0, 0, 0, // First Offset
0, 0, 0, 89, // Length
0, 0, 0, 0, // Partition Leader Epoch
2, // Version
169, 74, 119, 197, // CRC
0, 3, // Attributes
0, 0, 0, 0, // Last Offset Delta
0, 0, 1, 88, 141, 205, 89, 56, // First Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
0, 0, // Producer Epoch
0, 0, 0, 0, // First Sequence
0, 0, 0, 1, // Number of Records
4, 34, 77, 24, 100, 112, 185, 21, 0, 0, 128, 40, 0, 10, 0, 8, 1, 2, 3, 4, 6, 5, 6, 7, 2,
6, 8, 9, 10, 4, 11, 12, 0, 0, 0, 0, 12, 59, 239, 146,
},
},
}
func isOldGo(t *testing.T) bool {
v := strings.Split(runtime.Version()[2:], ".")
if len(v) < 2 {
t.Logf("Can't parse version: %s", runtime.Version())
return false
}
maj, err := strconv.Atoi(v[0])
if err != nil {
t.Logf("Can't parse version: %s", runtime.Version())
return false
}
min, err := strconv.Atoi(v[1])
if err != nil {
t.Logf("Can't parse version: %s", runtime.Version())
return false
}
return maj < 1 || (maj == 1 && min < 8)
}
func TestRecordBatchEncoding(t *testing.T) {
for _, tc := range recordBatchTestCases {
if tc.oldGoEncoded != nil && isOldGo(t) {
testEncodable(t, tc.name, &tc.batch, tc.oldGoEncoded)
} else {
testEncodable(t, tc.name, &tc.batch, tc.encoded)
}
}
}
func TestRecordBatchDecoding(t *testing.T) {
for _, tc := range recordBatchTestCases {
batch := RecordBatch{}
testDecodable(t, tc.name, &batch, tc.encoded)
for _, r := range batch.Records {
r.length = varintLengthField{}
}
for _, r := range tc.batch.Records {
r.length = varintLengthField{}
}
if !reflect.DeepEqual(batch, tc.batch) {
t.Errorf(spew.Sprintf("invalid decode of %s\ngot %+v\nwanted %+v", tc.name, batch, tc.batch))
}
}
}

vendor/github.com/Shopify/sarama/records.go generated vendored Normal file

@@ -0,0 +1,96 @@
package sarama
import "fmt"
const (
legacyRecords = iota
defaultRecords
)
// Records implements a union type containing either a RecordBatch or a legacy MessageSet.
type Records struct {
recordsType int
msgSet *MessageSet
recordBatch *RecordBatch
}
func newLegacyRecords(msgSet *MessageSet) Records {
return Records{recordsType: legacyRecords, msgSet: msgSet}
}
func newDefaultRecords(batch *RecordBatch) Records {
return Records{recordsType: defaultRecords, recordBatch: batch}
}
func (r *Records) encode(pe packetEncoder) error {
switch r.recordsType {
case legacyRecords:
if r.msgSet == nil {
return nil
}
return r.msgSet.encode(pe)
case defaultRecords:
if r.recordBatch == nil {
return nil
}
return r.recordBatch.encode(pe)
}
return fmt.Errorf("unknown records type: %v", r.recordsType)
}
func (r *Records) decode(pd packetDecoder) error {
switch r.recordsType {
case legacyRecords:
r.msgSet = &MessageSet{}
return r.msgSet.decode(pd)
case defaultRecords:
r.recordBatch = &RecordBatch{}
return r.recordBatch.decode(pd)
}
return fmt.Errorf("unknown records type: %v", r.recordsType)
}
func (r *Records) numRecords() (int, error) {
switch r.recordsType {
case legacyRecords:
if r.msgSet == nil {
return 0, nil
}
return len(r.msgSet.Messages), nil
case defaultRecords:
if r.recordBatch == nil {
return 0, nil
}
return len(r.recordBatch.Records), nil
}
return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
}
func (r *Records) isPartial() (bool, error) {
switch r.recordsType {
case legacyRecords:
if r.msgSet == nil {
return false, nil
}
return r.msgSet.PartialTrailingMessage, nil
case defaultRecords:
if r.recordBatch == nil {
return false, nil
}
return r.recordBatch.PartialTrailingRecord, nil
}
return false, fmt.Errorf("unknown records type: %v", r.recordsType)
}
func (r *Records) isControl() (bool, error) {
switch r.recordsType {
case legacyRecords:
return false, nil
case defaultRecords:
if r.recordBatch == nil {
return false, nil
}
return r.recordBatch.Control, nil
}
return false, fmt.Errorf("unknown records type: %v", r.recordsType)
}

vendor/github.com/Shopify/sarama/records_test.go generated vendored Normal file

@@ -0,0 +1,137 @@
package sarama
import (
"bytes"
"reflect"
"testing"
)
func TestLegacyRecords(t *testing.T) {
set := &MessageSet{
Messages: []*MessageBlock{
{
Msg: &Message{
Version: 1,
},
},
},
}
r := newLegacyRecords(set)
exp, err := encode(set, nil)
if err != nil {
t.Fatal(err)
}
buf, err := encode(&r, nil)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf, exp) {
t.Errorf("Wrong encoding for legacy records, wanted %v, got %v", exp, buf)
}
set = &MessageSet{}
r = newLegacyRecords(nil)
err = decode(exp, set)
if err != nil {
t.Fatal(err)
}
err = decode(buf, &r)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(set, r.msgSet) {
t.Errorf("Wrong decoding for legacy records, wanted %#+v, got %#+v", set, r.msgSet)
}
n, err := r.numRecords()
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Errorf("Wrong number of records, wanted 1, got %d", n)
}
p, err := r.isPartial()
if err != nil {
t.Fatal(err)
}
if p {
t.Errorf("MessageSet shouldn't have a partial trailing message")
}
c, err := r.isControl()
if err != nil {
t.Fatal(err)
}
if c {
t.Errorf("MessageSet can't be a control batch")
}
}
func TestDefaultRecords(t *testing.T) {
batch := &RecordBatch{
Version: 2,
Records: []*Record{
{
Value: []byte{1},
},
},
}
r := newDefaultRecords(batch)
exp, err := encode(batch, nil)
if err != nil {
t.Fatal(err)
}
buf, err := encode(&r, nil)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf, exp) {
t.Errorf("Wrong encoding for default records, wanted %v, got %v", exp, buf)
}
batch = &RecordBatch{}
r = newDefaultRecords(nil)
err = decode(exp, batch)
if err != nil {
t.Fatal(err)
}
err = decode(buf, &r)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(batch, r.recordBatch) {
t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
}
n, err := r.numRecords()
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Errorf("Wrong number of records, wanted 1, got %d", n)
}
p, err := r.isPartial()
if err != nil {
t.Fatal(err)
}
if p {
t.Errorf("RecordBatch shouldn't have a partial trailing record")
}
c, err := r.isControl()
if err != nil {
t.Fatal(err)
}
if c {
t.Errorf("RecordBatch shouldn't be a control batch")
}
}

vendor/github.com/Shopify/sarama/timestamp.go generated vendored Normal file

@@ -0,0 +1,40 @@
package sarama
import (
"fmt"
"time"
)
type Timestamp struct {
*time.Time
}
func (t Timestamp) encode(pe packetEncoder) error {
timestamp := int64(-1)
if !t.Before(time.Unix(0, 0)) {
timestamp = t.UnixNano() / int64(time.Millisecond)
} else if !t.IsZero() {
return PacketEncodingError{fmt.Sprintf("invalid timestamp (%v)", t)}
}
pe.putInt64(timestamp)
return nil
}
func (t Timestamp) decode(pd packetDecoder) error {
millis, err := pd.getInt64()
if err != nil {
return err
}
// negative timestamps are invalid, in these cases we should return
// a zero time
timestamp := time.Time{}
if millis >= 0 {
timestamp = time.Unix(millis/1000, (millis%1000)*int64(time.Millisecond))
}
*t.Time = timestamp
return nil
}
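On the wire this is an int64 of milliseconds since the Unix epoch; the decode arm splits it into seconds plus a nanosecond remainder. A standalone round-trip sketch using the fixture timestamp from the record tests:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	millis := int64(1479847795000) // the FirstTimestamp used by the record-batch fixtures
	t := time.Unix(millis/1000, (millis%1000)*int64(time.Millisecond))
	fmt.Println(t.UnixNano() / int64(time.Millisecond)) // 1479847795000 again
}
```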


@@ -146,5 +146,6 @@ var (
V0_10_0_1 = newKafkaVersion(0, 10, 0, 1)
V0_10_1_0 = newKafkaVersion(0, 10, 1, 0)
V0_10_2_0 = newKafkaVersion(0, 10, 2, 0)
V0_11_0_0 = newKafkaVersion(0, 11, 0, 0)
minVersion = V0_8_2_0
)


@@ -6,6 +6,14 @@ Thank you for your interest in contributing to the Apache Thrift project! Infor
* [Get involved!](http://www.apache.org/foundation/getinvolved.html)
* [Legal aspects on Submission of Contributions (Patches)](http://www.apache.org/licenses/LICENSE-2.0.html#contributions)
## If you want to build the project locally ##
For Windows systems, see our detailed instructions on the [CMake README](/build/cmake/README.md).
For Windows Native C++ builds, see our detailed instructions on the [WinCPP README](/build/wincpp/README.md).
For unix systems, see our detailed instructions on the [Docker README](/build/docker/README.md).
## If you want to review open issues... ##
1. Review the [GitHub Pull Request Backlog](https://github.com/apache/thrift/pulls). Code reviews are open to all.


@@ -1,7 +1,7 @@
Apache Thrift
=============
Last Modified: 2014-03-16
Last Modified: 2017--10
License
=======
@@ -171,3 +171,8 @@ To run the cross-language test suite, please run:
This will run a set of tests that use different language clients and
servers.
Development
===========
To build the project the same way Travis CI builds it, you should use Docker.
We have [comprehensive building instructions for docker](build/docker/README.md).


@@ -0,0 +1,26 @@
#### Support
If you do have a contribution to the package, feel free to create a Pull Request or an Issue.
#### What to contribute
If you don't know what to do, here are some features and functions that still need to be done
- [ ] Refactor code
- [ ] Edit docs and [README](https://github.com/asaskevich/govalidator/README.md): spellcheck, grammar and typo check
- [ ] Create an up-to-date list of contributors and projects that are currently using this package
- [ ] Resolve [issues and bugs](https://github.com/asaskevich/govalidator/issues)
- [ ] Keep the [list of functions](https://github.com/asaskevich/govalidator#list-of-functions) up to date
- [ ] Update the [list of validators](https://github.com/asaskevich/govalidator#validatestruct-2) that are available for `ValidateStruct` and add new ones
- [ ] Implement new validators: `IsFQDN`, `IsIMEI`, `IsPostalCode`, `IsISIN`, `IsISRC` etc
- [ ] Implement [validation by maps](https://github.com/asaskevich/govalidator/issues/224)
- [ ] Implement fuzz testing
- [ ] Implement some struct/map/array utilities
- [ ] Implement map/array validation
- [ ] Implement benchmarking
- [ ] Implement a batch of examples
- [ ] Look at forks for new features and fixes
#### Advice
Feel free to create what you want, but keep in mind when you implement new features:
- Code must be clear and readable, and names of variables/constants must clearly describe what they are for
- Public functions must be documented and described in the source file and added to the list of available functions in README.md
- There must be unit tests for any new functions and improvements


@@ -156,6 +156,7 @@ func IsPort(str string) bool
func IsPositive(value float64) bool
func IsPrintableASCII(str string) bool
func IsRFC3339(str string) bool
func IsRFC3339WithoutZone(str string) bool
func IsRGBcolor(str string) bool
func IsRequestURI(rawurl string) bool
func IsRequestURL(rawurl string) bool
@@ -317,6 +318,7 @@ Here is a list of available validators for struct fields (validator - used funct
"ssn": IsSSN,
"semver": IsSemver,
"rfc3339": IsRFC3339,
"rfc3339WithoutZone": IsRFC3339WithoutZone,
"ISO3166Alpha2": IsISO3166Alpha2,
"ISO3166Alpha3": IsISO3166Alpha3,
```
@@ -409,7 +411,31 @@ Documentation is available here: [godoc.org](https://godoc.org/github.com/asaske
Full information about code coverage is also available here: [govalidator on gocover.io](http://gocover.io/github.com/asaskevich/govalidator).
#### Support
If you do have a contribution for the package feel free to put up a Pull Request or open Issue.
If you do have a contribution to the package, feel free to create a Pull Request or an Issue.
#### What to contribute
If you don't know what to do, here are some features and functions that still need to be done
- [ ] Refactor code
- [ ] Edit docs and [README](https://github.com/asaskevich/govalidator/README.md): spellcheck, grammar and typo check
- [ ] Create an up-to-date list of contributors and projects that are currently using this package
- [ ] Resolve [issues and bugs](https://github.com/asaskevich/govalidator/issues)
- [ ] Keep the [list of functions](https://github.com/asaskevich/govalidator#list-of-functions) up to date
- [ ] Update the [list of validators](https://github.com/asaskevich/govalidator#validatestruct-2) that are available for `ValidateStruct` and add new ones
- [ ] Implement new validators: `IsFQDN`, `IsIMEI`, `IsPostalCode`, `IsISIN`, `IsISRC` etc
- [ ] Implement [validation by maps](https://github.com/asaskevich/govalidator/issues/224)
- [ ] Implement fuzz testing
- [ ] Implement some struct/map/array utilities
- [ ] Implement map/array validation
- [ ] Implement benchmarking
- [ ] Implement a batch of examples
- [ ] Look at forks for new features and fixes
#### Advice
Feel free to create what you want, but keep in mind when you implement new features:
- Code must be clear and readable, and names of variables/constants must clearly describe what they are for
- Public functions must be documented and described in the source file and added to the list of available functions in README.md
- There must be unit tests for any new functions and improvements
#### Special thanks to [contributors](https://github.com/asaskevich/govalidator/graphs/contributors)
* [Daniel Lohse](https://github.com/annismckenzie)


@@ -1,5 +1,7 @@
package govalidator
import "strings"
// Errors is an array of multiple errors and conforms to the error interface.
type Errors []error
@@ -9,11 +11,11 @@ func (es Errors) Errors() []error {
}
func (es Errors) Error() string {
var err string
var errs []string
for _, e := range es {
err += e.Error() + ";"
errs = append(errs, e.Error())
}
return err
return strings.Join(errs, ";")
}
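With `strings.Join` the aggregated message no longer ends in a dangling semicolon, which is exactly what the updated expectations in error_test.go below assert. A quick usage sketch:

```go
package main

import (
	"errors"
	"fmt"

	"github.com/asaskevich/govalidator"
)

func main() {
	errs := govalidator.Errors{errors.New("Error 1"), errors.New("Error 2")}
	fmt.Println(errs.Error()) // Error 1;Error 2 (no trailing semicolon anymore)
}
```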
// Error encapsulates a name, an error and whether there's a custom error message or not.
@@ -21,6 +23,9 @@ type Error struct {
Name string
Err error
CustomErrorMessageExists bool
// Validator indicates the name of the validator that failed
Validator string
}
func (e Error) Error() string {


@@ -15,10 +15,10 @@ func TestErrorsToString(t *testing.T) {
expected string
}{
{Errors{}, ""},
{Errors{fmt.Errorf("Error 1")}, "Error 1;"},
{Errors{fmt.Errorf("Error 1"), fmt.Errorf("Error 2")}, "Error 1;Error 2;"},
{Errors{customErr, fmt.Errorf("Error 2")}, "Custom Error Name: stdlib error;Error 2;"},
{Errors{fmt.Errorf("Error 123"), customErrWithCustomErrorMessage}, "Error 123;Bad stuff happened;"},
{Errors{fmt.Errorf("Error 1")}, "Error 1"},
{Errors{fmt.Errorf("Error 1"), fmt.Errorf("Error 2")}, "Error 1;Error 2"},
{Errors{customErr, fmt.Errorf("Error 2")}, "Custom Error Name: stdlib error;Error 2"},
{Errors{fmt.Errorf("Error 123"), customErrWithCustomErrorMessage}, "Error 123;Bad stuff happened"},
}
for _, test := range tests {
actual := test.param1.Error()


@@ -1,6 +1,9 @@
package govalidator
import "math"
import (
"math"
"reflect"
)
// Abs returns absolute value of number
func Abs(value float64) float64 {
@@ -39,13 +42,47 @@ func IsNonPositive(value float64) bool {
}
// InRange returns true if value lies between left and right border
func InRange(value, left, right float64) bool {
func InRangeInt(value, left, right int) bool {
if left > right {
left, right = right, left
}
return value >= left && value <= right
}
// InRange returns true if value lies between left and right border
func InRangeFloat32(value, left, right float32) bool {
if left > right {
left, right = right, left
}
return value >= left && value <= right
}
// InRange returns true if value lies between left and right border
func InRangeFloat64(value, left, right float64) bool {
if left > right {
left, right = right, left
}
return value >= left && value <= right
}
// InRange returns true if value lies between left and right border; it accepts int, float32 or float64, but all three arguments must be of the same type
func InRange(value interface{}, left interface{}, right interface{}) bool {
reflectValue := reflect.TypeOf(value).Kind()
reflectLeft := reflect.TypeOf(left).Kind()
reflectRight := reflect.TypeOf(right).Kind()
if reflectValue == reflect.Int && reflectLeft == reflect.Int && reflectRight == reflect.Int {
return InRangeInt(value.(int), left.(int), right.(int))
} else if reflectValue == reflect.Float32 && reflectLeft == reflect.Float32 && reflectRight == reflect.Float32 {
return InRangeFloat32(value.(float32), left.(float32), right.(float32))
} else if reflectValue == reflect.Float64 && reflectLeft == reflect.Float64 && reflectRight == reflect.Float64 {
return InRangeFloat64(value.(float64), left.(float64), right.(float64))
} else {
return false
}
}
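Because the generic wrapper dispatches purely on the dynamic types, mixed-type calls return false instead of converting, as the new type-mix test cases below confirm. A quick usage sketch:

```go
package main

import (
	"fmt"

	"github.com/asaskevich/govalidator"
)

func main() {
	fmt.Println(govalidator.InRange(5, 1, 10))       // true: all int
	fmt.Println(govalidator.InRange(5.0, 1.0, 10.0)) // true: all float64
	fmt.Println(govalidator.InRange(5, 1.0, 10.0))   // false: mixed int and float64
}
```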
// IsWhole returns true if value is whole number
func IsWhole(value float64) bool {
return math.Remainder(value, 1) == 0


@@ -177,7 +177,60 @@ func TestIsNatural(t *testing.T) {
}
}
}
func TestInRange(t *testing.T) {
func TestInRangeInt(t *testing.T) {
t.Parallel()
var tests = []struct {
param int
left int
right int
expected bool
}{
{0, 0, 0, true},
{1, 0, 0, false},
{-1, 0, 0, false},
{0, -1, 1, true},
{0, 0, 1, true},
{0, -1, 0, true},
{0, 0, -1, true},
{0, 10, 5, false},
}
for _, test := range tests {
actual := InRangeInt(test.param, test.left, test.right)
if actual != test.expected {
t.Errorf("Expected InRangeInt(%v, %v, %v) to be %v, got %v", test.param, test.left, test.right, test.expected, actual)
}
}
}
func TestInRangeFloat32(t *testing.T) {
t.Parallel()
var tests = []struct {
param float32
left float32
right float32
expected bool
}{
{0, 0, 0, true},
{1, 0, 0, false},
{-1, 0, 0, false},
{0, -1, 1, true},
{0, 0, 1, true},
{0, -1, 0, true},
{0, 0, -1, true},
{0, 10, 5, false},
}
for _, test := range tests {
actual := InRangeFloat32(test.param, test.left, test.right)
if actual != test.expected {
t.Errorf("Expected InRangeFloat32(%v, %v, %v) to be %v, got %v", test.param, test.left, test.right, test.expected, actual)
}
}
}
func TestInRangeFloat64(t *testing.T) {
t.Parallel() t.Parallel()
var tests = []struct { var tests = []struct {
@ -196,6 +249,98 @@ func TestInRange(t *testing.T) {
{0, 10, 5, false}, {0, 10, 5, false},
} }
for _, test := range tests { for _, test := range tests {
actual := InRangeFloat64(test.param, test.left, test.right)
if actual != test.expected {
t.Errorf("Expected InRangeFloat64(%v, %v, %v) to be %v, got %v", test.param, test.left, test.right, test.expected, actual)
}
}
}
func TestInRange(t *testing.T) {
t.Parallel()
var testsInt = []struct {
param int
left int
right int
expected bool
}{
{0, 0, 0, true},
{1, 0, 0, false},
{-1, 0, 0, false},
{0, -1, 1, true},
{0, 0, 1, true},
{0, -1, 0, true},
{0, 0, -1, true},
{0, 10, 5, false},
}
for _, test := range testsInt {
actual := InRange(test.param, test.left, test.right)
if actual != test.expected {
t.Errorf("Expected InRange(%v, %v, %v) to be %v, got %v", test.param, test.left, test.right, test.expected, actual)
}
}
var testsFloat32 = []struct {
param float32
left float32
right float32
expected bool
}{
{0, 0, 0, true},
{1, 0, 0, false},
{-1, 0, 0, false},
{0, -1, 1, true},
{0, 0, 1, true},
{0, -1, 0, true},
{0, 0, -1, true},
{0, 10, 5, false},
}
for _, test := range testsFloat32 {
actual := InRange(test.param, test.left, test.right)
if actual != test.expected {
t.Errorf("Expected InRange(%v, %v, %v) to be %v, got %v", test.param, test.left, test.right, test.expected, actual)
}
}
var testsFloat64 = []struct {
param float64
left float64
right float64
expected bool
}{
{0, 0, 0, true},
{1, 0, 0, false},
{-1, 0, 0, false},
{0, -1, 1, true},
{0, 0, 1, true},
{0, -1, 0, true},
{0, 0, -1, true},
{0, 10, 5, false},
}
for _, test := range testsFloat64 {
actual := InRange(test.param, test.left, test.right)
if actual != test.expected {
t.Errorf("Expected InRange(%v, %v, %v) to be %v, got %v", test.param, test.left, test.right, test.expected, actual)
}
}
var testsTypeMix = []struct {
param int
left float64
right float64
expected bool
}{
{0, 0, 0, false},
{1, 0, 0, false},
{-1, 0, 0, false},
{0, -1, 1, false},
{0, 0, 1, false},
{0, -1, 0, false},
{0, 0, -1, false},
{0, 10, 5, false},
}
for _, test := range testsTypeMix {
actual := InRange(test.param, test.left, test.right) actual := InRange(test.param, test.left, test.right)
if actual != test.expected { if actual != test.expected {
t.Errorf("Expected InRange(%v, %v, %v) to be %v, got %v", test.param, test.left, test.right, test.expected, actual) t.Errorf("Expected InRange(%v, %v, %v) to be %v, got %v", test.param, test.left, test.right, test.expected, actual)


@ -33,7 +33,6 @@ const (
IP string = `(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|:((:[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))` IP string = `(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|:((:[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))`
URLSchema string = `((ftp|tcp|udp|wss?|https?):\/\/)` URLSchema string = `((ftp|tcp|udp|wss?|https?):\/\/)`
URLUsername string = `(\S+(:\S*)?@)` URLUsername string = `(\S+(:\S*)?@)`
Hostname string = ``
URLPath string = `((\/|\?|#)[^\s]*)` URLPath string = `((\/|\?|#)[^\s]*)`
URLPort string = `(:(\d{1,5}))` URLPort string = `(:(\d{1,5}))`
URLIP string = `([1-9]\d?|1\d\d|2[01]\d|22[0-3])(\.(1?\d{1,2}|2[0-4]\d|25[0-5])){2}(?:\.([0-9]\d?|1\d\d|2[0-4]\d|25[0-4]))` URLIP string = `([1-9]\d?|1\d\d|2[01]\d|22[0-3])(\.(1?\d{1,2}|2[0-4]\d|25[0-5])){2}(?:\.([0-9]\d?|1\d\d|2[0-4]\d|25[0-4]))`


@ -34,6 +34,7 @@ var ParamTagMap = map[string]ParamValidator{
"stringlength": StringLength, "stringlength": StringLength,
"matches": StringMatches, "matches": StringMatches,
"in": isInRaw, "in": isInRaw,
"rsapub": IsRsaPub,
} }
// ParamTagRegexMap maps param tags to their respective regexes. // ParamTagRegexMap maps param tags to their respective regexes.
@ -44,6 +45,7 @@ var ParamTagRegexMap = map[string]*regexp.Regexp{
"stringlength": regexp.MustCompile("^stringlength\\((\\d+)\\|(\\d+)\\)$"), "stringlength": regexp.MustCompile("^stringlength\\((\\d+)\\|(\\d+)\\)$"),
"in": regexp.MustCompile(`^in\((.*)\)`), "in": regexp.MustCompile(`^in\((.*)\)`),
"matches": regexp.MustCompile(`^matches\((.+)\)$`), "matches": regexp.MustCompile(`^matches\((.+)\)$`),
"rsapub": regexp.MustCompile("^rsapub\\((\\d+)\\)$"),
} }
type customTypeTagMap struct { type customTypeTagMap struct {
@ -120,6 +122,7 @@ var TagMap = map[string]Validator{
"ssn": IsSSN, "ssn": IsSSN,
"semver": IsSemver, "semver": IsSemver,
"rfc3339": IsRFC3339, "rfc3339": IsRFC3339,
"rfc3339WithoutZone": IsRFC3339WithoutZone,
"ISO3166Alpha2": IsISO3166Alpha2, "ISO3166Alpha2": IsISO3166Alpha2,
"ISO3166Alpha3": IsISO3166Alpha3, "ISO3166Alpha3": IsISO3166Alpha3,
"ISO4217": IsISO4217, "ISO4217": IsISO4217,


@ -108,7 +108,7 @@ func CamelCaseToUnderscore(str string) string {
var output []rune var output []rune
var segment []rune var segment []rune
for _, r := range str { for _, r := range str {
if !unicode.IsLower(r) { if !unicode.IsLower(r) && string(r) != "_" {
output = addSegment(output, segment) output = addSegment(output, segment)
segment = nil segment = nil
} }
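With the added underscore guard, an existing "_" no longer starts a new segment, so already-underscored input passes through unchanged. A small sketch of the expected behavior, matching the test case added below:

package main

import (
    "fmt"

    "github.com/asaskevich/govalidator"
)

func main() {
    fmt.Println(govalidator.CamelCaseToUnderscore("MyFunc"))  // my_func
    fmt.Println(govalidator.CamelCaseToUnderscore("foo_bar")) // foo_bar, not foo__bar
}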


@ -269,6 +269,7 @@ func TestCamelCaseToUnderscore(t *testing.T) {
{"MyFunc", "my_func"}, {"MyFunc", "my_func"},
{"ABC", "a_b_c"}, {"ABC", "a_b_c"},
{"1B", "1_b"}, {"1B", "1_b"},
{"foo_bar", "foo_bar"},
} }
for _, test := range tests { for _, test := range tests {
actual := CamelCaseToUnderscore(test.param) actual := CamelCaseToUnderscore(test.param)


@ -2,8 +2,14 @@
package govalidator package govalidator
import ( import (
"bytes"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/url" "net/url"
"reflect" "reflect"
@ -20,10 +26,12 @@ var (
fieldsRequiredByDefault bool fieldsRequiredByDefault bool
notNumberRegexp = regexp.MustCompile("[^0-9]+") notNumberRegexp = regexp.MustCompile("[^0-9]+")
whiteSpacesAndMinus = regexp.MustCompile("[\\s-]+") whiteSpacesAndMinus = regexp.MustCompile("[\\s-]+")
paramsRegexp = regexp.MustCompile("\\(.*\\)$")
) )
const maxURLRuneCount = 2083 const maxURLRuneCount = 2083
const minURLRuneCount = 3 const minURLRuneCount = 3
const RF3339WithoutZone = "2006-01-02T15:04:05"
// SetFieldsRequiredByDefault causes validation to fail when struct fields // SetFieldsRequiredByDefault causes validation to fail when struct fields
// do not include validations or are not explicitly marked as exempt (using `valid:"-"` or `valid:"email,optional"`). // do not include validations or are not explicitly marked as exempt (using `valid:"-"` or `valid:"email,optional"`).
@ -54,7 +62,13 @@ func IsURL(str string) bool {
if str == "" || utf8.RuneCountInString(str) >= maxURLRuneCount || len(str) <= minURLRuneCount || strings.HasPrefix(str, ".") { if str == "" || utf8.RuneCountInString(str) >= maxURLRuneCount || len(str) <= minURLRuneCount || strings.HasPrefix(str, ".") {
return false return false
} }
u, err := url.Parse(str) strTemp := str
if strings.Index(str, ":") >= 0 && strings.Index(str, "://") == -1 {
// support URLs without an explicit scheme but with a colon for the port number
// http:// is appended so url.Parse will succeed; strTemp is used so the prefix does not affect rxURL.MatchString
strTemp = "http://" + str
}
u, err := url.Parse(strTemp)
if err != nil { if err != nil {
return false return false
} }
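The strTemp workaround means host:port strings without a scheme are now accepted, while a malformed scheme still fails. A sketch matching the test expectations added further below:

package main

import (
    "fmt"

    "github.com/asaskevich/govalidator"
)

func main() {
    fmt.Println(govalidator.IsURL("foo_bar-fizz-buzz:1313"))   // true: port, no scheme
    fmt.Println(govalidator.IsURL("foo_bar-fizz-buzz://1313")) // false: bogus scheme
}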
@ -65,7 +79,6 @@ func IsURL(str string) bool {
return false return false
} }
return rxURL.MatchString(str) return rxURL.MatchString(str)
} }
// IsRequestURL check if the string rawurl, assuming // IsRequestURL check if the string rawurl, assuming
@ -486,6 +499,33 @@ func IsDNSName(str string) bool {
return !IsIP(str) && rxDNSName.MatchString(str) return !IsIP(str) && rxDNSName.MatchString(str)
} }
// IsHash checks if a string is a hash of the given algorithm.
// Algorithm is one of ['md4', 'md5', 'sha1', 'sha256', 'sha384', 'sha512', 'ripemd128', 'ripemd160', 'tiger128', 'tiger160', 'tiger192', 'crc32', 'crc32b']
func IsHash(str string, algorithm string) bool {
len := "0"
algo := strings.ToLower(algorithm)
if algo == "crc32" || algo == "crc32b" {
len = "8"
} else if algo == "md5" || algo == "md4" || algo == "ripemd128" || algo == "tiger128" {
len = "32"
} else if algo == "sha1" || algo == "ripemd160" || algo == "tiger160" {
len = "40"
} else if algo == "tiger192" {
len = "48"
} else if algo == "sha256" {
len = "64"
} else if algo == "sha384" {
len = "96"
} else if algo == "sha512" {
len = "128"
} else {
return false
}
return Matches(str, "^[a-f0-9]{" + len + "}$")
}
// IsDialString validates the given string for usage with the various Dial() functions // IsDialString validates the given string for usage with the various Dial() functions
func IsDialString(str string) bool { func IsDialString(str string) bool {
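IsHash simply maps each algorithm to its expected hex-digest length and matches the input against an [a-f0-9]{len} pattern. A sketch:

package main

import (
    "fmt"

    "github.com/asaskevich/govalidator"
)

func main() {
    fmt.Println(govalidator.IsHash("deadbeef", "crc32"))  // true: 8 hex chars
    fmt.Println(govalidator.IsHash("deadbeef", "sha256")) // false: sha256 needs 64 chars
    fmt.Println(govalidator.IsHash("deadbeef", "foo"))    // false: unknown algorithm
}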
@ -560,6 +600,40 @@ func IsLongitude(str string) bool {
return rxLongitude.MatchString(str) return rxLongitude.MatchString(str)
} }
// IsRsaPublicKey checks if a string is a valid RSA public key with the provided bit length
func IsRsaPublicKey(str string, keylen int) bool {
bb := bytes.NewBufferString(str)
pemBytes, err := ioutil.ReadAll(bb)
if err != nil {
return false
}
block, _ := pem.Decode(pemBytes)
if block != nil && block.Type != "PUBLIC KEY" {
return false
}
var der []byte
if block != nil {
der = block.Bytes
} else {
der, err = base64.StdEncoding.DecodeString(str)
if err != nil {
return false
}
}
key, err := x509.ParsePKIXPublicKey(der)
if err != nil {
return false
}
pubkey, ok := key.(*rsa.PublicKey)
if !ok {
return false
}
bitlen := len(pubkey.N.Bytes()) * 8
return bitlen == int(keylen)
}
func toJSONName(tag string) string { func toJSONName(tag string) string {
if tag == "" { if tag == "" {
return "" return ""
@ -568,7 +642,16 @@ func toJSONName(tag string) string {
// JSON name always comes first. If there's no options then split[0] is // JSON name always comes first. If there's no options then split[0] is
// JSON name, if JSON name is not set, then split[0] is an empty string. // JSON name, if JSON name is not set, then split[0] is an empty string.
split := strings.SplitN(tag, ",", 2) split := strings.SplitN(tag, ",", 2)
return split[0]
name := split[0]
// However it is possible that the field is skipped when
// (de-)serializing from/to JSON, in which case assume that there is no
// tag name to use
if name == "-" {
return ""
}
return name
} }
// ValidateStruct use tags for fields. // ValidateStruct use tags for fields.
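With the json:"-" special case above, a field that is skipped during JSON (de-)serialization reports validation errors under its Go field name instead of under "-". A sketch (the struct is illustrative):

package main

import (
    "fmt"

    "github.com/asaskevich/govalidator"
)

type login struct {
    Secret string `json:"-" valid:"-,required"`
}

func main() {
    _, err := govalidator.ValidateStruct(&login{})
    fmt.Println(err) // the error names "Secret", not "-"
}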
@ -613,6 +696,14 @@ func ValidateStruct(s interface{}) (bool, error) {
jsonError.Name = jsonTag jsonError.Name = jsonTag
err2 = jsonError err2 = jsonError
case Errors: case Errors:
for i2, err3 := range jsonError {
switch customErr := err3.(type) {
case Error:
customErr.Name = jsonTag
jsonError[i2] = customErr
}
}
err2 = jsonError err2 = jsonError
} }
} }
@ -630,8 +721,11 @@ func ValidateStruct(s interface{}) (bool, error) {
// parseTagIntoMap parses a struct tag `valid:required~Some error message,length(2|3)` into map[string]string{"required": "Some error message", "length(2|3)": ""} // parseTagIntoMap parses a struct tag `valid:required~Some error message,length(2|3)` into map[string]string{"required": "Some error message", "length(2|3)": ""}
func parseTagIntoMap(tag string) tagOptionsMap { func parseTagIntoMap(tag string) tagOptionsMap {
optionsMap := make(tagOptionsMap) optionsMap := make(tagOptionsMap)
options := strings.SplitN(tag, ",", -1) options := strings.Split(tag, ",")
for _, option := range options { for _, option := range options {
option = strings.TrimSpace(option)
validationOptions := strings.Split(option, "~") validationOptions := strings.Split(option, "~")
if !isValidTag(validationOptions[0]) { if !isValidTag(validationOptions[0]) {
continue continue
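parseTagIntoMap splits the tag on commas and now trims surrounding whitespace, so a space after a comma no longer breaks an option; the part after "~" is the custom error message. A sketch (struct and message are illustrative):

package main

import (
    "fmt"

    "github.com/asaskevich/govalidator"
)

type form struct {
    // the space after the comma is tolerated thanks to strings.TrimSpace
    Name string `valid:"required~Name is mandatory, length(2|10)"`
}

func main() {
    ok, err := govalidator.ValidateStruct(&form{})
    fmt.Println(ok, err) // false; err carries the custom "Name is mandatory" message
}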
@ -688,6 +782,11 @@ func IsRFC3339(str string) bool {
return IsTime(str, time.RFC3339) return IsTime(str, time.RFC3339)
} }
// IsRFC3339WithoutZone checks if a string is a valid RFC3339 timestamp without timezone information.
func IsRFC3339WithoutZone(str string) bool {
return IsTime(str, RF3339WithoutZone)
}
// IsISO4217 check if string is valid ISO currency code // IsISO4217 check if string is valid ISO currency code
func IsISO4217(str string) bool { func IsISO4217(str string) bool {
for _, currency := range ISO4217List { for _, currency := range ISO4217List {
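RF3339WithoutZone is simply the RFC3339 layout with the zone stripped, so IsRFC3339WithoutZone rejects any input carrying zone information. A sketch matching the IsTime test cases added below:

package main

import (
    "fmt"

    "github.com/asaskevich/govalidator"
)

func main() {
    fmt.Println(govalidator.IsRFC3339WithoutZone("2016-12-31T11:00:00"))  // true
    fmt.Println(govalidator.IsRFC3339WithoutZone("2016-12-31T11:00:00Z")) // false: zone present
}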
@ -716,6 +815,17 @@ func RuneLength(str string, params ...string) bool {
return StringLength(str, params...) return StringLength(str, params...)
} }
// IsRsaPub checks whether a string is a valid RSA public key.
// Alias for IsRsaPublicKey that takes the key length as a string parameter.
func IsRsaPub(str string, params ...string) bool {
if len(params) == 1 {
len, _ := ToInt(params[0])
return IsRsaPublicKey(str, int(len))
}
return false
}
// StringMatches checks if a string matches a given pattern. // StringMatches checks if a string matches a given pattern.
func StringMatches(s string, params ...string) bool { func StringMatches(s string, params ...string) bool {
if len(params) == 1 { if len(params) == 1 {
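Together with the rsapub entries added to ParamTagMap and ParamTagRegexMap above, this makes the key length available as a struct-tag parameter. A sketch:

package main

import (
    "fmt"

    "github.com/asaskevich/govalidator"
)

type cfg struct {
    PubKey string `valid:"rsapub(2048)"`
}

func main() {
    ok, err := govalidator.ValidateStruct(&cfg{PubKey: "not a key"})
    fmt.Println(ok, err) // false: the value does not parse as an RSA public key
}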
@ -776,11 +886,11 @@ func IsIn(str string, params ...string) bool {
func checkRequired(v reflect.Value, t reflect.StructField, options tagOptionsMap) (bool, error) { func checkRequired(v reflect.Value, t reflect.StructField, options tagOptionsMap) (bool, error) {
if requiredOption, isRequired := options["required"]; isRequired { if requiredOption, isRequired := options["required"]; isRequired {
if len(requiredOption) > 0 { if len(requiredOption) > 0 {
return false, Error{t.Name, fmt.Errorf(requiredOption), true} return false, Error{t.Name, fmt.Errorf(requiredOption), true, "required"}
} }
return false, Error{t.Name, fmt.Errorf("non zero value required"), false} return false, Error{t.Name, fmt.Errorf("non zero value required"), false, "required"}
} else if _, isOptional := options["optional"]; fieldsRequiredByDefault && !isOptional { } else if _, isOptional := options["optional"]; fieldsRequiredByDefault && !isOptional {
return false, Error{t.Name, fmt.Errorf("All fields are required to at least have one validation defined"), false} return false, Error{t.Name, fmt.Errorf("All fields are required to at least have one validation defined"), false, "required"}
} }
// not required and empty is valid // not required and empty is valid
return true, nil return true, nil
@ -799,7 +909,7 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value, options
if !fieldsRequiredByDefault { if !fieldsRequiredByDefault {
return true, nil return true, nil
} }
return false, Error{t.Name, fmt.Errorf("All fields are required to at least have one validation defined"), false} return false, Error{t.Name, fmt.Errorf("All fields are required to at least have one validation defined"), false, "required"}
case "-": case "-":
return true, nil return true, nil
} }
@ -822,10 +932,10 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value, options
if result := validatefunc(v.Interface(), o.Interface()); !result { if result := validatefunc(v.Interface(), o.Interface()); !result {
if len(customErrorMessage) > 0 { if len(customErrorMessage) > 0 {
customTypeErrors = append(customTypeErrors, Error{Name: t.Name, Err: fmt.Errorf(customErrorMessage), CustomErrorMessageExists: true}) customTypeErrors = append(customTypeErrors, Error{Name: t.Name, Err: fmt.Errorf(customErrorMessage), CustomErrorMessageExists: true, Validator: stripParams(validatorName)})
continue continue
} }
customTypeErrors = append(customTypeErrors, Error{Name: t.Name, Err: fmt.Errorf("%s does not validate as %s", fmt.Sprint(v), validatorName), CustomErrorMessageExists: false}) customTypeErrors = append(customTypeErrors, Error{Name: t.Name, Err: fmt.Errorf("%s does not validate as %s", fmt.Sprint(v), validatorName), CustomErrorMessageExists: false, Validator: stripParams(validatorName)})
} }
} }
} }
@ -844,7 +954,7 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value, options
for validator := range options { for validator := range options {
isValid = false isValid = false
resultErr = Error{t.Name, fmt.Errorf( resultErr = Error{t.Name, fmt.Errorf(
"The following validator is invalid or can't be applied to the field: %q", validator), false} "The following validator is invalid or can't be applied to the field: %q", validator), false, stripParams(validator)}
return return
} }
} }
@ -888,16 +998,16 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value, options
field := fmt.Sprint(v) // make value into string, then validate with regex field := fmt.Sprint(v) // make value into string, then validate with regex
if result := validatefunc(field, ps[1:]...); (!result && !negate) || (result && negate) { if result := validatefunc(field, ps[1:]...); (!result && !negate) || (result && negate) {
if customMsgExists { if customMsgExists {
return false, Error{t.Name, fmt.Errorf(customErrorMessage), customMsgExists} return false, Error{t.Name, fmt.Errorf(customErrorMessage), customMsgExists, stripParams(validatorSpec)}
} }
if negate { if negate {
return false, Error{t.Name, fmt.Errorf("%s does validate as %s", field, validator), customMsgExists} return false, Error{t.Name, fmt.Errorf("%s does validate as %s", field, validator), customMsgExists, stripParams(validatorSpec)}
} }
return false, Error{t.Name, fmt.Errorf("%s does not validate as %s", field, validator), customMsgExists} return false, Error{t.Name, fmt.Errorf("%s does not validate as %s", field, validator), customMsgExists, stripParams(validatorSpec)}
} }
default: default:
// type not yet supported, fail // type not yet supported, fail
return false, Error{t.Name, fmt.Errorf("Validator %s doesn't support kind %s", validator, v.Kind()), false} return false, Error{t.Name, fmt.Errorf("Validator %s doesn't support kind %s", validator, v.Kind()), false, stripParams(validatorSpec)}
} }
} }
@ -909,17 +1019,17 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value, options
field := fmt.Sprint(v) // make value into string, then validate with regex field := fmt.Sprint(v) // make value into string, then validate with regex
if result := validatefunc(field); !result && !negate || result && negate { if result := validatefunc(field); !result && !negate || result && negate {
if customMsgExists { if customMsgExists {
return false, Error{t.Name, fmt.Errorf(customErrorMessage), customMsgExists} return false, Error{t.Name, fmt.Errorf(customErrorMessage), customMsgExists, stripParams(validatorSpec)}
} }
if negate { if negate {
return false, Error{t.Name, fmt.Errorf("%s does validate as %s", field, validator), customMsgExists} return false, Error{t.Name, fmt.Errorf("%s does validate as %s", field, validator), customMsgExists, stripParams(validatorSpec)}
} }
return false, Error{t.Name, fmt.Errorf("%s does not validate as %s", field, validator), customMsgExists} return false, Error{t.Name, fmt.Errorf("%s does not validate as %s", field, validator), customMsgExists, stripParams(validatorSpec)}
} }
default: default:
//Not Yet Supported Types (Fail here!) //Not Yet Supported Types (Fail here!)
err := fmt.Errorf("Validator %s doesn't support kind %s for value %v", validator, v.Kind(), v) err := fmt.Errorf("Validator %s doesn't support kind %s for value %v", validator, v.Kind(), v)
return false, Error{t.Name, err, false} return false, Error{t.Name, err, false, stripParams(validatorSpec)}
} }
} }
} }
@ -933,10 +1043,19 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value, options
sort.Sort(sv) sort.Sort(sv)
result := true result := true
for _, k := range sv { for _, k := range sv {
resultItem, err := ValidateStruct(v.MapIndex(k).Interface()) var resultItem bool
var err error
if v.MapIndex(k).Kind() != reflect.Struct {
resultItem, err = typeCheck(v.MapIndex(k), t, o, options)
if err != nil { if err != nil {
return false, err return false, err
} }
} else {
resultItem, err = ValidateStruct(v.MapIndex(k).Interface())
if err != nil {
return false, err
}
}
result = result && resultItem result = result && resultItem
} }
return result, nil return result, nil
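The map branch previously assumed struct values; non-struct values now go through typeCheck, so plain maps can satisfy validators such as required. A sketch mirroring the new tests further below:

package main

import (
    "fmt"

    "github.com/asaskevich/govalidator"
)

type settings struct {
    Labels map[string]string `valid:"required"`
}

func main() {
    ok, _ := govalidator.ValidateStruct(&settings{Labels: map[string]string{"env": "prod"}})
    fmt.Println(ok) // true: map values no longer need to be structs
}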
@ -978,6 +1097,10 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value, options
} }
} }
func stripParams(validatorString string) string {
return paramsRegexp.ReplaceAllString(validatorString, "")
}
func isEmptyValue(v reflect.Value) bool { func isEmptyValue(v reflect.Value) bool {
switch v.Kind() { switch v.Kind() {
case reflect.String, reflect.Array: case reflect.String, reflect.Array:
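The new Validator field on Error carries the validator name with any parameters removed by stripParams, which is what the error inspection in the tests below relies on. A sketch:

package main

import (
    "fmt"

    "github.com/asaskevich/govalidator"
)

type post struct {
    Title string `valid:"length(1|10)"`
}

func main() {
    _, err := govalidator.ValidateStruct(&post{Title: "this title is far too long"})
    for _, e := range err.(govalidator.Errors) {
        fmt.Println(e.(govalidator.Error).Validator) // "length", with "(1|10)" stripped
    }
}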


@ -536,6 +536,37 @@ func TestIsInt(t *testing.T) {
} }
} }
func TestIsHash(t *testing.T) {
t.Parallel()
var tests = []struct {
param string
algo string
expected bool
}{
{"3ca25ae354e192b26879f651a51d92aa8a34d8d3", "sha1", true},
{"3ca25ae354e192b26879f651a51d34d8d3", "sha1", false},
{"3ca25ae354e192b26879f651a51d92aa8a34d8d3", "Tiger160", true},
{"3ca25ae354e192b26879f651a51d34d8d3", "ripemd160", false},
{"579282cfb65ca1f109b78536effaf621b853c9f7079664a3fbe2b519f435898c", "sha256", true},
{"579282cfb65ca1f109b78536effaf621b853c9f7079664a3fbe2b519f435898casfdsafsadfsdf", "sha256", false},
{"bf547c3fc5841a377eb1519c2890344dbab15c40ae4150b4b34443d2212e5b04aa9d58865bf03d8ae27840fef430b891", "sha384", true},
{"579282cfb65ca1f109b78536effaf621b853c9f7079664a3fbe2b519f435898casfdsafsadfsdf", "sha384", false},
{"45bc5fa8cb45ee408c04b6269e9f1e1c17090c5ce26ffeeda2af097735b29953ce547e40ff3ad0d120e5361cc5f9cee35ea91ecd4077f3f589b4d439168f91b9", "sha512", true},
{"579282cfb65ca1f109b78536effaf621b853c9f7079664a3fbe2b519f435898casfdsafsadfsdf", "sha512", false},
{"46fc0125a148788a3ac1d649566fc04eb84a746f1a6e4fa7", "tiger192", true},
{"46fc0125a148788a3ac1d649566fc04eb84a746f1a6$$%@^", "TIGER192", false},
{"46fc0125a148788a3ac1d649566fc04eb84a746f1a6$$%@^", "SOMEHASH", false},
}
for _, test := range tests {
actual := IsHash(test.param, test.algo)
if actual != test.expected {
t.Errorf("Expected IsHash(%q, %q) to be %v, got %v", test.param, test.algo, test.expected, actual)
}
}
}
func TestIsEmail(t *testing.T) { func TestIsEmail(t *testing.T) {
t.Parallel() t.Parallel()
@ -633,6 +664,7 @@ func TestIsURL(t *testing.T) {
{"https://pbs.twimg.com/profile_images/560826135676588032/j8fWrmYY_normal.jpeg", true}, {"https://pbs.twimg.com/profile_images/560826135676588032/j8fWrmYY_normal.jpeg", true},
// according to #125 // according to #125
{"http://prometheus-alertmanager.service.q:9093", true}, {"http://prometheus-alertmanager.service.q:9093", true},
{"aio1_alertmanager_container-63376c45:9093", true},
{"https://www.logn-123-123.url.with.sigle.letter.d:12345/url/path/foo?bar=zzz#user", true}, {"https://www.logn-123-123.url.with.sigle.letter.d:12345/url/path/foo?bar=zzz#user", true},
{"http://me.example.com", true}, {"http://me.example.com", true},
{"http://www.me.example.com", true}, {"http://www.me.example.com", true},
@ -661,6 +693,10 @@ func TestIsURL(t *testing.T) {
{"foo_bar.example.com", true}, {"foo_bar.example.com", true},
{"foo_bar_fizz_buzz.example.com", true}, {"foo_bar_fizz_buzz.example.com", true},
{"http://hello_world.example.com", true}, {"http://hello_world.example.com", true},
// According to #212
{"foo_bar-fizz-buzz:1313", true},
{"foo_bar-fizz-buzz:13:13", false},
{"foo_bar-fizz-buzz://1313", false},
} }
for _, test := range tests { for _, test := range tests {
actual := IsURL(test.param) actual := IsURL(test.param)
@ -980,6 +1016,7 @@ func TestIsMultibyte(t *testing.T) {
{"testexample.com", true}, {"testexample.com", true},
{"1234abcDE", true}, {"1234abcDE", true},
{"カタカナ", true}, {"カタカナ", true},
{"", true},
} }
for _, test := range tests { for _, test := range tests {
actual := IsMultibyte(test.param) actual := IsMultibyte(test.param)
@ -1850,6 +1887,13 @@ func TestIsTime(t *testing.T) {
{"2016-12-31T11:00:00.05Z", time.RFC3339, true}, {"2016-12-31T11:00:00.05Z", time.RFC3339, true},
{"2016-12-31T11:00:00.05-01:00", time.RFC3339, true}, {"2016-12-31T11:00:00.05-01:00", time.RFC3339, true},
{"2016-12-31T11:00:00.05+01:00", time.RFC3339, true}, {"2016-12-31T11:00:00.05+01:00", time.RFC3339, true},
{"2016-12-31T11:00:00", RF3339WithoutZone, true},
{"2016-12-31T11:00:00Z", RF3339WithoutZone, false},
{"2016-12-31T11:00:00+01:00", RF3339WithoutZone, false},
{"2016-12-31T11:00:00-01:00", RF3339WithoutZone, false},
{"2016-12-31T11:00:00.05Z", RF3339WithoutZone, false},
{"2016-12-31T11:00:00.05-01:00", RF3339WithoutZone, false},
{"2016-12-31T11:00:00.05+01:00", RF3339WithoutZone, false},
} }
for _, test := range tests { for _, test := range tests {
actual := IsTime(test.param, test.format) actual := IsTime(test.param, test.format)
@ -2162,7 +2206,7 @@ func TestInvalidValidator(t *testing.T) {
invalidStruct := InvalidStruct{1} invalidStruct := InvalidStruct{1}
if valid, err := ValidateStruct(&invalidStruct); valid || err == nil || if valid, err := ValidateStruct(&invalidStruct); valid || err == nil ||
err.Error() != `Field: The following validator is invalid or can't be applied to the field: "someInvalidValidator";` { err.Error() != `Field: The following validator is invalid or can't be applied to the field: "someInvalidValidator"` {
t.Errorf("Got an unexpected result for struct with invalid validator: %t %s", valid, err) t.Errorf("Got an unexpected result for struct with invalid validator: %t %s", valid, err)
} }
} }
@ -2184,12 +2228,12 @@ func TestCustomValidator(t *testing.T) {
t.Errorf("Got an unexpected result for struct with custom always true validator: %t %s", valid, err) t.Errorf("Got an unexpected result for struct with custom always true validator: %t %s", valid, err)
} }
if valid, err := ValidateStruct(&InvalidStruct{Field: 1}); valid || err == nil || err.Error() != "Custom validator error;;" { if valid, err := ValidateStruct(&InvalidStruct{Field: 1}); valid || err == nil || err.Error() != "Custom validator error" {
t.Errorf("Got an unexpected result for struct with custom always false validator: %t %s", valid, err) t.Errorf("Got an unexpected result for struct with custom always false validator: %t %s", valid, err)
} }
mixedStruct := StructWithCustomAndBuiltinValidator{} mixedStruct := StructWithCustomAndBuiltinValidator{}
if valid, err := ValidateStruct(&mixedStruct); valid || err == nil || err.Error() != "Field: non zero value required;" { if valid, err := ValidateStruct(&mixedStruct); valid || err == nil || err.Error() != "Field: non zero value required" {
t.Errorf("Got an unexpected result for invalid struct with custom and built-in validators: %t %s", valid, err) t.Errorf("Got an unexpected result for invalid struct with custom and built-in validators: %t %s", valid, err)
} }
@ -2522,6 +2566,8 @@ func TestValidateStruct(t *testing.T) {
type testByteArray [8]byte type testByteArray [8]byte
type testByteMap map[byte]byte type testByteMap map[byte]byte
type testByteSlice []byte type testByteSlice []byte
type testStringStringMap map[string]string
type testStringIntMap map[string]int
func TestRequired(t *testing.T) { func TestRequired(t *testing.T) {
@ -2606,6 +2652,22 @@ func TestRequired(t *testing.T) {
}{}, }{},
false, false,
}, },
{
struct {
TestStringStringMap testStringStringMap `valid:"required"`
}{
testStringStringMap{"test": "test"},
},
true,
},
{
struct {
TestIntMap testStringIntMap `valid:"required"`
}{
testStringIntMap{"test": 42},
},
true,
},
} }
for _, test := range tests { for _, test := range tests {
actual, err := ValidateStruct(test.param) actual, err := ValidateStruct(test.param)
@ -2693,7 +2755,7 @@ func TestErrorsByField(t *testing.T) {
{"CustomField", "An error occurred"}, {"CustomField", "An error occurred"},
} }
err = Error{"CustomField", fmt.Errorf("An error occurred"), false} err = Error{"CustomField", fmt.Errorf("An error occurred"), false, "hello"}
errs = ErrorsByField(err) errs = ErrorsByField(err)
if len(errs) != 1 { if len(errs) != 1 {
@ -2842,6 +2904,7 @@ func TestJSONValidator(t *testing.T) {
WithoutJSONName string `valid:"-,required"` WithoutJSONName string `valid:"-,required"`
WithJSONOmit string `json:"with_other_json_name,omitempty" valid:"-,required"` WithJSONOmit string `json:"with_other_json_name,omitempty" valid:"-,required"`
WithJSONOption string `json:",omitempty" valid:"-,required"` WithJSONOption string `json:",omitempty" valid:"-,required"`
WithEmptyJSONName string `json:"-" valid:"-,required"`
} }
_, err := ValidateStruct(val) _, err := ValidateStruct(val)
@ -2861,4 +2924,117 @@ func TestJSONValidator(t *testing.T) {
if Contains(err.Error(), "omitempty") { if Contains(err.Error(), "omitempty") {
t.Errorf("Expected error message to not contain ',omitempty' but actual error is: %s", err.Error()) t.Errorf("Expected error message to not contain ',omitempty' but actual error is: %s", err.Error())
} }
if !Contains(err.Error(), "WithEmptyJSONName") {
t.Errorf("Expected error message to contain WithEmptyJSONName but actual error is: %s", err.Error())
}
}
func TestValidatorIncludedInError(t *testing.T) {
post := Post{
Title: "",
Message: "👍",
AuthorIP: "xyz",
}
validatorMap := map[string]string{
"Title": "required",
"Message": "ascii",
"AuthorIP": "ipv4",
}
ok, errors := ValidateStruct(post)
if ok {
t.Errorf("expected validation to fail %v", ok)
}
for _, e := range errors.(Errors) {
casted := e.(Error)
if validatorMap[casted.Name] != casted.Validator {
t.Errorf("expected validator for %s to be %s, but was %s", casted.Name, validatorMap[casted.Name], casted.Validator)
}
}
// check to make sure that validators with arguments (like length(1|10)) don't include the arguments
// in the validator name
message := MessageWithSeveralFieldsStruct{
Title: "",
Body: "asdfasdfasdfasdfasdf",
}
validatorMap = map[string]string{
"Title": "length",
"Body": "length",
}
ok, errors = ValidateStruct(message)
if ok {
t.Errorf("expected validation to fail, %v", ok)
}
for _, e := range errors.(Errors) {
casted := e.(Error)
if validatorMap[casted.Name] != casted.Validator {
t.Errorf("expected validator for %s to be %s, but was %s", casted.Name, validatorMap[casted.Name], casted.Validator)
}
}
// make sure validators with custom messages don't show up in the validator string
type CustomMessage struct {
Text string `valid:"length(1|10)~Custom message"`
}
cs := CustomMessage{Text: "asdfasdfasdfasdf"}
ok, errors = ValidateStruct(&cs)
if ok {
t.Errorf("expected validation to fail, %v", ok)
}
validator := errors.(Errors)[0].(Error).Validator
if validator != "length" {
t.Errorf("expected validator for Text to be length, but was %s", validator)
}
}
func TestIsRsaPublicKey(t *testing.T) {
var tests = []struct {
rsastr string
keylen int
expected bool
}{
{`fubar`, 2048, false},
{`MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAvncDCeibmEkabJLmFec7x9y86RP6dIvkVxxbQoOJo06E+p7tH6vCmiGHKnuu
XwKYLq0DKUE3t/HHsNdowfD9+NH8caLzmXqGBx45/Dzxnwqz0qYq7idK+Qff34qrk/YFoU7498U1Ee7PkKb7/VE9BmMEcI3uoKbeXCbJRI
HoTp8bUXOpNTSUfwUNwJzbm2nsHo2xu6virKtAZLTsJFzTUmRd11MrWCvj59lWzt1/eIMN+ekjH8aXeLOOl54CL+kWp48C+V9BchyKCShZ
B7ucimFvjHTtuxziXZQRO7HlcsBOa0WwvDJnRnskdyoD31s4F4jpKEYBJNWTo63v6lUvbQIDAQAB`, 2048, true},
{`MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAvncDCeibmEkabJLmFec7x9y86RP6dIvkVxxbQoOJo06E+p7tH6vCmiGHKnuu
XwKYLq0DKUE3t/HHsNdowfD9+NH8caLzmXqGBx45/Dzxnwqz0qYq7idK+Qff34qrk/YFoU7498U1Ee7PkKb7/VE9BmMEcI3uoKbeXCbJRI
HoTp8bUXOpNTSUfwUNwJzbm2nsHo2xu6virKtAZLTsJFzTUmRd11MrWCvj59lWzt1/eIMN+ekjH8aXeLOOl54CL+kWp48C+V9BchyKCShZ
B7ucimFvjHTtuxziXZQRO7HlcsBOa0WwvDJnRnskdyoD31s4F4jpKEYBJNWTo63v6lUvbQIDAQAB`, 1024, false},
{`-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAvncDCeibmEkabJLmFec7
x9y86RP6dIvkVxxbQoOJo06E+p7tH6vCmiGHKnuuXwKYLq0DKUE3t/HHsNdowfD9
+NH8caLzmXqGBx45/Dzxnwqz0qYq7idK+Qff34qrk/YFoU7498U1Ee7PkKb7/VE9
BmMEcI3uoKbeXCbJRIHoTp8bUXOpNTSUfwUNwJzbm2nsHo2xu6virKtAZLTsJFzT
UmRd11MrWCvj59lWzt1/eIMN+ekjH8aXeLOOl54CL+kWp48C+V9BchyKCShZB7uc
imFvjHTtuxziXZQRO7HlcsBOa0WwvDJnRnskdyoD31s4F4jpKEYBJNWTo63v6lUv
bQIDAQAB
-----END PUBLIC KEY-----`, 2048, true},
{`-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAvncDCeibmEkabJLmFec7
x9y86RP6dIvkVxxbQoOJo06E+p7tH6vCmiGHKnuuXwKYLq0DKUE3t/HHsNdowfD9
+NH8caLzmXqGBx45/Dzxnwqz0qYq7idK+Qff34qrk/YFoU7498U1Ee7PkKb7/VE9
BmMEcI3uoKbeXCbJRIHoTp8bUXOpNTSUfwUNwJzbm2nsHo2xu6virKtAZLTsJFzT
UmRd11MrWCvj59lWzt1/eIMN+ekjH8aXeLOOl54CL+kWp48C+V9BchyKCShZB7uc
imFvjHTtuxziXZQRO7HlcsBOa0WwvDJnRnskdyoD31s4F4jpKEYBJNWTo63v6lUv
bQIDAQAB
-----END PUBLIC KEY-----`, 4096, false},
}
for i, test := range tests {
actual := IsRsaPublicKey(test.rsastr, test.keylen)
if actual != test.expected {
t.Errorf("Expected TestIsRsaPublicKey(%d, %d) to be %v, got %v", i, test.keylen, test.expected, actual)
}
}
} }


@ -1 +1,3 @@
secrets.yml secrets.yml
vendor
Godeps


@ -17,8 +17,8 @@ package swag
import ( import (
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath"
"path" "path"
"path/filepath"
"runtime" "runtime"
"testing" "testing"


@ -40,6 +40,7 @@ var commonInitialisms = map[string]bool{
"IP": true, "IP": true,
"JSON": true, "JSON": true,
"LHS": true, "LHS": true,
"OAI": true,
"QPS": true, "QPS": true,
"RAM": true, "RAM": true,
"RHS": true, "RHS": true,
@ -163,8 +164,8 @@ func split(str string) (words []string) {
// Split when uppercase is found (needed for Snake) // Split when uppercase is found (needed for Snake)
str = rex1.ReplaceAllString(str, " $1") str = rex1.ReplaceAllString(str, " $1")
// check if consecutive single char things make up an initialism
// check if consecutive single char things make up an initialism
for _, k := range initialisms { for _, k := range initialisms {
str = strings.Replace(str, rex1.ReplaceAllString(k, " $1"), " "+k, -1) str = strings.Replace(str, rex1.ReplaceAllString(k, " $1"), " "+k, -1)
} }
@ -189,10 +190,47 @@ func lower(str string) string {
return strings.ToLower(trim(str)) return strings.ToLower(trim(str))
} }
// Camelize an uppercased word
func Camelize(word string) (camelized string) {
for pos, ru := range word {
if pos > 0 {
camelized += string(unicode.ToLower(ru))
} else {
camelized += string(unicode.ToUpper(ru))
}
}
return
}
// ToFileName lowercases and underscores a go type name // ToFileName lowercases and underscores a go type name
func ToFileName(name string) string { func ToFileName(name string) string {
var out []string var out []string
for _, w := range split(name) { cml := trim(name)
// Camelize any capital word preceding a reserved keyword ("initialism")
// thus, upper-cased words preceding a common initialism will get separated
// e.g.: ELBHTTPLoadBalancer becomes elb_http_load_balancer
rexPrevious := regexp.MustCompile(`(?P<word>\p{Lu}{2,})(?:HTTP|OAI)`)
cml = rexPrevious.ReplaceAllStringFunc(cml, func(match string) (replaceInMatch string) {
for _, m := range rexPrevious.FindAllStringSubmatch(match, -1) { // [ match submatch ]
if m[1] != "" {
replaceInMatch = strings.Replace(m[0], m[1], Camelize(m[1]), -1)
}
}
return
})
// Pre-camelize reserved keywords ("initialisms") to avoid unnecessary hyphenization
for _, k := range initialisms {
cml = strings.Replace(cml, k, Camelize(k), -1)
}
// Camelize other capital words to avoid unnecessary hyphenization
rexCase := regexp.MustCompile(`(\p{Lu}{2,})`)
cml = rexCase.ReplaceAllStringFunc(cml, Camelize)
// Final split with hyphens
for _, w := range split(cml) {
out = append(out, lower(w)) out = append(out, lower(w))
} }
return strings.Join(out, "_") return strings.Join(out, "_")
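ToFileName now pre-camelizes runs of capitals (and initialisms such as HTTP or OAI) before splitting, so all-caps segments become separate words instead of collapsing. A sketch against the vendored go-openapi/swag, matching the tests below:

package main

import (
    "fmt"

    "github.com/go-openapi/swag"
)

func main() {
    fmt.Println(swag.Camelize("HTTP"))                  // Http
    fmt.Println(swag.ToFileName("ELBHTTPLoadBalancer")) // elb_http_load_balancer
    fmt.Println(swag.ToFileName("Type_OAIAlias"))       // type_oai_alias
}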


@ -39,6 +39,7 @@ func TestToGoName(t *testing.T) {
{"findThingById", "FindThingByID"}, {"findThingById", "FindThingByID"},
{"日本語sample 2 Text", "X日本語sample2Text"}, {"日本語sample 2 Text", "X日本語sample2Text"},
{"日本語findThingById", "X日本語findThingByID"}, {"日本語findThingById", "X日本語findThingByID"},
{"findTHINGSbyID", "FindTHINGSbyID"},
} }
for k := range commonInitialisms { for k := range commonInitialisms {
@ -122,8 +123,16 @@ func TestToFileName(t *testing.T) {
samples := []translationSample{ samples := []translationSample{
{"SampleText", "sample_text"}, {"SampleText", "sample_text"},
{"FindThingByID", "find_thing_by_id"}, {"FindThingByID", "find_thing_by_id"},
{"CAPWD.folwdBYlc", "capwd_folwd_bylc"},
{"CAPWDfolwdBYlc", "capwdfolwd_bylc"},
{"CAP_WD_folwdBYlc", "cap_wd_folwd_bylc"},
{"TypeOAI_alias", "type_oai_alias"},
{"Type_OAI_alias", "type_oai_alias"},
{"Type_OAIAlias", "type_oai_alias"},
{"ELB.HTTPLoadBalancer", "elb_http_load_balancer"},
{"elbHTTPLoadBalancer", "elb_http_load_balancer"},
{"ELBHTTPLoadBalancer", "elb_http_load_balancer"},
} }
for k := range commonInitialisms { for k := range commonInitialisms {
samples = append(samples, samples = append(samples,
translationSample{"Sample" + k + "Text", "sample_" + lower(k) + "_text"}, translationSample{"Sample" + k + "Text", "sample_" + lower(k) + "_text"},


@ -1,27 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unix
import (
"os"
"syscall"
)
// FIXME: unexported function from os
// syscallMode returns the syscall-specific mode bits from Go's portable mode bits.
func syscallMode(i os.FileMode) (o uint32) {
o |= uint32(i.Perm())
if i&os.ModeSetuid != 0 {
o |= syscall.S_ISUID
}
if i&os.ModeSetgid != 0 {
o |= syscall.S_ISGID
}
if i&os.ModeSticky != 0 {
o |= syscall.S_ISVTX
}
// No mapping for Go's ModeTemporary (plan9 only).
return
}


@ -1125,6 +1125,10 @@ func PtracePokeData(pid int, addr uintptr, data []byte) (count int, err error) {
return ptracePoke(PTRACE_POKEDATA, PTRACE_PEEKDATA, pid, addr, data) return ptracePoke(PTRACE_POKEDATA, PTRACE_PEEKDATA, pid, addr, data)
} }
func PtracePokeUser(pid int, addr uintptr, data []byte) (count int, err error) {
return ptracePoke(PTRACE_POKEUSR, PTRACE_PEEKUSR, pid, addr, data)
}
func PtraceGetRegs(pid int, regsout *PtraceRegs) (err error) { func PtraceGetRegs(pid int, regsout *PtraceRegs) (err error) {
return ptrace(PTRACE_GETREGS, pid, 0, uintptr(unsafe.Pointer(regsout))) return ptrace(PTRACE_GETREGS, pid, 0, uintptr(unsafe.Pointer(regsout)))
} }


@ -796,6 +796,75 @@ func ConnectEx(fd Handle, sa Sockaddr, sendBuf *byte, sendDataLen uint32, bytesS
return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped) return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
} }
var sendRecvMsgFunc struct {
once sync.Once
sendAddr uintptr
recvAddr uintptr
err error
}
func loadWSASendRecvMsg() error {
sendRecvMsgFunc.once.Do(func() {
var s Handle
s, sendRecvMsgFunc.err = Socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)
if sendRecvMsgFunc.err != nil {
return
}
defer CloseHandle(s)
var n uint32
sendRecvMsgFunc.err = WSAIoctl(s,
SIO_GET_EXTENSION_FUNCTION_POINTER,
(*byte)(unsafe.Pointer(&WSAID_WSARECVMSG)),
uint32(unsafe.Sizeof(WSAID_WSARECVMSG)),
(*byte)(unsafe.Pointer(&sendRecvMsgFunc.recvAddr)),
uint32(unsafe.Sizeof(sendRecvMsgFunc.recvAddr)),
&n, nil, 0)
if sendRecvMsgFunc.err != nil {
return
}
sendRecvMsgFunc.err = WSAIoctl(s,
SIO_GET_EXTENSION_FUNCTION_POINTER,
(*byte)(unsafe.Pointer(&WSAID_WSASENDMSG)),
uint32(unsafe.Sizeof(WSAID_WSASENDMSG)),
(*byte)(unsafe.Pointer(&sendRecvMsgFunc.sendAddr)),
uint32(unsafe.Sizeof(sendRecvMsgFunc.sendAddr)),
&n, nil, 0)
})
return sendRecvMsgFunc.err
}
func WSASendMsg(fd Handle, msg *WSAMsg, flags uint32, bytesSent *uint32, overlapped *Overlapped, croutine *byte) error {
err := loadWSASendRecvMsg()
if err != nil {
return err
}
r1, _, e1 := syscall.Syscall6(sendRecvMsgFunc.sendAddr, 6, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(croutine)))
if r1 == socket_error {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return err
}
func WSARecvMsg(fd Handle, msg *WSAMsg, bytesReceived *uint32, overlapped *Overlapped, croutine *byte) error {
err := loadWSASendRecvMsg()
if err != nil {
return err
}
r1, _, e1 := syscall.Syscall6(sendRecvMsgFunc.recvAddr, 5, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(unsafe.Pointer(bytesReceived)), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(croutine)), 0)
if r1 == socket_error {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return err
}
// Invented structures to support what package os expects. // Invented structures to support what package os expects.
type Rusage struct { type Rusage struct {
CreationTime Filetime CreationTime Filetime
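WSASendMsg and WSARecvMsg are loaded lazily through SIO_GET_EXTENSION_FUNCTION_POINTER on first use. A minimal Windows-only sketch of receiving one datagram with WSARecvMsg; the port is arbitrary, error handling is reduced to log.Fatal, and the field names follow the WSAMsg struct defined below:

package main

import (
    "log"
    "syscall"
    "unsafe"

    "golang.org/x/sys/windows"
)

func main() {
    s, err := windows.Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
    if err != nil {
        log.Fatal(err)
    }
    defer windows.CloseHandle(s)
    if err := windows.Bind(s, &windows.SockaddrInet4{Port: 9999}); err != nil {
        log.Fatal(err)
    }
    buf := make([]byte, 1500)
    wsabuf := windows.WSABuf{Len: uint32(len(buf)), Buf: &buf[0]}
    var from syscall.RawSockaddrAny
    msg := windows.WSAMsg{
        Name:        &from,
        Namelen:     int32(unsafe.Sizeof(from)),
        Buffers:     &wsabuf,
        BufferCount: 1,
    }
    var n uint32
    // Blocks until a datagram arrives; MSG_TRUNC in msg.Flags would signal truncation.
    if err := windows.WSARecvMsg(s, &msg, &n, nil, nil); err != nil {
        log.Fatal(err)
    }
    log.Printf("received %d bytes", n)
}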


@ -29,6 +29,7 @@ const (
ERROR_NOT_FOUND syscall.Errno = 1168 ERROR_NOT_FOUND syscall.Errno = 1168
ERROR_PRIVILEGE_NOT_HELD syscall.Errno = 1314 ERROR_PRIVILEGE_NOT_HELD syscall.Errno = 1314
WSAEACCES syscall.Errno = 10013 WSAEACCES syscall.Errno = 10013
WSAEMSGSIZE syscall.Errno = 10040
WSAECONNRESET syscall.Errno = 10054 WSAECONNRESET syscall.Errno = 10054
) )
@ -567,6 +568,16 @@ const (
IPV6_JOIN_GROUP = 0xc IPV6_JOIN_GROUP = 0xc
IPV6_LEAVE_GROUP = 0xd IPV6_LEAVE_GROUP = 0xd
MSG_OOB = 0x1
MSG_PEEK = 0x2
MSG_DONTROUTE = 0x4
MSG_WAITALL = 0x8
MSG_TRUNC = 0x0100
MSG_CTRUNC = 0x0200
MSG_BCAST = 0x0400
MSG_MCAST = 0x0800
SOMAXCONN = 0x7fffffff SOMAXCONN = 0x7fffffff
TCP_NODELAY = 1 TCP_NODELAY = 1
@ -584,6 +595,15 @@ type WSABuf struct {
Buf *byte Buf *byte
} }
type WSAMsg struct {
Name *syscall.RawSockaddrAny
Namelen int32
Buffers *WSABuf
BufferCount uint32
Control WSABuf
Flags uint32
}
// Invented values to support what package os expects. // Invented values to support what package os expects.
const ( const (
S_IFMT = 0x1f000 S_IFMT = 0x1f000
@ -1011,6 +1031,20 @@ var WSAID_CONNECTEX = GUID{
[8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e}, [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
} }
var WSAID_WSASENDMSG = GUID{
0xa441e712,
0x754f,
0x43ca,
[8]byte{0x84, 0xa7, 0x0d, 0xee, 0x44, 0xcf, 0x60, 0x6d},
}
var WSAID_WSARECVMSG = GUID{
0xf689d7c8,
0x6f1f,
0x436b,
[8]byte{0x8a, 0x53, 0xe5, 0x4f, 0xe3, 0x51, 0xc3, 0x22},
}
const ( const (
FILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1 FILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1
FILE_SKIP_SET_EVENT_ON_HANDLE = 2 FILE_SKIP_SET_EVENT_ON_HANDLE = 2