package cache

import (
	"context"
	"errors"
	"testing"

	"github.com/distribution/distribution/v3"
	digest "github.com/opencontainers/go-digest"
)

// TestCacheSet checks the cached blob statter's happy path: a digest unknown
// to both statters yields distribution.ErrBlobUnknown, a descriptor found in
// the backend is written to the cache, and a subsequent Stat is expected to be
// answered with whatever the cache currently holds.
func TestCacheSet(t *testing.T) {
	cache := newTestStatter()
	backend := newTestStatter()
	st := NewCachedBlobStatter(cache, backend)
	ctx := context.Background()

	// Nothing has been stored yet, so the lookup must miss everywhere.
	dgst := digest.Digest("dontvalidate")
	_, err := st.Stat(ctx, dgst)
	if err != distribution.ErrBlobUnknown {
		t.Fatalf("Unexpected error %v, expected %v", err, distribution.ErrBlobUnknown)
	}

	// Seed only the backend; the cache is still empty at this point.
	desc := distribution.Descriptor{
		Digest: dgst,
	}
	if err := backend.SetDescriptor(ctx, dgst, desc); err != nil {
		t.Fatal(err)
	}

	actual, err := st.Stat(ctx, dgst)
	if err != nil {
		t.Fatal(err)
	}
	if actual.Digest != desc.Digest {
		t.Fatalf("Unexpected descriptor %v, expected %v", actual, desc)
	}

	// The backend hit must have been recorded in the cache under dgst.
	if len(cache.sets) != 1 || len(cache.sets[dgst]) == 0 {
		t.Fatal("Expected cache set")
	}
	if cache.sets[dgst][0].Digest != desc.Digest {
		t.Fatalf("Unexpected descriptor %v, expected %v", cache.sets[dgst][0], desc)
	}

	// Plant a different descriptor directly in the cache (bypassing the
	// backend). If Stat now returns it, the answer came from the cache.
	desc2 := distribution.Descriptor{
		Digest: digest.Digest("dontvalidate 2"),
	}
	cache.sets[dgst] = append(cache.sets[dgst], desc2)

	actual, err = st.Stat(ctx, dgst)
	if err != nil {
		t.Fatal(err)
	}
	if actual.Digest != desc2.Digest {
		// Report desc2: it is the descriptor this assertion compares against.
		t.Fatalf("Unexpected descriptor %v, expected %v", actual, desc2)
	}
}

// TestCacheError covers a cache whose Stat always fails with an arbitrary
// error. The cached statter is expected to fall through to the backend for
// the lookup and to leave the failing cache untouched (no SetDescriptor call).
func TestCacheError(t *testing.T) {
	ctx := context.Background()
	brokenCache := newErrTestStatter(errors.New("cache error"))
	backend := newTestStatter()
	statter := NewCachedBlobStatter(brokenCache, backend)

	key := digest.Digest("dontvalidate")
	want := distribution.Descriptor{Digest: key}

	// The backend is empty, so its "unknown" answer must surface rather than
	// the cache's own error.
	if _, err := statter.Stat(ctx, key); err != distribution.ErrBlobUnknown {
		t.Fatalf("Unexpected error %v, expected %v", err, distribution.ErrBlobUnknown)
	}

	if err := backend.SetDescriptor(ctx, key, want); err != nil {
		t.Fatal(err)
	}

	got, err := statter.Stat(ctx, key)
	switch {
	case err != nil:
		t.Fatal(err)
	case got.Digest != want.Digest:
		t.Fatalf("Unexpected descriptor %v, expected %v", got, want)
	}

	// testStatter.SetDescriptor records every call, even when it returns an
	// error, so any entry here means the cache was written to.
	if len(brokenCache.sets) != 0 {
		t.Fatalf("Set should not be called after stat error")
	}
}

// newTestStatter returns a testStatter with empty, non-nil stats and sets and
// no configured error.
func newTestStatter() *testStatter {
	s := new(testStatter)
	s.stats = []digest.Digest{}
	s.sets = make(map[digest.Digest][]distribution.Descriptor)
	return s
}

// newErrTestStatter returns a testStatter whose methods all report err. Its
// sets map is initialised so that SetDescriptor calls can still be recorded.
func newErrTestStatter(err error) *testStatter {
	s := &testStatter{err: err}
	s.sets = make(map[digest.Digest][]distribution.Descriptor)
	return s
}

// testStatter is an in-memory fake used by the tests in this file as both the
// cache and the backend of a cached blob statter. It records every descriptor
// passed to SetDescriptor and can be configured to fail every call.
type testStatter struct {
	// stats is initialised by newTestStatter; no code visible in this file
	// reads or appends to it.
	stats []digest.Digest
	// sets holds, per digest, the descriptors passed to SetDescriptor in call
	// order. Stat returns the most recent entry.
	sets map[digest.Digest][]distribution.Descriptor
	// err, when non-nil, is returned by Stat, SetDescriptor and Clear.
	err error
}

// Stat returns the configured error if one is set. Otherwise it returns the
// most recently recorded descriptor for dgst, or distribution.ErrBlobUnknown
// when nothing has been recorded for that digest.
func (s *testStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
	var none distribution.Descriptor

	if s.err != nil {
		return none, s.err
	}

	recorded := s.sets[dgst]
	if len(recorded) == 0 {
		return none, distribution.ErrBlobUnknown
	}

	// The latest entry wins, so tests can override an earlier descriptor by
	// appending to sets.
	return recorded[len(recorded)-1], nil
}

// SetDescriptor appends desc to the history kept for dgst and then returns the
// configured error (nil if none). The descriptor is recorded even when an
// error is configured, which lets tests detect that a set was attempted.
func (s *testStatter) SetDescriptor(ctx context.Context, dgst digest.Digest, desc distribution.Descriptor) error {
	history := s.sets[dgst]
	s.sets[dgst] = append(history, desc)
	return s.err
}

// Clear does not remove anything from sets; it only returns the configured
// error (nil if none).
func (s *testStatter) Clear(ctx context.Context, dgst digest.Digest) error {
	return s.err
}