// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package firestore

import (
	"testing"

	"golang.org/x/net/context"
	"google.golang.org/grpc/status"

	pb "google.golang.org/genproto/googleapis/firestore/v1beta1"

	"github.com/golang/protobuf/ptypes/empty"
	"google.golang.org/api/iterator"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
)

// TestRunTransaction runs Client.RunTransaction through four scenarios that
// are all expected to succeed: an empty transaction, a read followed by a
// write, a query, and a commit that is aborted once and then retried.
//
// Before each scenario the expected request/response pairs are queued on srv
// with addRPC; srv.reset() is called between scenarios.
func TestRunTransaction(t *testing.T) {
	ctx := context.Background()
	const db = "projects/projectID/databases/(default)"
	tid := []byte{1}
	c, srv := newMock(t)
	// Close the client when the test ends, as TestTransactionGetAll does.
	defer c.Close()
	beginReq := &pb.BeginTransactionRequest{Database: db}
	beginRes := &pb.BeginTransactionResponse{Transaction: tid}
	commitReq := &pb.CommitRequest{Database: db, Transaction: tid}

	// Empty transaction.
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
	err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
	if err != nil {
		t.Fatal(err)
	}

	// Transaction with read and write.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	aDoc := &pb.Document{
		Name:       db + "/documents/C/a",
		CreateTime: aTimestamp,
		UpdateTime: aTimestamp2,
		Fields:     map[string]*pb.Value{"count": intval(1)},
	}
	srv.addRPC(
		&pb.BatchGetDocumentsRequest{
			Database:            c.path(),
			Documents:           []string{db + "/documents/C/a"},
			ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{Transaction: tid},
		}, []interface{}{
			&pb.BatchGetDocumentsResponse{
				Result:   &pb.BatchGetDocumentsResponse_Found{Found: aDoc},
				ReadTime: aTimestamp2,
			},
		})
	// The document the commit is expected to carry: same name, count bumped to 2.
	aDoc2 := &pb.Document{
		Name:   aDoc.Name,
		Fields: map[string]*pb.Value{"count": intval(2)},
	}
	srv.addRPC(
		&pb.CommitRequest{
			Database:    db,
			Transaction: tid,
			Writes: []*pb.Write{{
				Operation:  &pb.Write_Update{Update: aDoc2},
				UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}},
				CurrentDocument: &pb.Precondition{
					ConditionType: &pb.Precondition_Exists{Exists: true},
				},
			}},
		},
		&pb.CommitResponse{CommitTime: aTimestamp3},
	)
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		docref := c.Collection("C").Doc("a")
		doc, err := tx.Get(docref)
		if err != nil {
			return err
		}
		count, err := doc.DataAt("count")
		if err != nil {
			return err
		}
		return tx.Update(docref, []Update{{Path: "count", Value: count.(int64) + 1}})
	})
	if err != nil {
		t.Fatal(err)
	}

	// Query
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(
		&pb.RunQueryRequest{
			Parent: db,
			QueryType: &pb.RunQueryRequest_StructuredQuery{
				StructuredQuery: &pb.StructuredQuery{
					From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: "C"}},
				},
			},
			ConsistencySelector: &pb.RunQueryRequest_Transaction{Transaction: tid},
		},
		[]interface{}{},
	)
	srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp3})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		it := tx.Documents(c.Collection("C"))
		defer it.Stop()
		switch _, err := it.Next(); err {
		case iterator.Done:
			// The queued query response is empty, so this is the expected outcome.
			return nil
		case nil:
			// A document came back even though none was queued; report it
			// instead of letting the scenario pass silently.
			t.Error("query returned a document, want iterator.Done")
			return nil
		default:
			return err
		}
	})
	if err != nil {
		t.Fatal(err)
	}

	// Retry entire transaction.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	// The first commit is aborted, so a second BeginTransaction carrying
	// RetryTransaction: tid is expected, followed by a successful commit.
	srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
	srv.addRPC(
		&pb.BeginTransactionRequest{
			Database: db,
			Options: &pb.TransactionOptions{
				Mode: &pb.TransactionOptions_ReadWrite_{
					ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
				},
			},
		},
		beginRes,
	)
	srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { return nil })
	if err != nil {
		t.Fatal(err)
	}
}

// TestTransactionErrors checks the error returned by Client.RunTransaction in
// failure scenarios: permanent errors from BeginTransaction, Get and Commit; a
// rollback that itself fails; reads issued after a write; a write inside a
// read-only transaction; running out of attempts; a nested transaction; and
// non-transactional operations invoked with the transaction's context.
//
// Before each scenario the expected request/response pairs are queued on srv
// with addRPC; srv.reset() is called between scenarios.
func TestTransactionErrors(t *testing.T) {
	ctx := context.Background()
	const db = "projects/projectID/databases/(default)"
	c, srv := newMock(t)
	// Close the client when the test ends, as TestTransactionGetAll does.
	defer c.Close()
	var (
		tid         = []byte{1}
		internalErr = status.Errorf(codes.Internal, "so sad")
		beginReq    = &pb.BeginTransactionRequest{
			Database: db,
		}
		beginRes = &pb.BeginTransactionResponse{Transaction: tid}
		getReq   = &pb.BatchGetDocumentsRequest{
			Database:            c.path(),
			Documents:           []string{db + "/documents/C/a"},
			ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{Transaction: tid},
		}
		rollbackReq = &pb.RollbackRequest{Database: db, Transaction: tid}
		commitReq   = &pb.CommitRequest{Database: db, Transaction: tid}
	)

	// BeginTransaction has a permanent error.
	srv.addRPC(beginReq, internalErr)
	err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
	if grpc.Code(err) != codes.Internal {
		t.Errorf("got <%v>, want Internal", err)
	}

	// Get has a permanent error.
	get := func(_ context.Context, tx *Transaction) error {
		_, err := tx.Get(c.Doc("C/a"))
		return err
	}
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(getReq, internalErr)
	srv.addRPC(rollbackReq, &empty.Empty{})
	err = c.RunTransaction(ctx, get)
	if grpc.Code(err) != codes.Internal {
		t.Errorf("got <%v>, want Internal", err)
	}

	// Get has a permanent error, but the rollback fails. We still
	// return Get's error.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(getReq, internalErr)
	srv.addRPC(rollbackReq, status.Errorf(codes.FailedPrecondition, ""))
	err = c.RunTransaction(ctx, get)
	if grpc.Code(err) != codes.Internal {
		t.Errorf("got <%v>, want Internal", err)
	}

	// Commit has a permanent error.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(getReq, []interface{}{
		&pb.BatchGetDocumentsResponse{
			Result: &pb.BatchGetDocumentsResponse_Found{Found: &pb.Document{
				Name:       "projects/projectID/databases/(default)/documents/C/a",
				CreateTime: aTimestamp,
				UpdateTime: aTimestamp2,
			}},
			ReadTime: aTimestamp2,
		},
	})
	srv.addRPC(commitReq, internalErr)
	err = c.RunTransaction(ctx, get)
	if grpc.Code(err) != codes.Internal {
		t.Errorf("got <%v>, want Internal", err)
	}

	// Read after write.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(rollbackReq, &empty.Empty{})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		if err := tx.Delete(c.Doc("C/a")); err != nil {
			return err
		}
		if _, err := tx.Get(c.Doc("C/a")); err != nil {
			return err
		}
		return nil
	})
	if err != errReadAfterWrite {
		t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
	}

	// Read after write, with query.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(rollbackReq, &empty.Empty{})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		if err := tx.Delete(c.Doc("C/a")); err != nil {
			return err
		}
		it := tx.Documents(c.Collection("C").Select("x"))
		defer it.Stop()
		if _, err := it.Next(); err != iterator.Done {
			return err
		}
		return nil
	})
	if err != errReadAfterWrite {
		t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
	}

	// Read after write fails even if the user ignores the read's error.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(rollbackReq, &empty.Empty{})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		if err := tx.Delete(c.Doc("C/a")); err != nil {
			return err
		}
		// Deliberately discard the read's result and error and report
		// success: the scenario is only meaningful if the callback does not
		// propagate the read-after-write error itself.
		_, _ = tx.Get(c.Doc("C/a"))
		return nil
	})
	if err != errReadAfterWrite {
		t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
	}

	// Write in read-only transaction.
	srv.reset()
	srv.addRPC(
		&pb.BeginTransactionRequest{
			Database: db,
			Options: &pb.TransactionOptions{
				Mode: &pb.TransactionOptions_ReadOnly_{ReadOnly: &pb.TransactionOptions_ReadOnly{}},
			},
		},
		beginRes,
	)
	srv.addRPC(rollbackReq, &empty.Empty{})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		return tx.Delete(c.Doc("C/a"))
	}, ReadOnly)
	if err != errWriteReadOnly {
		t.Errorf("got <%v>, want <%v>", err, errWriteReadOnly)
	}

	// Too many retries.
	// With MaxAttempts(2), both commits are aborted; a rollback is then
	// expected and the Aborted error is what RunTransaction should return.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
	srv.addRPC(
		&pb.BeginTransactionRequest{
			Database: db,
			Options: &pb.TransactionOptions{
				Mode: &pb.TransactionOptions_ReadWrite_{
					ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
				},
			},
		},
		beginRes,
	)
	srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
	srv.addRPC(rollbackReq, &empty.Empty{})
	err = c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil },
		MaxAttempts(2))
	if grpc.Code(err) != codes.Aborted {
		t.Errorf("got <%v>, want Aborted", err)
	}

	// Nested transaction.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(rollbackReq, &empty.Empty{})
	err = c.RunTransaction(ctx, func(ctx context.Context, tx *Transaction) error {
		return c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
	})
	if got, want := err, errNestedTransaction; got != want {
		t.Errorf("got <%v>, want <%v>", got, want)
	}

	// Non-transactional operation.
	// Each op below is a client, collection, document or batch call that takes
	// a context directly; it is invoked with the context RunTransaction hands
	// to the callback and is expected to fail with errNonTransactionalOp.
	dr := c.Doc("C/d")

	for i, op := range []func(ctx context.Context) error{
		func(ctx context.Context) error { _, err := c.GetAll(ctx, []*DocumentRef{dr}); return err },
		func(ctx context.Context) error { _, _, err := c.Collection("C").Add(ctx, testData); return err },
		func(ctx context.Context) error { _, err := dr.Get(ctx); return err },
		func(ctx context.Context) error { _, err := dr.Create(ctx, testData); return err },
		func(ctx context.Context) error { _, err := dr.Set(ctx, testData); return err },
		func(ctx context.Context) error { _, err := dr.Delete(ctx); return err },
		func(ctx context.Context) error {
			_, err := dr.Update(ctx, []Update{{FieldPath: []string{"*"}, Value: 1}})
			return err
		},
		func(ctx context.Context) error { it := c.Collections(ctx); _, err := it.Next(); return err },
		func(ctx context.Context) error { it := dr.Collections(ctx); _, err := it.Next(); return err },
		func(ctx context.Context) error {
			_, err := c.Batch().Set(dr, testData).Commit(ctx)
			return err
		},
		func(ctx context.Context) error {
			it := c.Collection("C").Documents(ctx)
			defer it.Stop()
			_, err := it.Next()
			return err
		},
	} {
		srv.reset()
		srv.addRPC(beginReq, beginRes)
		srv.addRPC(rollbackReq, &empty.Empty{})
		err = c.RunTransaction(ctx, func(ctx context.Context, _ *Transaction) error {
			return op(ctx)
		})
		if got, want := err, errNonTransactionalOp; got != want {
			t.Errorf("#%d: got <%v>, want <%v>", i, got, want)
		}
	}
}

// TestTransactionGetAll runs the shared testGetAll helper with
// Transaction.GetAll as the function under test, from inside a RunTransaction
// callback. The BatchGetDocuments request handed to the helper names documents
// C/a, C/b and C/c and carries the transaction ID.
func TestTransactionGetAll(t *testing.T) {
	const dbPath = "projects/projectID/databases/(default)"
	c, srv := newMock(t)
	defer c.Close()

	tid := []byte{1}
	srv.addRPC(
		&pb.BeginTransactionRequest{Database: dbPath},
		&pb.BeginTransactionResponse{Transaction: tid},
	)
	getReq := &pb.BatchGetDocumentsRequest{
		Database: dbPath,
		Documents: []string{
			dbPath + "/documents/C/a",
			dbPath + "/documents/C/b",
			dbPath + "/documents/C/c",
		},
		ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{Transaction: tid},
	}

	body := func(_ context.Context, tx *Transaction) error {
		// Adapt the transaction's GetAll to the function shape testGetAll takes.
		getAll := func(drs []*DocumentRef) ([]*DocumentSnapshot, error) { return tx.GetAll(drs) }
		testGetAll(t, c, srv, dbPath, getAll, getReq)
		// The commit is queued only after testGetAll returns, matching the
		// original ordering.
		// NOTE(review): presumably testGetAll queues its own RPCs on srv — confirm.
		srv.addRPC(
			&pb.CommitRequest{Database: dbPath, Transaction: tid},
			&pb.CommitResponse{CommitTime: aTimestamp},
		)
		return nil
	}
	if err := c.RunTransaction(context.Background(), body); err != nil {
		t.Fatal(err)
	}
}