GLM di R: Model Linear Umum dengan Contoh

Daftar Isi:

Anonim

Apa itu regresi logistik?

Regresi logistik digunakan untuk memprediksi kelas, yaitu probabilitas. Regresi logistik dapat memprediksi hasil biner secara akurat.

Bayangkan Anda ingin memprediksi apakah suatu pinjaman ditolak / diterima berdasarkan banyak atribut. Regresi logistik berbentuk 0/1. y = 0 jika pinjaman ditolak, y = 1 jika diterima.

Model regresi logistik berbeda dari model regresi linier dalam dua cara.

  • Pertama-tama, regresi logistik hanya menerima input dikotomis (biner) sebagai variabel dependen (yaitu, vektor 0 dan 1).
  • Kedua, hasil diukur dengan fungsi tautan probabilistik berikut yang disebut sigmoid karena berbentuk S:

Output dari fungsi ini selalu antara 0 dan 1. Periksa Gambar di bawah

Fungsi sigmoid mengembalikan nilai dari 0 ke 1. Untuk tugas klasifikasi, kita membutuhkan keluaran diskrit 0 atau 1.

Untuk mengubah aliran kontinu menjadi nilai diskrit, kita dapat menetapkan batasan keputusan pada 0,5. Semua nilai di atas ambang batas ini diklasifikasikan sebagai 1

Dalam tutorial ini, Anda akan belajar

  • Apa itu regresi logistik?
  • Cara membuat Generalized Liner Model (GLM)
  • Langkah 1) Periksa variabel kontinu
  • Langkah 2) Periksa variabel faktor
  • Langkah 3) Rekayasa fitur
  • Langkah 4) Statistik Ringkasan
  • Langkah 5) Latih / set tes
  • Langkah 6) Bangun modelnya
  • Langkah 7) Nilai kinerja model

Cara membuat Generalized Liner Model (GLM)

Mari gunakan kumpulan data dewasa untuk menggambarkan regresi logistik. "Dewasa" adalah kumpulan data yang bagus untuk tugas klasifikasi. Tujuannya adalah untuk memprediksi apakah pendapatan tahunan dalam dolar seseorang akan melebihi 50.000. Dataset berisi 46.033 observasi dan sepuluh fitur:

  • usia: usia individu. Numerik
  • pendidikan: Tingkat pendidikan individu. Faktor.
  • marital.status: Status perkawinan individu. Faktor yaitu Belum kawin, Suami-istri-kawin,…
  • jenis kelamin: Jenis kelamin individu. Faktor, yaitu Pria atau Wanita
  • pendapatan: Variabel target. Penghasilan di atas atau di bawah 50K. Faktorkan yaitu> 50K, <= 50K

di antara yang lain

library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)

Keluaran:

Observations: 48,842Variables: 10$ x  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age  25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass  Private, Private, Local-gov, Private, ?, Private,… $ education  11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num  7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status  Never-married, Married-civ-spouse, Married-civ-sp… $ race  Black, White, White, Black, White, White, Black,… $ gender  Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week  40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income  <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5… 

Kami akan melanjutkan sebagai berikut:

  • Langkah 1: Periksa variabel kontinu
  • Langkah 2: Periksa variabel faktor
  • Langkah 3: Rekayasa fitur
  • Langkah 4: Statistik ringkasan
  • Langkah 5: Latih / set pengujian
  • Langkah 6: Buat model
  • Langkah 7: Nilai performa model
  • langkah 8: Tingkatkan model

Tugas Anda adalah memprediksi individu mana yang akan memiliki pendapatan lebih tinggi dari 50K.

Dalam tutorial ini, setiap langkah akan dirinci untuk melakukan analisis pada dataset nyata.

Langkah 1) Periksa variabel kontinu

Pada langkah pertama, Anda dapat melihat distribusi variabel kontinu.

continuous <-select_if(data_adult, is.numeric)summary(continuous)

Penjelasan Kode

  • continuous <- select_if (data_adult, is.numeric): Gunakan fungsi select_if () dari pustaka dplyr untuk memilih hanya kolom numerik
  • ringkasan (berkelanjutan): Cetak statistik ringkasan

Keluaran:

## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00

Dari tabel di atas, Anda dapat melihat bahwa datanya memiliki skala yang sangat berbeda dan jam.per.weeks memiliki pencilan yang besar (.ie lihat kuartil terakhir dan nilai maksimum).

Anda dapat mengatasinya dengan dua langkah berikut:

  • 1: Buat plot distribusi jam.per.week
  • 2: Standarisasi variabel kontinu
  1. Buat plot distribusinya

Mari kita lihat lebih dekat distribusi jam.per.week

# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")

Keluaran:

Variabel memiliki banyak pencilan dan tidak terdistribusi dengan baik. Anda dapat mengatasi sebagian masalah ini dengan menghapus 0,01 persen jam kerja teratas per minggu.

Sintaks dasar kuantil:

quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.

Kami menghitung persentil 2 persen teratas

top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent

Penjelasan Kode

  • quantile (data_adult $ hours.per.week, .99): Menghitung nilai 99 persen waktu kerja

Keluaran:

## 99%## 80 

98 persen populasi bekerja di bawah 80 jam per minggu.

Anda dapat menghentikan pengamatan di atas ambang batas ini. Anda menggunakan filter dari perpustakaan dplyr.

data_adult_drop <-data_adult %>%filter(hours.per.week

Keluaran:

## [1] 45537 10 
  1. Standarisasi variabel kontinu

Anda dapat membakukan setiap kolom untuk meningkatkan kinerja karena data Anda tidak memiliki skala yang sama. Anda dapat menggunakan fungsi mutate_if dari pustaka dplyr. Sintaks dasarnya adalah:

mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the function

Anda dapat membakukan kolom numerik sebagai berikut:

data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)

Penjelasan Kode

  • mutate_if (is.numeric, funs (scale)): Kondisinya hanya kolom numerik dan fungsinya adalah skala

Keluaran:

## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50K

Langkah 2) Periksa variabel faktor

Langkah ini memiliki dua tujuan:

  • Periksa level di setiap kolom kategori
  • Tentukan level baru

Kami akan membagi langkah ini menjadi tiga bagian:

  • Pilih kolom kategorikal
  • Simpan diagram batang dari setiap kolom dalam daftar
  • Cetak grafiknya

Kita dapat memilih kolom faktor dengan kode di bawah ini:

# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)

Penjelasan Kode

  • data.frame (select_if (data_adult, is.factor)): Kami menyimpan kolom faktor dalam faktor dalam tipe bingkai data. Pustaka ggplot2 membutuhkan objek bingkai data.

Keluaran:

## [1] 6 

Dataset berisi 6 variabel kategori

Langkah kedua lebih terampil. Anda ingin memplot diagram batang untuk setiap kolom dalam faktor bingkai data. Akan lebih mudah untuk mengotomatiskan proses, terutama dalam situasi ada banyak kolom.

library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))

Penjelasan Kode

  • lapply (): Gunakan fungsi lapply () untuk melewatkan fungsi di semua kolom dataset. Anda menyimpan output dalam daftar
  • function (x): Fungsi akan diproses untuk setiap x. Di sini x adalah kolomnya
  • ggplot (factor, aes (get (x))) + geom_bar () + theme (axis.text.x = element_text (angle = 90)): Buat diagram bar char untuk setiap elemen x. Catatan, untuk mengembalikan x sebagai kolom, Anda perlu memasukkannya ke dalam get ()

Langkah terakhir relatif mudah. Anda ingin mencetak 6 grafik.

# Print the graphgraph

Keluaran:

## [[1]]

## ## [[2]]

## ## [[3]]

## ## [[4]]

## ## [[5]]

## ## [[6]]

Catatan: Gunakan tombol berikutnya untuk menavigasi ke grafik berikutnya

Langkah 3) Rekayasa fitur

Tata ulang pendidikan

Dari grafik di atas, terlihat bahwa variabel pendidikan memiliki 16 tingkatan. Ini substansial, dan beberapa level memiliki jumlah observasi yang relatif rendah. Jika Anda ingin meningkatkan jumlah informasi yang dapat Anda peroleh dari variabel ini, Anda dapat menyusunnya kembali ke tingkat yang lebih tinggi. Yakni, Anda membuat grup yang lebih besar dengan tingkat pendidikan yang sama. Misalnya, tingkat pendidikan yang rendah akan diubah menjadi putus sekolah. Tingkat pendidikan yang lebih tinggi akan diubah menjadi master.

Berikut detailnya:

Level lama

Level baru

Prasekolah

keluar

10

Keluar

11

Keluar

12

Keluar

1st-4th

Keluar

5-6

Keluar

7-8

Keluar

9

Keluar

HS-Grad

HighGrad

Beberapa perguruan tinggi

Masyarakat

Assoc-acdm

Masyarakat

Assoc-voc

Masyarakat

Sarjana

Sarjana

Master

Master

Prof-sekolah

Master

Gelar doktor

PhD

recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Penjelasan Kode

  • Kami menggunakan kata kerja mutate dari perpustakaan dplyr. Kami mengubah nilai-nilai pendidikan dengan pernyataan ifelse

Pada tabel di bawah, Anda membuat statistik ringkasan untuk melihat, rata-rata, berapa tahun pendidikan (nilai z) yang diperlukan untuk mencapai Sarjana, Magister atau PhD.

recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)

Keluaran:

## # A tibble: 6 x 3## education average_educ_year count##   ## 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557

Ubah status Perkawinan

Mungkin juga untuk membuat tingkat yang lebih rendah untuk status perkawinan. Dalam kode berikut Anda mengubah level sebagai berikut:

Level lama

Level baru

Tidak pernah menikah

Belum nikah

Menikah-pasangan-absen

Belum nikah

Menikah-AF-pasangan

Menikah

Suami-istri-suami-istri

Terpisah

Terpisah

Bercerai

Janda

Janda

# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
Anda dapat memeriksa jumlah individu dalam setiap grup.
table(recast_data$marital.status)

Keluaran:

## ## Married Not_married Separated Widow## 21165 15359 7727 1286 

Langkah 4) Statistik Ringkasan

Saatnya untuk memeriksa beberapa statistik tentang variabel target kita. Pada grafik di bawah, Anda menghitung persentase individu yang berpenghasilan lebih dari 50.000 berdasarkan jenis kelamin mereka.

# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()

Keluaran:

Selanjutnya, periksa apakah asal individu memengaruhi penghasilannya.

# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))

Keluaran:

Jumlah jam kerja menurut gender.

# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()

Keluaran:

Plot kotak menegaskan bahwa distribusi waktu kerja sesuai dengan kelompok yang berbeda. Dalam plot kotak, kedua jenis kelamin tidak memiliki pengamatan yang homogen.

Anda dapat memeriksa kepadatan waktu kerja mingguan berdasarkan jenis pendidikan. Distribusi memiliki banyak pilihan berbeda. Mungkin bisa dijelaskan dengan jenis kontrak di AS.

# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()

Penjelasan Kode

  • ggplot (recast_data, aes (x = hours.per.week)): Sebuah plot kepadatan hanya membutuhkan satu variabel
  • geom_density (aes (color = education), alpha = 0,5): Objek geometris untuk mengontrol kepadatan

Keluaran:

Untuk mengkonfirmasi pemikiran Anda, Anda dapat melakukan tes ANOVA satu arah:

anova <- aov(hours.per.week~education, recast_data)summary(anova)

Keluaran:

## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Tes ANOVA mengkonfirmasi perbedaan rata-rata antar kelompok.

Non-linearitas

Sebelum Anda menjalankan model, Anda dapat melihat apakah jumlah jam kerja terkait dengan usia.

library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()

Penjelasan Kode

  • ggplot (recast_data, aes (x = age, y = hours.per.week)): Mengatur estetika grafik
  • geom_point (aes (color = income), size = 0,5): Buat plot titik
  • stat_smooth (): Tambahkan garis tren dengan argumen berikut:
    • metode = 'lm': Plot nilai pas jika regresi linier
    • formula = y ~ poly (x, 2): Cocokkan regresi polinomial
    • se = TRUE: Tambahkan kesalahan standar
    • aes (warna = pendapatan): Pisahkan model dengan pendapatan

Keluaran:

Singkatnya, Anda dapat menguji istilah interaksi dalam model untuk mengambil efek non-linearitas antara waktu kerja mingguan dan fitur lainnya. Penting untuk mendeteksi dalam kondisi apa waktu kerja berbeda.

Korelasi

Pemeriksaan selanjutnya adalah untuk memvisualisasikan korelasi antar variabel. Anda mengonversi tipe tingkat faktor menjadi numerik sehingga Anda dapat memplot peta panas yang berisi koefisien korelasi yang dihitung dengan metode Spearman.

library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")

Penjelasan Kode

  • data.frame (lapply (recast_data, as.integer)): Mengonversi data menjadi numerik
  • ggcorr () memplot peta panas dengan argumen berikut:
    • Metode: Metode untuk menghitung korelasi
    • nbreaks = 6: Jumlah break
    • hjust = 0.8: Posisi kontrol nama variabel di plot
    • label = TRUE: Tambahkan label di tengah jendela
    • label_size = 3: Label ukuran
    • color = "grey50"): Warna label

Keluaran:

Langkah 5) Latih / set tes

Setiap tugas machine learning yang diawasi perlu memisahkan data antara satu set rangkaian dan set pengujian. Anda dapat menggunakan "fungsi" yang Anda buat di tutorial pembelajaran yang diawasi lainnya untuk membuat set latihan / tes.

set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)

Keluaran:

## [1] 36429 9
dim(data_test)

Keluaran:

## [1] 9108 9 

Langkah 6) Bangun modelnya

Untuk melihat bagaimana algoritme bekerja, Anda menggunakan paket glm (). The Generalized Linear Model adalah kumpulan model. Sintaks dasarnya adalah:

glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")

Anda siap memperkirakan model logistik untuk membagi tingkat pendapatan di antara sekumpulan fitur.

formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)

Penjelasan Kode

  • formula <- income ~.: Buat model agar sesuai
  • logit <- glm (formula, data = data_train, family = 'binomial'): Sesuaikan model logistik (family = 'binomial') dengan data data_train.
  • ringkasan (logit): Cetak ringkasan model

Keluaran:

#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6

Ringkasan model kami mengungkapkan informasi yang menarik. Kinerja regresi logistik dievaluasi dengan metrik kunci tertentu.

  • AIC (Akaike Information Criteria): Ini setara dengan R2 dalam regresi logistik. Ini mengukur kecocokan saat penalti diterapkan pada jumlah parameter. Nilai AIC yang lebih kecil menunjukkan model mendekati kebenaran.
  • Penyimpangan nol: Cocok untuk model hanya dengan intersep. Derajat kebebasan adalah n-1. Kita dapat menafsirkannya sebagai nilai Chi-square (nilai pas berbeda dari pengujian hipotesis nilai aktual).
  • Sisa Penyimpangan: Model dengan semua variabel. Ini juga diartikan sebagai pengujian hipotesis Chi-square.
  • Number of Fisher Scoring iterations: Jumlah iterasi sebelum konvergen.

Output dari fungsi glm () disimpan dalam sebuah daftar. Kode di bawah ini menunjukkan semua item yang tersedia di variabel logit yang kami buat untuk mengevaluasi regresi logistik.

# Daftarnya sangat panjang, cetak hanya tiga elemen pertama

lapply(logit, class)[1:3]

Keluaran:

## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"

Setiap nilai dapat diekstraksi dengan tanda $ diikuti dengan nama metrik. Misalnya, Anda menyimpan model sebagai logit. Untuk mengekstrak kriteria AIC, Anda menggunakan:

logit$aic

Keluaran:

## [1] 27086.65

Langkah 7) Nilai kinerja model

Confusion Matrix

The matriks kebingungan adalah pilihan yang lebih baik untuk mengevaluasi kinerja klasifikasi dibandingkan dengan metrik yang berbeda Anda lihat sebelumnya. Ide umumnya adalah menghitung berapa kali instance True diklasifikasikan sebagai Salah.

Untuk menghitung matriks konfusi, Anda harus terlebih dahulu memiliki sekumpulan prediksi agar dapat dibandingkan dengan target sebenarnya.

predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_mat

Penjelasan Kode

  • prediksi (logit, data_test, type = 'response'): Hitung prediksi pada set pengujian. Set type = 'response' untuk menghitung probabilitas respons.
  • table (data_test $ income, predict> 0.5): Hitung matriks konfusi. predict> 0,5 berarti mengembalikan 1 jika probabilitas yang diprediksi di atas 0,5, jika tidak 0.

Keluaran:

#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229

Setiap baris dalam matriks kebingungan mewakili target aktual, sedangkan setiap kolom mewakili target yang diprediksi. Baris pertama dari matriks ini menganggap pendapatan lebih rendah dari 50k (kelas Salah): 6241 diklasifikasikan dengan benar sebagai individu dengan pendapatan lebih rendah dari 50k ( Negatif benar ), sedangkan sisanya salah diklasifikasikan sebagai di atas 50k ( Positif palsu ). Baris kedua menganggap pendapatan di atas 50k, kelas positif adalah 1229 ( Benar positif ), sedangkan negatif Benar adalah 1074.

Anda dapat menghitung keakuratan model dengan menjumlahkan positif benar + negatif benar di atas total observasi

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test

Penjelasan Kode

  • sum (diag (table_mat)): Jumlah diagonal
  • sum (table_mat): Jumlah dari matriks.

Keluaran:

## [1] 0.8277339 

Model tersebut tampaknya mengalami satu masalah, yaitu melebih-lebihkan jumlah negatif palsu. Ini disebut paradoks uji akurasi . Kami menyatakan bahwa akurasi adalah rasio prediksi yang benar terhadap jumlah kasus. Kami dapat memiliki akurasi yang relatif tinggi tetapi model yang tidak berguna. Itu terjadi ketika ada kelas dominan. Jika Anda melihat kembali matriks kebingungan, Anda dapat melihat sebagian besar kasus diklasifikasikan sebagai negatif benar. Bayangkan sekarang, model mengklasifikasikan semua kelas sebagai negatif (yaitu lebih rendah dari 50k). Anda akan memiliki akurasi 75 persen (6718/6718 + 2257). Model Anda berkinerja lebih baik tetapi kesulitan membedakan positif benar dengan negatif sebenarnya.

Dalam situasi seperti itu, lebih disukai untuk memiliki metrik yang lebih ringkas. Kita bisa melihat:

  • Presisi = TP / (TP + FP)
  • Perolehan = TP / (TP + FN)

Presisi vs Perolehan

Presisi melihat keakuratan prediksi positif. Perolehan adalah rasio kejadian positif yang dideteksi dengan benar oleh pengklasifikasi;

Anda dapat membuat dua fungsi untuk menghitung dua metrik ini

  1. Bangun presisi
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}

Penjelasan Kode

  • mat [1,1]: Kembalikan sel pertama dari kolom pertama dari bingkai data, yaitu positif benar
  • tikar [1,2]; Kembalikan sel pertama dari kolom kedua dari bingkai data, yaitu positif palsu
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}

Penjelasan Kode

  • mat [1,1]: Kembalikan sel pertama dari kolom pertama dari bingkai data, yaitu positif benar
  • tikar [2,1]; Kembalikan sel kedua dari kolom pertama dari bingkai data, yaitu negatif palsu

Anda dapat menguji fungsi Anda

prec <- precision(table_mat)precrec <- recall(table_mat)rec

Keluaran:

## [1] 0.712877## [2] 0.5336518

Jika model mengatakan itu adalah individu di atas 50k, itu benar hanya dalam 54 persen kasus, dan dapat mengklaim individu di atas 50k dalam 72 persen kasus.

Anda dapat membuat adalah rata-rata harmonis dari kedua metrik ini, yang berarti memberi bobot lebih pada nilai yang lebih rendah.

f1 <- 2 * ((prec * rec) / (prec + rec))f1

Keluaran:

## [1] 0.6103799 

Pengorbanan Presisi vs Perolehan

Tidak mungkin memiliki presisi tinggi dan recall tinggi.

Jika kita meningkatkan presisi, individu yang benar akan lebih baik diprediksi, tetapi kita akan kehilangan banyak dari mereka (recall lebih rendah). Dalam beberapa situasi, kami lebih memilih presisi yang lebih tinggi daripada perolehan. Ada hubungan cekung antara presisi dan perolehan.

  • Bayangkan, Anda perlu memprediksi apakah seorang pasien mengidap suatu penyakit. Anda ingin menjadi setepat mungkin.
  • Jika Anda perlu mendeteksi potensi orang curang di jalan melalui pengenalan wajah, akan lebih baik jika Anda menangkap banyak orang yang dicap sebagai penipu meskipun tingkat ketelitiannya rendah. Polisi akan dapat membebaskan individu yang tidak curang tersebut.

Kurva KOP

The Receiver Operating Karakteristik kurva adalah alat lain yang umum digunakan dengan klasifikasi biner. Ini sangat mirip dengan kurva presisi / perolehan, tetapi alih-alih menggambarkan presisi versus perolehan, kurva KOP menunjukkan rasio positif benar (yaitu, perolehan) terhadap rasio positif palsu. Rasio positif palsu adalah rasio kejadian negatif yang salah diklasifikasikan sebagai positif. Ini sama dengan satu dikurangi tingkat negatif sebenarnya. Tingkat negatif sebenarnya juga disebut spesifisitas . Oleh karena itu kurva KOP memplot sensitivitas (recall) versus 1-spesifisitas

Untuk memplot kurva ROC, kita perlu menginstal perpustakaan yang disebut RORC. Kita bisa menemukannya di perpustakaan conda. Anda bisa mengetik kode:

conda install -cr r-rocr --ya

Kita bisa memplot ROC dengan fungsi prediction () dan performance ().

library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Penjelasan Kode

  • prediction (predict, data_test $ income): Pustaka ROCR perlu membuat objek prediksi untuk mengubah data masukan
  • performance (ROCRpred, 'tpr', 'fpr'): Kembalikan dua kombinasi yang akan dihasilkan dalam grafik. Di sini, tpr dan fpr dibangun. Untuk presisi plot total dan recall bersama, gunakan "prec", "rec".

Keluaran:

Langkah 8) Tingkatkan model

Anda dapat mencoba menambahkan non-linearitas ke model dengan interaksi di antaranya

  • usia dan jam. per minggu
  • gender dan hours.per.week.

Anda perlu menggunakan tes skor untuk membandingkan kedua model

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2

Keluaran:

## [1] 0.6109181 

Skor tersebut sedikit lebih tinggi dari yang sebelumnya. Anda dapat terus mengerjakan data untuk mencoba mengalahkan skor.

Ringkasan

Fungsi untuk melatih regresi logistik dapat kita rangkum dalam tabel di bawah ini:

Paket

Objektif

fungsi

argumen

-

Buat set data kereta / uji

create_train_set ()

data, ukuran, kereta

glm

Latih Model Linear Umum

glm ()

rumus, data, keluarga *

glm

Rangkum modelnya

ringkasan()

model pas

mendasarkan

Buat prediksi

meramalkan()

model pas, dataset, type = 'response'

mendasarkan

Buat matriks kebingungan

meja()

y, prediksi ()

mendasarkan

Buat skor akurasi

jumlah (diag (tabel ()) / jumlah (tabel ()

ROCR

Buat ROC: Langkah 1 Buat prediksi

ramalan()

prediksi (), y

ROCR

Buat ROC: Langkah 2 Buat kinerja

kinerja ()

prediksi (), 'tpr', 'fpr'

ROCR

Buat ROC: Langkah 3 Grafik plot

merencanakan()

kinerja ()

Jenis model GLM lainnya adalah:

- binomial: (link = "logit")

- gaussian: (link = "identity")

- Gamma: (link = "inverse")

- inverse.gaussian: (link = "1 / mu 2")

- poisson: (link = "log")

- quasi: (link = "identity", variance = "constant")

- quasibinomial: (link = "logit")

- quasipoisson: (link = "log")