What does booster.refit actually do? #6838

@mlondschien

I am looking for a reference on the inner workings of refit that goes beyond "refit uses new data to update tree leaf values, keeping the tree's structure intact". How are the tree leaf values updated? Is this documented somewhere in detail?

I am aware of #3003, #1473, #1529.

My understanding, from here and here, is the following (pseudo-code):

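import numpy as np
from typing import List

class Booster:
    def __init__(self, trees: List["Tree"], params: dict):
        self.trees = trees
        self.params = params

    def get_grad(self, y, f):
        # Negative gradient of the loss w.r.t. the raw score f, so that the
        # mean over a leaf is a descent step for both objectives.
        if self.params["objective"] == "regression":
            return y - f
        elif self.params["objective"] == "classification":
            return y - 1 / (1 + np.exp(-f))
        raise ValueError("unsupported objective")

    def refit(self, X, y):
        f = np.zeros_like(y, dtype=float)  # or some init_model
        decay_rate = self.params["decay_rate"]

        for tree in self.trees:
            # Gradients are recomputed after every tree, using the already
            # refitted trees that precede it.
            grad = self.get_grad(y, f)

            # Leaf that each row of X falls into; the tree structure is fixed.
            leaf_indices = tree.get_leaf_indices(X)

            # Only leaves reached by the new data are updated.
            for leaf_index in np.unique(leaf_indices):
                old_leaf_value = tree.get_leaf_value(leaf_index)
                new_leaf_value = np.mean(grad[leaf_indices == leaf_index])
                tree.set_leaf_value(
                    leaf_index,
                    decay_rate * old_leaf_value + (1 - decay_rate) * new_leaf_value
                )

            f += self.params["learning_rate"] * tree.predict(X)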
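
To make the pseudo-code above concrete, here is a minimal runnable sketch of the same idea for squared-error loss only. The `Stump` class and `refit_sketch` function are hypothetical helpers written for this illustration; they are not part of LightGBM, and the sketch only mirrors my understanding as stated above (plain mean of residuals per leaf, blended with the old leaf value), not necessarily what the library does internally.

```python
import numpy as np


class Stump:
    """Hypothetical one-split tree: leaf 0 if X[:, feature] <= threshold, else leaf 1."""

    def __init__(self, feature, threshold, leaf_values):
        self.feature = feature
        self.threshold = threshold
        self.leaf_values = np.asarray(leaf_values, dtype=float)

    def get_leaf_indices(self, X):
        # Leaf index per row; the split itself is never changed by refitting.
        return (X[:, self.feature] > self.threshold).astype(int)

    def predict(self, X):
        return self.leaf_values[self.get_leaf_indices(X)]


def refit_sketch(trees, X, y, decay_rate, learning_rate=1.0):
    """Sequentially re-estimate leaf values in place (squared-error loss)."""
    f = np.zeros(len(y), dtype=float)
    for tree in trees:
        residual = y - f  # negative gradient of 0.5 * (y - f) ** 2
        leaf_of_row = tree.get_leaf_indices(X)
        for leaf in range(len(tree.leaf_values)):
            mask = leaf_of_row == leaf
            if not mask.any():
                continue  # no new data reaches this leaf: keep the old value
            new_value = residual[mask].mean()
            tree.leaf_values[leaf] = (
                decay_rate * tree.leaf_values[leaf] + (1 - decay_rate) * new_value
            )
        # The next tree sees the predictions of the trees refitted so far.
        f += learning_rate * tree.predict(X)
    return f


X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([1.0, 1.0, 3.0, 3.0])
trees = [Stump(0, 1.5, [0.0, 0.0]), Stump(0, 0.5, [0.0, 0.0])]
f = refit_sketch(trees, X, y, decay_rate=0.0)
print(trees[0].leaf_values, trees[1].leaf_values, f)
```

With `decay_rate=0.0` the old leaf values are discarded entirely, so the first stump absorbs the per-leaf means of `y` and the second stump, seeing zero residuals, ends up with zero leaf values. With `decay_rate=1.0` the trees would be left unchanged.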

Is this correct? If so, it seems to contradict #5609 (comment), since the loop above updates the trees one after another, recomputing the gradients after each tree:

but the refit method updates all the trees in one go.

It would be nice if this mechanism were documented somewhere in more detail, so that it can be referenced when using the refit method.
