Skip to content

Commit 3b8d431

Browse files
committed
added error for predicting with untrained PipeOp + test
1 parent 2c80f4e commit 3b8d431

3 files changed

Lines changed: 7 additions & 2 deletions

File tree

R/PipeOp.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@
188188
#' should not modify the `PipeOp` in any way.\cr
189189
#'
190190
#' @section Inheriting:
191-
#' To create your own `PipeOp`, you need to overload the `private$.train()` and `private$.test()` functions.
191+
#' To create your own `PipeOp`, you need to overload the `private$.train()` and `private$.predict()` functions.
192192
#' It is most likely also necessary to overload the `$initialize()` function to do additional initialization.
193193
#' The `$initialize()` method should have at least the arguments `id` and `param_vals`, which should be passed on to `super$initialize()` unchanged.
194194
#' `id` should have a useful default value, and `param_vals` should have the default value `list()`, meaning no initialization of hyperparameters.
@@ -325,6 +325,10 @@ PipeOp = R6Class("PipeOp",
325325
predict = function(input) {
326326
assert_list(input, .var.name = sprintf("input to PipeOp %s's $predict()", self$id))
327327

328+
if (!self$is_trained) {
329+
stopf("Cannot predict, PipeOp '%s' has not been trained yet", self$id)
330+
}
331+
328332
# need to load packages in train *and* predict, because they might run in different R instances
329333
require_namespaces(self$packages)
330334

man/PipeOp.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_PipeOp.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ test_that("PipeOp - General functions", {
1313
expect_equal(po_1$packages, "mlr3pipelines")
1414
expect_null(po_1$state)
1515
assert_subset(po_1$tags, mlr_reflections$pipeops$valid_tags)
16+
expect_error(po_1$predict(list(tsk("iris"))), "has not been trained yet")
1617

1718
expect_output(expect_equal(po_1$train(list(1)), list(output = 1)), "Training debug.basic")
1819
expect_equal(po_1$state, list(input = 1))

0 commit comments

Comments
 (0)