diff --git a/pkg/session/session.go b/pkg/session/session.go index ecdb93b2..4d9780ca 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -135,6 +135,50 @@ func (t *Token) Signature() *pkg.Signature { ) } +// SetContext sets context of the Token. +// +// Supported contexts: +// - *ContainerContext. +// +// Resets context if it is not supported. +func (t *Token) SetContext(v interface{}) { + var cV2 session.SessionTokenContext + + switch c := v.(type) { + case *ContainerContext: + cV2 = c.ToV2() + } + + t.setBodyField(func(body *session.SessionTokenBody) { + body.SetContext(cV2) + }) +} + +// Context returns context of the Token. +// +// Supports same contexts as SetContext. +// +// Returns nil if context is not supported. +func (t *Token) Context() interface{} { + switch v := (*session.SessionToken)(t). + GetBody(). + GetContext(); c := v.(type) { + default: + return nil + case *session.ContainerSessionContext: + return ContainerContextFromV2(c) + } +} + +// GetContainerContext is a helper function that casts +// Token context to ContainerContext. +// +// Returns nil if context is not a ContainerContext. +func GetContainerContext(t *Token) *ContainerContext { + c, _ := t.Context().(*ContainerContext) + return c +} + // Marshal marshals Token into a protobuf binary form. // // Buffer is allocated when the argument is empty. diff --git a/pkg/session/session_test.go b/pkg/session/session_test.go index 99645410..5476f153 100644 --- a/pkg/session/session_test.go +++ b/pkg/session/session_test.go @@ -86,3 +86,55 @@ func TestToken_VerifySignature(t *testing.T) { require.True(t, tok.VerifySignature()) }) } + +var unsupportedContexts = []interface{}{ + 123, + true, + session.NewToken(), +} + +var nonContainerContexts = unsupportedContexts + +func TestToken_Context(t *testing.T) { + tok := session.NewToken() + + for _, item := range []struct { + ctx interface{} + v2assert func(interface{}) + }{ + { + ctx: sessiontest.ContainerContext(), + v2assert: func(c interface{}) { + require.Equal(t, c.(*session.ContainerContext).ToV2(), tok.ToV2().GetBody().GetContext()) + }, + }, + } { + tok.SetContext(item.ctx) + + require.Equal(t, item.ctx, tok.Context()) + + item.v2assert(item.ctx) + } + + for _, c := range unsupportedContexts { + tok.SetContext(c) + + require.Nil(t, tok.Context()) + } +} + +func TestGetContainerContext(t *testing.T) { + tok := session.NewToken() + + c := sessiontest.ContainerContext() + + tok.SetContext(c) + + require.Equal(t, c, session.GetContainerContext(tok)) + + for _, c := range nonContainerContexts { + tok.SetContext(c) + + require.Nil(t, session.GetContainerContext(tok)) + } +}