Apa itu Pohon Keputusan?
Pohon Keputusan adalah algoritma Pembelajaran Mesin serbaguna yang dapat melakukan tugas klasifikasi dan regresi. Mereka adalah algoritme yang sangat kuat, yang mampu menyesuaikan dengan kumpulan data yang kompleks. Selain itu, pohon keputusan adalah komponen fundamental dari hutan acak, yang merupakan salah satu algoritme Pembelajaran Mesin paling ampuh yang tersedia saat ini.
Pelatihan dan Visualisasi pohon keputusan
Untuk membangun pohon keputusan pertama Anda dalam contoh R, kita akan melanjutkan seperti berikut dalam tutorial Pohon Keputusan ini:
- Langkah 1: Impor data
- Langkah 2: Bersihkan kumpulan data
- Langkah 3: Buat set latihan / pengujian
- Langkah 4: Buat model
- Langkah 5: Buat prediksi
- Langkah 6: Ukur kinerja
- Langkah 7: Sesuaikan hyper-parameter
Langkah 1) Impor data
Jika Anda penasaran dengan nasib titanic, Anda dapat menonton video ini di Youtube. Tujuan dari kumpulan data ini adalah untuk memprediksi orang mana yang lebih mungkin untuk bertahan hidup setelah tabrakan dengan gunung es. Dataset berisi 13 variabel dan 1309 observasi. Dataset diurutkan oleh variabel X.
set.seed(678)path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'titanic <-read.csv(path)head(titanic)
Keluaran:
## X pclass survived name sex## 1 1 1 1 Allen, Miss. Elisabeth Walton female## 2 2 1 1 Allison, Master. Hudson Trevor male## 3 3 1 0 Allison, Miss. Helen Loraine female## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female## 6 6 1 1 Anderson, Mr. Harry male## age sibsp parch ticket fare cabin embarked## 1 29.0000 0 0 24160 211.3375 B5 S## 2 0.9167 1 2 113781 151.5500 C22 C26 S## 3 2.0000 1 2 113781 151.5500 C22 C26 S## 4 30.0000 1 2 113781 151.5500 C22 C26 S## 5 25.0000 1 2 113781 151.5500 C22 C26 S## 6 48.0000 0 0 19952 26.5500 E12 S## home.dest## 1 St Louis, MO## 2 Montreal, PQ / Chesterville, ON## 3 Montreal, PQ / Chesterville, ON## 4 Montreal, PQ / Chesterville, ON## 5 Montreal, PQ / Chesterville, ON## 6 New York, NY
tail(titanic)
Keluaran:
## X pclass survived name sex age sibsp## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0## parch ticket fare cabin embarked home.dest## 1304 0 2627 14.4583 C## 1305 0 2665 14.4542 C## 1306 0 2665 14.4542 C## 1307 0 2656 7.2250 C## 1308 0 2670 7.2250 C## 1309 0 315082 7.8750 S
Dari head and tail output, Anda dapat melihat bahwa data tidak diacak. Ini masalah besar! Saat Anda akan membagi data antara set kereta dan set pengujian, Anda hanya akan memilih penumpang dari kelas 1 dan 2 (Tidak ada penumpang dari kelas 3 yang termasuk dalam 80 persen pengamatan teratas), yang berarti algoritme tidak akan pernah melihat fitur penumpang kelas 3. Kesalahan ini akan menyebabkan prediksi yang buruk.
Untuk mengatasi masalah ini, Anda bisa menggunakan function sample ().
shuffle_index <- sample(1:nrow(titanic))head(shuffle_index)
Penjelasan kode R pohon keputusan
- sample (1: nrow (titanic)): Menghasilkan daftar indeks secara acak dari 1 hingga 1309 (yaitu jumlah baris maksimum).
Keluaran:
## [1] 288 874 1078 633 887 992
Anda akan menggunakan indeks ini untuk mengacak kumpulan data titanic.
titanic <- titanic[shuffle_index, ]head(titanic)
Keluaran:
## X pclass survived## 288 288 1 0## 874 874 3 0## 1078 1078 3 1## 633 633 3 0## 887 887 3 1## 992 992 3 1## name sex age## 288 Sutton, Mr. Frederick male 61## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42## 1078 O'Driscoll, Miss. Bridget female NA## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39## 887 Jermyn, Miss. Annie female NA## 992 Mamee, Mr. Hanna male NA## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ## 874 0 0 348121 7.6500 F G63 S## 1078 0 0 14311 7.7500 Q## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN## 887 0 0 14313 7.7500 Q## 992 0 0 2677 7.2292 C
Langkah 2) Bersihkan dataset
Struktur datanya menunjukkan beberapa variabel memiliki NA. Pembersihan data dilakukan sebagai berikut
- Jatuhkan variabel home.dest, cabin, name, X dan ticket
- Buat variabel faktor untuk pclass dan selamat
- Jatuhkan NA
library(dplyr)# Drop variablesclean_titanic <- titanic % > %select(-c(home.dest, cabin, name, X, ticket)) % > %#Convert to factor levelmutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %na.omit()glimpse(clean_titanic)
Penjelasan Kode
- pilih (-c (home.dest, cabin, name, X, ticket)): Jatuhkan variabel yang tidak perlu
- pclass = faktor (pclass, level = c (1,2,3), label = c ('Atas', 'Tengah', 'Bawah')): Tambahkan label ke variabel pclass. 1 menjadi Atas, 2 menjadi MIddle dan 3 menjadi lebih rendah
- faktor (bertahan, level = c (0,1), label = c ('Tidak', 'Ya')): Tambahkan label ke variabel bertahan. 1 Menjadi Tidak dan 2 menjadi Ya
- na.omit (): Hapus pengamatan NA
Keluaran:
## Observations: 1,045## Variables: 8## $ pclassUpper, Lower, Lower, Upper, Middle, Upper, Middle, U… ## $ survived No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y… ## $ sex male, male, female, female, male, male, female, male… ## $ age 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0,… ## $ sibsp 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,… ## $ parch 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,… ## $ fare 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542,… ## $ embarked S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C…
Langkah 3) Buat set latihan / pengujian
Sebelum melatih model Anda, Anda perlu melakukan dua langkah:
- Membuat kereta dan set pengujian: Anda melatih model di set kereta dan menguji prediksi di set pengujian (yaitu data yang tidak terlihat)
- Instal rpart.plot dari konsol
Praktik yang umum dilakukan adalah membagi data 80/20, 80 persen data berfungsi untuk melatih model, dan 20 persen untuk membuat prediksi. Anda perlu membuat dua bingkai data terpisah. Anda tidak ingin menyentuh set pengujian sampai Anda selesai membuat model Anda. Anda bisa membuat nama fungsi create_train_test () yang membutuhkan tiga argumen.
create_train_test(df, size = 0.8, train = TRUE)arguments:-df: Dataset used to train the model.-size: Size of the split. By default, 0.8. Numerical value-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample < - 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}
Penjelasan Kode
- function (data, size = 0.8, train = TRUE): Tambahkan argumen ke dalam fungsi
- n_row = nrow (data): Hitung jumlah baris dalam dataset
- total_row = size * n_row: Kembalikan baris ke-n untuk membuat set kereta
- train_sample <- 1: total_row: Pilih baris pertama hingga baris ke-n
- if (kereta == TRUE) {} else {}: Jika kondisi disetel ke true, kembalikan set rangkaian, kalau tidak set pengujian.
Anda dapat menguji fungsi Anda dan memeriksa dimensinya.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)dim(data_train)
Keluaran:
## [1] 836 8
dim(data_test)
Keluaran:
## [1] 209 8
Dataset kereta memiliki 1046 baris sedangkan set data uji memiliki 262 baris.
Anda menggunakan fungsi prop.table () dikombinasikan dengan tabel () untuk memverifikasi apakah proses pengacakan sudah benar.
prop.table(table(data_train$survived))
Keluaran:
#### No Yes## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Keluaran:
#### No Yes## 0.5789474 0.4210526
Di kedua dataset, jumlah survivor sama, sekitar 40 persen.
Pasang rpart.plot
rpart.plot tidak tersedia dari perpustakaan conda. Anda dapat menginstalnya dari konsol:
install.packages("rpart.plot")
Langkah 4) Bangun modelnya
Anda siap membuat model. Sintaks untuk fungsi pohon keputusan Rpart adalah:
rpart(formula, data=, method='')arguments:- formula: The function to predict- data: Specifies the data frame- method:- "class" for a classification tree- "anova" for a regression tree
Anda menggunakan metode kelas karena Anda memprediksi kelas.
library(rpart)library(rpart.plot)fit <- rpart(survived~., data = data_train, method = 'class')rpart.plot(fit, extra = 106
Penjelasan Kode
- rpart (): Berfungsi untuk menyesuaikan model. Argumennya adalah:
- selamat ~ .: Formula Pohon Keputusan
- data = data_train: Set data
- method = 'class': Cocok dengan model biner
- rpart.plot (fit, extra = 106): Plot pohonnya. Fitur tambahan disetel ke 101 untuk menampilkan probabilitas kelas ke-2 (berguna untuk respons biner). Anda dapat merujuk ke sketsa untuk informasi lebih lanjut tentang pilihan lain.
Keluaran:
Anda mulai di simpul akar (kedalaman 0 di atas 3, bagian atas grafik):
- Di atas, itu adalah kemungkinan bertahan hidup secara keseluruhan. Ini menunjukkan proporsi penumpang yang selamat dari kecelakaan itu. 41 persen penumpang selamat.
- Node ini menanyakan apakah jenis kelamin penumpang adalah laki-laki. Jika ya, maka Anda turun ke simpul anak kiri root (kedalaman 2). 63 persen adalah laki-laki dengan kemungkinan bertahan hidup sebesar 21 persen.
- Di simpul kedua, Anda menanyakan apakah penumpang pria berusia di atas 3,5 tahun. Jika ya, maka peluang bertahan hidup adalah 19 persen.
- Anda terus melakukannya untuk memahami fitur apa yang memengaruhi kemungkinan bertahan hidup.
Perhatikan bahwa, salah satu dari banyak kualitas Pohon Keputusan adalah bahwa mereka memerlukan sedikit persiapan data. Secara khusus, mereka tidak memerlukan penskalaan atau pemusatan fitur.
Secara default, fungsi rpart () menggunakan ukuran ketidakmurnian Gini untuk membagi catatan. Semakin tinggi koefisien Gini, semakin banyak contoh berbeda di dalam node.
Langkah 5) Buat prediksi
Anda dapat memprediksi set data pengujian Anda. Untuk membuat prediksi, Anda bisa menggunakan fungsi predict (). Sintaks dasar untuk memprediksi pohon keputusan R adalah:
predict(fitted_model, df, type = 'class')arguments:- fitted_model: This is the object stored after model estimation.- df: Data frame used to make the prediction- type: Type of prediction- 'class': for classification- 'prob': to compute the probability of each class- 'vector': Predict the mean response at the node level
Anda ingin memprediksi penumpang mana yang lebih mungkin bertahan setelah tabrakan dari set pengujian. Artinya, di antara 209 penumpang itu Anda akan tahu, mana yang akan selamat atau tidak.
predict_unseen <-predict(fit, data_test, type = 'class')
Penjelasan Kode
- memprediksi (fit, data_test, type = 'class'): Memprediksi kelas (0/1) set pengujian
Menguji penumpang yang tidak berhasil dan yang berhasil.
table_mat <- table(data_test$survived, predict_unseen)table_mat
Penjelasan Kode
- tabel (data_test $ survived, predict_unseen): Buat tabel untuk menghitung berapa banyak penumpang yang diklasifikasikan sebagai survivor dan meninggal dunia dibandingkan dengan klasifikasi pohon keputusan yang benar di R
Keluaran:
## predict_unseen## No Yes## No 106 15## Yes 30 58
Model tersebut dengan tepat memprediksi 106 penumpang tewas tetapi mengklasifikasikan 15 korban tewas. Dengan analogi, model tersebut salah mengklasifikasikan 30 penumpang sebagai korban yang ternyata tewas.
Langkah 6) Ukur kinerja
Anda dapat menghitung ukuran akurasi untuk tugas klasifikasi dengan matriks kebingungan :
The matriks kebingungan adalah pilihan yang lebih baik untuk mengevaluasi kinerja klasifikasi. Ide umumnya adalah menghitung berapa kali instance True diklasifikasikan sebagai Salah.
Setiap baris dalam matriks kebingungan mewakili target aktual, sedangkan setiap kolom mewakili target yang diprediksi. Baris pertama dari matriks ini menganggap penumpang yang meninggal (kelas Salah): 106 diklasifikasikan dengan benar sebagai meninggal ( True negative ), sedangkan sisanya salah diklasifikasikan sebagai penumpang yang selamat ( False positive ). Baris kedua menghitung survivor, kelas positif 58 ( True positive ), sedangkan True negative 30.
Anda dapat menghitung uji akurasi dari matriks kebingungan:
Ini adalah proporsi positif benar dan negatif benar atas jumlah matriks. Dengan R, Anda dapat membuat kode sebagai berikut:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Penjelasan Kode
- sum (diag (table_mat)): Jumlah diagonal
- sum (table_mat): Jumlah dari matriks.
Anda dapat mencetak keakuratan set pengujian:
print(paste('Accuracy for test', accuracy_Test))
Keluaran:
## [1] "Accuracy for test 0.784688995215311"
Anda memiliki skor 78 persen untuk set tes. Anda dapat mereplikasi latihan yang sama dengan set data pelatihan.
Langkah 7) Sesuaikan hyper-parameternya
Pohon keputusan di R memiliki berbagai parameter yang mengontrol aspek fit. Di pustaka pohon keputusan rpart, Anda dapat mengontrol parameter menggunakan fungsi rpart.control (). Dalam kode berikut, Anda memperkenalkan parameter yang akan Anda setel. Anda dapat merujuk ke sketsa untuk parameter lain.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)Arguments:-minsplit: Set the minimum number of observations in the node before the algorithm perform a split-minbucket: Set the minimum number of observations in the final note i.e. the leaf-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Kami akan melanjutkan sebagai berikut:
- Membangun fungsi untuk mengembalikan akurasi
- Sesuaikan kedalaman maksimum
- Sesuaikan jumlah sampel minimum yang harus dimiliki node sebelum dapat dipisahkan
- Menyesuaikan jumlah sampel minimum yang harus dimiliki simpul daun
Anda dapat menulis fungsi untuk menampilkan akurasi. Anda cukup membungkus kode yang Anda gunakan sebelumnya:
- prediksi: predict_unseen <- prediksi (fit, data_test, type = 'class')
- Menghasilkan tabel: table_mat <- table (data_test $ survived, predict_unseen)
- Akurasi komputasi: akurasi_Test <- jumlah (diag (mat_tabel)) / jumlah (mat_tabel)
accuracy_tune <- function(fit) {predict_unseen <- predict(fit, data_test, type = 'class')table_mat <- table(data_test$survived, predict_unseen)accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test}
Anda dapat mencoba menyesuaikan parameter dan melihat apakah Anda dapat meningkatkan model di atas nilai default. Sebagai pengingat, Anda perlu mendapatkan akurasi yang lebih tinggi dari 0,78
control <- rpart.control(minsplit = 4,minbucket = round(5 / 3),maxdepth = 3,cp = 0)tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)accuracy_tune(tune_fit)
Keluaran:
## [1] 0.7990431
Dengan parameter berikut:
minsplit = 4minbucket= round(5/3)maxdepth = 3cp=0
Anda mendapatkan performa yang lebih tinggi dari model sebelumnya. Selamat!
Ringkasan
Kita dapat meringkas fungsi untuk melatih algoritma pohon keputusan di R
Perpustakaan |
Objektif |
fungsi |
kelas |
parameter |
detailnya |
---|---|---|---|---|---|
rpart |
Pohon klasifikasi kereta api di R |
rpart () |
kelas |
rumus, df, metode | |
rpart |
Latih pohon regresi |
rpart () |
anova |
rumus, df, metode | |
rpart |
Plot pohonnya |
rpart.plot () |
model pas | ||
mendasarkan |
meramalkan |
meramalkan() |
kelas |
model pas, tipe | |
mendasarkan |
meramalkan |
meramalkan() |
masalah |
model pas, tipe | |
mendasarkan |
meramalkan |
meramalkan() |
vektor |
model pas, tipe | |
rpart |
Parameter kontrol |
rpart.control () |
menitpisah |
Tetapkan jumlah minimum observasi di node sebelum algoritme melakukan pemisahan |
|
minbucket.dll |
Tetapkan jumlah minimum pengamatan di catatan akhir yaitu daun |
||||
maxdepth |
Tetapkan kedalaman maksimum simpul mana pun dari pohon terakhir. Node root diperlakukan sebagai kedalaman 0 |
||||
rpart |
Latih model dengan parameter kontrol |
rpart () |
rumus, df, metode, kontrol |
Catatan: Latih model pada data pelatihan dan uji performa pada set data yang tidak terlihat, yaitu set pengujian.