Last updated: 2020-06-23

Checks: 7 passed, 0 failed

Knit directory: MSTPsummerstatistics/

This reproducible R Markdown analysis was created with workflowr (version 1.5.0). The Checks tab describes the reproducibility checks that were applied when the results were created. The Past versions tab lists the development history.


Great! Since the R Markdown file has been committed to the Git repository, you know the exact version of the code that produced these results.

Great job! The global environment was empty. Objects defined in the global environment can affect the analysis in your R Markdown file in unknown ways. For reproducibility it’s best to always run the code in an empty environment.

The command set.seed(20180927) was run prior to running the code in the R Markdown file. Setting a seed ensures that any results that rely on randomness, e.g. subsampling or permutations, are reproducible.

Great job! Recording the operating system, R version, and package versions is critical for reproducibility.

Nice! There were no cached chunks for this analysis, so you can be confident that you successfully produced the results during this run.

Great job! Using relative paths to the files within your workflowr project makes it easier to run your code on other machines.

Great! You are using Git for version control. Tracking code development and connecting the code version to the results is critical for reproducibility. The version displayed above was the version of the Git repository at the time these results were generated.

Note that you need to be careful to ensure that all relevant files for the analysis have been committed to Git prior to generating the results (you can use wflow_publish or wflow_git_commit). workflowr only checks the R Markdown file, but you know if there are other scripts or data files that it depends on. Below is the status of the Git repository when the results were generated:


Ignored files:
    Ignored:    .DS_Store
    Ignored:    .RData
    Ignored:    .Rhistory
    Ignored:    .Rproj.user/
    Ignored:    analysis/.DS_Store
    Ignored:    analysis/.RData
    Ignored:    analysis/.Rhistory
    Ignored:    data/.DS_Store

Note that any generated files, e.g. HTML, png, CSS, etc., are not included in this status report because it is ok for generated content to have uncommitted changes.


These are the previous versions of the R Markdown and HTML files. If you’ve configured a remote Git repository (see ?wflow_git_remote), click on the hyperlinks in the table below to view them.

File Version Author Date Message
html 61e3892 Anthony Hung 2020-06-23 add cheatsheet links
html 2a37983 Anthony Hung 2020-06-23 Build site.
html a1a0bd4 Anthony Hung 2020-06-23 Build site.
Rmd 2ac74f5 Anthony Hung 2020-06-22 edits to cross-validation snippet
html 14c0094 Anthony Hung 2020-06-12 Build site.
html 1378fca Anthony Hung 2020-06-12 Build site.
Rmd 5663923 Anthony Hung 2020-05-12 knit files
html 5663923 Anthony Hung 2020-05-12 knit files
html 2114e6c Anthony Hung 2020-05-10 Build site.
html 29c91df Anthony Hung 2020-05-10 Build site.
Rmd 8e5d9b0 Anthony Hung 2020-05-09 add exercises
Rmd ab1bee8 Anthony Hung 2020-05-09 add description of PR vs ROC
html ab1bee8 Anthony Hung 2020-05-09 add description of PR vs ROC
Rmd ecd1380 Anthony Hung 2020-05-09 add ROC
html ecd1380 Anthony Hung 2020-05-09 add ROC
html e18c369 Anthony Hung 2020-05-02 Build site.
html 0e6b6d0 Anthony Hung 2020-04-30 Build site.
html 5cbe42c Anthony Hung 2020-04-23 Build site.
html 4e08935 Anthony Hung 2020-03-30 Build site.
html f15db48 Anthony Hung 2020-03-30 Build site.
html 310d040 Anthony Hung 2020-02-20 Build site.
Rmd e6a84ff Anthony Hung 2020-02-14 correct typos
html 96722bd Anthony Hung 2019-08-07 Build site.
Rmd 7c27654 Anthony Hung 2019-08-07 add NaiveBayes
html 7c27654 Anthony Hung 2019-08-07 add NaiveBayes

Introduction

Naive Bayes classifiers are a class of supervised machine learning algorithms that use Bayes rule to solve classification problems. Today we will introduce the concept of classification and implement a Naive Bayes classifier to identify poisonous mushrooms.

What is Machine learning?

Machine learning is a field that uses generic algorithms to learn patterns from data, rather than relying on code that explicitly tells the algorithm which patterns to focus on.

What is classification?

Classification is essentially the act of arranging items into categories according to characteristics shared by the items in the same category. For example, we classify a banana and an apple as fruits, and lettuce and spinach as vegetables. In machine learning, classification is typically an example of supervised learning. In supervised learning, an algorithm learns from user-supplied gold-standard examples of data with their assigned category labels, and uses them to determine the patterns and shared characteristics that define each category. For example, consider this table of data that contains information about individual mushroom samples. The original data come from the UCI Machine Learning Repository via Kaggle: https://www.kaggle.com/uciml/mushroom-classification.

Other broad types of machine learning are depicted here: https://d2h0cx97tjks2p.cloudfront.net/blogs/wp-content/uploads/sites/2/2019/08/Types-of-Machine-Learning-algorithms.jpg

Motivating example: Is this mushroom poisonous or not?

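# stringsAsFactors = TRUE keeps the categorical columns as factors (the default became FALSE in R 4.0.0)
mush_data <- read.csv("data/mushrooms.csv", stringsAsFactors = TRUE)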
mush_data[1:6,1:6]
  class cap.shape cap.surface cap.color bruises odor
1     p         x           s         n       t    p
2     e         x           s         y       t    a
3     e         b           s         w       t    l
4     p         x           y         w       t    p
5     e         x           s         g       f    n
6     e         x           y         y       t    a

We can break down the above table into two different elements. The first element contains every column except the first one and is known as the “feature matrix”. It describes the features of each individual mushroom. For example, mushroom 1 has a cap shape that can be described as conveX, a cap surface that is Smooth, a cap color of browN, has bruises, and has a Pungent odor. The second element is the first column, known as the “response vector”. The response vector contains the class of each mushroom, which is the response variable in our case.
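As a small illustration in R, ordinary column indexing splits the six rows printed above into these two elements. The data frame below is rebuilt by hand from that printout, and the names mush_head, X, and y are ours. With the full data, the same indexing applies to mush_data, as in train_data[,-1] in the cross-validation code later in this lesson.

```r
# Hypothetical: the six rows printed above, rebuilt by hand
# (in the lesson itself the data come from read.csv("data/mushrooms.csv"))
mush_head <- data.frame(
  class       = c("p", "e", "e", "p", "e", "e"),
  cap.shape   = c("x", "x", "b", "x", "x", "x"),
  cap.surface = c("s", "s", "s", "y", "s", "y"),
  cap.color   = c("n", "y", "w", "w", "g", "y"),
  bruises     = c("t", "t", "t", "t", "f", "t"),
  odor        = c("p", "a", "l", "p", "n", "a"),
  stringsAsFactors = TRUE
)

# Feature matrix: every column except the first
X <- mush_head[, -1]
# Response vector: the first column (class labels)
y <- mush_head$class

dim(X)     # 6 mushrooms by 5 features
levels(y)  # the two classes, "e" and "p"
```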

Notice that the response variable is a categorical variable that can take one of two values: {e=edible, p=poisonous}. In classification problems, the response variable is always categorical. The analogous task for a continuous response variable is regression, which will be covered in a later lecture.

Using this labeled dataset, we can teach a classification algorithm to predict whether a new mushroom is poisonous, given information about that mushroom's other features. To do this, the algorithm must search the other features for patterns that poisonous mushrooms share and that distinguish them from edible ones. One way to do this is with a Naive Bayes classifier.

As a side note, a third task that is often performed with machine learning is clustering, in which clusters are learned from the input data themselves rather than being specified by the user. This is a case of unsupervised learning, meaning the algorithm is not supplied with gold standard examples of what each cluster looks like but rather defines them on its own.

Review of Bayes rule

Before we jump into the applied example, we must first review the components of Bayes rule. Bayes rule is a result from probability theory that describes the relationship between conditional probabilities.

Let us define A and B as two separate types of events. P(A|B), or “the probability of A given B” denotes the conditional probability of A occurring if we know that B has occurred. Likewise, P(B|A) denotes “the probability of B given A”. Bayes theorem relates P(A|B) and P(B|A) in a deceptively simple equation.

Derivation of Bayes rule

https://oracleaide.files.wordpress.com/2012/12/ovals_pink_and_blue1.png

From our definition of conditional probability, we know that P(A|B) is the probability that A occurs given that B has occurred. This can be written mathematically as:

\[P(A|B) = \frac{P(A \cap B )}{P(B)}\]

Here, \(\cap\) denotes the intersection between A and B (i.e. “A AND B occur together”). To calculate the probability of A conditional on B, we restrict our attention to the situations in which B has occurred and ask how often A also occurs in those situations. That is the probability of both A and B occurring, divided by the probability of B.

In a similar way, we can write P(B|A) mathematically:

\[P(B|A) = \frac{P(B \cap A )}{P(A)}\]

Since \(P(B \cap A )=P(A \cap B)\) (does this make sense?), we can combine the two equations:

\[P(A|B)P(B) = P(B \cap A ) = P(B|A)P(A)\]

If we divide both sides by P(B):

\[P(A|B) = \frac{P(B|A )P(A)}{P(B)}\]

This is Bayes theorem! Notice that using this equation, we can connect the two conditional probabilities. Oftentimes, knowing this relationship is extremely useful because we will know P(B|A) but want to compute P(A|B). Let’s explore an example.

Applying Bayes theorem: Example of screening test

Let us assume that a patient named John goes to see a doctor to undergo a screening test for an infectious disease. The test has been researched previously and is known to have 99% reliability when administered to patients like John. In other words, 99% of sick people test positive and 99% of healthy people test negative. The doctor also knows that 1% of people in general have the disease in question. If John tests positive, what is the probability that he is sick?

Answer: 50%
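The 50% answer follows from Bayes rule once P(+) is expanded with the law of total probability: \(P(+) = P(+|sick)P(sick) + P(+|healthy)P(healthy)\). Here is a short R sketch of the arithmetic. The variable names are ours.

```r
# Bayes rule for the screening example:
# P(sick | +) = P(+ | sick) P(sick) / P(+)
p_sick              <- 0.01  # prior: 1% of people have the disease
p_pos_given_sick    <- 0.99  # 99% of sick people test positive
p_pos_given_healthy <- 0.01  # 99% of healthy people test negative, so 1% test positive

# Law of total probability: all the ways of getting a positive test
p_pos <- p_pos_given_sick * p_sick + p_pos_given_healthy * (1 - p_sick)

p_sick_given_pos <- p_pos_given_sick * p_sick / p_pos
p_sick_given_pos  # 0.5
```

The true positives (0.99 × 0.01) and the false positives (0.01 × 0.99) are equally common, so a positive result is a coin flip despite the test's 99% reliability.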

Exercise:

In a particular pain clinic, 10% of patients are prescribed narcotic pain killers. Overall, five percent of the clinic’s patients are addicted to narcotics (including pain killers and illegal substances). Of all the people prescribed pain pills, 8% are addicts. If a patient is an addict, what is the probability that they will be prescribed pain pills? Answer: 16%

Components of Bayes rule

\[P(A|B) = \frac{P(B|A )P(A)}{P(B)}\]

If we define B to be our observed data (i.e. features of a mushroom), then Bayes theorem becomes:

\[P(A|Data) = \frac{P(Data|A )P(A)}{P(Data)}\]

Notice the \(P(Data|A)\) term is something we’ve talked about earlier, the likelihood. “The likelihood for a model is the probability of the data under the model.” With that in mind, we can now attach names to each of the terms in the equation.

  • P(A|Data) is known as the a posteriori probability, or the probability of a model A given some observations.

  • P(Data|A) is known as the likelihood.

  • P(A) is known as the prior probability, or the probability of A before we have the observations.

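  • P(Data) is the marginal probability of the data (sometimes called the evidence). It is the probability of observing these data at all, averaged over every possible model, and it acts as a normalizing constant.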

Suppose we have two candidate models, \(A_p\) (poisonous) and \(A_e\) (edible), that could have generated our data, and we would like to choose between them.

\[P(A_p|Data) = \frac{P(Data|A_p )P(A_p)}{P(Data)}\]

and

\[P(A_e|Data) = \frac{P(Data|A_e )P(A_e)}{P(Data)}\]

Suppose we want to compare the probability that a certain set of data (or features of a mushroom) came from model \(A_e\) with the probability that it came from model \(A_p\). We can simply compare their posterior probabilities \(P(A_e|Data)\) and \(P(A_p|Data)\). Notice that both posterior probabilities share the same denominator, \(P(Data)\), so we can ignore the denominator when checking which one is larger. Therefore, we can compare \(P(Data|A_e)P(A_e)\) with \(P(Data|A_p)P(A_p)\) to determine which model is more probable given the observed data. Being able to ignore the denominator when using Bayes rule to compare two (or more) models is an extremely convenient trick that simplifies the math a great deal.

A Naive Bayes Classifier uses data that is fed into it to estimate prior probabilities and likelihoods to be used in Bayes rule. The classifier then uses Bayes rule to compute the posterior probability that a new observation belongs to each possible defined class given its other features. It then assigns the observation to the class that has the largest posterior probability (“Maximum A Posteriori Probability”, or MAP).
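The R sketch below shows the MAP decision for a single mushroom, using made-up priors and likelihoods. These numbers are hypothetical, not estimates from the data. Dividing by \(P(Data)\) rescales both classes by the same constant, so it does not change which class wins.

```r
# Hypothetical priors and likelihoods for one mushroom's observed features
# (illustrative numbers only, not estimated from the mushroom data)
prior      <- c(e = 0.52, p = 0.48)
likelihood <- c(e = 0.002, p = 0.010)  # P(Data | class)

# Numerators of Bayes rule: P(Data | class) * P(class)
unnormalized <- likelihood * prior

# e and p are the only two classes, so P(Data) is the sum of the numerators
posterior <- unnormalized / sum(unnormalized)

# The MAP class is the same with or without the denominator
names(which.max(unnormalized))  # "p"
names(which.max(posterior))     # "p"
```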

Exercise: Naive Bayes classifier to classify mushrooms

This classifier is called a “Naive” Bayes classifier because the algorithm assumes that the features in the feature matrix are independent of one another within each class (conditionally independent given the class). For example, in our mushroom example a Naive Bayes classifier assumes that cap shape does not depend on cap color, both among poisonous mushrooms and, separately, among edible ones. The classifier also treats every feature the same way. Each feature contributes one likelihood term to the calculation, and no feature is weighted more heavily than another or discarded as irrelevant. These conditions are almost never completely true in practice, but the classifier often still works very well when they are violated.

Let’s walk through conceptually what the Naive Bayes algorithm does, and then look at how R can help perform each step.

Training a Naive Bayes algorithm using training data

As mentioned previously, training a Naive Bayes algorithm essentially boils down to using training data to estimate the prior probabilities (\(P(A_p)\) and \(P(A_e)\)) and likelihoods (\(P(Data|A_p)\) and \(P(Data|A_e)\)) to plug into Bayes rule.

1. Estimating Prior probabilities from training data

Question: What would be a good way to come up with a prior probability for a mushroom being poisonous if you have access to a labeled training dataset? Recall that a prior probability is the probability of a certain model before you have access to any data.

Simply calculate the proportion of observations in the training data that are poisonous. Alternatively, one could argue for setting both prior probabilities to 1/2, treating each class as equally likely.
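The R sketch below estimates the priors from a vector of labels with the same class counts that summary(mush_data) reports later in this lesson (4208 edible, 3916 poisonous). The vector names are ours.

```r
# Class labels with the same counts as the full mushroom data
classes <- factor(rep(c("e", "p"), times = c(4208, 3916)))

# Empirical priors: the proportion of training observations in each class
priors <- prop.table(table(classes))
round(priors, 3)  # e: 0.518, p: 0.482

# Alternative: flat priors that treat both classes as equally likely
flat_priors <- c(e = 0.5, p = 0.5)
```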

Pretty straightforward.

2. Estimating likelihoods from training data

Now, we need to calculate the other missing part of our Bayes rule equation: the likelihoods. A likelihood is similar to, but not equivalent to, a probability.

Likelihood vs probability

Probability

Recall from our previous class on probability distributions that a probability can be visualized as an area under the curve of a probability distribution. For a discrete distribution such as the binomial, that area is the sum of the heights of the relevant bars. For example, let’s say that we have a fair coin (P(heads) = 0.5) and we flip it 30 times:

library(ggplot2)
library(cowplot)

library(grid)

x1  <- 5:25
df <- data.frame(x = x1, y = dbinom(x1, 30, 0.5))

ggplot(df, aes(x = x, y = y)) +
  geom_bar(stat = "identity", col = "red", fill = c("white")) +
  scale_y_continuous(expand = c(0.01, 0)) + xlab("number of heads") + ylab("Density")

Version Author Date
61e3892 Anthony Hung 2020-06-23
5cbe42c Anthony Hung 2020-04-23
310d040 Anthony Hung 2020-02-20
7c27654 Anthony Hung 2019-08-07

To find the probability of getting more than 20 heads in 30 flips, we could calculate the area represented by the bars that are greater than 20 on the x axis:

ggplot(df, aes(x = x, y = y)) +
  geom_bar(stat = "identity", col = "red", fill = c(rep("white", 16), rep("red", 5))) +
  scale_y_continuous(expand = c(0.01, 0)) + xlab("number of heads") + ylab("Density")

Similarly, we could calculate the probability of getting strictly between 9 and 13 heads (i.e. 10, 11, or 12 heads):

ggplot(df, aes(x = x, y = y)) +
  geom_bar(stat = "identity", col = "red", fill = c(rep("white", 5), rep("red", 3), rep("white", 13))) +
  scale_y_continuous(expand = c(0.01, 0)) + xlab("number of heads") + ylab("Density")

In each case, notice that the shape of the distribution does not change. The only thing that changes is the area that we shade in. In mathematical terms, in the first case we are calculating:

\[P(num\_heads > 20 | Binom(n=30, p=0.5))\]

and in the second:

\[P(9< num\_heads < 13 | Binom(n=30, p=0.5))\]

What is changing is the left side of the | . The shape of the distribution stays the same. When we discuss probabilities, we are talking about the areas under a fixed distribution (model).
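For a discrete distribution such as the binomial, the “area” is simply the sum of the heights of the shaded bars. The base-R sketch below computes the two probabilities above with dbinom(), as the lesson does, and checks the first one with the cumulative pbinom(). The variable names are ours.

```r
# P(num_heads > 20 | Binom(n = 30, p = 0.5)): add up the bars for 21, 22, ..., 30 heads
p_more_than_20 <- sum(dbinom(21:30, size = 30, prob = 0.5))

# The same upper-tail probability from the binomial CDF
p_more_than_20_cdf <- pbinom(20, size = 30, prob = 0.5, lower.tail = FALSE)

# P(9 < num_heads < 13 | Binom(n = 30, p = 0.5)): bars for 10, 11 and 12 heads
p_between_9_and_13 <- sum(dbinom(10:12, size = 30, prob = 0.5))

c(p_more_than_20, p_more_than_20_cdf, p_between_9_and_13)
```

In both calls the distribution (size = 30, prob = 0.5) stays fixed. Only the set of outcomes being summed changes.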

Likelihood

So what about likelihood? Before we look at it graphically, let’s define what we mean by the term. “The likelihood for a model is the probability of the data under the model.” Mathematically,

\[L(Model;Data) = P(Data|Model)\]

This may look the same as what we did before, but in this case our data are fixed, not the distribution. Instead of asking, “If I keep my distribution constant, what is the probability of observing something?” with likelihood we are asking “Given that I have collected some data, how well does a certain distribution fit the data?”

Let’s assume the same coin situation we used for probability. This time, we do not know whether the coin is actually fair (P(heads) = 0.5) or rigged (e.g. P(heads) = 0.6). We flip the coin 30 times and observe 20 heads.

What is the likelihood for our fair model (\(Binom(n=30, p=0.5)\)) given that we observe these data? In other words, how well does the model as parameterized fit our observations?

\[L(Model;Data) = P(num\_heads = 20|Binom(n=30, p=0.5))\]

Let’s look at this graphically.

ggplot(df, aes(x = x, y = y)) +
  geom_bar(stat = "identity", col = "red", fill = c(rep("white", 15), rep("red", 1), rep("white", 5))) +
  scale_y_continuous(expand = c(0.01, 0)) + xlab("number of heads") + ylab("Density")

We can also compute the exact probability using the “dbinom” function in R.

dbinom(x = 20, size = 30, prob = 0.5)
[1] 0.0279816

Okay. How well do our data fit a rigged coin model, where P(heads) = 0.6? What is the likelihood for the rigged coin model given our data?

\[L(Model;Data) = P(num\_heads = 20|Binom(n=30, p=0.6))\]

Let’s look at this graphically.

x1  <- 5:25
df_rigged <- data.frame(x = x1, y = dbinom(x1, 30, 0.6))

ggplot(df_rigged, aes(x = x, y = y)) +
  geom_bar(stat = "identity", col = "red", fill = c(rep("white", 15), rep("red", 1), rep("white", 5))) +
  scale_y_continuous(expand = c(0.01, 0)) + xlab("number of heads") + ylab("Density")

We can also compute the exact probability using the “dbinom” function in R.

dbinom(x = 20, size = 30, prob = 0.6)
[1] 0.1151854

It looks like the likelihood for the rigged coin model is higher!
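Because the data stay fixed, we can also treat the likelihood as a function of the model parameter p and evaluate it for many candidate coins. The base-R sketch below does this; the helper name likelihood_fn is ours, for illustration. The ratio of the two likelihoods above is greater than 1, and the best-fitting p lands near the observed proportion of heads, 20/30.

```r
# Observed data stay fixed: 20 heads in 30 flips
heads <- 20
flips <- 30

# Likelihood of a Binom(n = 30, p) model given the data, as a function of p
likelihood_fn <- function(p) dbinom(heads, size = flips, prob = p)

# Likelihood ratio: rigged (p = 0.6) vs fair (p = 0.5) model
likelihood_ratio <- likelihood_fn(0.6) / likelihood_fn(0.5)
likelihood_ratio  # greater than 1, so the data favor the rigged model

# Evaluate many candidate models on a grid and keep the best-fitting p
p_grid <- seq(0.01, 0.99, by = 0.01)
best_grid_p <- p_grid[which.max(likelihood_fn(p_grid))]

# Numerical maximization; the result should be close to 20/30
best_p <- optimize(likelihood_fn, interval = c(0, 1), maximum = TRUE)$maximum
c(best_grid_p, best_p)
```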

Determining likelihood equations

Now, let’s approach computing likelihoods for our mushroom example. The first thing we need to do is choose an appropriate distribution to model our data. Given that each of our features is a categorical variable that has either 2 or more than 2 levels, what is the best distribution to choose for the expression \(P(Data|A_p)\) for each individual feature (e.g. cap shape)?

For the features that are categorical variables with 2 levels, a Bernoulli likelihood is most appropriate (this is also equivalent to a binomial distribution with only 1 trial). For the features that are categorical variables with more than 2 levels, a multinomial likelihood is appropriate (or more specifically, the categorical distribution).

The multinomial distribution is a generalization of the binomial to cases where the number of categories is greater than 2. Recall that a binomial distribution describes the probability of observing k successes in n trials, where the result of each trial is either a success or a failure (2 possible categories of results). The multinomial distribution expands the number of possible result categories beyond 2. In other words, you can use the binomial distribution to model the probability of observing a certain number of heads in n coin flips. You can use the multinomial distribution to model the probability of landing on each side of a 6-sided die a certain number of times in n throws of the die.

We can get the PMFs for the Bernoulli and multinomial distributions from memory or from Wikipedia: https://en.wikipedia.org/wiki/Bernoulli_distribution, https://en.wikipedia.org/wiki/Multinomial_distribution. These are our likelihood equations.
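For reference, here are the two PMFs in standard notation. The symbols are our choice and follow the linked Wikipedia pages. For a Bernoulli variable X that equals 1 with probability p and 0 otherwise:

\[P(X = x) = p^{x}(1-p)^{1-x}, \quad x \in \{0, 1\}\]

For a multinomial with n trials and k categories with event probabilities \(p_1, \dots, p_k\) (summing to 1), where \(x_j\) counts how many trials landed in category j:

\[P(X_1 = x_1, \dots, X_k = x_k) = \frac{n!}{x_1! \, x_2! \cdots x_k!} \, p_1^{x_1} p_2^{x_2} \cdots p_k^{x_k}, \quad \sum_{j=1}^{k} x_j = n\]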

Estimating likelihoods from training data

Now that we have determined our likelihood equations, we can move on to estimating the parameters that will fit into these equations from our training data.

In both cases (Bernoulli and multinomial), the number of trials n will be 1. Why is this? Each mushroom has only one value for each feature. For example, a mushroom cannot have a bell-shaped and a conical cap at the same time, and cannot have a doubly bell-shaped cap.

All that is left to estimate, therefore, are the event probabilities: p for the Bernoulli, or \(p_1, p_2, p_3, ... p_k\) for the multinomial. What is an intuitive way to estimate these probabilities from our training data? Much like we did above for our prior probabilities, we can simply find the proportion of poisonous mushrooms in our training data that have bell-shaped caps, conical caps, etc., and do the same for the edible mushrooms. Within each class of mushroom, these probabilities should add up to 1 for each feature. If you look hard enough at the PMFs for either the Bernoulli or multinomial distributions where n = 1, you’ll see that it’s actually very easy to compute likelihoods for each individual feature.
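To see the n = 1 case concretely, the sketch below uses base R's dmultinom(). The counts are the poisonous-mushroom cap.shape counts printed by table(poisonous$cap.shape) later in this lesson, and the vector names are ours. With a single trial, the multinomial likelihood of one mushroom is just the estimated probability of the level that mushroom has.

```r
# Counts of each cap.shape level among poisonous mushrooms
cap_shape_counts_p <- c(b = 48, c = 4, f = 1556, k = 600, s = 0, x = 1708)

# Estimated event probabilities p_1, ..., p_k: proportions within the class
p_hat <- cap_shape_counts_p / sum(cap_shape_counts_p)
sum(p_hat)  # should be 1

# One mushroom (n = 1 trial) with a convex ("x") cap, written as category counts
one_convex <- c(b = 0, c = 0, f = 0, k = 0, s = 0, x = 1)

# With n = 1 the multinomial PMF collapses to the probability of the observed level
lik_convex <- dmultinom(one_convex, prob = p_hat)
c(lik_convex, p_hat[["x"]])  # the same number twice

# A level never seen among poisonous mushrooms gets likelihood 0
lik_sunken <- dmultinom(c(b = 0, c = 0, f = 0, k = 0, s = 1, x = 0), prob = p_hat)
```

The s level has a count of 0 among poisonous mushrooms, so a mushroom with that cap shape gets a likelihood of exactly 0 for the poisonous class. Its posterior for that class is therefore also 0.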

3. Putting it all together: computing a posteriori probabilities using Bayes rule

Cool, now that we have all the pieces, we can use Bayes rule to calculate our posterior probabilities. Recall that since the denominators will always be the same across categories when using Bayes rule, we can omit them and instead use a simplified version of the equation:

\[Posterior \propto Prior * Likelihood\]

Recall also that the Naive in Naive Bayes means we assume that our features are independent of one another within each class. This means that to calculate the joint likelihood across all features for an individual mushroom, we can simply take the product of the likelihoods for each individual feature (multiplying many multinomial/Bernoulli likelihoods together).

\[Posterior_{edible} \propto Prior_{edible} * Likelihood_{cap.shape\_edible} * Likelihood_{cap.surface\_edible} * Likelihood_{cap.color\_edible} * ... * Likelihood_{habitat\_edible}\]

or

\[P(A_{edible}|Data) \propto P(A_{edible}) * P(Data_{cap.shape}|A_{edible}) * P(Data_{cap.surface}|A_{edible}) * P(Data_{cap.color}|A_{edible}) * ... * P(Data_{habitat}|A_{edible})\]
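To make the product concrete, below is a minimal from-scratch sketch in base R. It is not how e1071 implements naiveBayes(). The toy training set and the helper names fit_naive_bayes() and score_naive_bayes() are hypothetical and exist only for illustration. Priors and per-feature event probabilities are estimated as class-wise proportions with no smoothing. Each class's score is its prior times the product of its per-feature likelihoods.

```r
# Hypothetical toy training data (NOT the mushroom data), two categorical features
train <- data.frame(
  class     = c("e", "e", "e", "e", "p", "p", "p", "p"),
  cap.shape = c("x", "b", "x", "f", "x", "f", "f", "x"),
  odor      = c("a", "n", "n", "p", "p", "p", "n", "f"),
  stringsAsFactors = TRUE
)

# Estimate priors P(class) and, for each feature, P(level | class) as proportions
fit_naive_bayes <- function(data, class_col = "class") {
  y <- data[[class_col]]
  features <- setdiff(names(data), class_col)
  cond_probs <- lapply(features, function(f) {
    # class-by-level table; margin = 1 makes each class row sum to 1
    prop.table(table(y, data[[f]]), margin = 1)
  })
  names(cond_probs) <- features
  list(priors = prop.table(table(y)), cond_probs = cond_probs)
}

# Unnormalized posterior for each class: prior * product of per-feature likelihoods
score_naive_bayes <- function(model, new_obs) {
  classes <- names(model$priors)
  sapply(classes, function(cl) {
    likelihoods <- sapply(names(model$cond_probs), function(f) {
      model$cond_probs[[f]][cl, as.character(new_obs[[f]])]
    })
    as.numeric(model$priors[cl]) * prod(likelihoods)
  })
}

model <- fit_naive_bayes(train)

# Score a new mushroom with a convex cap and pungent odor
scores <- score_naive_bayes(model, list(cap.shape = "x", odor = "p"))
posterior <- scores / sum(scores)  # optional: the MAP class is the same either way
names(which.max(scores))           # MAP class
posterior
```

With many features, products of small probabilities can underflow, so practical implementations commonly sum log-probabilities instead. The ranking of the classes is unchanged because the logarithm is monotonic.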

Using R, a few commands give us the numbers needed to estimate the prior probabilities and the event probabilities from our mushroom data. We will not walk through calculating every likelihood by hand, because packages in R and other programming languages automate all of these steps.

#counts needed for the prior probabilities (see the class column):
summary(mush_data)
 class    cap.shape cap.surface   cap.color    bruises       odor     
 e:4208   b: 452    f:2320      n      :2284   f:4748   n      :3528  
 p:3916   c:   4    g:   4      g      :1840   t:3376   f      :2160  
          f:3152    s:2556      e      :1500            s      : 576  
          k: 828    y:3244      y      :1072            y      : 576  
          s:  32                w      :1040            a      : 400  
          x:3656                b      : 168            l      : 400  
                                (Other): 220            (Other): 484  
 gill.attachment gill.spacing gill.size   gill.color   stalk.shape stalk.root
 a: 210          c:6812       b:5612    b      :1728   e:3516      ?:2480    
 f:7914          w:1312       n:2512    p      :1492   t:4608      b:3776    
                                        w      :1202               c: 556    
                                        n      :1048               e:1120    
                                        g      : 752               r: 192    
                                        h      : 732                         
                                        (Other):1170                         
 stalk.surface.above.ring stalk.surface.below.ring stalk.color.above.ring
 f: 552                   f: 600                   w      :4464          
 k:2372                   k:2304                   p      :1872          
 s:5176                   s:4936                   g      : 576          
 y:  24                   y: 284                   n      : 448          
                                                   b      : 432          
                                                   o      : 192          
                                                   (Other): 140          
 stalk.color.below.ring veil.type veil.color ring.number ring.type
 w      :4384           p:8124    n:  96     n:  36      e:2776   
 p      :1872                     o:  96     o:7488      f:  48   
 g      : 576                     w:7924     t: 600      l:1296   
 n      : 512                     y:   8                 n:  36   
 b      : 432                                            p:3968   
 o      : 192                                                     
 (Other): 156                                                     
 spore.print.color population habitat 
 w      :2388      a: 384     d:3148  
 n      :1968      c: 340     g:2148  
 k      :1872      n: 400     l: 832  
 h      :1632      s:1248     m: 292  
 r      :  72      v:4040     p:1144  
 b      :  48      y:1712     u: 368  
 (Other): 144                 w: 192  
#subset data into poisonous and edible
poisonous <- mush_data[mush_data$class=="p",]
edible <- mush_data[mush_data$class=="e",]

#count each cap.shape level within each class (dividing by the class total gives the event probabilities)
##poisonous
table(poisonous$cap.shape)

   b    c    f    k    s    x 
  48    4 1556  600    0 1708 
##edible
table(edible$cap.shape)

   b    c    f    k    s    x 
 404    0 1596  228   32 1948 

Additional notes relevant to training machine learning algorithms.

Before actually carrying out this algorithm in R on our data, we need to discuss some topics that are relevant to machine learning in general when using data to train and evaluate a model.

Another issue: overfitting

If you’ve ever heard anyone talking about machine learning, you’ve probably heard the term “overfitting.” Overfitting describes any situation in which an algorithm learns from the training data … a bit too well. For example, some patterns in your training data may be flukes of sampling error that do not reflect the relationships found in the wild. An algorithm trained on those data will falsely apply these patterns to new classification problems, leading to inaccurate results. One way to test whether your algorithm has overfit your data is to collect more labeled data that were not used to train it, and see how well the trained algorithm classifies these new “test” data.

One way to assess overfitting: Splitting data into training vs test data sets (i.e. the “Hold out” method)

Of course, oftentimes the easier alternative to going out to collect new data is to just pretend that a subset of the data that is already collected is your “test” data set. By randomly partitioning your data into a “training set” and a “test set,” you can assess your algorithm’s ability to perform on previously unseen data to determine if it has overfit your training data. A common split is 20% test data and 80% training data.
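Here is a minimal base-R sketch of an 80/20 hold-out split that works on row indices only. Here n = 8124 is the number of rows in the mushroom data (4208 edible + 3916 poisonous, per summary(mush_data) above), and the variable names are ours. Because a seed is set at the top of this document, the split would be reproducible.

```r
# Number of observations (rows) to split
n <- 8124

# Randomly pick ~20% of row indices for the test set; the rest form the training set
test_idx  <- sample(seq_len(n), size = round(0.2 * n))
train_idx <- setdiff(seq_len(n), test_idx)

length(test_idx)   # 1625
length(train_idx)  # 6499

# With the real data this would be:
# train_data <- mush_data[train_idx, ]
# test_data  <- mush_data[test_idx, ]
```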

An additional problem: Unlucky splits

Of course, the word “randomly” above may trigger warning bells, since the whole problem we are trying to avoid in the first place is sampling error! Therefore, by randomly partitioning our training and test sets, we run into the potential problem of unlucky splits in which our training set looks nothing like our test set, despite both having come from the same larger group of data. In this case, using our test set to evaluate the ability of our algorithm to perform on previously unseen data would be inaccurate.

Cross-validation to avoid unlucky splits

To avoid the above problem, we can use cross-validation, or ‘k-fold cross-validation’. Cross-validation randomly splits our data into ‘k’ groups of (roughly) equal size. In each iteration, one group is used as the test set and the rest are used as the training set: the algorithm is trained on the training set and evaluated on the test set. In the next iteration, a different group is held out as the test set and the remaining groups are used for training. The process is repeated until each group has been used as the test set exactly once. For example, in 5-fold cross-validation, we would split the dataset randomly into 5 groups and run 5 iterations of training and testing. See the graphical depiction below.

https://miro.medium.com/max/1400/1*rgba1BIOUys7wQcXcL4U5A.png

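Every observation is used for testing exactly once, and performance is averaged over k different train/test splits. As a result, k-fold cross-validation is less sensitive to a single unlucky split than the hold-out method. This makes it an effective way of comparing different models to select the one with the best performance for your data and problem.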
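Here is a sketch of the cross-validation bookkeeping in base R. To keep it self-contained, it uses labels with the mushroom data's class counts (4208 edible, 3916 poisonous). In place of a real model, it uses a deliberately trivial placeholder classifier that always predicts the training folds' majority class. The names are ours. The key step is the last line, which averages the per-fold accuracies into a single cross-validated estimate.

```r
# Labels with the same class counts as the mushroom data
y <- factor(rep(c("e", "p"), times = c(4208, 3916)))
n_folds <- 5

# Randomly assign each observation a fold label from 1 to 5
folds <- sample(rep(1:n_folds, length.out = length(y)))

# Placeholder "model": predict the most common class in the training folds.
# It stands in for any real classifier; only the cross-validation loop matters here.
fold_accuracy <- sapply(1:n_folds, function(k) {
  test_i   <- which(folds == k)
  train_y  <- y[-test_i]
  majority <- names(which.max(table(train_y)))
  mean(y[test_i] == majority)  # accuracy on the held-out fold
})

fold_accuracy        # one accuracy per held-out fold
mean(fold_accuracy)  # cross-validated estimate of accuracy on unseen data
```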

Naive Bayes classifier in R

Finally, after discussing the theoretical basis behind the classifier and the necessity for training/test sets and cross-validation, we will be able to implement our algorithm with 5-fold cross-validation in R. Since others have written neat programs to automate most of the calculations for us, you will find that the implementation is actually quite simple to carry out.

#Note: a seed has been set already at the beginning of this document, which allows us to reproducibly recreate the random splits that we perform in this code chunk.

#load two libraries in R that are popular for implementing machine learning algorithms
library(e1071)
library(caret)
library(pROC) #for getting AUC
#randomly assign each observation to one of 5 groups (folds_i holds the fold label, 1 to 5, for each row of mush_data)
n_folds <- 5
folds_i <- sample(rep(1:n_folds, length.out = nrow(mush_data)))
table(folds_i)
folds_i
   1    2    3    4    5 
1625 1625 1625 1625 1624 
#iterate through training and testing a naiveBayes classifier for all 5 cross-validation folds and print the results
for(k in 1:n_folds){
  #select which group is our test data and define our training/testing data
  test_i <- which(folds_i == k)
  train_data <- mush_data[-test_i, ]
  test_data <- mush_data[test_i, ]

  #train classifier
  classifier_nb <- naiveBayes(train_data[,-1], train_data$class)

  #use classifier to predict classes of test data
  nb_pred_probabilities <- predict(classifier_nb, newdata = test_data, type = "raw")

  nb_pred_classes <- predict(classifier_nb, newdata = test_data, type = "class")

  #assess the accuracy of the classifier on the test data using a confusion matrix
  print(confusionMatrix(nb_pred_classes,test_data$class))

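  #assess the classifier on the test data with a ROC curve and its AUC (roc() comes from the pROC package);
  #column 1 ("e") of nb_pred_probabilities is the predicted probability of being edible
  ROC <- roc(predictor = nb_pred_probabilities[,1],
             response = test_data$class)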
  print(ROC$auc)
  plot(ROC)
}
Confusion Matrix and Statistics

          Reference
Prediction   e   p
         e 836  91
         p   2 696
                                          
               Accuracy : 0.9428          
                 95% CI : (0.9303, 0.9536)
    No Information Rate : 0.5157          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.885           
                                          
 Mcnemar's Test P-Value : < 2.2e-16       
                                          
            Sensitivity : 0.9976          
            Specificity : 0.8844          
         Pos Pred Value : 0.9018          
         Neg Pred Value : 0.9971          
             Prevalence : 0.5157          
         Detection Rate : 0.5145          
   Detection Prevalence : 0.5705          
      Balanced Accuracy : 0.9410          
                                          
       'Positive' Class : e               
                                          
Setting levels: control = e, case = p
Setting direction: controls > cases
Area under the curve: 0.997
Confusion Matrix and Statistics

          Reference
Prediction   e   p
         e 847 100
         p   4 674
                                         
               Accuracy : 0.936          
                 95% CI : (0.923, 0.9474)
    No Information Rate : 0.5237         
    P-Value [Acc > NIR] : < 2.2e-16      
                                         
                  Kappa : 0.871          
                                         
 Mcnemar's Test P-Value : < 2.2e-16      
                                         
            Sensitivity : 0.9953         
            Specificity : 0.8708         
         Pos Pred Value : 0.8944         
         Neg Pred Value : 0.9941         
             Prevalence : 0.5237         
         Detection Rate : 0.5212         
   Detection Prevalence : 0.5828         
      Balanced Accuracy : 0.9331         
                                         
       'Positive' Class : e              
                                         
Setting levels: control = e, case = p
Setting direction: controls > cases

Version Author Date
61e3892 Anthony Hung 2020-06-23
ecd1380 Anthony Hung 2020-05-09
Area under the curve: 0.995
Confusion Matrix and Statistics

          Reference
Prediction   e   p
         e 831  97
         p  10 687
                                         
               Accuracy : 0.9342         
                 95% CI : (0.921, 0.9457)
    No Information Rate : 0.5175         
    P-Value [Acc > NIR] : < 2.2e-16      
                                         
                  Kappa : 0.8676         
                                         
 Mcnemar's Test P-Value : < 2.2e-16      
                                         
            Sensitivity : 0.9881         
            Specificity : 0.8763         
         Pos Pred Value : 0.8955         
         Neg Pred Value : 0.9857         
             Prevalence : 0.5175         
         Detection Rate : 0.5114         
   Detection Prevalence : 0.5711         
      Balanced Accuracy : 0.9322         
                                         
       'Positive' Class : e              
                                         
Setting levels: control = e, case = p
Setting direction: controls > cases

Area under the curve: 0.9939
Confusion Matrix and Statistics

          Reference
Prediction   e   p
         e 856  90
         p   7 672
                                          
               Accuracy : 0.9403          
                 95% CI : (0.9277, 0.9513)
    No Information Rate : 0.5311          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.8794          
                                          
 Mcnemar's Test P-Value : < 2.2e-16       
                                          
            Sensitivity : 0.9919          
            Specificity : 0.8819          
         Pos Pred Value : 0.9049          
         Neg Pred Value : 0.9897          
             Prevalence : 0.5311          
         Detection Rate : 0.5268          
   Detection Prevalence : 0.5822          
      Balanced Accuracy : 0.9369          
                                          
       'Positive' Class : e               
                                          
Setting levels: control = e, case = p
Setting direction: controls > cases

Area under the curve: 0.9953
Confusion Matrix and Statistics

          Reference
Prediction   e   p
         e 810  74
         p   5 735
                                          
               Accuracy : 0.9514          
                 95% CI : (0.9397, 0.9613)
    No Information Rate : 0.5018          
    P-Value [Acc > NIR] : <2e-16          
                                          
                  Kappa : 0.9027          
                                          
 Mcnemar's Test P-Value : 2e-14           
                                          
            Sensitivity : 0.9939          
            Specificity : 0.9085          
         Pos Pred Value : 0.9163          
         Neg Pred Value : 0.9932          
             Prevalence : 0.5018          
         Detection Rate : 0.4988          
   Detection Prevalence : 0.5443          
      Balanced Accuracy : 0.9512          
                                          
       'Positive' Class : e               
                                          
Setting levels: control = e, case = p
Setting direction: controls > cases

Area under the curve: 0.9976
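The AUC values printed above can be read as a probability. The messages show that pROC set `control = e` and `direction: controls > cases`, so a higher predicted probability of "e" counts as evidence for edible. Under that convention, the area under the empirical ROC curve equals the chance that a randomly chosen edible mushroom gets a higher score than a randomly chosen poisonous one, with ties counting one half. Here is a base-R sketch of that pairwise definition. The scores are made up, and `auc_by_pairs` is a hypothetical helper, not a pROC function.

```r
# AUC as P(score of a random "e" > score of a random "p"), ties counting 1/2
auc_by_pairs <- function(prob_e, truth) {
  pos <- prob_e[truth == "e"]
  neg <- prob_e[truth == "p"]
  # outer() compares every edible/poisonous pair at once
  cmp <- outer(pos, neg, function(a, b) (a > b) + 0.5 * (a == b))
  mean(cmp)
}

# Made-up scores for illustration only (not from the mushroom data)
truth  <- c("e", "e", "e", "p", "p", "p")
prob_e <- c(0.95, 0.80, 0.40, 0.60, 0.20, 0.05)
auc_by_pairs(prob_e, truth)  # 8 of the 9 pairs are ranked correctly
```

An AUC near 0.995, as in every fold above, means almost every edible/poisonous pair is ranked correctly. This holds even though some mushrooms still fall on the wrong side of the default 0.5 cut-off used for the confusion matrices.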


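The algorithm performs well in every iteration: across the five folds, accuracy ranges from 93.4% to 95.1% and the AUC from 0.994 to 0.998. Each fold's metrics are computed on observations the classifier never saw during training, and they are consistent from fold to fold. We can therefore be fairly confident that the classifier is not overfitting our data. The errors are not symmetric, however. Sensitivity (for the positive class "e") is above 0.98 in every fold, while specificity is only 0.87 to 0.91, so nearly all mistakes are poisonous mushrooms predicted to be edible.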
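To summarise all five folds at once, the confusion matrices printed above can be pooled. The counts in this base-R sketch are copied from that output. Pooling works here because every observation appears in exactly one test fold.

```r
# Confusion matrices copied from the five folds above, entered column by column:
# column 1 = reference "e", column 2 = reference "p"; rows = prediction "e", "p".
fold_cms <- list(
  matrix(c(836,  2,  91, 696), nrow = 2),
  matrix(c(847,  4, 100, 674), nrow = 2),
  matrix(c(831, 10,  97, 687), nrow = 2),
  matrix(c(856,  7,  90, 672), nrow = 2),
  matrix(c(810,  5,  74, 735), nrow = 2)
)

# Every row is tested exactly once, so summing the matrices gives one pooled matrix
pooled <- Reduce(`+`, fold_cms)
dimnames(pooled) <- list(Prediction = c("e", "p"), Reference = c("e", "p"))
print(pooled)

per_fold_accuracy <- sapply(fold_cms, function(m) sum(diag(m)) / sum(m))
pooled_accuracy   <- sum(diag(pooled)) / sum(pooled)
sensitivity <- pooled["e", "e"] / sum(pooled[, "e"])  # edible predicted edible
specificity <- pooled["p", "p"] / sum(pooled[, "p"])  # poisonous predicted poisonous

print(round(c(accuracy = pooled_accuracy, sensitivity = sensitivity,
              specificity = specificity), 4))
```

In the pooled matrix, 452 of the 3916 poisonous mushrooms (about 11.5%) are predicted as edible. By contrast, only 28 of the 4208 edible mushrooms are predicted as poisonous.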

To better understand what the algorithm has learned (i.e. the prior probabilities of each class and the conditional probabilities of each feature value given the class), we can examine the fitted classifier from one iteration of our cross-validation (treating group 1 as our test set). Do any patterns jump out at you? Which variables do you think are most informative for distinguishing edible from poisonous mushrooms?
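One way to put a number on "informative" is the total variation distance between a variable's distribution in the two classes. This is half the summed absolute difference between the "e" and "p" rows of its conditional probability table. This heuristic is our own choice for illustration and is not something `naiveBayes()` computes. The sketch below uses a handful of tables copied, rounded to 4 decimals, from the classifier summary printed further down. `tvd` and `tabs` are hypothetical names.

```r
# Conditional probability tables copied (rounded to 4 decimals) from the
# classifier summary: rows are classes, columns are levels of the variable.
tabs <- list(
  odor = rbind(
    e = c(a = 0.0917, c = 0, f = 0, l = 0.0950, m = 0, n = 0.8134, p = 0, s = 0, y = 0),
    p = c(a = 0, c = 0.0502, f = 0.5519, l = 0, m = 0.0102, n = 0.0297,
          p = 0.0630, s = 0.1499, y = 0.1451)
  ),
  gill.size   = rbind(e = c(b = 0.9282, n = 0.0718), p = c(b = 0.4356, n = 0.5644)),
  bruises     = rbind(e = c(f = 0.3549, t = 0.6451), p = c(f = 0.8444, t = 0.1556)),
  stalk.shape = rbind(e = c(e = 0.3878, t = 0.6122), p = c(e = 0.4880, t = 0.5120)),
  veil.type   = rbind(e = c(p = 1), p = c(p = 1))
)

# Total variation distance between P(level | e) and P(level | p):
# 0 means the variable looks identical in both classes, 1 means no overlap at all.
tvd <- function(tab) 0.5 * sum(abs(tab[1, ] - tab[2, ]))

informativeness <- sort(sapply(tabs, tvd), decreasing = TRUE)
print(round(informativeness, 3))
```

For the fitted object itself, the same ranking over all 22 predictors might be obtained with `sort(sapply(classifier_nb$tables, tvd), decreasing = TRUE)`. This assumes, as the printed summary suggests, that `classifier_nb$tables` stores one class-by-level table per predictor. Note that `veil.type` takes a single value in this data, so it cannot help separate the classes at all.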

#select which group is our test data and define our training/testing data
test_i <- which(folds_i == 1)
train_data <- mush_data[-test_i, ]
test_data <- mush_data[test_i, ]

#train classifier
classifier_nb <- naiveBayes(train_data[,-1], train_data$class)

#use classifier to predict classes of test data
nb_pred <- predict(classifier_nb, type = 'class', newdata = test_data)

#assess the accuracy of the classifier on the test data using a confusion matrix
print(confusionMatrix(nb_pred,test_data$class))
Confusion Matrix and Statistics

          Reference
Prediction   e   p
         e 836  91
         p   2 696
                                          
               Accuracy : 0.9428          
                 95% CI : (0.9303, 0.9536)
    No Information Rate : 0.5157          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.885           
                                          
 Mcnemar's Test P-Value : < 2.2e-16       
                                          
            Sensitivity : 0.9976          
            Specificity : 0.8844          
         Pos Pred Value : 0.9018          
         Neg Pred Value : 0.9971          
             Prevalence : 0.5157          
         Detection Rate : 0.5145          
   Detection Prevalence : 0.5705          
      Balanced Accuracy : 0.9410          
                                          
       'Positive' Class : e               
                                          
#print a summary of the classifier
print(classifier_nb)

Naive Bayes Classifier for Discrete Predictors

Call:
naiveBayes.default(x = train_data[, -1], y = train_data$class)

A-priori probabilities:
train_data$class
        e         p 
0.5185413 0.4814587 

Conditional probabilities:
                cap.shape
train_data$class           b           c           f           k           s
               e 0.097032641 0.000000000 0.381899110 0.055489614 0.007715134
               p 0.012144455 0.001278364 0.395973154 0.153403643 0.000000000
                cap.shape
train_data$class           x
               e 0.457863501
               p 0.437200384

                cap.surface
train_data$class           f           g           s           y
               e 0.374777448 0.000000000 0.273590504 0.351632047
               p 0.197826782 0.001278364 0.356983062 0.443911793

                cap.color
train_data$class           b           c           e           g           n
               e 0.012462908 0.007418398 0.148664688 0.243620178 0.298516320
               p 0.029721956 0.003195909 0.224033237 0.210930010 0.259827421
                cap.color
train_data$class           p           r           u           w           y
               e 0.013649852 0.003857567 0.004154303 0.172403561 0.095252226
               p 0.022371365 0.000000000 0.000000000 0.079258549 0.170661553

                bruises
train_data$class         f         t
               e 0.3548961 0.6451039
               p 0.8443592 0.1556408

                odor
train_data$class          a          c          f          l          m
               e 0.09169139 0.00000000 0.00000000 0.09495549 0.00000000
               p 0.00000000 0.05017578 0.55193353 0.00000000 0.01022691
                odor
train_data$class          n          p          s          y
               e 0.81335312 0.00000000 0.00000000 0.00000000
               p 0.02972196 0.06295941 0.14988814 0.14509428

                gill.attachment
train_data$class           a           f
               e 0.045994065 0.954005935
               p 0.005113455 0.994886545

                gill.spacing
train_data$class          c          w
               e 0.70593472 0.29406528
               p 0.96963886 0.03036114

                gill.size
train_data$class          b          n
               e 0.92818991 0.07181009
               p 0.43560243 0.56439757

                gill.color
train_data$class           b           e           g           h           k
               e 0.000000000 0.024925816 0.060534125 0.048961424 0.082492582
               p 0.440396293 0.000000000 0.131032279 0.134547779 0.014062001
                gill.color
train_data$class           n           o           p           r           u
               e 0.217804154 0.016023739 0.200890208 0.000000000 0.105341246
               p 0.028124001 0.000000000 0.163630553 0.006072228 0.012464046
                gill.color
train_data$class           w           y
               e 0.227596439 0.015430267
               p 0.063598594 0.006072228

                stalk.shape
train_data$class         e         t
               e 0.3878338 0.6121662
               p 0.4880153 0.5119847

                stalk.root
train_data$class          ?          b          c          e          r
               e 0.17833828 0.44985163 0.11869436 0.20949555 0.04362018
               p 0.44646852 0.47778843 0.01278364 0.06295941 0.00000000

                stalk.surface.above.ring
train_data$class           f           k           s           y
               e 0.100296736 0.035905045 0.859940653 0.003857567
               p 0.034196229 0.571108981 0.392138063 0.002556727

                stalk.surface.below.ring
train_data$class          f          k          s          y
               e 0.11186944 0.03620178 0.80445104 0.04747774
               p 0.03323746 0.55640780 0.39149888 0.01885586

                stalk.color.above.ring
train_data$class           b           c           e           g           n
               e 0.000000000 0.000000000 0.024035608 0.137388724 0.003857567
               p 0.107062959 0.010226910 0.000000000 0.000000000 0.111217641
                stalk.color.above.ring
train_data$class           o           p           w           y
               e 0.045994065 0.129970326 0.658753709 0.000000000
               p 0.000000000 0.342921061 0.426014701 0.002556727

                stalk.color.below.ring
train_data$class           b           c           e           g           n
               e 0.000000000 0.000000000 0.024925816 0.139465875 0.016023739
               p 0.110258869 0.010226910 0.000000000 0.000000000 0.112496005
                stalk.color.below.ring
train_data$class           o           p           w           y
               e 0.045994065 0.129080119 0.644510386 0.000000000
               p 0.000000000 0.340044743 0.421860019 0.005113455

                veil.type
train_data$class p
               e 1
               p 1

                veil.color
train_data$class           n           o           w           y
               e 0.024332344 0.021661721 0.954005935 0.000000000
               p 0.000000000 0.000000000 0.997443273 0.002556727

                ring.number
train_data$class          n          o          t
               e 0.00000000 0.86795252 0.13204748
               p 0.01022691 0.97091723 0.01885586

                ring.type
train_data$class          e          f          l          n          p
               e 0.24718101 0.01186944 0.00000000 0.00000000 0.74094955
               p 0.44902525 0.00000000 0.33493129 0.01022691 0.20581655

                spore.print.color
train_data$class          b          h          k          n          o
               e 0.01038576 0.01186944 0.38427300 0.41305638 0.01127596
               p 0.00000000 0.40651965 0.05496964 0.05816555 0.00000000
                spore.print.color
train_data$class          r          u          w          y
               e 0.00000000 0.01246291 0.14421365 0.01246291
               p 0.01885586 0.00000000 0.46148929 0.00000000

                population
train_data$class          a          c          n          s          v
               e 0.09376855 0.07210682 0.09436202 0.21038576 0.28011869
               p 0.00000000 0.01502077 0.00000000 0.09395973 0.72962608
                population
train_data$class          y
               e 0.24925816
               p 0.16139342

                habitat
train_data$class           d           g           l           m           p
               e 0.440059347 0.337091988 0.058160237 0.061424332 0.030860534
               p 0.322786833 0.187919463 0.150527325 0.008948546 0.259827421
                habitat
train_data$class           u           w
               e 0.023442136 0.048961424
               p 0.069990412 0.000000000
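As a worked example of how these tables are used, the sketch below scores a hypothetical mushroom with bruises = t, gill.size = b and odor = n. It uses only these three of the 22 predictors, whereas the fitted classifier multiplies in all 22. The priors and conditional probabilities are copied from the summary above. `prior`, `cond` and `posterior` are names introduced for the example.

```r
prior <- c(e = 0.5185413, p = 0.4814587)

# P(observed value | class) for bruises = t, gill.size = b, odor = n
cond <- rbind(
  bruises   = c(e = 0.6451039,  p = 0.1556408),
  gill.size = c(e = 0.92818991, p = 0.43560243),
  odor      = c(e = 0.81335312, p = 0.02972196)
)

# Naive Bayes assumes features are independent given the class, so
# P(class | features) is proportional to P(class) * product of P(feature | class).
# Summing logs instead of multiplying avoids underflow when many features are used.
log_score <- log(prior) + colSums(log(cond))

# Normalise so the two posteriors sum to 1 (subtracting the max first for stability)
posterior <- exp(log_score - max(log_score))
posterior <- posterior / sum(posterior)
print(round(posterior, 4))
```

For this combination, the posterior for "e" comes out above 0.99, driven mostly by odor = n (0.813 versus 0.030). The many exact zeros in the tables matter. For example, a mushroom with odor = f has P(odor = f | e) = 0, which forces the product for "e" to zero regardless of the other features. `naiveBayes()` has a `laplace` argument for additive smoothing, and e1071's `predict()` method has `threshold` and `eps` arguments for replacing near-zero probabilities. The e1071 documentation describes their exact behaviour.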

sessionInfo()
R version 3.6.3 (2020-02-29)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS Catalina 10.15.5

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] grid      stats     graphics  grDevices utils     datasets  methods  
[8] base     

other attached packages:
[1] pROC_1.16.2     caret_6.0-86    lattice_0.20-38 e1071_1.7-3    
[5] cowplot_1.0.0   ggplot2_3.3.0  

loaded via a namespace (and not attached):
 [1] tidyselect_1.0.0     xfun_0.12            reshape2_1.4.3      
 [4] purrr_0.3.4          splines_3.6.3        colorspace_1.4-1    
 [7] vctrs_0.2.4          generics_0.0.2       stats4_3.6.3        
[10] htmltools_0.4.0      yaml_2.2.1           prodlim_2019.11.13  
[13] survival_3.1-8       rlang_0.4.5          ModelMetrics_1.2.2.2
[16] later_1.0.0          pillar_1.4.3         glue_1.4.0          
[19] withr_2.1.2          foreach_1.4.7        lifecycle_0.2.0     
[22] plyr_1.8.5           lava_1.6.6           stringr_1.4.0       
[25] timeDate_3043.102    munsell_0.5.0        gtable_0.3.0        
[28] workflowr_1.5.0      recipes_0.1.11       codetools_0.2-16    
[31] evaluate_0.14        labeling_0.3         knitr_1.26          
[34] httpuv_1.5.2         class_7.3-15         Rcpp_1.0.4.6        
[37] promises_1.1.0       scales_1.1.0         backports_1.1.6     
[40] ipred_0.9-9          farver_2.0.3         fs_1.3.1            
[43] digest_0.6.25        stringi_1.4.5        dplyr_0.8.5         
[46] rprojroot_1.3-2      tools_3.6.3          magrittr_1.5        
[49] tibble_3.0.1         crayon_1.3.4         whisker_0.4         
[52] pkgconfig_2.0.3      MASS_7.3-51.5        ellipsis_0.3.0      
[55] Matrix_1.2-18        data.table_1.12.8    lubridate_1.7.4     
[58] gower_0.2.1          assertthat_0.2.1     rmarkdown_1.18      
[61] iterators_1.0.12     R6_2.4.1             rpart_4.1-15        
[64] nnet_7.3-12          nlme_3.1-144         git2r_0.26.1        
[67] compiler_3.6.3