|
| 1 | +#' @title Main function for CausalEGM to estimate causal effect in either binary or continuous treatment settings. |
| 2 | +#' |
| 3 | +#' @description This function takes observation data (x,y,v) as input, and estimate the ATE/ITE/ADRF. |
| 4 | +#' |
| 5 | +#' @param X is the treatment variable. |
| 6 | +#' @param Y is the potential outcome. |
| 7 | +#' @param V is the covariates. |
| 8 | +#' @param yaml_file is the deployment file for initializing a CausalEGM instance. |
| 9 | +#' |
| 10 | +#' |
| 11 | +#' @return NULL |
| 12 | +#' |
| 13 | +#' @examples causalegm(X=X,Y=Y,V=V,yaml_file='example.yaml') |
| 14 | +#' |
| 15 | +#' @export causalegm |
| 16 | +causalegm <- function(X, Y, V, |
| 17 | + output_dir = "./", |
| 18 | + dataset = "myData", |
| 19 | + z_dims = c(3,3,6,6), |
| 20 | + lr = 0.0002, |
| 21 | + alpha =1, |
| 22 | + beta = 1, |
| 23 | + gamma = 10, |
| 24 | + g_d_freq = 5, |
| 25 | + g_units = c(64,64,64,64,64), |
| 26 | + e_units = c(64,64,64,64,64), |
| 27 | + f_units = c(64,32,8), |
| 28 | + h_units = c(64,32,8), |
| 29 | + dv_units = c(64,32,8), |
| 30 | + dz_units = c(64,32,8), |
| 31 | + save_model = FALSE, |
| 32 | + binary_treatment = TRUE, |
| 33 | + use_z_rec = TRUE, |
| 34 | + use_v_gan = TRUE, |
| 35 | + random_seed = 123, |
| 36 | + n_iter = 20000) { |
| 37 | + |
| 38 | + #To ignore the warnings during usage |
| 39 | + options(warn=-1) |
| 40 | + options("getSymbols.warning4.0"=FALSE) |
| 41 | + if (!(py_module_available('CausalEGM'))){ |
| 42 | + py_install("CausalEGM", method="auto",pip=TRUE) |
| 43 | + } |
| 44 | + #get number of covariates |
| 45 | + v_dim <- dim(V)[2] |
| 46 | + params <- list(output_dir = output_dir, |
| 47 | + dataset = dataset, |
| 48 | + z_dims = as.integer(z_dims), |
| 49 | + v_dim = v_dim, |
| 50 | + lr = lr, |
| 51 | + alpha = alpha, |
| 52 | + beta = beta, |
| 53 | + gamma = gamma, |
| 54 | + g_d_freq = as.integer(g_d_freq), |
| 55 | + g_units = g_units, |
| 56 | + e_units = e_units, |
| 57 | + f_units = f_units, |
| 58 | + h_units = h_units, |
| 59 | + dv_units = dv_units, |
| 60 | + dz_units = dz_units, |
| 61 | + save_model = save_model, |
| 62 | + binary_treatment = binary_treatment, |
| 63 | + use_z_rec = use_z_rec, |
| 64 | + use_v_gan = use_v_gan) |
| 65 | + |
| 66 | + cegm <- import("CausalEGM") |
| 67 | + model <- cegm$CausalEGM(params=params,random_seed=as.integer(random_seed)) |
| 68 | + data <- list(X,Y,V) |
| 69 | + model$train(data,n_iter = as.integer(n_iter)) |
| 70 | + model |
| 71 | +} |
0 commit comments