When building a model for your dataset, there is a wide variety of models and modeling packages to choose from, each with its own parameters and specifications. The parsnip package provides a unified, simple interface for model creation so that you can easily swap between models without worrying about minor syntax changes.
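
The examples below assume that the parsnip package has been attached, either directly or through the tidymodels meta-package:

# attach parsnip (it is also attached automatically by library(tidymodels))
library(parsnip)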


Initializing a Model

A parsnip model is built by defining an initial model object, setting its parameters/engine, and then fitting the model to data. Each type of supported model has its own initializing function.

For example, suppose we wanted to create a linear regression model. To initialize this model, call the linear_reg() function.

lr <- linear_reg()

Alternatively, to initialize a random forest model, use the rand_forest() function.

rf <- rand_forest()

Other supported models include logistic regression, boosted trees, k-nearest neighbors, and more. For a full list of supported models, see the official Function Reference.
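
For illustration, a few of the other initializing functions follow the same pattern; each call returns an unfitted model specification:

# a few other parsnip model specifications
logistic_reg()      # logistic regression
boost_tree()        # boosted trees
nearest_neighbor()  # k-nearest neighbors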


Setting Parameters

After initializing a model, set its mode, engine, and any model arguments.

# setting mode and parameters for both of our example models 
# note that you don't need to set every model parameter if you're happy with the defaults
lr <- lr |>
  set_mode("regression") |> 
  set_args(penalty = NULL) 

rf <- rf |>
  set_mode("classification") |>
  set_args(trees = 200)
# what engines are available for random forest?
show_engines("rand_forest")
## # A tibble: 6 × 2
##   engine       mode          
##   <chr>        <chr>         
## 1 ranger       classification
## 2 ranger       regression    
## 3 randomForest classification
## 4 randomForest regression    
## 5 spark        classification
## 6 spark        regression
# setting the engine to ranger
# note the engine-specific argument used in this call 
rf <- rf |>
  set_engine("ranger", verbose = TRUE)

# to change the engine, call set_engine() again
rf <- rf |>
  set_engine("randomForest")

# using translate() to view the model and how parsnip arguments transfer to the randomForest package
translate(rf)
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   trees = 200
## 
## Computational engine: randomForest 
## 
## Model fit template:
## randomForest::randomForest(x = missing_arg(), y = missing_arg(), 
##     ntree = 200)

Note: Mode, engine, and model arguments can be specified within the initial model call. However, setting them with separate functions enhances readability and creates a more flexible interface.

# defining parameters within initial call
rand_forest(mode = "classification", 
            engine = "randomForest", 
            trees = 500)
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   trees = 500
## 
## Computational engine: randomForest


Fitting

Once a model has been initialized and defined, fit it to our prepped training data using the fit() function. fit() takes a model object, a formula, and the training data, and returns a fitted model.

We will use the iris_recipe_prepped recipe and the prepped_training and prepped_testing datasets defined in the Recipes Tutorial.

# extracting model formula from prepped recipe object
rf_formula <- formula(iris_recipe_prepped)

# fitting our random forest model to our training data
rf_fit <- fit(rf, rf_formula, data = prepped_training)

# taking a look at our fitted model
rf_fit 
## parsnip model object
## 
## 
## Call:
##  randomForest(x = maybe_data_frame(x), y = y, ntree = ~200) 
##                Type of random forest: classification
##                      Number of trees: 200
## No. of variables tried at each split: 2
## 
##         OOB estimate of  error rate: 0%
## Confusion matrix:
##            setosa versicolor virginica class.error
## setosa         39          0         0           0
## versicolor      0         36         0           0
## virginica       0          0        37           0
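
As an aside, if you prefer not to use a formula interface, parsnip also provides fit_xy(), which takes the predictors and the outcome separately. The sketch below is hypothetical and assumes Species is the outcome column in prepped_training:

# hypothetical sketch: fitting the same spec without a formula
# any non-predictor columns (e.g. an ID column such as Row) are dropped from x
rf_fit_xy <- fit_xy(
  rf,
  x = prepped_training[, setdiff(names(prepped_training), c("Species", "Row"))],
  y = prepped_training$Species
)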

Note: If a model uses a specialized formula that specifies model structure as well as terms (e.g., GAMs), pass that formula directly to fit().
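
For instance, a generalized additive model keeps its smooth terms in the formula passed to fit(). The sketch below is illustrative only; it assumes the mgcv engine is installed, and the outcome and predictors are chosen arbitrarily from the prepped data:

# hypothetical sketch: a GAM whose structure lives in the fit() formula
gam_spec <- gen_additive_mod() |>
  set_mode("regression") |>
  set_engine("mgcv")

gam_fit <- fit(gam_spec,
               Sepal.Length ~ s(Petal.Width) + Species,
               data = prepped_training)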

Note: If the recipe used to preprocess data modified any roles from your original formula, be sure to extract your model formula from the recipe object to account for those changes.


Predicting

In order to generate predictions from new data, pass the fitted model and a dataset to either the predict() or augment() methods. predict() returns a tibble of predictions, and augment() binds prediction columns to the input data.

# generating predictions
predict(rf_fit, prepped_testing) |>
  head()
## # A tibble: 6 × 1
##   .pred_class
##   <fct>      
## 1 setosa     
## 2 setosa     
## 3 setosa     
## 4 setosa     
## 5 setosa     
## 6 setosa
# note the use of type = "prob" to retrieve probabilities for each class
predict(rf_fit, prepped_testing, type = "prob") |>
  head()
## # A tibble: 6 × 3
##   .pred_setosa .pred_versicolor .pred_virginica
##          <dbl>            <dbl>           <dbl>
## 1        1                0                   0
## 2        1                0                   0
## 3        1                0                   0
## 4        0.945            0.055               0
## 5        0.99             0.01                0
## 6        1                0                   0
# augment binds predictions to original data
# augment does not have a type argument
iris_preds <- augment(rf_fit, prepped_testing) 

head(iris_preds)
## # A tibble: 6 × 9
##   Sepal.Length Sepal.Width Petal.Width   Row Species .pred_class .pred_setosa .pred_versicolor .pred_virginica
##          <dbl>       <dbl>       <dbl> <int> <fct>   <fct>              <dbl>            <dbl>           <dbl>
## 1       -0.945      0.789        -1.27     8 setosa  setosa             1                0                   0
## 2       -1.06       0.0986       -1.40    10 setosa  setosa             1                0                   0
## 3       -1.17       0.789        -1.27    12 setosa  setosa             1                0                   0
## 4       -0.147      1.71         -1.14    19 setosa  setosa             0.945            0.055               0
## 5       -0.489      0.789        -1.27    21 setosa  setosa             0.99             0.01                0
## 6       -0.831      1.48         -1.02    22 setosa  setosa             1                0                   0

Methods to fit resampling sets are covered in the Tune Tutorial.


Further Resources