Skip to content

Commit 09c07a8

Browse files
committed
reworked hyperparameters
1 parent a480221 commit 09c07a8

5 files changed

Lines changed: 67 additions & 58 deletions

File tree

R/PipeOpSplines.R

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#'
1010
#' Depending on the type parameter, constructs polynomial B-splines [`splines::bs()`] or natural cubic splines [`splines::ns()`] for the respective column.
1111
#'
12-
#'
1312
#' @section Construction:
1413
#' ```
1514
#' po("splines", param_vals = list())
@@ -31,32 +30,31 @@
3130
#'
3231
#' @section Parameters:
3332
#' The parameters are the parameters inherited from [`PipeOpTaskPreproc`], as well as:
34-
#' * `type` :: `character(1)` \cr
35-
#' `polynomial` when polynomial splines are applied [`splines::bs`] or
36-
#' `natural` when natural splines are applied [`splines::ns`].
37-
#' Default is `natural`.
33+
#' * `type` :: `character(1)` \cr
34+
#' Controls the type of splines that are to be created. Can be either `polynomial` ([`splines::bs`])
35+
#' or `natural` ([`splines::ns`]). Initializied to `natural`.
3836
#' * `df` :: `integer(1)` \cr
39-
#' Number of degrees of freedom for calculation of splines basis matrix.
40-
#' Default is `NULL`.
41-
#' For further information look up [`splines::bs()`] or [`splines::ns()`].
37+
#' Number of degrees of freedom for calculation of the spline basis matrix. Initialized to `NULL`.
38+
#' Depending on `type`, see either [`splines::bs()`] or [`splines::ns()`].
4239
#' * `knots` :: named `list` \cr
43-
#' The internal breakpoints that define the spline. Parameter has to be passed as a named list.
44-
#' Default is `NULL`. For further information consult [`splines::bs()`] or [`splines::ns()`].
40+
#' Internal breakpoints that define the spline, given as a named list of numeric vectors,
41+
#' where each name corresponds to a feature and its value specifies the knots for that feature.
42+
#' Initialized to `NULL`. Depending on `type`, see either [`splines::bs()`] or [`splines::ns()`].
4543
#' * `intercept` :: `logical(1)` \cr
4644
#' If `TRUE`, an intercept is included in the basis. Default is `FALSE`.
47-
#' For further information look up [`splines::bs()`] or [`splines::ns()`].
45+
#' Depending on `type`, see either [`splines::bs()`] or [`splines::ns()`].
4846
#' * `degree` :: `integer(1)` \cr
49-
#' This parameter depends on type = "polynomial". Degree of the polynomial used to compute B-splines.
50-
#' Default is `3`. For further information look up [`splines::bs()`].
47+
#' Degree of the polynomial used to compute polynomial splines. Only used if `type` is `"polynomial"`.
48+
#' Default is `3`. See [`splines::bs()`].
5149
#' * `Boundary.knots` :: named `list` \cr
52-
#' Boundary points at which to anchor the spline basis. Parameter has to be passed as a named list.
53-
#' Default is `NULL`.
54-
#' For further information look up [`splines::bs()`] or [`splines::ns()`].
50+
#' Boundary points at which to anchor the spline basis, given as a named list of numeric vectors,
51+
#' where each name corresponds to a feature and its value specifies the boundary points for that feature.
52+
#' Initialized to `NULL`. Depending on `type`, see either [`splines::bs()`] or [`splines::ns()`].
5553
#'
5654
#' @section Internals:
57-
#' Creates spline basis via [`splines::bs`]/[`splines::ns`] function depending on `type`.
58-
#' After training, the `Boundary.knots` that are either defined in the Parameter Set
59-
#' or have been calculated during training will be passed to the `$state` of the PipeOp.
55+
#' Creates a spline basis using either [`splines::bs`] or [`splines::ns`] depending on the hyperparameter `type`.
56+
#' After training, the `Boundary.knots` that were either provided by the user or calculated during training are
57+
#' stored in the `PipeOp`'s `$state`.
6058
#'
6159
#' @section Fields:
6260
#' Only fields inherited from [`PipeOp`].
@@ -72,25 +70,28 @@
7270
#'
7371
#' pop$train(list(task))[[1]]$data()
7472
#'
75-
#' pobk = po("splines", param_vals = list(Boundary.knots = list("Petal.Length" = c(0, 4), "Petal.Width" = c(4, 7), "Sepal.Length" = c(1, 5), "Sepal.Width" = c(3, 6))))
73+
#' pobk = po("splines", Boundary.knots = list(
74+
#' Petal.Length = c(0, 4), Petal.Width = c(4, 7), Sepal.Length = c(1, 5), Sepal.Width = c(3, 6))
75+
#' )
7676
#' pobk$train(list(task))[[1]]$data()
7777
#'
7878
#' @family PipeOps
7979
#' @template seealso_pipeopslist
8080
#' @include PipeOpTaskPreproc.R
8181
#' @export
82-
8382
PipeOpSplines = R6Class("PipeOpSplines",
8483
inherit = PipeOpTaskPreproc,
8584
public = list(
8685
initialize = function(id = "splines", param_vals = list()) {
8786
ps = ps(
8887
type = p_fct(levels = c("polynomial", "natural"), init = "natural", tags = c("train", "splines", "required")),
89-
df = p_int(lower = 1, upper = Inf, special_vals = list(NULL), init = NULL, tags = c("train", "splines")),
90-
knots = p_uty(special_vals = list(NULL), init = NULL, tags = c("train", "splines")),
91-
degree = p_int(lower = 1, upper = Inf, default = 3, depends = type == "polynomial", tags = c("train", "splines")),
92-
intercept = p_lgl(init = FALSE, tags = c("train", "splines")),
93-
Boundary.knots = p_uty(tags = c("train", "splines"))
88+
df = p_int(lower = 1, upper = Inf, special_vals = list(NULL), default = NULL, tags = c("train", "splines")),
89+
knots = p_uty(special_vals = list(NULL), init = NULL, custom_check = function(x) check_list(x, any.missing = FALSE, null.ok = TRUE, names = "named"),
90+
tags = c("train", "splines")),
91+
degree = p_int(lower = 1, upper = Inf, default = 3, depends = quote(type == "polynomial"), tags = c("train", "splines")),
92+
intercept = p_lgl(default = FALSE, tags = c("train", "splines")),
93+
Boundary.knots = p_uty(special_vals = list(NULL), init = NULL, custom_check = function(x) check_list(x, any.missing = FALSE, null.ok = TRUE, names = "named"),
94+
tags = c("train", "splines"), )
9495
)
9596
super$initialize(id = id, param_set = ps, param_vals = param_vals, packages = c("splines", "stats"))
9697
}
@@ -100,25 +101,27 @@ PipeOpSplines = R6Class("PipeOpSplines",
100101
result = list()
101102
bk = list()
102103
pv = self$param_set$get_values(tags = "splines")
103-
for (i in colnames(dt)) {
104-
args = pv
105-
args$type = NULL
106-
args$knots = pv$knots[[i]]
107-
args$Boundary.knots = pv$Boundary.knots[[i]]
108-
if (pv$type == "polynomial") {
109-
result[[i]] = invoke(splines::bs, .args = args, x = dt[[i]], warn.outside = FALSE)
110-
} else {
111-
result[[i]] = invoke(splines::ns, .args = args, x = dt[[i]])
112-
}
113-
colnames(result[[i]]) = paste0("splines.", seq_len(ncol(result[[i]])))
114-
bk[[i]] = attributes(result[[i]])$Boundary.knots
104+
105+
for (i in colnames(dt)) {
106+
args = pv
107+
args$type = NULL
108+
args$knots = pv$knots[[i]]
109+
args$Boundary.knots = pv$Boundary.knots[[i]]
110+
if (pv$type == "polynomial") {
111+
result[[i]] = invoke(splines::bs, .args = args, x = dt[[i]], warn.outside = FALSE)
112+
} else {
113+
result[[i]] = invoke(splines::ns, .args = args, x = dt[[i]])
115114
}
116-
self$state$Boundary.knots = bk
115+
colnames(result[[i]]) = paste0("splines.", seq_len(ncol(result[[i]])))
116+
bk[[i]] = attributes(result[[i]])$Boundary.knots
117+
}
118+
self$state$Boundary.knots = bk
117119
result
118120
},
119121
.predict_dt = function(dt, levels) {
120122
result = list()
121123
pv = self$param_set$get_values(tags = "splines")
124+
122125
for (i in colnames(dt)) {
123126
args = pv
124127
args$type = NULL

man/mlr3pipelines-package.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_pipeops_info.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_pipeops_isomap.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_pipeops_splines.Rd

Lines changed: 22 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)