Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# mlr3pipelines 0.7.2-9000

- Added missing error for predicting with untrained `PipeOp`s / `Graph`s.

# mlr3pipelines 0.7.2

Expand Down
7 changes: 6 additions & 1 deletion R/Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ Graph = R6Class("Graph",
self$edges[, c("src_id", "dst_id") := list(map_values(src_id, old, new), map_values(dst_id, old, new))]
invisible(self)
},

update_ids = function(prefix = "", postfix = "") {
ids = names2(self$pipeops)
self$set_names(ids, sprintf("%s%s%s", assert_string(prefix), ids, assert_string(postfix)))
Expand All @@ -456,9 +457,13 @@ Graph = R6Class("Graph",
},

predict = function(input, single_input = TRUE) {
if (!self$is_trained) {
stop("Cannot predict, Graph has not been trained yet")
}
graph_load_namespaces(self, "predict")
graph_reduce(self, input, "predict", single_input)
},

help = function(help_type = getOption("help_type")) {
parts = strsplit(self$man, split = "::", fixed = TRUE)[[1]]
match.fun("help")(parts[[2]], package = parts[[1]], help_type = help_type)
Expand Down Expand Up @@ -717,7 +722,7 @@ graph_load_namespaces = function(self, info) {
#' @export
predict.Graph = function(object, newdata, ...) {
if (!object$is_trained) {
stop("Graph is not trained.")
stop("Cannot predict, Graph has not been trained yet")
}
output = object$output
if (nrow(output) != 1) {
Expand Down
6 changes: 5 additions & 1 deletion R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@
#' should not modify the `PipeOp` in any way.\cr
#'
#' @section Inheriting:
#' To create your own `PipeOp`, you need to overload the `private$.train()` and `private$.test()` functions.
#' To create your own `PipeOp`, you need to overload the `private$.train()` and `private$.predict()` functions.
#' It is most likely also necessary to overload the `$initialize()` function to do additional initialization.
#' The `$initialize()` method should have at least the arguments `id` and `param_vals`, which should be passed on to `super$initialize()` unchanged.
#' `id` should have a useful default value, and `param_vals` should have the default value `list()`, meaning no initialization of hyperparameters.
Expand Down Expand Up @@ -325,6 +325,10 @@ PipeOp = R6Class("PipeOp",
predict = function(input) {
assert_list(input, .var.name = sprintf("input to PipeOp %s's $predict()", self$id))

if (!self$is_trained) {
stopf("Cannot predict, PipeOp '%s' has not been trained yet", self$id)
}

# need to load packages in train *and* predict, because they might run in different R instances
require_namespaces(self$packages)

Expand Down
2 changes: 1 addition & 1 deletion man/PipeOp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions tests/testthat/test_Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,10 @@ test_that("Graph with vararg input", {
list(debugvararg.output = 1006, debugvararg2.output = 2006))

})

test_that("Error when predicting with untrained Graph, #893", {
g = Graph$new()$
add_pipeop(PipeOpDebugBasic$new())

expect_error(g$predict(tsk("iris")), "Graph has not been trained yet")
})
9 changes: 6 additions & 3 deletions tests/testthat/test_PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ test_that("PipeOp - General functions", {
expect_equal(po_1$packages, "mlr3pipelines")
expect_null(po_1$state)
assert_subset(po_1$tags, mlr_reflections$pipeops$valid_tags)
expect_error(po_1$predict(list(tsk("iris"))), "has not been trained yet")

expect_output(expect_equal(po_1$train(list(1)), list(output = 1)), "Training debug.basic")
expect_equal(po_1$state, list(input = 1))
Expand Down Expand Up @@ -95,14 +96,16 @@ test_that("Informative error and warning messages", {

expect_no_warning(suppressWarnings(gr$predict(tsk("iris"))))


gr$param_set$values$classif.debug.warning_train = 0
gr$param_set$values$classif.debug.warning_predict = 0

gr$param_set$values$classif.debug.error_train = 1
gr$param_set$values$classif.debug.error_predict = 1

expect_error(gr$train(tsk("iris")), "This happened PipeOp classif.debug's \\$train\\(\\)$")

gr$param_set$values$classif.debug.error_train = 0
gr$param_set$values$classif.debug.error_predict = 1
# Need to first train the Graph for predict to work
gr$train(tsk("iris"))
expect_error(gr$predict(tsk("iris")), "This happened PipeOp classif.debug's \\$predict\\(\\)$")

potest = R6::R6Class("potest", inherit = PipeOp,
Expand Down
4 changes: 4 additions & 0 deletions tests/testthat/test_pipeop_ovr.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ test_that("PipeOpOVRUnite- train and predict", {
lrn$train(task)
lrn$predict(task)
})
# Need to first train before predicting
expect_list(po$train(list(as.Multiplicity(NULL))), len = 1, types = "null")
pout = po$predict(list(as.Multiplicity(tin)))
expect_prediction_classif(pout[[1]])

Expand Down Expand Up @@ -112,6 +114,8 @@ test_that("PipeOpOVRSplit and PipeOpOVRUnite - train and predict", {
lrn$train(task)
lrn$predict(task)
})
# Need to first train before predicting
po$train(list(as.Multiplicity(NULL)))
pout_ref = po$predict(list(as.Multiplicity(tin)))

gr = PipeOpOVRSplit$new() %>>% LearnerClassifRpart$new() %>>% PipeOpOVRUnite$new()
Expand Down