diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index 2da29ff..aa857c9 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -10,7 +10,7 @@ import ( ) // ID is the ID of rule chain. -type ID string +type ID []byte // MatchType is the match type for chain rules. type MatchType uint8 diff --git a/pkg/chain/chain_easyjson.go b/pkg/chain/chain_easyjson.go index c744878..fd7ff8c 100644 Binary files a/pkg/chain/chain_easyjson.go and b/pkg/chain/chain_easyjson.go differ diff --git a/pkg/chain/marshal_binary.go b/pkg/chain/marshal_binary.go index 343eca4..83e6380 100644 --- a/pkg/chain/marshal_binary.go +++ b/pkg/chain/marshal_binary.go @@ -19,7 +19,7 @@ var ( func (c *Chain) MarshalBinary() ([]byte, error) { s := marshal.UInt8Size // Marshaller version s += marshal.UInt8Size // Chain version - s += marshal.StringSize(string(c.ID)) + s += marshal.SliceSize(c.ID, func(byte) int { return marshal.ByteSize }) s += marshal.SliceSize(c.Rules, ruleSize) s += marshal.UInt8Size // MatchType @@ -34,7 +34,7 @@ func (c *Chain) MarshalBinary() ([]byte, error) { if err != nil { return nil, err } - offset, err = marshal.StringMarshal(buf, offset, string(c.ID)) + offset, err = marshal.SliceMarshal(buf, offset, c.ID, marshal.ByteMarshal) if err != nil { return nil, err } @@ -72,11 +72,11 @@ func (c *Chain) UnmarshalBinary(data []byte) error { return fmt.Errorf("unsupported chain version %d", chainVersion) } - idStr, offset, err := marshal.StringUnmarshal(data, offset) + idBytes, offset, err := marshal.SliceUnmarshal(data, offset, marshal.ByteUnmarshal) if err != nil { return err } - c.ID = ID(idStr) + c.ID = ID(idBytes) c.Rules, offset, err = marshal.SliceUnmarshal(data, offset, unmarshalRule) if err != nil { diff --git a/pkg/chain/marshal_json.go b/pkg/chain/marshal_json.go index 6039081..8dec214 100644 --- a/pkg/chain/marshal_json.go +++ b/pkg/chain/marshal_json.go @@ -143,11 +143,3 @@ func (ct *ConditionType) UnmarshalEasyJSON(l *jlexer.Lexer) { } *ct = ConditionType(v) } - -func (id ID) MarshalEasyJSON(w *jwriter.Writer) { - w.Base64Bytes([]byte(id)) -} - -func (id *ID) UnmarshalEasyJSON(l *jlexer.Lexer) { - *id = ID(l.Bytes()) -} diff --git a/pkg/chain/marshal_json_test.go b/pkg/chain/marshal_json_test.go index 75b1bc6..6c15ae9 100644 --- a/pkg/chain/marshal_json_test.go +++ b/pkg/chain/marshal_json_test.go @@ -38,11 +38,11 @@ func TestMatchTypeJson(t *testing.T) { data, err := chain.MarshalJSON() require.NoError(t, err) if mt == MatchTypeDenyPriority { - require.Equal(t, []byte("{\"ID\":\"\",\"Rules\":null,\"MatchType\":\"DenyPriority\"}"), data) + require.Equal(t, []byte("{\"ID\":null,\"Rules\":null,\"MatchType\":\"DenyPriority\"}"), data) } else if mt == MatchTypeFirstMatch { - require.Equal(t, []byte("{\"ID\":\"\",\"Rules\":null,\"MatchType\":\"FirstMatch\"}"), data) + require.Equal(t, []byte("{\"ID\":null,\"Rules\":null,\"MatchType\":\"FirstMatch\"}"), data) } else { - require.Equal(t, []byte(fmt.Sprintf("{\"ID\":\"\",\"Rules\":null,\"MatchType\":\"%d\"}", mt)), data) + require.Equal(t, []byte(fmt.Sprintf("{\"ID\":null,\"Rules\":null,\"MatchType\":\"%d\"}", mt)), data) } var parsed Chain @@ -55,7 +55,7 @@ func TestMatchTypeJson(t *testing.T) { func TestJsonEnums(t *testing.T) { chain := Chain{ - ID: "2cca5ae7-cee8-428d-b45f-567fb1d03f01", // will be encoded to base64 + ID: []byte("2cca5ae7-cee8-428d-b45f-567fb1d03f01"), // will be encoded to base64 MatchType: MatchTypeFirstMatch, Rules: []Rule{ { diff --git a/pkg/engine/inmemory/local_storage.go b/pkg/engine/inmemory/local_storage.go index 21d3d55..30553bd 100644 --- a/pkg/engine/inmemory/local_storage.go +++ b/pkg/engine/inmemory/local_storage.go @@ -1,6 +1,7 @@ package inmemory import ( + "bytes" "fmt" "math/rand" "strings" @@ -14,14 +15,14 @@ import ( type targetToChain map[engine.Target][]*chain.Chain type inmemoryLocalStorage struct { - usedChainID map[chain.ID]struct{} + usedChainID map[string]struct{} nameToResourceChains map[chain.Name]targetToChain guard *sync.RWMutex } func NewInmemoryLocalStorage() engine.LocalOverrideStorage { return &inmemoryLocalStorage{ - usedChainID: map[chain.ID]struct{}{}, + usedChainID: map[string]struct{}{}, nameToResourceChains: make(map[chain.Name]targetToChain), guard: &sync.RWMutex{}, } @@ -35,12 +36,13 @@ func (s *inmemoryLocalStorage) generateChainID(name chain.Name, target engine.Ta sid = strings.ReplaceAll(sid, "*", "") sid = strings.ReplaceAll(sid, "/", ":") sid = strings.ReplaceAll(sid, "::", ":") - id = chain.ID(sid) - _, ok := s.usedChainID[id] + _, ok := s.usedChainID[sid] if ok { continue } - s.usedChainID[id] = struct{}{} + s.usedChainID[sid] = struct{}{} + + id = chain.ID(sid) break } return id @@ -51,7 +53,7 @@ func (s *inmemoryLocalStorage) AddOverride(name chain.Name, target engine.Target defer s.guard.Unlock() // AddOverride assigns generated chain ID if it has not been assigned. - if c.ID == "" { + if len(c.ID) == 0 { c.ID = s.generateChainID(name, target) } if s.nameToResourceChains[name] == nil { @@ -59,7 +61,7 @@ func (s *inmemoryLocalStorage) AddOverride(name chain.Name, target engine.Target } rc := s.nameToResourceChains[name] for i := range rc[target] { - if rc[target][i].ID == c.ID { + if bytes.Equal(rc[target][i].ID, c.ID) { rc[target][i] = c return c.ID, nil } @@ -80,7 +82,7 @@ func (s *inmemoryLocalStorage) GetOverride(name chain.Name, target engine.Target return nil, engine.ErrResourceNotFound } for _, c := range chains { - if c.ID == chainID { + if bytes.Equal(c.ID, chainID) { return c, nil } } @@ -99,7 +101,7 @@ func (s *inmemoryLocalStorage) RemoveOverride(name chain.Name, target engine.Tar return engine.ErrResourceNotFound } for i, c := range chains { - if c.ID == chainID { + if bytes.Equal(c.ID, chainID) { s.nameToResourceChains[name][target] = append(chains[:i], chains[i+1:]...) return nil } diff --git a/pkg/engine/inmemory/local_storage_test.go b/pkg/engine/inmemory/local_storage_test.go index 3609070..5c37879 100644 --- a/pkg/engine/inmemory/local_storage_test.go +++ b/pkg/engine/inmemory/local_storage_test.go @@ -14,9 +14,7 @@ const ( nonExistChainId = "ingress:LxGyWyL" ) -var ( - resrc = engine.ContainerTarget(container) -) +var resrc = engine.ContainerTarget(container) func testInmemLocalStorage() *inmemoryLocalStorage { return NewInmemoryLocalStorage().(*inmemoryLocalStorage) @@ -210,12 +208,12 @@ func TestGenerateID(t *testing.T) { } func hasDuplicates(ids []chain.ID) bool { - seen := make(map[chain.ID]bool) + seen := make(map[string]bool) for _, id := range ids { - if seen[id] { + if seen[string(id)] { return true } - seen[id] = true + seen[string(id)] = true } return false } diff --git a/pkg/morph/policy/policy_contract_storage.go b/pkg/morph/policy/policy_contract_storage.go index 2e0a756..2eaa099 100644 --- a/pkg/morph/policy/policy_contract_storage.go +++ b/pkg/morph/policy/policy_contract_storage.go @@ -44,7 +44,7 @@ func NewContractStorageWithSimpleActor(rpcActor actor.RPCActor, acc *wallet.Acco } func (s *ContractStorage) AddMorphRuleChain(name chain.Name, target engine.Target, c *chain.Chain) (txHash util.Uint256, vub uint32, err error) { - if c.ID == "" { + if len(c.ID) == 0 { err = ErrEmptyChainID return } @@ -61,7 +61,7 @@ func (s *ContractStorage) AddMorphRuleChain(name chain.Name, target engine.Targe } func (s *ContractStorage) RemoveMorphRuleChain(name chain.Name, target engine.Target, chainID chain.ID) (txHash util.Uint256, vub uint32, err error) { - if chainID == "" { + if len(chainID) == 0 { err = ErrEmptyChainID return }