library(grf)

Visualize trees in a random forest

All GRF forest objects work with the tree extractor function get_tree, which returns a single tree that you can inspect with either print or plot. The example below demonstrates this for a regression forest.

n <- 100
p <- 5
X <- matrix(rnorm(n * p), n, p)
Y <- X[, 1] * rnorm(n)
r.forest <- regression_forest(X, Y, num.trees = 100)

# Extract the first tree from the fitted forest.
tree <- get_tree(r.forest, 1)
# Print the first tree.
print(tree)
#> GRF tree object 
#> Number of training samples: 50 
#> Variable splits: 
#> (1) split_variable: X.2  split_value: 1.84386 
#>   (2) split_variable: X.3  split_value: 0.0471544 
#>     (4) split_variable: X.1  split_value: 1.25381 
#>       (6) split_variable: X.2  split_value: -0.895363 
#>         (10) * num_samples: 3  avg_Y: 0.52 
#>         (11) split_variable: X.2  split_value: -0.849704 
#>           (14) * num_samples: 1  avg_Y: 0.32 
#>           (15) split_variable: X.5  split_value: -0.718466 
#>             (16) * num_samples: 1  avg_Y: 0.11 
#>             (17) * num_samples: 4  avg_Y: 0.22 
#>       (7) * num_samples: 2  avg_Y: -1.57 
#>     (5) split_variable: X.5  split_value: 0.659903 
#>       (8) split_variable: X.2  split_value: -0.741336 
#>         (12) * num_samples: 1  avg_Y: -1.91 
#>         (13) * num_samples: 9  avg_Y: 0.49 
#>       (9) * num_samples: 2  avg_Y: -0.2 
#>   (3) * num_samples: 2  avg_Y: -0.02

# Plot the first tree.
plot(tree)
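The extracted tree can also be inspected programmatically rather than only printed or plotted. The sketch below is a minimal example that assumes the tree object exposes a nodes list whose elements carry an is_leaf flag and, for interior nodes, a split_variable column index, as described in the get_tree reference. The names is.leaf, num.leaves, split.vars and split.counts are illustrative only, and set.seed is added here just to make the run reproducible.

```r
library(grf)

# Rebuild the forest from the example above so this block runs on its own.
set.seed(1)
n <- 100
p <- 5
X <- matrix(rnorm(n * p), n, p)
Y <- X[, 1] * rnorm(n)
r.forest <- regression_forest(X, Y, num.trees = 100)
tree <- get_tree(r.forest, 1)

# Assumption: every element of tree$nodes is a list with an is_leaf flag.
is.leaf <- vapply(tree$nodes, function(node) isTRUE(node$is_leaf), logical(1))
num.leaves <- sum(is.leaf)

# Assumption: interior nodes store split_variable as a column index of X.
split.vars <- vapply(
  tree$nodes[!is.leaf],
  function(node) as.integer(node$split_variable),
  integer(1)
)

# Count how often each of the p covariates is used to split in this tree.
split.counts <- table(factor(split.vars, levels = seq_len(p)))
print(num.leaves)
print(split.counts)
```

Counting splits per covariate in one tree gives only a rough impression of which variables matter, since individual trees in a forest vary considerably.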

To find the leaf node that a given sample falls into, use the get_leaf_node function.

n.test <- 4
X.test <- matrix(rnorm(n.test * p), n.test, p)
print(X.test)
#>            [,1]       [,2]       [,3]        [,4]       [,5]
#> [1,]  1.2127367 0.10594824 0.84944344 -1.74628449  0.2242183
#> [2,]  0.5771822 3.05712530 1.61090482 -0.75187304 -1.7156039
#> [3,]  2.2804284 0.01974265 0.05529995  0.07762225  0.5352278
#> [4,] -1.0041101 0.89149575 1.20539645 -0.59010514 -0.4925395
# Get a vector giving the leaf node number of each sample.
get_leaf_node(tree, X.test)
#> [1] 13  3 13 13
# Get a list of samples per node.
get_leaf_node(tree, X.test, node.id = FALSE)
#> $`3`
#> [1] 2
#> 
#> $`13`
#> [1] 1 3 4
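The leaf numbers returned by get_leaf_node can be used to look up the matching entries of the tree. The following sketch assumes that each leaf in tree$nodes stores a samples vector holding the indices of the training rows that ended up in that leaf (with honesty enabled, only the rows used to repopulate the tree). The names leaf.ids, leaf.means and by.node are illustrative, and the data is regenerated with a fixed seed so the block runs on its own.

```r
library(grf)

# Regenerate training and test data so the block is self-contained.
set.seed(1)
n <- 100
p <- 5
X <- matrix(rnorm(n * p), n, p)
Y <- X[, 1] * rnorm(n)
r.forest <- regression_forest(X, Y, num.trees = 100)
tree <- get_tree(r.forest, 1)

n.test <- 4
X.test <- matrix(rnorm(n.test * p), n.test, p)

# Leaf node number for each test row.
leaf.ids <- get_leaf_node(tree, X.test)

# Assumption: leaf nodes expose `samples`, the training row indices they hold.
# Average the training outcomes in the leaf each test row falls into.
leaf.means <- vapply(
  leaf.ids,
  function(id) mean(Y[tree$nodes[[id]]$samples]),
  numeric(1)
)

# The per-node listing covers the same test rows, grouped by leaf.
by.node <- get_leaf_node(tree, X.test, node.id = FALSE)

print(data.frame(test.row = seq_len(n.test), leaf = leaf.ids, leaf.mean = leaf.means))
```

These per-leaf averages describe a single tree only. Forest predictions from predict combine information across all trees, so they generally differ from the values computed here.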

For a tutorial on visualizing tree-based treatment assignment rules, see the Policy learning vignette.