Skip to content

Commit 1c2ee53

Browse files
committed
Merge branch 'filterensemble_trafos' of https://github.com/mlr-org/mlr3pipelines into filterensemble_trafos
2 parents ed9551c + 752bbff commit 1c2ee53

5 files changed

Lines changed: 77 additions & 25 deletions

File tree

.github/workflows/pkgdown.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444

4545
- name: Deploy
4646
if: github.event_name != 'pull_request'
47-
uses: JamesIves/github-pages-deploy-action@v4.7.4
47+
uses: JamesIves/github-pages-deploy-action@v4.8.0
4848
with:
4949
clean: false
5050
branch: gh-pages

NEWS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
* feat: All imputation PipeOps now support feature types `Date` and `POSIXct`.
66
* Fix: `PipeOpTextVectorizer` now uses coercion to `TsparseMatrix` instead of deprecated `dgTMatrix` to avoid `Matrix` deprecation warnings.
77
* New method `$predict_newdata_fast()` for `GraphLearner`. Note that currently this is only a thin wrapper around `$predict_newdata()` to maintain compatibility, but in the future it may get optimized to enable faster predictions on new data.
8-
- feat: Added new hyperparameters `filter_score_transform`, `result_score_transform`, and `aggregator` to `FilterEnsemble`. BREAKING CHANGE: The default behavior for handling NA scores in the aggregation has changed. Previously, NA scores were simply ignored and weights were not changed. Now, `weighted.mean` is used, which normalizes the weights for all non-NA scores.
8+
* feat: `PipeOpRenameColumns`'s hyperparameter `renaming` can now also take a function transforming old column names to new column names.
9+
* feat: Added new hyperparameters `filter_score_transform`, `result_score_transform`, and `aggregator` to `FilterEnsemble`. BREAKING CHANGE: The default behavior for handling NA scores in the aggregation has changed. Previously, NA scores were simply ignored and weights were not changed. Now, `weighted.mean` is used, which normalizes the weights for all non-NA scores.
910

1011
# mlr3pipelines 0.10.0
1112

R/PipeOpRenameColumns.R

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,21 @@
2828
#'
2929
#' @section Parameters:
3030
#' The parameters are the parameters inherited from [`PipeOpTaskPreproc`], as well as:
31-
#' * `renaming` :: named `character`\cr
32-
#' Named `character` vector. The names of the vector specify the old column names that should be
33-
#' changed to the new column names as given by the elements of the vector. Initialized to the empty
34-
#' character vector.
31+
#' * `renaming` :: named `character` | `function`\cr
32+
#' Takes the form of either a named `character` or a `function`.
33+
#' For a named `character` vector, the names of the vector elements specify the
34+
#' old column names and the corresponding element values give the new column names.
35+
#' A `function` specifies how the old column names should be changed to the new column names.
36+
#' The function must return a `character` vector with one entry per input column name so that each selected column receives a new name.
37+
#' To choose which columns are renamed, use the `affect_columns` parameter.
38+
#' Initialized to `character(0)`.
3539
#' * `ignore_missing` :: `logical(1)`\cr
3640
#' Ignore if columns named in `renaming` are not found in the input [`Task`][mlr3::Task]. If this is
3741
#' `FALSE`, then names in `renaming` that are not found in the [`Task`][mlr3::Task] cause an error.
3842
#' Initialized to `FALSE`.
3943
#'
4044
#' @section Internals:
41-
#' Uses the `$rename()` mutator of the [`Task`][mlr3::Task] to set the new column names.
45+
#' Uses the `$rename()` mutator of the [`Task`][mlr3::Task] to set new column names.
4246
#'
4347
#' @section Fields:
4448
#' Only fields inherited from [`PipeOp`].
@@ -56,36 +60,53 @@
5660
#' task = tsk("iris")
5761
#' pop = po("renamecolumns", param_vals = list(renaming = c("Petal.Length" = "PL")))
5862
#' pop$train(list(task))
63+
#'
64+
#' pof = po("renamecolumns", param_vals = list(renaming = function(colnames) {
65+
#' sub("Petal", "P", colnames)
66+
#' }))
67+
#' pof$train(list(task))
68+
#'
5969
PipeOpRenameColumns = R6Class("PipeOpRenameColumns",
6070
inherit = PipeOpTaskPreprocSimple,
6171
public = list(
6272
initialize = function(id = "renamecolumns", param_vals = list()) {
6373
ps = ps(
6474
renaming = p_uty(
65-
custom_check = crate(function(x) check_character(x, any.missing = FALSE, names = "strict") %check&&% check_names(x, type = "strict"),
66-
.parent = topenv()),
75+
custom_check = crate(function(x) (check_character(x, any.missing = FALSE, names = "strict") %check&&% check_names(x, type = "strict")) %check||% check_function(x)),
6776
tags = c("train", "predict", "required")
6877
),
6978
ignore_missing = p_lgl(tags = c("train", "predict", "required"))
7079
)
7180
ps$values = list(renaming = character(0), ignore_missing = FALSE)
72-
super$initialize(id, ps, param_vals = param_vals, can_subset_cols = FALSE)
81+
super$initialize(id, ps, param_vals = param_vals, can_subset_cols = TRUE)
7382
}
7483
),
7584
private = list(
85+
.get_state = function(task) {
86+
if (is.function(self$param_set$values$renaming)) {
87+
new_names = self$param_set$values$renaming(task$feature_names)
88+
assert_character(new_names, any.missing = FALSE, len = length(task$feature_names), .var.name = "the value returned by `renaming` function")
89+
names(new_names) = task$feature_names
90+
list(old_names = task$feature_names, new_names = new_names)
91+
} else {
92+
pv = self$param_set$get_values(tags = "train")
93+
new_names = pv$renaming
94+
innames = names(new_names)
95+
nontargets = task$col_roles
96+
nontargets$target = NULL
97+
takenames = innames %in% unlist(nontargets)
98+
if (!pv$ignore_missing && !all(takenames)) {
99+
# we can't rely on task$rename because it could also change the target name, which we don't want.
100+
stopf("The names %s from `renaming` parameter were not found in the Task.", str_collapse(innames[!takenames]))
101+
}
102+
list(old_names = innames[takenames], new_names = new_names[takenames])
103+
}
104+
},
76105
.transform = function(task) {
77-
if (!length(self$param_set$values$renaming)) {
106+
if (!length(self$state$new_names)) {
78107
return(task) # early exit
79108
}
80-
innames = names(self$param_set$values$renaming)
81-
nontargets = task$col_roles
82-
nontargets$target = NULL
83-
takenames = innames %in% unlist(nontargets)
84-
if (!self$param_set$values$ignore_missing && !all(takenames)) {
85-
# we can't rely on task$rename because it could also change the target name, which we don't want.
86-
stopf("The names %s from `renaming` parameter were not found in the Task.", str_collapse(innames[!takenames]))
87-
}
88-
task$rename(old = innames[takenames], new = self$param_set$values$renaming[takenames])
109+
task$rename(old = self$state$old_names, new = self$state$new_names)
89110
}
90111
)
91112
)

man/mlr_pipeops_renamecolumns.Rd

Lines changed: 15 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_pipeop_renamecolumns.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,23 @@ test_that("error handling", {
3939
op$param_set$values$ignore_missing = TRUE
4040
expect_equal(task$data(), op$train(list(task))[[1]]$data())
4141
})
42+
43+
test_that("PipeOpRenameColumns - errors for renaming function", {
44+
task = mlr_tasks$get("iris")
45+
expect_error(po("renamecolumns", param_vals = list(renaming = 1 + 1)))
46+
47+
op = po("renamecolumns", param_vals = list(renaming = function(x) "a"))
48+
expect_error(op$train(list(task)), "value returned by `renaming` function.*length")
49+
})
50+
51+
test_that("PipeOpRenameColumns - renaming by function", {
52+
task = mlr_tasks$get("iris")
53+
54+
op = po("renamecolumns", param_vals = list(renaming = function(colnames) sub("Petal", "P", colnames)))
55+
result = op$train(list(task))
56+
expect_equal(result[[1]]$feature_names, c("P.Length", "P.Width", "Sepal.Length", "Sepal.Width"))
57+
58+
op$param_set$set_values(affect_columns = selector_name("Petal.Length"))
59+
result = op$train(list(task))
60+
expect_equal(result[[1]]$feature_names, c("P.Length", "Petal.Width", "Sepal.Length", "Sepal.Width"))
61+
})

0 commit comments

Comments
 (0)