#' Loess-smooth a series and evaluate the fit at every input point
#'
#' Only pairs where both `x` and `y` are finite enter the fit, so the result
#' does not depend on the global `options("na.action")`. The smoother is
#' still evaluated at every element of `x`; `surface = "direct"` computes
#' each local fit exactly, including at points whose `y` was dropped.
#'
#' @param x Numeric predictor (e.g. decision times).
#' @param y Numeric response, same length as `x`.
#' @param span,degree Passed to `stats::loess()`.
#' @return Numeric vector of fitted values with the same length as `x`;
#'   all `NA` when no `(x, y)` pair is finite.
#' @keywords internal
smooth_series <- function(x, y, span = 0.25, degree = 1) {
  ok <- is.finite(x) & is.finite(y)
  if (!any(ok)) {
    # Nothing to fit: loess would error on an empty model frame.
    return(rep(NA_real_, length(x)))
  }
  fit_x <- x[ok]
  fit_y <- y[ok]
  fit <- stats::loess(
    fit_y ~ fit_x,
    span = span, degree = degree,
    control = stats::loess.control(surface = "direct")
  )
  # A plain (non data.frame) newdata is used as the predictor matrix as-is.
  as.numeric(stats::predict(fit, x))
}

#' Unbiased weighted variance with reliability weights
#'
#' Drops elements whose value or weight is non-finite, or whose weight is
#' not positive, then returns
#' `sum(w * (y - mu)^2) / (V1 - V2 / V1)` with `V1 = sum(w)`,
#' `V2 = sum(w^2)` and `mu` the weighted mean.
#'
#' @param y Numeric values.
#' @param w Numeric weights, same length as `y`.
#' @return A single number, or `NA_real_` when no usable element remains or
#'   the denominator is not positive (e.g. only one usable element).
#' @keywords internal
wvar_unbiased <- function(y, w) {
  usable <- is.finite(y) & is.finite(w) & w > 0
  vals <- y[usable]
  wts <- w[usable]

  v1 <- sum(wts)
  if (v1 <= 0) {
    return(NA_real_)
  }

  center <- sum(wts * vals) / v1
  denom <- v1 - sum(wts^2) / v1
  if (denom <= 0) {
    return(NA_real_)
  }

  sum(wts * (vals - center)^2) / denom
}

#' Estimate a smoothed, standardized treatment effect from MRT-style data
#'
#' At every decision time, fits a weighted least-squares regression of
#' `outcome` on an intercept, `treatment` and optional `covariates`. The
#' treatment coefficient is `beta_hat`. Weights are
#' `availability / P(arm received)`, with the probability floored at 1e-6.
#' `s_hat` is the square root of the mean of the two weighted within-arm
#' outcome variances from `wvar_unbiased()`. With `smooth = TRUE`, both
#' series are smoothed over time by `smooth_series()`, and `std_estimate` is
#' their ratio. With `do_bootstrap = TRUE`, subjects are resampled with
#' replacement via `boot::boot()` to form pointwise bands for `std_estimate`.
#'
#' Rows whose outcome, weight or regressors are non-finite, or whose weight
#' is not positive, are left out of the per-time regression. This matches
#' how `wvar_unbiased()` treats them.
#'
#' @param data A data.frame or tibble, one row per subject and decision time.
#' @param id,outcome,treatment,time,rand_prob,availability Names of the
#'   subject id, outcome, treatment (compared with 1 / 0), decision time,
#'   probability of treatment, and availability columns in `data`.
#' @param covariates Optional character vector of covariate column names.
#' @param smooth Smooth `beta_hat` and `s_hat` over time?
#' @param loess_span,loess_degree Passed to `smooth_series()`.
#' @param do_bootstrap Compute bootstrap bands?
#' @param boot_replications Number of bootstrap resamples.
#' @param confidence_alpha Bands have nominal coverage `1 - confidence_alpha`.
#' @param ci_type `"perc"` (percentile, default), `"basic"`, or `"norm"`
#'   (bias-corrected normal, as in `boot::norm.ci()`). Any other value
#'   falls back to `"perc"` with a warning.
#' @return A data.frame with columns `time`, `beta_hat`, `s_hat`, `beta_sm`,
#'   `s_sm`, `std_estimate`, plus `std_lower` / `std_upper` when
#'   `do_bootstrap = TRUE`.
#' @keywords internal
estimate_mrt_std_effect <- function(
    data, id, outcome, treatment, time,
    rand_prob, availability,
    covariates = NULL,
    smooth = TRUE,
    loess_span = 0.25,
    loess_degree = 1,
    do_bootstrap = TRUE,
    boot_replications = 1000,
    confidence_alpha = 0.05,
    ci_type = "perc"
) {
  # `data[[col]][idx]` yields a plain vector for data.frames AND tibbles;
  # `data[idx, col]` would return a one-column tibble for the latter.
  column_at <- function(col, idx) data[[col]][idx]

  all_ids   <- sort(unique(data[[id]]))
  all_t     <- sort(unique(data[[time]]))
  Tlen      <- length(all_t)
  rows_by_t <- split(seq_len(nrow(data)), data[[time]])

  Y_list <- X_list <- Wbase_list <- idchar_list <- vector("list", Tlen)
  W1_list <- W0_list <- A_list <- vector("list", Tlen)

  # Precompute, per decision time, everything that does not depend on the
  # bootstrap resample, so each replicate only has to reweight.
  for (i in seq_len(Tlen)) {
    idx   <- rows_by_t[[as.character(all_t[i])]]
    Y     <- as.vector(column_at(outcome, idx))
    A     <- as.vector(column_at(treatment, idx))
    prob  <- column_at(rand_prob, idx)
    avail <- column_at(availability, idx)
    Z <- if (is.null(covariates)) {
      NULL
    } else {
      as.matrix(data[idx, covariates, drop = FALSE])
    }

    # Inverse probability of the arm actually received (floored at 1e-6 to
    # avoid division by zero), scaled by availability.
    w_trt <- avail / pmax(prob, 1e-6)
    w_ctl <- avail / pmax(1 - prob, 1e-6)

    X_list[[i]] <- if (is.null(Z)) {
      cbind(`(Intercept)` = 1, A = A)
    } else {
      cbind(`(Intercept)` = 1, A = A, Z)
    }
    Y_list[[i]]      <- Y
    A_list[[i]]      <- A
    Wbase_list[[i]]  <- ifelse(A == 1, w_trt, w_ctl)
    W1_list[[i]]     <- ifelse(A == 1, w_trt, 0)
    W0_list[[i]]     <- ifelse(A == 0, w_ctl, 0)
    idchar_list[[i]] <- as.character(column_at(id, idx))
  }

  # Weighted least-squares coefficient of the treatment column (column 2 of
  # X). Unusable rows are dropped first, so one missing value neither
  # errors (an NA weight used to reach `if (NA)`) nor blanks the whole time
  # point. A singular system yields NA.
  coef_A_fast <- function(X, y, w) {
    ok <- is.finite(y) & is.finite(w) & w > 0 &
      rowSums(!is.finite(X)) == 0
    if (!any(ok)) return(NA_real_)
    X <- X[ok, , drop = FALSE]
    y <- y[ok]
    w <- w[ok]
    XtWX <- crossprod(X, X * w)
    XtWy <- crossprod(X, y * w)
    beta <- tryCatch(
      solve(XtWX, XtWy),
      error = function(e) {
        tryCatch(
          qr.solve(XtWX, XtWy, tol = 1e-10),
          error = function(e2) NA_real_
        )
      }
    )
    if (anyNA(beta)) return(NA_real_)
    beta[2L]
  }

  # Per-time effect and scale, optionally smoothed. `mult_vec_named` maps
  # subject ids (as character) to bootstrap multiplicities; subjects absent
  # from it get multiplier 0.
  estimate_core <- function(mult_vec_named = NULL) {
    beta_hat <- s_hat <- numeric(Tlen)
    for (i in seq_len(Tlen)) {
      ids_i <- idchar_list[[i]]
      mult <- if (is.null(mult_vec_named)) {
        rep(1, length(ids_i))
      } else {
        mv <- as.numeric(mult_vec_named[ids_i])
        mv[is.na(mv)] <- 0
        mv
      }
      A_i <- A_list[[i]]
      Y_i <- Y_list[[i]]

      beta_hat[i] <- coef_A_fast(X_list[[i]], Y_i, Wbase_list[[i]] * mult)

      W1 <- W1_list[[i]] * mult
      W0 <- W0_list[[i]] * mult
      v1 <- wvar_unbiased(Y_i[A_i == 1], W1[A_i == 1])
      v0 <- wvar_unbiased(Y_i[A_i == 0], W0[A_i == 0])
      m  <- mean(c(v1, v0), na.rm = TRUE)
      s_hat[i] <- if (is.finite(m)) sqrt(m) else NA_real_
    }

    if (smooth) {
      beta_sm <- smooth_series(all_t, beta_hat,
                               span = loess_span, degree = loess_degree)
      s_sm    <- smooth_series(all_t, s_hat,
                               span = loess_span, degree = loess_degree)
    } else {
      beta_sm <- beta_hat
      s_sm <- s_hat
    }

    list(beta = beta_hat, s = s_hat, beta_sm = beta_sm, s_sm = s_sm,
         std = beta_sm / s_sm)
  }

  base <- estimate_core()

  if (!do_bootstrap) {
    return(data.frame(
      time = all_t,
      beta_hat = base$beta, s_hat = base$s,
      beta_sm = base$beta_sm, s_sm = base$s_sm,
      std_estimate = base$std
    ))
  }

  ci_type <- as.character(ci_type)[1L]
  if (is.na(ci_type) || !ci_type %in% c("perc", "basic", "norm")) {
    warning("Unsupported ci_type '", ci_type, "'; using \"perc\" instead.",
            call. = FALSE)
    ci_type <- "perc"
  }

  # Worker count comes from SLURM (default 1); fork-based parallelism is
  # not available on Windows.
  ncpus <- as.integer(Sys.getenv("SLURM_CPUS_PER_TASK", "1"))
  if (is.na(ncpus) || ncpus < 1) ncpus <- 1
  par_type <- if (.Platform$OS.type == "windows") "no" else "multicore"

  # Subject-level bootstrap: a subject drawn k times gets multiplier k. A
  # resample that makes the fit error (e.g. a degenerate loess) yields an
  # all-NA row instead of aborting every replicate; such rows are reported
  # below.
  stat_fun <- function(all_ids_arg, selected_subjects) {
    reps <- table(all_ids_arg[selected_subjects])
    tryCatch(
      estimate_core(mult_vec_named = reps)$std,
      error = function(e) rep(NA_real_, Tlen)
    )
  }

  out_boot <- boot::boot(
    data = all_ids,
    statistic = stat_fun,
    R = boot_replications,
    parallel = par_type,
    ncpus = ncpus
  )

  n_empty <- sum(rowSums(is.finite(out_boot$t)) == 0)
  if (n_empty > 0) {
    warning(n_empty, " of ", nrow(out_boot$t), " bootstrap replicates ",
            "gave no finite estimate at any time point.", call. = FALSE)
  }

  std_hat <- base$std
  probs <- c(confidence_alpha / 2, 1 - confidence_alpha / 2)
  lo <- up <- rep(NA_real_, Tlen)
  for (i in seq_len(Tlen)) {
    draws <- out_boot$t[, i]
    if (ci_type == "norm") {
      # Bias-corrected normal interval, matching boot::norm.ci().
      draws <- draws[is.finite(draws)]
      if (length(draws) < 2L || !is.finite(std_hat[i])) next
      bias <- mean(draws) - std_hat[i]
      half <- stats::qnorm(1 - confidence_alpha / 2) * stats::sd(draws)
      lo[i] <- std_hat[i] - bias - half
      up[i] <- std_hat[i] - bias + half
    } else {
      qs <- stats::quantile(draws, probs = probs,
                            na.rm = TRUE, names = FALSE, type = 6)
      if (ci_type == "perc") {
        lo[i] <- qs[1]
        up[i] <- qs[2]
      } else {
        # Basic (reverse percentile) interval: 2 * t0 - upper/lower quantile.
        lo[i] <- 2 * std_hat[i] - qs[2]
        up[i] <- 2 * std_hat[i] - qs[1]
      }
    }
  }

  data.frame(
    time = all_t,
    beta_hat = base$beta, s_hat = base$s,
    beta_sm = base$beta_sm, s_sm = base$s_sm,
    std_estimate = std_hat,
    std_lower = lo, std_upper = up
  )
}
