In this example, we will illustrate several ways of dealing with categorical variables when using grf.
One of the approaches below relies on grf’s sister package sufrep. Let’s install and load it first.
library(grf)
install.packages("https://github.com/grf-labs/sufrep/blob/master/sufrep_0.1.0.tar.gz?raw=true", repos = NULL, type = "source")
#> Installing package into '/tmp/RtmpMA3g0E/temp_libpath7d252b4b86f1'
#> (as 'lib' is unspecified)
library(sufrep)
Let’s pretend we would like to predict miles per gallon (mpg) from number of cylinders (cyl), quarter-mile time (qsec), and car brand name (brand, created below).
# Create a categorical column with brand name
df <- within(mtcars, {
  # E.g. 'Mazda RX4' --> 'Mazda'
  brand <- factor(sapply(rownames(mtcars), function(x) strsplit(x, " ")[[1]][1]))
})
x <- c("cyl", "qsec")  # Continuous variables
g <- c("brand")        # Categorical variable
head(df[c(x, g)])
#>                   cyl  qsec   brand
#> Mazda RX4           6 16.46   Mazda
#> Mazda RX4 Wag       6 17.02   Mazda
#> Datsun 710          4 18.61  Datsun
#> Hornet 4 Drive      6 19.44  Hornet
#> Hornet Sportabout   8 17.02  Hornet
#> Valiant             6 20.22 Valiant
The following call would raise an error, because the brand column is a factor and regression_forest requires numeric covariates.
# rf <- regression_forest(X=df[c(x, g)], Y=df$mpg)
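Before fitting, it can help to check which columns are not numeric and therefore need encoding. The following is a minimal base-R sketch that rebuilds the example data so it runs on its own; the names `cols`, `is_num`, and `non_numeric` are illustrative and not part of grf.

```r
# Rebuild the example data so this block is self-contained
df <- within(mtcars, {
  brand <- factor(sapply(rownames(mtcars), function(x) strsplit(x, " ")[[1]][1]))
})
cols <- c("cyl", "qsec", "brand")

# TRUE for numeric columns, FALSE for the factor column
is_num <- sapply(df[cols], is.numeric)
print(is_num)

# Columns that must be encoded before being passed to a forest
non_numeric <- names(is_num)[!is_num]
print(non_numeric)
```

Any column listed in `non_numeric` has to be converted to numbers by one of the encodings discussed in this example.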
We can consider three approaches here: transforming the variable into integers, one-hot encoding, and the means method from the sufrep package.
The last method involves substituting the brand column by averages of the continuous columns cyl and qsec, grouped by category. If you are curious about why that works, or would like to know more about sufficient representations, please check out our sufrep paper (arXiv).
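To make the idea of grouped averages concrete, here is a rough base-R sketch that computes per-brand means of cyl and qsec and attaches them to each row. It is only an illustration of the idea, not sufrep's implementation: the output layout (original columns plus two mean columns) and the names `brand_means`, `X_means`, `cyl_mean`, and `qsec_mean` are assumptions made for this example.

```r
# Rebuild the example data so this block is self-contained
df <- within(mtcars, {
  brand <- factor(sapply(rownames(mtcars), function(x) strsplit(x, " ")[[1]][1]))
})
x <- c("cyl", "qsec")

# Average the continuous columns within each brand
brand_means <- aggregate(df[x], by = list(brand = df$brand), FUN = mean)

# For every car, look up the averages of its brand
idx <- match(df$brand, brand_means$brand)
X_means <- cbind(
  df[x],
  cyl_mean  = brand_means$cyl[idx],
  qsec_mean = brand_means$qsec[idx]
)
head(X_means)
```

Cars of the same brand share identical values in the two mean columns, so the brand is represented by a pair of numbers rather than by a label.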
# Solution 1: Transform variable into numbers
X1 <- within(df[c(x, g)], brand <- as.numeric(brand))
rf1 <- regression_forest(X1, df$mpg)

# Solution 2: One-hot encoding
X2 <- model.matrix(~ 0 + ., df[c(x, g)])
rf2 <- regression_forest(X2, df$mpg)

# Solution 3: 'Means' encoding using the 'sufrep' package
encoder <- make_encoder(df[x], df$brand, method="means")
X3 <- encoder(df[x], df$brand)
#> [1] 22 2
rf3 <- regression_forest(X3, df$mpg)
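It may be useful to inspect what the first two encodings actually produce before comparing forests. The sketch below uses only base R and rebuilds the data so it runs on its own; `codes` and `X_onehot` are illustrative names. It shows that the integer codes simply follow the sorted order of the factor levels, and that one-hot encoding adds one indicator column per brand.

```r
# Rebuild the example data so this block is self-contained
df <- within(mtcars, {
  brand <- factor(sapply(rownames(mtcars), function(x) strsplit(x, " ")[[1]][1]))
})

# Integer encoding: each brand is replaced by the position of its factor level
codes <- as.numeric(df$brand)
head(data.frame(brand = df$brand, code = codes))

# One-hot encoding: with the intercept removed, every brand gets its own column
X_onehot <- model.matrix(~ 0 + ., df[c("cyl", "qsec", "brand")])
dim(X_onehot)
nlevels(df$brand)
```

The integer codes impose an ordering on brands that carries no real meaning, while the one-hot matrix is much wider than the original data. Both effects are worth keeping in mind when interpreting the comparison.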
Different encodings can yield different forest performance. Here we compare the mean debiased error of each forest.
mse1 <- mean(rf1$debiased.error)
mse2 <- mean(rf2$debiased.error)
mse3 <- mean(rf3$debiased.error)

print("MSE when representing categorical variables as...")
#> [1] "MSE when representing categorical variables as..."
print(paste0("Integers: ", mse1))
#> [1] "Integers: 15.6558414382405"
print(paste0("One-hot vectors: ", mse2))
#> [1] "One-hot vectors: 14.6681216279153"
print(paste0("'Means' encoding [sufrep]: ", mse3))
#> [1] "'Means' encoding [sufrep]: 14.4303582052086"