Seal of Approval: mlr3

Categories: seal of approval, application package
Author: Maximilian Mücke
Published: October 1, 2024

mlr3

Author(s): Michel Lang, Bernd Bischl, Jakob Richter, Patrick Schratz, Martin Binder, Florian Pfisterer, Raphael Sonabend, Marc Becker, Sebastian Fischer

Maintainer: Marc Becker (marcbecker@posteo.de)

[mlr3 hex sticker]

A modern, object-oriented machine learning framework for R, and the successor of mlr.

Relationship with data.table

mlr3 was designed to integrate closely with data.table for efficient data handling in machine learning workflows. This integration takes two main forms:

  1. Data Backend: mlr3 uses data.table as the core data backend for all Task objects. This means that when you work with tasks in mlr3, the underlying data is stored and managed using data.table. Moreover, users can leverage data.table syntax directly within mlr3 workflows: accessing task data via task$data() returns a data.table, so you can apply data.table operations for data preprocessing, feature engineering, and subsetting without any additional conversion or overhead (see the sketch after this list).
  2. Result Storage: mlr3 stores various results such as predictions, resampling outcomes, and benchmarking results as data.table objects.
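
As a minimal sketch of the data backend point, assuming only the built-in penguins task (the grouped count at the end is ordinary data.table syntax, not an mlr3 API):

library(mlr3)
library(data.table)

task = tsk("penguins")

# task$data() returns a data.table (a copy, so edits do not touch the task)
dt = task$data()
class(dt)
# [1] "data.table" "data.frame"

# plain data.table syntax works directly, e.g. counting rows per class
dt[, .N, by = species]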

Overview

Excerpted from the mlr3 book

The mlr3 universe includes a wide range of tools taking you from basic ML to complex experiments. To get started, here is an example of the simplest functionality – training a model and making predictions.

library(mlr3)

# load a built-in task, split it into train/test rows, and pick a learner
task = tsk("penguins")
split = partition(task)
learner = lrn("classif.rpart")

# train the decision tree on the training rows only
learner$train(task, row_ids = split$train)
learner$model
n= 230 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 230 124 Adelie (0.460869565 0.191304348 0.347826087)  
  2) flipper_length< 207 148  44 Adelie (0.702702703 0.290540541 0.006756757)  
    4) bill_length< 43.35 105   4 Adelie (0.961904762 0.038095238 0.000000000) *
    5) bill_length>=43.35 43   4 Chinstrap (0.069767442 0.906976744 0.023255814) *
  3) flipper_length>=207 82   3 Gentoo (0.024390244 0.012195122 0.963414634) *

# predict on the held-out test rows
prediction = learner$predict(task, row_ids = split$test)
prediction
<PredictionClassif> for 114 observations:
    row_ids     truth  response
          4    Adelie    Adelie
          5    Adelie    Adelie
         10    Adelie    Adelie
---                            
        340 Chinstrap    Gentoo
        343 Chinstrap    Gentoo
        344 Chinstrap Chinstrap

# evaluate the test predictions with the accuracy measure
prediction$score(msr("classif.acc"))
classif.acc 
  0.9473684 

In this example, we trained a decision tree on a subset of the penguins dataset, made predictions for the remaining rows, and then evaluated them with the accuracy measure.
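
The same score() interface accepts multiple measures at once, and the prediction object also exposes a confusion matrix. A small sketch, continuing from the code above (classif.ce is the built-in classification error measure):

# score several measures in one call
prediction$score(msrs(c("classif.acc", "classif.ce")))

# cross-tabulation of truth vs. response
prediction$confusion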

The mlr3 interface also lets you run more complicated experiments in just a few lines of code:

library(mlr3verse)

tasks = tsks(c("german_credit", "sonar"))

# random forest whose num.trees is tuned by grid search on a holdout set,
# preceded by the "robustify" preprocessing pipeline
glrn_rf_tuned = as_learner(ppl("robustify") %>>% auto_tuner(
    tnr("grid_search", resolution = 5),
    lrn("classif.ranger", num.trees = to_tune(200, 500)),
    rsmp("holdout")
))
glrn_rf_tuned$id = "RF"

# stacking ensemble: decision tree and k-NN as base learners,
# logistic regression as the combiner
glrn_stack = as_learner(ppl("robustify") %>>% ppl("stacking",
    lrns(c("classif.rpart", "classif.kknn")),
    lrn("classif.log_reg")
))
glrn_stack$id = "Stack"

# benchmark both pipelines on both tasks with 3-fold cross-validation
learners = c(glrn_rf_tuned, glrn_stack)
bmr = benchmark(benchmark_grid(tasks, learners, rsmp("cv", folds = 3)))

bmr$aggregate(msr("classif.acc"))
      nr       task_id learner_id resampling_id iters classif.acc
   <int>        <char>     <char>        <char> <int>       <num>
1:     1 german_credit         RF            cv     3   0.7749966
2:     2 german_credit      Stack            cv     3   0.7450175
3:     3         sonar         RF            cv     3   0.8077295
4:     4         sonar      Stack            cv     3   0.7121463
Hidden columns: resample_result

In this more complex example, we selected two tasks and two learners, used automated tuning to optimize the number of trees in the random forest learner, and employed a machine learning pipeline that imputes missing data, consolidates factor levels, and stacks models. We also showed basic features like loading learners and choosing resampling strategies for benchmarking. Finally, we compared the performance of the models using the mean accuracy with three-fold cross-validation.
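
Tying this back to the Result Storage point above: the benchmark result exposes its scores as data.table objects, so the usual data.table idioms apply. A minimal sketch, continuing from the benchmark above:

# aggregate() returns a data.table ...
aggr = bmr$aggregate(msr("classif.acc"))
class(aggr)

# ... as does score(), which holds one row per resampling iteration
scores = bmr$score(msr("classif.acc"))
scores[, .(task_id, learner_id, iteration, classif.acc)]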
