package daemon

import (
	"context"
	"log"
	"log/slog"
	"net"
	"os"
	"path/filepath"

	"github.com/itchio/butler/butlerd/horror"

	"crawshaw.io/sqlite/sqlitex"
	"github.com/google/gops/agent"
	"github.com/google/uuid"
	"github.com/itchio/butler/butlerd"
	"github.com/itchio/butler/database"
	"github.com/itchio/butler/database/models"
	"github.com/itchio/headway/state"

	"github.com/itchio/butler/comm"
	"github.com/itchio/butler/mansion"
	"github.com/pkg/errors"
)

// args receives the values of the "daemon" command's flags (bound in
// Register). Every field starts at its zero value.
var args struct {
	destinyPids []int64 // --destiny-pid, may be given several times
	transport   string  // --transport: "http", "tcp" or "stdio"
	keepAlive   bool    // --keep-alive, forwarded to the TCP server params
	log         bool    // --log, forwarded to the TCP server params
}

// origStdout holds the real stdout before redirecting it for stdio transport.
// It is assigned by do only when the selected transport is "stdio" (right
// before os.Stdout is repointed at os.Stderr) and stays nil otherwise. Do
// hands it to the stdio transport as its output stream.
var origStdout *os.File

// Register declares the hidden "daemon" command on the application, binds
// its flags to the package-level args struct, and registers do as the
// command's handler.
func Register(ctx *mansion.Context) {
	daemonCmd := ctx.App.Command("daemon", "Start a butlerd instance").Hidden()

	// Flags are declared in the same order as before; each one writes
	// straight into the matching args field.
	destinyPidFlag := daemonCmd.Flag("destiny-pid", "The daemon will shutdown whenever any of its destiny PIDs shuts down")
	destinyPidFlag.Int64ListVar(&args.destinyPids)

	transportFlag := daemonCmd.Flag("transport", "Which transport to use").Default("tcp")
	transportFlag.EnumVar(&args.transport, "http", "tcp", "stdio")

	keepAliveFlag := daemonCmd.Flag("keep-alive", "Accept multiple TCP connections, stay up until killed or a destiny PID shuts down")
	keepAliveFlag.BoolVar(&args.keepAlive)

	logFlag := daemonCmd.Flag("log", "Log all requests to stderr")
	logFlag.BoolVar(&args.log)

	ctx.Register(daemonCmd, do)
}

// do is the handler registered for the "daemon" command. In order, it:
//
//  1. refuses to run (exit code 1) unless JSON output is enabled;
//  2. installs the optional SQL and HTTP debug loggers;
//  3. for the stdio transport, saves the real stdout and points os.Stdout and
//     the standard log package at stderr;
//  4. starts the gops agent (a failure there is only a warning);
//  5. starts one tieDestiny goroutine per destiny PID;
//  6. generates the session secret;
//  7. opens and prepares the database;
//  8. hands over to Do, which serves until it returns.
func do(ctx *mansion.Context) {
	if !comm.JsonEnabled() {
		details := []string{
			"We can't do anything interesting without --json, bailing out",
			"",
			"Learn more: https://docs.itch.zone/butlerd/master/",
		}
		comm.Notice("Hello from butler daemon", details)
		os.Exit(1)
	}

	// The hades SQL logger must be in place before the first DB call below
	// (Prepare/AutoMigrate included).
	if models.LogSql {
		sqlHandler := comm.NewSlogHandler(slog.LevelDebug)
		models.SetHadesLogger(slog.New(sqlHandler))
	}

	// HTTP debug logging for go-itchio clients.
	if mansion.LogHttp {
		httpHandler := comm.NewSlogHandler(slog.LevelDebug)
		ctx.SetClientLogger(slog.New(httpHandler))
	}

	// With the stdio transport the real stdout carries JSON-RPC traffic, so
	// keep a handle on it and send everything else (comm messages, Go's log
	// package, which main.go pointed at stdout) to stderr instead.
	if args.transport == "stdio" {
		origStdout, os.Stdout = os.Stdout, os.Stderr
		log.SetOutput(os.Stderr)
	}

	ctx.EnsureDBPath()

	// A gops agent that fails to start is not fatal: warn and carry on.
	gopsOptions := agent.Options{
		Addr:            "localhost:0",
		ShutdownCleanup: true,
	}
	if listenErr := agent.Listen(gopsOptions); listenErr != nil {
		comm.Warnf("butlerd: Could not start gops agent: %+v", listenErr)
	}

	// One watcher goroutine per destiny PID. tieDestiny lives elsewhere;
	// per the flag's help text the daemon shuts down when any of them does.
	for i := range args.destinyPids {
		go tieDestiny(args.destinyPids[i])
	}

	// The secret is four freshly generated UUID strings, concatenated.
	var secret string
	for i := 0; i < 4; i++ {
		secret += uuid.New().String()
	}

	dbDir := filepath.Dir(ctx.DBPath)
	if mkdirErr := os.MkdirAll(dbDir, 0o755); mkdirErr != nil {
		ctx.Must(errors.WithMessage(mkdirErr, "creating DB directory if necessary"))
	}

	// NOTE(review): any Stat failure, not only "does not exist", marks the
	// DB as new — confirm this is intended.
	_, statErr := os.Stat(ctx.DBPath)
	justCreated := statErr != nil
	if justCreated {
		comm.Logf("butlerd: creating new DB at %s", ctx.DBPath)
	}

	dbPool, openErr := sqlitex.Open(ctx.DBPath, 0, 100)
	if openErr != nil {
		ctx.Must(errors.WithMessage(openErr, "opening DB for the first time"))
	}
	defer dbPool.Close()

	prepareHandler := comm.NewSlogHandler(slog.LevelDebug)
	dbPrepareLogger := slog.New(prepareHandler).With("source", "db_prepare")

	// Preparation runs in its own function so that horror.RecoverInto is
	// deferred against the named result (presumably turning a panic into an
	// error — verify in horror) and the connection goes back to the pool
	// before that happens.
	prepareDB := func() (retErr error) {
		defer horror.RecoverInto(&retErr)

		conn := dbPool.Get(context.Background())
		defer dbPool.Put(conn)

		// Forward preparation messages to the "db_prepare" logger.
		consumer := &state.Consumer{
			OnMessage: func(lvl string, msg string) {
				dbPrepareLogger.Log(context.Background(), stateLevelToSlogLevel(lvl), msg)
			},
		}
		return database.Prepare(consumer, conn, justCreated)
	}
	if prepareErr := prepareDB(); prepareErr != nil {
		ctx.Must(errors.WithMessage(prepareErr, "preparing DB"))
	}

	ctx.Must(Do(ctx, context.Background(), dbPool, secret))
}

// Do creates a butlerd server for secret and a router backed by dbPool, then
// serves requests over the transport named by args.transport:
//
//   - "tcp": listens on 127.0.0.1 with an automatically chosen port, emits a
//     "butlerd/listen-notification" object carrying the secret and the
//     listener address, then serves TCP connections;
//   - "stdio": serves over os.Stdin and the stdout handle saved in
//     origStdout;
//   - "http": reports that the transport is deprecated via comm.Dief.
//
// It returns the listen or serve error, if any, and nil otherwise (including
// when args.transport matches none of the cases).
func Do(mansionContext *mansion.Context, ctx context.Context, dbPool *sqlitex.Pool, secret string) error {
	server := butlerd.NewServer(secret)
	router := GetRouter(dbPool, mansionContext)

	switch args.transport {
	case "http":
		// presumably comm.Dief does not return — verify in comm.
		comm.Dief("The HTTP transport is deprecated. Use TCP instead.")

	case "stdio":
		stdio := &stdioReadWriteCloser{in: os.Stdin, out: origStdout}
		stdioParams := butlerd.ServeStdioParams{
			Handler:      router,
			ShutdownChan: router.ShutdownChan,
		}
		if err := server.ServeStdio(ctx, stdioParams, stdio); err != nil {
			return err
		}

	case "tcp":
		ln, err := net.Listen("tcp", "127.0.0.1:")
		if err != nil {
			return err
		}

		// Announce where we listen (and the secret) before serving.
		announcement := map[string]any{
			"secret": secret,
			"tcp": map[string]any{
				"address": ln.Addr().String(),
			},
		}
		comm.Object("butlerd/listen-notification", announcement)

		tcpParams := butlerd.ServeTCPParams{
			Handler:   router,
			Listener:  ln,
			Secret:    secret,
			Log:       args.log,
			KeepAlive: args.keepAlive,

			ShutdownChan: router.ShutdownChan,
		}
		if err := server.ServeTCP(ctx, tcpParams); err != nil {
			return err
		}
	}

	return nil
}

// stdioReadWriteCloser presents two separate files as one stream with Read,
// Write and Close methods: reads come from in, writes go to out.
type stdioReadWriteCloser struct {
	in  *os.File // read side; the only file Close closes
	out *os.File // write side; left open by Close
}

// Read reads from the input file into p.
func (rwc *stdioReadWriteCloser) Read(p []byte) (n int, err error) {
	n, err = rwc.in.Read(p)
	return n, err
}

// Write writes p to the output file.
func (rwc *stdioReadWriteCloser) Write(p []byte) (n int, err error) {
	n, err = rwc.out.Write(p)
	return n, err
}

// Close closes the input file only; the output file is not touched.
func (rwc *stdioReadWriteCloser) Close() error {
	closeErr := rwc.in.Close()
	return closeErr
}

// stateLevelToSlogLevel converts a level name, as received by a
// state.Consumer OnMessage callback, into the matching slog.Level.
// "debug", "warning" and "error" map to their slog counterparts; "info" and
// any unrecognized name map to slog.LevelInfo. Matching is exact and
// case-sensitive.
func stateLevelToSlogLevel(level string) slog.Level {
	// Start from the fallback and override only for the names that differ.
	mapped := slog.LevelInfo
	switch level {
	case "debug":
		mapped = slog.LevelDebug
	case "warning":
		mapped = slog.LevelWarn
	case "error":
		mapped = slog.LevelError
	}
	return mapped
}
