Skip to content

Commit be39a01

Browse files
authored
Merge pull request #897 from mlr-org/error_untrained_predict
Add error message for predicting with untrained PipeOp or Graph
2 parents d579ba4 + b90ee86 commit be39a01

7 files changed

Lines changed: 30 additions & 7 deletions

File tree

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# mlr3pipelines 0.7.2-9000
2-
2+
- Added missing error for predicting with untrained `PipeOp`s / `Graph`s.
33

44
# mlr3pipelines 0.7.2
55

R/Graph.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ Graph = R6Class("Graph",
444444
self$edges[, c("src_id", "dst_id") := list(map_values(src_id, old, new), map_values(dst_id, old, new))]
445445
invisible(self)
446446
},
447+
447448
update_ids = function(prefix = "", postfix = "") {
448449
ids = names2(self$pipeops)
449450
self$set_names(ids, sprintf("%s%s%s", assert_string(prefix), ids, assert_string(postfix)))
@@ -456,9 +457,13 @@ Graph = R6Class("Graph",
456457
},
457458

458459
predict = function(input, single_input = TRUE) {
460+
if (!self$is_trained) {
461+
stop("Cannot predict, Graph has not been trained yet")
462+
}
459463
graph_load_namespaces(self, "predict")
460464
graph_reduce(self, input, "predict", single_input)
461465
},
466+
462467
help = function(help_type = getOption("help_type")) {
463468
parts = strsplit(self$man, split = "::", fixed = TRUE)[[1]]
464469
match.fun("help")(parts[[2]], package = parts[[1]], help_type = help_type)
@@ -717,7 +722,7 @@ graph_load_namespaces = function(self, info) {
717722
#' @export
718723
predict.Graph = function(object, newdata, ...) {
719724
if (!object$is_trained) {
720-
stop("Graph is not trained.")
725+
stop("Cannot predict, Graph has not been trained yet")
721726
}
722727
output = object$output
723728
if (nrow(output) != 1) {

R/PipeOp.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@
188188
#' should not modify the `PipeOp` in any way.\cr
189189
#'
190190
#' @section Inheriting:
191-
#' To create your own `PipeOp`, you need to overload the `private$.train()` and `private$.test()` functions.
191+
#' To create your own `PipeOp`, you need to overload the `private$.train()` and `private$.predict()` functions.
192192
#' It is most likely also necessary to overload the `$initialize()` function to do additional initialization.
193193
#' The `$initialize()` method should have at least the arguments `id` and `param_vals`, which should be passed on to `super$initialize()` unchanged.
194194
#' `id` should have a useful default value, and `param_vals` should have the default value `list()`, meaning no initialization of hyperparameters.
@@ -328,6 +328,10 @@ PipeOp = R6Class("PipeOp",
328328
predict = function(input) {
329329
assert_list(input, .var.name = sprintf("input to PipeOp %s's $predict()", self$id))
330330

331+
if (!self$is_trained) {
332+
stopf("Cannot predict, PipeOp '%s' has not been trained yet", self$id)
333+
}
334+
331335
# need to load packages in train *and* predict, because they might run in different R instances
332336
require_namespaces(self$packages)
333337

man/PipeOp.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_Graph.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,3 +666,10 @@ test_that("Graph with vararg input", {
666666
list(debugvararg.output = 1006, debugvararg2.output = 2006))
667667

668668
})
669+
670+
test_that("Error when predicting with untrained Graph, #893", {
671+
g = Graph$new()$
672+
add_pipeop(PipeOpDebugBasic$new())
673+
674+
expect_error(g$predict(tsk("iris")), "Graph has not been trained yet")
675+
})

tests/testthat/test_PipeOp.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ test_that("PipeOp - General functions", {
1313
expect_equal(po_1$packages, "mlr3pipelines")
1414
expect_null(po_1$state)
1515
assert_subset(po_1$tags, mlr_reflections$pipeops$valid_tags)
16+
expect_error(po_1$predict(list(tsk("iris"))), "has not been trained yet")
1617

1718
expect_output(expect_equal(po_1$train(list(1)), list(output = 1)), "Training debug.basic")
1819
expect_equal(po_1$state, list(input = 1))
@@ -95,14 +96,16 @@ test_that("Informative error and warning messages", {
9596

9697
expect_no_warning(suppressWarnings(gr$predict(tsk("iris"))))
9798

98-
9999
gr$param_set$values$classif.debug.warning_train = 0
100100
gr$param_set$values$classif.debug.warning_predict = 0
101+
101102
gr$param_set$values$classif.debug.error_train = 1
102-
gr$param_set$values$classif.debug.error_predict = 1
103-
104103
expect_error(gr$train(tsk("iris")), "This happened PipeOp classif.debug's \\$train\\(\\)$")
105104

105+
gr$param_set$values$classif.debug.error_train = 0
106+
gr$param_set$values$classif.debug.error_predict = 1
107+
# Need to first train the Graph for predict to work
108+
gr$train(tsk("iris"))
106109
expect_error(gr$predict(tsk("iris")), "This happened PipeOp classif.debug's \\$predict\\(\\)$")
107110

108111
potest = R6::R6Class("potest", inherit = PipeOp,

tests/testthat/test_pipeop_ovr.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ test_that("PipeOpOVRUnite- train and predict", {
6767
lrn$train(task)
6868
lrn$predict(task)
6969
})
70+
# Need to first train before predicting
71+
expect_list(po$train(list(as.Multiplicity(NULL))), len = 1, types = "null")
7072
pout = po$predict(list(as.Multiplicity(tin)))
7173
expect_prediction_classif(pout[[1]])
7274

@@ -112,6 +114,8 @@ test_that("PipeOpOVRSplit and PipeOpOVRUnite - train and predict", {
112114
lrn$train(task)
113115
lrn$predict(task)
114116
})
117+
# Need to first train before predicting
118+
po$train(list(as.Multiplicity(NULL)))
115119
pout_ref = po$predict(list(as.Multiplicity(tin)))
116120

117121
gr = PipeOpOVRSplit$new() %>>% LearnerClassifRpart$new() %>>% PipeOpOVRUnite$new()

0 commit comments

Comments
 (0)