package main

import (
	"context"
	"fmt"
	"log"
	"net"
	"sync"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
	srv "git.frostfs.info/TrueCloudLab/frostfs-observability/tracing/examples/grpc/server"
	tracing_grpc "git.frostfs.info/TrueCloudLab/frostfs-observability/tracing/grpc"
	"go.opentelemetry.io/otel/trace"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// server is the example implementation of the Server gRPC service that
// main registers via srv.RegisterServerServer. Echo is the only method
// defined on it in this file; the embedded srv.UnimplementedServerServer
// presumably supplies default handlers for any other service methods, as
// protoc-gen-go-grpc generated types do (verify against the generated code).
type server struct {
	srv.UnimplementedServerServer
}

// Echo replies with the request's Value unchanged.
//
// Before replying it inspects the span attached to ctx and returns an error
// when either the trace ID or the span ID is invalid; on success both IDs are
// logged. This is how the example checks that a span reached the server
// handler (main installs the tracing unary server interceptor for this).
func (s *server) Echo(ctx context.Context, req *srv.Request) (*srv.Response, error) {
	spanCtx := trace.SpanFromContext(ctx).SpanContext()
	traceID, spanID := spanCtx.TraceID(), spanCtx.SpanID()

	if valid := traceID.IsValid() && spanID.IsValid(); !valid {
		return nil, fmt.Errorf("no trace id or span id on server side")
	}

	log.Printf("server trace id: %v", traceID)
	log.Printf("server span id: %v", spanID)

	resp := &srv.Response{Value: req.GetValue()}
	return resp, nil
}

// verifyClientTraceID is a grpc.UnaryClientInterceptor that only lets a call
// through when ctx carries a span with a valid trace ID and span ID.
//
// When both IDs are valid they are logged and the call is forwarded unchanged
// to invoker; otherwise the RPC is not sent and an error is returned. main
// chains it after the tracing unary client interceptor, so it observes the
// context as prepared by that interceptor.
func verifyClientTraceID(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	spanCtx := trace.SpanFromContext(ctx).SpanContext()
	traceID, spanID := spanCtx.TraceID(), spanCtx.SpanID()

	if traceID.IsValid() && spanID.IsValid() {
		log.Printf("client trace id: %v", traceID)
		log.Printf("client span id: %v", spanID)
		return invoker(ctx, method, req, reply, cc, opts...)
	}

	return fmt.Errorf("no trace id or span id on client side")
}

// main runs a self-contained gRPC round trip that demonstrates trace
// propagation:
//
//  1. tracing is set up with a no-op exporter (the program exits if setup
//     fails or reports tracing as disabled);
//  2. a gRPC server with the tracing stream/unary server interceptors is
//     started on listenAddr in a background goroutine;
//  3. a client wired with the tracing client interceptors (plus
//     verifyClientTraceID) calls Echo once and logs the reply;
//  4. the client connection is closed, the server is stopped gracefully and
//     main waits for the serving goroutine to finish.
func main() {
	const (
		// listenAddr is shared by the listener and the dialer so the two
		// cannot drift apart.
		listenAddr = ":7000"
		// rpcTimeout bounds the single Echo call so a broken setup fails
		// with an error instead of hanging forever.
		rpcTimeout = 5 * time.Second
	)

	ctx := context.Background()

	tracingCfg := tracing.Config{
		Enabled:  true,
		Exporter: tracing.NoOpExporter,
		Service:  "example-grpc",
	}
	enabled, err := tracing.Setup(ctx, tracingCfg)
	if err != nil {
		log.Fatalf("failed to setup tracing: %v", err)
	}
	if !enabled {
		log.Fatal("failed to enable tracing")
	}

	// The socket is bound once Listen returns, so clients can connect right
	// away; there is no need to sleep until the Serve goroutine starts.
	lis, err := net.Listen("tcp", listenAddr)
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	s := grpc.NewServer(
		grpc.ChainStreamInterceptor(tracing_grpc.NewStreamServerInterceptor()),
		grpc.ChainUnaryInterceptor(tracing_grpc.NewUnaryServerInterceptor()),
	)
	srv.RegisterServerServer(s, &server{})

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// Serve returns nil after GracefulStop; any other error is fatal.
		if err := s.Serve(lis); err != nil {
			log.Fatalf("failed to serve: %v", err)
		}
	}()

	cc, err := grpc.DialContext(ctx,
		listenAddr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithChainUnaryInterceptor(
			tracing_grpc.NewUnaryClientInteceptor(),
			verifyClientTraceID,
		),
		grpc.WithChainStreamInterceptor(
			tracing_grpc.NewStreamClientInterceptor(),
		),
	)
	if err != nil {
		log.Fatalf("failed to dial: %v", err)
	}

	client := srv.NewServerClient(cc)
	callCtx, cancel := context.WithTimeout(ctx, rpcTimeout)
	resp, err := client.Echo(callCtx, &srv.Request{
		Value: "Hello!",
	})
	// Cancel explicitly rather than via defer: log.Fatalf below would skip
	// deferred calls anyway.
	cancel()
	if err != nil {
		log.Fatalf("failed to get response: %v", err)
	}
	log.Printf("response received: %s", resp.GetValue())

	// Release the client connection (previously leaked) before shutting the
	// server down, so graceful shutdown has no idle client transport left.
	if err := cc.Close(); err != nil {
		log.Printf("failed to close client connection: %v", err)
	}

	s.GracefulStop()

	wg.Wait()
}