Commit 64d6ae1

Merge pull request #976 from mlr-org/filterensemble_trafos
Add score and result trafos in `FilterEnsemble`
2 parents 49ca931 + 1c2ee53 commit 64d6ae1

6 files changed

Lines changed: 214 additions & 31 deletions

NEWS.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 * Fix: `PipeOpTextVectorizer` now uses coercion to `TsparseMatrix` instead of deprecated `dgTMatrix` to avoid `Matrix` deprecation warnings.
 * New method `$predict_newdata_fast()` for `GraphLearner`. Note that currently this is only a thin wrapper around `$predict_newdata()` to maintain compatibility, but in the future it may get optimized to enable faster predictions on new data.
 * feat: `PipeOpRenameColumns`'s hyperparameter `renaming` can now also take a function transforming old column names to new column names.
+* feat: Added new hyperparameters `filter_score_transform`, `result_score_transform`, and `aggregator` to `FilterEnsemble`. BREAKING CHANGE: The default behavior for handling NA scores in the aggregation has changed. Previously, NA scores were simply ignored and weights were not changed. Now, `weighted.mean` is used, which normalizes the weights for all non-NA scores.

 # mlr3pipelines 0.10.0
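
The behavior change described in the new NEWS entry can be illustrated in base R. This is a minimal sketch with made-up scores and weights, not code from the package: the old path multiplied each score by its weight and summed with `na.rm = TRUE`, whereas `stats::weighted.mean()` divides by the sum of the weights that belong to non-NA scores.

```r
# Hypothetical scores of one feature from two filters; the second filter returned NA.
scores = c(variance = 2, auc = NA)
weights = c(variance = 0.7, auc = 0.3)

# Old behavior: weighted sum, NA contributions dropped, weights left untouched.
old = sum(scores * weights, na.rm = TRUE)

# New default: weights are renormalized over the non-NA scores.
new = stats::weighted.mean(scores, w = weights, na.rm = TRUE)

print(c(old = old, new = new))
```

Because of the division by the weight sum, results also differ from the old behavior whenever the weights do not sum to 1, even without NA scores.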

R/FilterEnsemble.R

Lines changed: 66 additions & 18 deletions
@@ -1,5 +1,3 @@
-
-
 #' @title Filter Ensemble
 #'
 #' @usage NULL
@@ -30,8 +28,16 @@
 #'   Required non-negative weights, one for each wrapped filter, with at least one strictly positive value.
 #'   Values are used as given when calculating the weighted mean. If named, names must match the wrapped filter ids.
 #' * `rank_transform` :: `logical(1)`\cr
-#'   If `TRUE`, ranks of individual filter scores are used instead of the raw scores before
-#'   averaging. Initialized to `FALSE`.
+#'   If `TRUE`, ranks of individual filter scores are used instead of the raw scores. Initialized to `FALSE`.
+#' * `filter_score_transform` :: `function`\cr
+#'   Function to be applied to the vector of individual filter scores after they were potentially transformed by
+#'   `rank_transform` but before weighting and aggregation. Initialized to `identity`.
+#' * `aggregator` :: `function`\cr
+#'   Function to aggregate the (potentially transformed) and weighted filter scores across filters. Must take
+#'   arguments `w` for weights and `na.rm`, the latter of which is always set to `TRUE`. Defaults to [`stats::weighted.mean`].
+#' * `result_score_transform` :: `function`\cr
+#'   Function to be applied to the vector of aggregated scores after they were potentially transformed by `rank_transform` and/or
+#'   `filter_score_transform`. Initialized to `identity`.
 #'
 #' Parameters of wrapped filters are available via `$param_set` and can be referenced using
 #' the wrapped filter id followed by `"."`, e.g. `"variance.na.rm"`.
@@ -54,9 +60,17 @@
 #'
 #' @section Internals:
 #' All wrapped filters are called with `nfeat` equal to the number of features to ensure that
-#' complete score vectors are available for aggregation. Scores are combined per feature by
-#' computing the weighted (optionally rank-based) mean.
-#'
+#' complete score vectors are available for aggregation.
+#' Scores are combined per feature by computing a weighted aggregation of transformed (default: `identity`)
+#' scores or ranks. Additionally, the final scores may also be transformed (default: `identity`).
+#'
+#' The order of transformations is as follows:
+#' 1. `$calculate` the filter's scores for all features;
+#' 2. If `rank_transform` is `TRUE`, convert filter scores to ranks;
+#' 3. Apply `filter_score_transform` to the scores / ranks;
+#' 4. Calculate the weighted aggregation across all filters using `aggregator`;
+#' 5. Apply `result_score_transform` to the vector of scores for each feature aggregated across filters.
+#'
 #' @section References:
 #' `r format_bib("binder_2020")`
 #'
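
The documented order of transformations can be sketched in base R, independently of the package. Everything below is an illustration under stated assumptions: `ensemble_scores` and the input `raw` are hypothetical names, and the scores are made up. Only the sequence (optional ranking, per-filter transformation, row-wise aggregation with `w` and `na.rm = TRUE`, final transformation) follows the documentation above.

```r
# Hypothetical per-filter scores for three features.
raw = list(
  variance = c(0.9, 0.1, 0.5),
  auc      = c(0.2, 0.4, NA)
)
weights = c(variance = 0.5, auc = 0.5)

ensemble_scores = function(raw, weights, rank_transform = FALSE,
  filter_score_transform = identity,
  aggregator = stats::weighted.mean,
  result_score_transform = identity) {
  # Optional ranking, then the per-filter transformation.
  per_filter = lapply(raw, function(s) {
    if (rank_transform) s = rank(s, na.last = "keep", ties.method = "average")
    filter_score_transform(s)
  })
  mat = do.call(cbind, per_filter)
  # Aggregate each row (one feature) across filters.
  combined = apply(mat, 1, aggregator, w = weights, na.rm = TRUE)
  # Transform the aggregated scores.
  result_score_transform(combined)
}

res = ensemble_scores(raw, weights)
print(res)
```

The third feature has an NA in the `auc` column, so with the default aggregator its score comes from the `variance` column alone.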
@@ -66,11 +80,30 @@
 #'
 #' task = tsk("sonar")
 #'
-#' flt = mlr_filters$get("ensemble",
+#' filter = flt("ensemble",
 #'   filters = list(FilterVariance$new(), FilterAUC$new()))
-#' flt$param_set$values$weights = c(variance = 0.5, auc = 0.5)
-#' flt$calculate(task)
-#' head(as.data.table(flt))
+#' filter$param_set$values$weights = c(variance = 0.5, auc = 0.5)
+#' filter$calculate(task)
+#' head(as.data.table(filter))
+#'
+#' # Weighted median as aggregator
+#' filter$param_set$set_values(aggregator = function(x, w, na.rm) {
+#'   if (na.rm) { keep <- !is.na(x); x <- x[keep]; w <- w[keep] }
+#'   o <- order(x)
+#'   x <- x[o]
+#'   w <- w[o]
+#'   x[match(TRUE, cumsum(w) >= sum(w) / 2)]
+#' })
+#' filter$calculate(task)
+#' head(as.data.table(filter))
+#'
+#' # Aggregate reciprocal ranking
+#' filter$param_set$set_values(rank_transform = TRUE,
+#'   filter_score_transform = function(x) 1 / x,
+#'   result_score_transform = function(x) rank(1 / x, ties.method = "average"))
+#' filter$calculate(task)
+#' head(as.data.table(filter))
+#'
 #' @export
 FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter,
   public = list(
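
A weighted-median aggregator like the one in the roxygen example can be checked in isolation. The following is a minimal base R sketch; `weighted_median` is a hypothetical helper name and not part of mlr3pipelines. It drops the weights together with the NA scores and returns the first sorted score whose cumulative weight reaches half of the total weight.

```r
# Hypothetical standalone weighted median with the (x, w, na.rm) signature.
weighted_median = function(x, w, na.rm = TRUE) {
  if (na.rm) {
    keep = !is.na(x)
    x = x[keep]
    w = w[keep]  # keep weights aligned with the remaining scores
  }
  o = order(x)
  x = x[o]
  w = w[o]
  # First position at which the cumulative weight reaches half the total.
  x[match(TRUE, cumsum(w) >= sum(w) / 2)]
}

m1 = weighted_median(c(3, 1, 2), w = c(1, 1, 1))
m2 = weighted_median(c(10, 1, NA), w = c(0.9, 0.1, 5))
print(c(m1, m2))
```

In the second call the large weight belongs to the NA score and is discarded with it, so the result is driven by the weight 0.9 on the value 10.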
@@ -96,7 +129,10 @@ FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter,
           }, fnames),
           tags = "required"
         ),
-        rank_transform = p_lgl(init = FALSE, tags = "required")
+        rank_transform = p_lgl(init = FALSE, tags = "required"),
+        filter_score_transform = p_uty(init = identity, tags = "required", custom_check = check_function),
+        result_score_transform = p_uty(init = identity, tags = "required", custom_check = check_function),
+        aggregator = p_uty(init = stats::weighted.mean, tags = "required", custom_check = crate(function(x) check_function(x, args = "w")))
       )

       super$initialize(
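
The `aggregator` check above only asks for a formal argument named `w`. A rough base R stand-in (the helper `has_w` is hypothetical, and it is an assumption that nothing beyond the presence of `w` is tested at this point) shows what passes, and why a custom aggregator should also accept `na.rm` or `...`, since the aggregator is called with `na.rm = TRUE`.

```r
# Hypothetical stand-in for the formal-argument check.
has_w = function(f) "w" %in% names(formals(f))

ok_default = has_w(stats::weighted.mean)   # formals are x, w, ...
ok_plain   = has_w(function(x) mean(x))    # no `w` formal

# This function has `w`, but it cannot be called with `na.rm = TRUE`.
no_na_rm = function(x, w) sum(x * w) / sum(w)
ok_no_na_rm = has_w(no_na_rm)
res = tryCatch(
  no_na_rm(c(1, 2), w = c(0.5, 0.5), na.rm = TRUE),
  error = function(e) "error"
)
print(list(ok_default, ok_plain, ok_no_na_rm, res))
```

`stats::weighted.mean` works as the default because its generic forwards `na.rm` through `...`.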
@@ -162,22 +198,35 @@ FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter,
       nfeat = length(fn)  # need to rank all features in an ensemble
       weights = pv$weights
       wnames = names(private$.wrapped)
+
       if (!is.null(names(weights))) {
         weights = weights[wnames]
       }
       if (!any(weights > 0)) {
         stop("At least one weight must be > 0.")
       }
-      scores = pmap(list(private$.wrapped, weights), function(x, w) {
+
+      # Calculate filter scores, apply rank and filter score trafo
+      scores = map(private$.wrapped, function(x) {
         x$calculate(task, nfeat)
         s = x$scores[fn]
         if (pv$rank_transform) s = rank(s, na.last = "keep", ties.method = "average")
-        s * w
+        s = pv$filter_score_transform(s)
+        if (!isTRUE(check_numeric(s, len = nfeat))) stopf("Filter score transformation did not return a numeric vector of the same length as there are features.")
+        s
       })
-      scores_df = as.data.frame(scores)
-      combined = rowSums(scores_df, na.rm = TRUE)
-      all_missing = rowSums(!is.na(scores_df)) == 0L
+      scores_dt = as.data.table(scores)
+
+      # Aggregate across filters, one row per feature
+      combined = apply(scores_dt, 1, pv$aggregator, w = weights, na.rm = TRUE)  # weighted.mean normalizes weights in case of NAs
+      if (!isTRUE(check_numeric(combined, len = nfeat))) stopf("Aggregator did not return a numeric vector of the same length as there are scored features.")
+      # Apply result score trafo
+      combined = pv$result_score_transform(combined)
+      if (!isTRUE(check_numeric(combined, len = nfeat))) stopf("Result score transformation did not return a numeric vector of the same length as there are features.")
+
+      all_missing = rowSums(!is.na(scores_dt)) == 0L
       combined[all_missing] = NA_real_
+
       structure(combined, names = fn)
     },
     deep_clone = function(name, value) {
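
One plausible reason the `all_missing` reset is kept after switching to `weighted.mean` can be shown in base R. This is a sketch with made-up values, not a statement of the authors' intent: when every filter returns NA for a feature, `stats::weighted.mean(..., na.rm = TRUE)` yields `NaN` rather than `NA`, and the reset makes the result a plain `NA_real_`.

```r
# A feature for which every wrapped filter returned NA (hypothetical values).
row = c(variance = NA_real_, auc = NA_real_)
weights = c(0.5, 0.5)

# All entries are dropped, so the mean is 0 / 0.
agg = stats::weighted.mean(row, w = weights, na.rm = TRUE)

# Mimic the reset: entries with no non-NA score become NA_real_.
combined = c(agg, 1.5)
all_missing = c(TRUE, FALSE)
combined[all_missing] = NA_real_
print(combined)
```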
@@ -212,5 +261,4 @@ FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter,
       private$.param_set
     }
   )
-
 )

R/PipeOpFilter.R

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ PipeOpFilter = R6Class("PipeOpFilter",
       filtercrit = c("nfeat", "frac", "cutoff", "permuted")
       filtercrit = Filter(function(name) !is.null(private$.outer_param_set$values[[name]]), filtercrit)
       if (length(filtercrit) != 1) {
-        stopf("Exactly one of 'nfeat', 'frac', 'cutoff', or 'permuted' must be given. Instead given: %s",
+        stopf("Exactly one hyperparameter of 'filter.nfeat', 'filter.frac', 'filter.cutoff', or 'filter.permuted' must be given. Instead given: %s",
           if (length(filtercrit) == 0) "none" else str_collapse(filtercrit))
       }
       critvalue = private$.outer_param_set$values[[filtercrit]]
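
The selection logic behind this error message can be reproduced with a plain list. This is a base R sketch under assumptions: `values` is a hypothetical stand-in for the outer parameter set's values, and `sprintf`/`paste` replace the package helpers `stopf`/`str_collapse`.

```r
# Hypothetical hyperparameter values, keyed without the "filter." prefix.
values = list(nfeat = 1, frac = 1, cutoff = NULL, permuted = NULL)

filtercrit = c("nfeat", "frac", "cutoff", "permuted")
# Keep only the criteria that are actually set.
given = Filter(function(name) !is.null(values[[name]]), filtercrit)

msg = if (length(given) != 1) {
  sprintf(
    "Exactly one hyperparameter of %s must be given. Instead given: %s",
    paste0("'filter.", filtercrit, "'", collapse = ", "),
    if (length(given) == 0) "none" else paste(given, collapse = ", ")
  )
}
print(msg)
```

As in the updated tests, the names listed after "Instead given:" appear without the `filter.` prefix.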

man/mlr_filters_ensemble.Rd

Lines changed: 44 additions & 8 deletions
Some generated files are not rendered by default.

tests/testthat/test_filter_ensemble.R

Lines changed: 100 additions & 2 deletions
@@ -317,7 +317,7 @@ test_that("FilterEnsemble ignores NA scores from wrapped filters", {
   flt_ensemble$calculate(task)

   combined_scores = flt_ensemble$scores[task$feature_names]
-  expect_equal(combined_scores, variance_scores * weights[["variance"]])
+  expect_equal(combined_scores, variance_scores)
 })

 test_that("FilterEnsemble rank transform ignores NA scores", {
@@ -362,7 +362,7 @@ test_that("FilterEnsemble rank transform ignores NA scores", {

   combined_scores = flt_ensemble$scores[task$feature_names]
   expected_rank = rank(variance_scores, na.last = "keep", ties.method = "average")
-  expect_equal(combined_scores, expected_rank * weights[["variance"]])
+  expect_equal(combined_scores, expected_rank)
 })

 test_that("FilterEnsemble weight helper normalization works", {
@@ -478,3 +478,101 @@ test_that("FilterEnsemble weight search space works with bbotk", {
   expect_true(nrow(instance$archive$data) >= 2)
   expect_true(all(instance$archive$data$classif.acc <= 1))
 })
+
+test_that("FilterEnsemble - trafos", {
+  skip_if_not_installed("mlr3filters")
+  task = tsk("sonar")
+  weights = c(0.7, 0.3)
+
+  filters = list(
+    mlr3filters::FilterVariance$new(),
+    mlr3filters::FilterAUC$new()
+  )
+  ensemble = FilterEnsemble$new(filters)
+
+  ensemble$param_set$set_values(
+    weights = weights,
+    rank_transform = TRUE,
+    filter_score_transform = function (x) 1 / x,
+    result_score_transform = function (x) rank(1 / x, ties.method = "average")
+  )
+
+  actual = ensemble$calculate(task)$scores
+
+  individual_scores = as.data.table(lapply(filters, function(flt) {
+    flt$calculate(task)
+    rank(flt$scores[task$feature_names], ties.method = "average")
+  }))
+  expected_scores = apply(individual_scores, 1, function(row) 1 / sum(1 / row * weights))
+  expected = rank(expected_scores, ties.method = "average")
+  expected = sort(setNames(expected, task$feature_names), decreasing = TRUE)
+
+  expect_equal(actual, expected)
+})
+
+test_that("FilterEnsemble - aggregator", {
+  skip_if_not_installed("mlr3filters")
+
+  task = mlr_tasks$get("sonar")
+  filters = list(
+    mlr3filters::FilterVariance$new(),
+    mlr3filters::FilterAUC$new()
+  )
+  flt_ensemble = FilterEnsemble$new(filters)
+
+  flt_ensemble$param_set$set_values(
+    weights = c(0.5, 0.5),
+    aggregator = function(x, w, na.rm) median(x, na.rm = na.rm)
+  )
+
+  flt_ensemble$calculate(task)
+  combined_scores = flt_ensemble$scores
+  individual_scores = as.data.table(lapply(filters, function(flt) {
+    flt$calculate(task)
+    flt$scores[task$feature_names]
+  }))
+  expected_scores = apply(individual_scores, 1, function(row) median(row, na.rm = TRUE))
+  expected = sort(setNames(expected_scores, task$feature_names), decreasing = TRUE)
+
+  expect_equal(combined_scores, expected)
+})
+
+test_that("FilterEnsemble - Error messages", {
+  skip_if_not_installed("mlr3filters")
+
+  task = mlr_tasks$get("sonar")
+  filters = list(
+    mlr3filters::FilterVariance$new(),
+    mlr3filters::FilterAUC$new()
+  )
+  flt_ensemble = FilterEnsemble$new(filters)
+
+  flt_ensemble$param_set$set_values(
+    weights = c(0.5, 0.5)
+  )
+
+  # Error if formal args are incorrect
+  expect_error(flt_ensemble$param_set$set_values(
+    aggregator = function(x) mean(x)
+  ), "Must have formal arguments: w")
+
+  # Error if filter_score_transform output has wrong length
+  flt_ensemble$param_set$set_values(
+    filter_score_transform = function(x) rep(1, length(x) + 1)
+  )
+  expect_error(flt_ensemble$calculate(task), "Filter score transformation.*length.*")
+
+  # Error if aggregator output has wrong length
+  flt_ensemble$param_set$set_values(
+    filter_score_transform = identity,
+    aggregator = function(x, w, na.rm) rep(1, length(x) + 1)
+  )
+  expect_error(flt_ensemble$calculate(task), "Aggregator.*length.*")
+
+  # Error if result_score_transform output has wrong length
+  flt_ensemble$param_set$set_values(
+    aggregator = weighted.mean,
+    result_score_transform = function(x) rep(1, length(x) + 1)
+  )
+  expect_error(flt_ensemble$calculate(task), "Result score transformation.*length.*")
+})

tests/testthat/test_pipeop_filter.R

Lines changed: 2 additions & 2 deletions
@@ -16,10 +16,10 @@ test_that("PipeOpFilter", {

   expect_equal(po$id, mlr3filters::FilterVariance$new()$id)

-  expect_error(po$train(list(task)), "Exactly one of 'nfeat', 'frac', 'cutoff', or 'permuted' must be given.*none")
+  expect_error(po$train(list(task)), "Exactly one hyperparameter of 'filter.nfeat', 'filter.frac', 'filter.cutoff', or 'filter.permuted' must be given.*none")

   po$param_set$values = list(filter.nfeat = 1, filter.frac = 1, na.rm = TRUE)
-  expect_error(po$train(list(task)), "Exactly one of 'nfeat', 'frac', 'cutoff', or 'permuted' must be given.*nfeat, frac")
+  expect_error(po$train(list(task)), "Exactly one hyperparameter of 'filter.nfeat', 'filter.frac', 'filter.cutoff', or 'filter.permuted' must be given.*nfeat, frac")

   po$param_set$values = list(filter.nfeat = 1, na.rm = TRUE)
