สอนปลูกต้นไม้ในภาษา R (ภาค 1): วิธีสร้าง tree-based models ใน 3 ขั้นตอนด้วย rpart และ randomForest packages — ตัวอย่างการทำนายประเภทเกียร์รถใน mtcars dataset

ในบทความนี้ เราจะมาทำความรู้จักและสร้าง tree-based models ในภาษา R กัน:

  1. Classification tree
  2. Random forest

ถ้าพร้อมแล้ว ไปเริ่มกันเลย


🌳 Tree-Based Algorithms คืออะไร?

Tree-based algorithm เป็น machine learning algorithm ประเภท supervised learning ที่ทำนายข้อมูลโดยใช้ decision tree

Decision tree เป็นเครื่องมือช่วยตัดสินใจที่ดูคล้ายกับต้นไม้กลับหัว เพราะประกอบด้วยจุดตัดสินใจ (node) ที่แตกยอด (ทางเลือก) ออกไปเรื่อย ๆ เพื่อใช้ทำนายผลลัพธ์ที่ต้องการ

ยกตัวอย่างเช่น เราอยากรู้ว่า เราควรจะสมัครงานกับบริษัทแห่งหนึ่งไหม เราอาจจะสร้าง decision tree จาก 3 ปัจจัย คือ:

  1. เงินเดือน
  2. ได้ใช้ทักษะที่มี
  3. การทำงานแบบ hybrid หรือ remote

ได้แบบนี้:

จาก decision tree ถ้าเราเห็นงานที่ได้เงินเดือนไม่ตรงใจ เราจะไม่สมัครงานนั้น (เส้นการตัดสินใจซ้ายสุด)

ในทางกลับกัน ถ้างานมีเงินเดือนที่น่าสนใจ และได้ใช้ทักษะที่มีอยู่ เราจะสมัครงานนั้น (เส้นการตัดสินใจขวาสุด) เป็นต้น


💻 Tree-Based Models ในภาษา R

ในภาษา R เราสามารถสร้าง tree-based model ได้ด้วย rpart() จาก rpart package:

# Install
install.packages("rpart")
# Load
library(rpart)

rpart() ต้องการ 4 arguments ดังนี้:

rpart(formula, data, method, control)
  • formula = สูตรในการวิเคราะห์ (ตัวแปรตาม ~ ตัวแปรต้น)
  • data = dataset ที่ใช้สร้าง model
  • method = ประเภท algorithm ("anova" สำหรับ regression และ "class" สำหรับ classification)
  • control (optional) = เงื่อนไขควบคุม “การเติบโต” ของ decision tree เช่น ระดับชั้นที่มีได้ เป็นต้น

(Note: ศึกษาการใช้งาน rpart() เพิ่มเติมได้ที่ rpart: Recursive Partitioning and Regression Trees)

เราไปดูตัวอย่างการใช้งาน rpart() กัน

.

🚗 Dataset: mtcars

ในบทความนี้ เราจะลองสร้าง tree-based model ประเภท classification หรือ classification tree เพื่อทำนายประเภทเกียร์รถใน mtcars dataset

mtcars เป็นชุดข้อมูลรถจาก ปี ค.ศ. 1974 ซึ่งประกอบไปด้วยข้อมูล เช่น รุ่นรถ น้ำหนัก ระดับการกินน้ำมัน แรงม้า เป็นต้น

เราสามารถโหลด mtcars มาใช้งานได้ด้วย data() และ preview ด้วย head():

# Load
data(mtcars)
# Preview
head(mtcars)

ตัวอย่างข้อมูล:

                   mpg cyl disp  hp drat    wt  qsec vs        am gear carb
Mazda RX4         21.0   6  160 110 3.90 2.620 16.46  0    manual    4    4
Mazda RX4 Wag     21.0   6  160 110 3.90 2.875 17.02  0    manual    4    4
Datsun 710        22.8   4  108  93 3.85 2.320 18.61  1    manual    4    1
Hornet 4 Drive    21.4   6  258 110 3.08 3.215 19.44  1 automatic    3    1
Hornet Sportabout 18.7   8  360 175 3.15 3.440 17.02  0 automatic    3    2
Valiant           18.1   6  225 105 2.76 3.460 20.22  1 automatic    3    1

.

🔧 Prepare the Data

ก่อนนำ mtcars ไปใช้สร้าง classification tree เราจะต้องทำ 2 อย่างก่อน:

อย่างที่ #1. ปรับ column am ให้เป็น factor เพราะสิ่งที่เราต้องการทำนายเป็น categorical data:

# Convert `am` to factor
mtcars$am <- factor(mtcars$am,
                    levels = c(0, 1),
                    labels = c("automatic", "manual"))
# Check the result
class(mtcars$am)

ผลลัพธ์:

[1] "factor"

อย่างที่ #2. Split ข้อมูลเป็น 2 ชุด:

  1. Training set สำหรับสร้าง model
  2. Test set สำหรับประเมิน model
# Set seed for reproducibility
set.seed(500)
# Get training index
train_index <- sample(nrow(mtcars),
                      nrow(mtcars) * 0.7)
# Split the data
train_set <- mtcars[train_index, ]
test_set <- mtcars[-train_index, ]

.

🪴 Train the Model

ตอนนี้ เราพร้อมที่จะสร้าง classification tree ด้วย rpart() แล้ว

สำหรับ classification tree ในบทความนี้ เราจะลองตั้งเงื่อนไขในการปลูกต้นไม้ (control) ดังนี้:

ArgumentExplanation
cp = 0คะแนนประสิทธิภาพขั้นต่ำ ก่อนจะแตกกิ่งใหม่ได้ = 0
minsplit = 1จำนวนกิ่งย่อยขั้นต่ำที่ต้องมี ก่อนจะแตกกิ่งใหม่ได้ = 1
maxdepth = 5จำนวนชั้นที่ decision tree มีได้สูงสุด = 5
# Classification tree
ct <- rpart(am ~ .,
            data = train_set,
            method = "class",
            control = rpart.control(cp = 0, minsplit = 1, maxdepth = 5))

เราสามารถดู classification tree ของเราได้ด้วย rpart.plot():

# Plot classification tree
rpart.plot(ct,
           type = 3,
           extra = 101,
           under = TRUE,
           digits = 3,
           tweak = 1.2)

ผลลัพธ์:

Note:

  • ก่อนใช้งาน rpart.plot() เราต้องจะติดตั้งและเรียกใช้งานด้วย install.packages() และ library() ตามลำดับ
  • ศึกษาการใช้งาน rpart.plot() เพิ่มเติมได้ที่ rpart.plot: Plot an rpart model. A simplified interface to the prp function.

.

📏 Evaluate the Model

เมื่อได้ classification tree มาแล้ว เราลองมาประเมินความสามารถของ model ด้วย accuracy หรือสัดส่วนการทำนายที่ถูกต้องกัน

เราเริ่มจากใช้ predict() เพื่อใช้ model ทำนายประเภทเกียร์:

# Predict the outcome
test_set$pred_ct <- predict(ct,
                            newdata = test_set,
                            type = "class")

จากนั้น สร้าง confusion matrix หรือ matrix เปรียบเทียบคำทำนายกับข้อมูลจริง:

# Create a confusion matrix
cm_ct <- table(Predicted = test_set$pred_ct,
               Actual = test_set$am)
# Print confusion matrix
print(cm_ct)

ผลลัพธ์:

           Actual
Predicted   automatic manual
  automatic         5      1
  manual            0      4

สุดท้าย เราหา accuracy ด้วยการนำจำนวนคำทำนายที่ถูกต้องมาหารด้วยจำนวนคำทำนายทั้งหมด:

# Get accuracy
acc_ct <- sum(diag(cm_ct)) / sum(cm_ct)
# Print accuracy
cat("Accuracy (classification tree):", acc_ct)

ผลลัพธ์:

Accuracy (classification tree): 0.9

จะเห็นได้ว่า model ของเรามีความแม่นยำถึง 90%


🍄 Random Forest

Random forest เป็น tree-based algorithm ที่ช่วยเพิ่มความแม่นยำในการทำนาย โดยสุ่มสร้าง decision trees ต้นเล็กขึ้นมาเป็นกลุ่ม (forest) แทนการปลูก decision tree ต้นเดียว

Decision tree แต่ละต้นใน random forest มีความสามารถในการทำนายแตกต่างกัน ซึ่งบางต้นอาจมีความสามารถที่น้อยมาก

แต่จุดแข็งของ random forest อยู่ที่จำนวน โดย random forest ทำนายผลลัพธ์โดยดูจากผลลัพธ์ในภาพรวม ดังนี้:

TaskPredict by
Regressionค่าเฉลี่ยของผลลัพธ์การทำนายของทุกต้น
Classificationเสียงส่วนมาก (majority vote)

ดังนั้น แม้ว่า decision tree บางต้นอาจทำนายผิดพลาด แต่โดยรวมแล้ว random forest มีโอกาสที่จะทำนายได้ดีกว่า decision tree ต้นเดียว

ในภาษา R เราสามารถสร้าง random forest ได้ด้วย randomForest() จาก randomForest package ซึ่งต้องการ 3 arguments:

randomFrest(formula, data, ntree)
  • formula = สูตรในการวิเคราะห์ (ตัวแปรตาม ~ ตัวแปรต้น)
  • data = dataset ที่ใช้สร้าง model
  • ntree = จำนวน decision trees ที่ต้องการสร้าง

Note:

  • เราไม่ต้องกำหนดว่า จะทำ classification หรือ regression model เพราะ randomForest() จะเลือก model ให้อัตโนมัติตามข้อมูลที่เราใส่เข้าไป
  • ศึกษาการใช้งาน randomForest() เพิ่มเติมได้ที่ randomForest: Classification and Regression with Random Forest

ก่อนใช้ randomForest() เราต้องเตรียมข้อมูลแบบเดียวกันกับ rpart() ได้แก่:

  1. เปลี่ยน am ให้เป็น factor
  2. Split the data

สมมุติว่า เราเตรียมข้อมูลแล้ว เราสามารถเรียกใช้ randomForest() ได้เลย โดยเราจะลองสร้าง random forest ที่ประกอบด้วย decision trees 100 ต้น:

# Random forest
rf <- randomForest(am ~ .,
                   data = train_set,
                   ntree = 100)

แล้วลองประเมินความสามารถของ model ด้วย accuracy

เริ่มจากทำนายประเภทเกียร์:

# Predict the outcome
test_set$pred_rf <- predict(rf,
                            newdata = test_set,
                            type = "class")

สร้าง confusion matrix:

# Create a confusion matrix
cm_rf <- table(Predicted = test_set$pred_rf,
               Actual = test_set$am)
# Print confusion matrix
print(cm_rf)

ผลลัพธ์:

           Actual
Predicted   automatic manual
  automatic         5      0
  manual            0      5

และสุดท้าย คำนวณ accuracy:

# Get accuracy
acc_rf <- sum(diag(cm_rf)) / sum(cm_rf)
# Print accuracy
cat("Accuracy (random forest):", acc_rf)

ผลลัพธ์:

Accuracy (random forest): 1

จะเห็นว่า random forest (100%) มีความแม่นยำในการทำนายมากกว่า classification tree ต้นเดียว (90%)


🐱 GitHub

ดู code ทั้งหมดในบทความนี้ได้ที่ GitHub:


📃 References


✅ R Book for Psychologists: หนังสือภาษา R สำหรับนักจิตวิทยา

📕 ขอฝากหนังสือเล่มแรกในชีวิตด้วยนะครับ 😆

🙋 ใครที่กำลังเรียนจิตวิทยาหรือทำงานสายจิตวิทยา และเบื่อที่ต้องใช้ software ราคาแพงอย่าง SPSS และ Excel เพื่อทำข้อมูล

💪 ผมขอแนะนำ R Book for Psychologists หนังสือสอนใช้ภาษา R เพื่อการวิเคราะห์ข้อมูลทางจิตวิทยา ที่เขียนมาเพื่อนักจิตวิทยาที่ไม่เคยมีประสบการณ์เขียน code มาก่อน

ในหนังสือ เราจะปูพื้นฐานภาษา R และพาไปดูวิธีวิเคราะห์สถิติที่ใช้บ่อยกัน เช่น:

  • Correlation
  • t-tests
  • ANOVA
  • Reliability
  • Factor analysis

🚀 เมื่ออ่านและทำตามตัวอย่างใน R Book for Psychologists ทุกคนจะไม่ต้องพึง SPSS และ Excel ในการทำงานอีกต่อไป และสามารถวิเคราะห์ข้อมูลด้วยตัวเองได้ด้วยความมั่นใจ

แล้วทุกคนจะแปลกใจว่า ทำไมภาษา R ง่ายขนาดนี้ 🙂‍↕️

👉 สนใจดูรายละเอียดหนังสือได้ที่ meb:

Comments

3 responses to “สอนปลูกต้นไม้ในภาษา R (ภาค 1): วิธีสร้าง tree-based models ใน 3 ขั้นตอนด้วย rpart และ randomForest packages — ตัวอย่างการทำนายประเภทเกียร์รถใน mtcars dataset”

  1.  Avatar

    […] ในบทความก่อน เราดูวิธีการใช้ randomForest แล้ว […]

    Like

Leave a reply to Machine Learning in R: รวบรวม 13 บทความสอนสร้าง machine learning ในภาษา R – Shi no Shigoto Cancel reply