Churn Prediction Using Machine Learning


With the growth of the internet over the last 10 years, more companies offer subscription-based services rather than perpetual software licenses. Why? Because subscriptions give the company a more sustainable, recurring income instead of a one-time payment. On the customer side, a subscription is also much cheaper than having to spend a lot up front.

As a company gets more users on its subscription-based service, it needs to figure out how to keep them, since continuous income is the main goal of this kind of service. One way to do so is to predict when subscribers are likely to stop their subscription, which we call “churn”. So, in this article I will try to predict user churn from user activity in the app using a machine learning model.

If you want to check my source code, please


The resource I use in this project is the Sparkify dataset provided by Udacity. Sparkify is a fictional music streaming app made by the Udacity team. The dataset contains 12 GB of user activity in JSON format, with each row representing one user event.

It has many columns, but these are the main ones:

  1. userId
    unique ID for each user
  2. page
    the page the user visited, which also describes their activity
  3. ts
    timestamp of the event
  4. userAgent
    user agent string describing the device the user connects with
  5. song
    name of the song
  6. artist
    artist of the song in the event
  7. sessionId
    ID of the session, which lasts until the user logs out
  8. itemInSession
    counter of events within the session

Sparkify has two service levels: a free version, where you can listen to music but have to hear an advertisement every few songs, and a paid version, where you pay to listen without advertisements.

Churn Definition

From the dataset above, we need to define what churn is. Since the page name also describes what the user is doing at the time, I looked for relevant pages and found two called “Submit Upgrade” and “Submit Downgrade”. I take these pages to mark when a user upgrades to the paid version and downgrades to the free version, respectively. So, in this article, a paid user churns when they submit a downgrade back to the free version.
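As a rough sketch of this churn definition, the function below labels users from raw events. It assumes each event is a Python dict with `userId` and `page` keys (mirroring the columns above) and treats any user who ever submitted a downgrade as churned, which simplifies the case of users who upgrade and downgrade several times. The function name and the toy events are hypothetical.

```python
def label_churn(events):
    """Label each upgraded user: 1 if they ever submitted a downgrade, else 0.

    Simplification: a user who downgrades and later re-upgrades still counts as churned.
    """
    upgraded, downgraded = set(), set()
    for event in events:
        if event["page"] == "Submit Upgrade":
            upgraded.add(event["userId"])
        elif event["page"] == "Submit Downgrade":
            downgraded.add(event["userId"])
    # Only users who were ever on the paid plan can churn from it.
    return {user: int(user in downgraded) for user in sorted(upgraded)}


# Hypothetical toy events
events = [
    {"userId": "1", "page": "Submit Upgrade"},
    {"userId": "1", "page": "NextSong"},
    {"userId": "1", "page": "Submit Downgrade"},
    {"userId": "2", "page": "Submit Upgrade"},
    {"userId": "3", "page": "NextSong"},
]
print(label_churn(events))  # {'1': 1, '2': 0}
```

User "3" never upgraded, so it gets no label at all rather than a 0.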

Dataset Summary

How many registered users does the service have?
The service has 12,082 active users.

Which artists have the most played songs?
The first is Kings of Leon, the second is Coldplay and the third is Florence + The Machine.

Which devices do users use?
The most used is Windows, and the second is MacOS.

How many paid users are there currently?
Since one user can upgrade and downgrade multiple times, a reasonable way to count this is to subtract the number of “Submit Downgrade” events (3,881) from the number of “Submit Upgrade” events (12,082), which gives 8,201 paid users.
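The summary questions above boil down to simple counts. Here is a minimal sketch on in-memory event dicts, assuming an empty `userId` marks an unregistered (logged-out) event and that only song-play events carry an `artist` value. `summarize` and the toy events are made up for illustration; the real 12 GB dataset would need a distributed engine rather than a Python list.

```python
from collections import Counter


def summarize(events):
    """Return (registered users, top 3 artists, estimated paid users)."""
    # Assumption: an empty userId marks an unregistered / logged-out event.
    registered = {e["userId"] for e in events if e["userId"]}
    # Assumption: only song-play events carry an artist value.
    artist_plays = Counter(e["artist"] for e in events if e.get("artist"))
    page_counts = Counter(e["page"] for e in events)
    # Upgrades minus downgrades, as in the estimate of paid users above.
    estimated_paid = page_counts["Submit Upgrade"] - page_counts["Submit Downgrade"]
    return len(registered), artist_plays.most_common(3), estimated_paid


events = [
    {"userId": "1", "page": "Submit Upgrade", "artist": None},
    {"userId": "1", "page": "NextSong", "artist": "Coldplay"},
    {"userId": "2", "page": "NextSong", "artist": "Coldplay"},
    {"userId": "2", "page": "NextSong", "artist": "Kings of Leon"},
    {"userId": "2", "page": "Submit Upgrade", "artist": None},
    {"userId": "2", "page": "Submit Downgrade", "artist": None},
    {"userId": "", "page": "Home", "artist": None},
]
print(summarize(events))  # (2, [('Coldplay', 2), ('Kings of Leon', 1)], 1)
```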

Data Preparation

The first thing I do is check for events created by unregistered users. I remove these events, since we are only interested in registered users’ data to predict churn.

Second, I remove events generated while the user was not on the premium (paid) plan. These events are not useful for our churn prediction either.

Third, I convert the userAgent string into the device the user uses, since the raw string contains a lot of information that is not useful for categorizing the data. This will be used to check whether device preference influences user churn.
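A sketch of these three preparation steps on event dicts. Two assumptions are not stated in the article: unregistered events have an empty `userId`, and each event carries a `level` field (`"paid"` or `"free"`) telling which plan the user was on at that moment. The substring rules in the hypothetical `device_from_user_agent` are illustrative, not the article’s actual mapping.

```python
# Hypothetical substring rules, checked in order (Android UAs also contain "Linux").
DEVICE_RULES = [
    ("iPhone", "iPhone"),
    ("iPad", "iPad"),
    ("Android", "Android"),
    ("Windows", "Windows"),
    ("Macintosh", "MacOS"),
    ("Linux", "Linux"),
]


def device_from_user_agent(user_agent):
    """Map a raw userAgent string to a coarse device label."""
    if not user_agent:
        return "Unknown"
    for token, device in DEVICE_RULES:
        if token in user_agent:
            return device
    return "Other"


def prepare(events):
    """Drop unregistered and non-paid events, then add a 'device' field."""
    cleaned = []
    for e in events:
        if not e.get("userId"):
            continue  # assumption: empty userId = unregistered user
        if e.get("level") != "paid":
            continue  # assumption: 'level' holds the plan at event time
        cleaned.append({**e, "device": device_from_user_agent(e.get("userAgent"))})
    return cleaned


win = "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36"
mac = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.77.4"
events = [
    {"userId": "1", "level": "paid", "userAgent": win},
    {"userId": "1", "level": "free", "userAgent": win},
    {"userId": "", "level": "paid", "userAgent": mac},
    {"userId": "2", "level": "paid", "userAgent": mac},
]
print([(e["userId"], e["device"]) for e in prepare(events)])
# [('1', 'Windows'), ('2', 'MacOS')]
```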

Feature Extraction

In my opinion, feature extraction is the most important step in this part. Why? Because it determines how good our model will be later on, so extracting the right features is the key.

Based on Sparkify dataset, I extract the features below:

  1. Number of songs played during the subscription phase
  2. Duration of the subscription phase
  3. Number of songs played per day during the subscription phase
  4. Number of songs added to a playlist
  5. Number of “Thumbs Up” and “Thumbs Down” given
  6. Number of sessions in the subscription phase
  7. Average time the user spent in one session
  8. Device the user uses
  9. Number of errors that occurred
  10. Number of friends added
  11. Number of “Cancellation Confirmation” page visits
  12. Number of unique songs and artists heard during the subscription phase

Please note, “subscription phase” is the term I use here for the period from when a user upgrades the service until they downgrade (if they churn), or until the latest event timestamp in the dataset (if they don’t churn).
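A sketch of how several of the features above could be computed for one user’s subscription phase. It assumes `ts` is in epoch milliseconds, the user’s events are sorted by `ts`, and only the first upgrade-to-downgrade phase matters. The page names “Thumbs Up”, “Thumbs Down” and “Cancellation Confirmation” come from the list above, while `subscription_features` is a hypothetical helper; playlist, error and friend counts would follow the same counting pattern.

```python
MS_PER_DAY = 24 * 60 * 60 * 1000  # assumption: ts is in epoch milliseconds


def subscription_features(user_events, dataset_max_ts):
    """Compute some subscription-phase features for one user.

    user_events must be sorted by 'ts'. The phase runs from the first
    'Submit Upgrade' to the next 'Submit Downgrade' (churn) or to
    dataset_max_ts (no churn). Later phases are ignored (simplification).
    """
    start = next(e["ts"] for e in user_events if e["page"] == "Submit Upgrade")
    end = next(
        (e["ts"] for e in user_events
         if e["page"] == "Submit Downgrade" and e["ts"] >= start),
        dataset_max_ts,
    )
    phase = [e for e in user_events if start <= e["ts"] <= end]
    songs = [e for e in phase if e.get("song")]
    duration_days = (end - start) / MS_PER_DAY

    # First and last timestamp seen in each session
    bounds = {}
    for e in phase:
        first, last = bounds.get(e["sessionId"], (e["ts"], e["ts"]))
        bounds[e["sessionId"]] = (min(first, e["ts"]), max(last, e["ts"]))
    avg_session_ms = sum(last - first for first, last in bounds.values()) / len(bounds)

    def count(page):
        return sum(e["page"] == page for e in phase)

    return {
        "songs_played": len(songs),
        "duration_days": duration_days,
        "songs_per_day": len(songs) / duration_days if duration_days else 0.0,
        "thumbs_up": count("Thumbs Up"),
        "thumbs_down": count("Thumbs Down"),
        "sessions": len(bounds),
        "avg_session_minutes": avg_session_ms / 60_000,
        "cancellation_visits": count("Cancellation Confirmation"),
        "unique_songs": len({e["song"] for e in songs}),
        "unique_artists": len({e["artist"] for e in songs}),
    }


events = [
    {"ts": 0, "page": "Submit Upgrade", "sessionId": 1},
    {"ts": 60_000, "page": "NextSong", "sessionId": 1, "song": "A", "artist": "X"},
    {"ts": 120_000, "page": "Thumbs Up", "sessionId": 1},
    {"ts": MS_PER_DAY, "page": "NextSong", "sessionId": 2, "song": "A", "artist": "X"},
    {"ts": MS_PER_DAY + 240_000, "page": "NextSong", "sessionId": 2, "song": "B", "artist": "Y"},
    {"ts": 2 * MS_PER_DAY, "page": "Submit Downgrade", "sessionId": 3},
]
features = subscription_features(events, dataset_max_ts=5 * MS_PER_DAY)
print(features["songs_per_day"], features["avg_session_minutes"])  # 1.5 2.0
```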

First, how do I impute null values? For categorical columns, I use one-hot encoding, which gives better results. If a user does not have a given category in a column, I fill it with False.
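A minimal sketch of this one-hot encoding for a per-user `device` column, where a missing device simply leaves every indicator False. The helper name and the column naming scheme (`device_Windows` and so on) are assumptions for illustration.

```python
def one_hot(rows, column):
    """Replace a categorical column with one boolean column per category.

    Rows whose value is missing (None) get False in every new column.
    """
    categories = sorted({r[column] for r in rows if r.get(column) is not None})
    encoded = []
    for r in rows:
        new_row = {k: v for k, v in r.items() if k != column}
        for c in categories:
            new_row[f"{column}_{c}"] = r.get(column) == c
        encoded.append(new_row)
    return encoded


users = [
    {"userId": "1", "device": "Windows"},
    {"userId": "2", "device": None},
    {"userId": "3", "device": "MacOS"},
]
for row in one_hot(users, "device"):
    print(row)
# {'userId': '1', 'device_MacOS': False, 'device_Windows': True}
# {'userId': '2', 'device_MacOS': False, 'device_Windows': False}
# {'userId': '3', 'device_MacOS': True, 'device_Windows': False}
```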

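Second, do we need to normalize the data? Since I decided not to use a regression-based machine learning model here, normalization won’t make much difference: tree-based models split on thresholds of each feature, and rescaling a feature doesn’t change the order of its values, so the same splits are found. The reason I don’t use a regression-based model is explained later.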
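To see why scaling doesn’t matter for trees, here is a small sketch of a single-feature decision stump that picks the threshold with the lowest weighted Gini impurity. An increasing rescaling such as `1000 * x + 5` leaves the order of values unchanged, so the stump sends exactly the same rows to the left branch. This is a simplified illustration, not how any specific library implements tree splits.

```python
def gini(labels):
    """Gini impurity of a list of 0/1 labels."""
    if not labels:
        return 0.0
    p = sum(labels) / len(labels)
    return 1.0 - p * p - (1.0 - p) * (1.0 - p)


def best_split(xs, ys):
    """Return the row indices sent left (x <= threshold) by the best threshold."""
    best_score, best_left = None, None
    for t in sorted(set(xs)):
        left = [i for i, x in enumerate(xs) if x <= t]
        right = [i for i, x in enumerate(xs) if x > t]
        score = (len(left) * gini([ys[i] for i in left])
                 + len(right) * gini([ys[i] for i in right])) / len(xs)
        if best_score is None or score < best_score:
            best_score, best_left = score, frozenset(left)
    return best_left


xs = [1.0, 2.0, 3.0, 10.0, 11.0]    # e.g. a raw feature
ys = [0, 0, 0, 1, 1]                 # churn labels
scaled = [1000 * x + 5 for x in xs]  # any increasing rescaling
print(best_split(xs, ys) == best_split(scaled, ys))  # True
```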

Dataset with extracted features

Model Training

First, which algorithm will we use for this problem? There are many classification algorithms, but I will use decision-tree-based ones. Why? First, compared to regression-based algorithms, they can give a much better model, since they can find the “sweet spots” (non-linear thresholds) in our feature values. Second, compared to deep learning models, decision-tree-based models are still much easier to interpret, for example through feature importances.

In this case I train models with both the random forest and gradient boosted tree (GBT) algorithms, then evaluate both to see which one performs better.

Second, I split the dataset with a 9:1 ratio: 90% of the data is used for training and validation, and the remaining 10% for testing.
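A sketch of a reproducible 9:1 split. Since the extracted dataset has one row per user, splitting the user IDs also keeps each user’s data in only one of the two sets. The seed value and the function name are assumptions.

```python
import random


def split_users(user_ids, test_fraction=0.1, seed=42):
    """Shuffle user IDs reproducibly and return (train_ids, test_ids)."""
    ids = sorted(user_ids)  # fixed starting order, so the seed fully determines the split
    random.Random(seed).shuffle(ids)
    n_test = round(len(ids) * test_fraction)
    return ids[n_test:], ids[:n_test]


train_ids, test_ids = split_users(range(100))
print(len(train_ids), len(test_ids))  # 90 10
```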

I also tune hyperparameters to find the best combination. Here I tune the number of trees and the tree depth, then choose the best model based on the area under the ROC curve.
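A sketch of the tuning loop: try every combination of number of trees and tree depth, score each on validation data with the area under the ROC curve, and keep the best. The AUC here is computed as the probability that a random churner gets a higher score than a random non-churner (ties count half). `train_and_predict` is a hypothetical callable standing in for the actual model training, and the grid values are illustrative apart from 20, 25 and 7, which appear in this article.

```python
from itertools import product


def roc_auc(labels, scores):
    """Area under ROC: P(random churner scores higher than random non-churner)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def grid_search(train_and_predict, val_labels, num_trees=(10, 20, 25), max_depths=(5, 7)):
    """Score every (num_trees, max_depth) pair on validation data; return the best.

    train_and_predict(num_trees, max_depth) is a hypothetical callable that trains a
    model and returns churn scores for the validation rows.
    """
    results = {
        (n, d): roc_auc(val_labels, train_and_predict(n, d))
        for n, d in product(num_trees, max_depths)
    }
    return max(results, key=results.get), results


def fake_train_and_predict(num_trees, max_depth):
    # Stand-in for real training: pretend (25, 7) ranks users perfectly.
    if (num_trees, max_depth) == (25, 7):
        return [0.1, 0.2, 0.7, 0.9]
    return [0.1, 0.4, 0.35, 0.8]


best, results = grid_search(fake_train_and_predict, val_labels=[0, 0, 1, 1])
print(best, results[best])  # (25, 7) 1.0
```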

After training, the best random forest and GBT models both use 25 trees with a tree depth of 7. Training also gives us the most important features for the model: subscription duration, average number of songs played per day, number of sessions during the subscription phase, and number of “Cancellation Confirmation” page visits.

Distribution of users by subscription duration

This is in line with the graph above, which shows that users who churn mostly have shorter subscription phases, while users who don’t churn tend to stay longer.

Model Evaluation

Here, I use recall and F1 score as the most important metrics, since they tell us how well we predict churning users (recall) and how well the model predicts overall (F1). I also show accuracy, precision and area under ROC to give you a clearer picture.
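For reference, all of these metrics except area under ROC come straight from the confusion matrix. A minimal sketch; the counts in the usage line are made up and simply show how a model can have high accuracy but only mediocre recall when most users don’t churn.

```python
def classification_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall and F1 from confusion-matrix counts."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Made-up counts: many true negatives inflate accuracy, recall stays modest.
print(classification_metrics(tp=30, fp=10, fn=20, tn=140))
```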

After training, we need to evaluate how well both models predict churn on the test dataset. The results can be seen in the images below.

Evaluation metrics from the random forest model
Evaluation metrics from the gradient boosted tree model
Confusion matrices from both models

As you can see above, the models reach 77.3% and 77.9% accuracy for GBT and RF respectively. That looks quite good, but the confusion matrix is clearly dominated by true negatives: since most users don’t churn, a model can reach high accuracy by mostly predicting “no churn”. The F1 scores, on the other hand, are just 63.3% and 65.2%, which is OK but can still be improved.

In churn prediction, we care more about catching churners, that is, true positives relative to false negatives (the recall score), than about true negatives, since we want to prevent users from churning. So, we need a more “aggressive” model that predicts more true positives.

Weighted Model Training and Evaluation

To train a weighted model, I first modify the training dataset to add a weight column. In this case, negative labels get a weight of 0.7 and positive labels a weight of 1.
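A sketch of adding the weight column to per-user rows that have a `label` key (1 = churn). The 0.7 and 1.0 weights are the article’s; the row layout and helper name are assumptions. If the models were trained with Spark ML (the article doesn’t show its code), such a column would typically be passed to the classifiers through their `weightCol` parameter.

```python
NEGATIVE_WEIGHT = 0.7  # weight for non-churners (label 0), as in the article
POSITIVE_WEIGHT = 1.0  # weight for churners (label 1)


def add_weight(rows, label_key="label"):
    """Return copies of the rows with a 'weight' column based on the label."""
    return [
        {**r, "weight": POSITIVE_WEIGHT if r[label_key] == 1 else NEGATIVE_WEIGHT}
        for r in rows
    ]


# Hypothetical data: 3 churners and 7 non-churners
rows = [{"userId": str(i), "label": 1 if i < 3 else 0} for i in range(10)]
weighted = add_weight(rows)
total_pos = sum(r["weight"] for r in weighted if r["label"] == 1)
total_neg = sum(r["weight"] for r in weighted if r["label"] == 0)
print(total_pos, round(total_neg, 2))  # 3.0 4.9
```

The 7 non-churners now carry a total weight of 4.9 instead of 7, so the churners’ share of the training signal rises from 30% to about 38%.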

Dataset with weight column

After configuring the algorithms to use the weight column, we get different results. Each model now has 20 trees and a tree depth of 7. The most important features are still the same as before, but with different importance values.

Weighted RF model evaluation metrics
Weighted GBT model evaluation metrics
weighted model confusion matrix

Retraining gives us much better results. All the metrics on both models improve compared to the previous models, as the pictures above show, and recall in particular improves a lot. The confusion matrices show that we are clearly predicting churn more aggressively: the numbers of true positives and false positives increased, while the numbers of false negatives and true negatives decreased.

Between random forest and gradient boosted tree, gradient boosted tree gives better results in both unweighted and weighted training. So, it would be better to deploy the gradient boosted tree model to predict churn in production.

So, why do these weighted models give better results than before? Because the weight column down-weights label-0 rows during training: each non-churning example counts as only 0.7 of a churning one, so training focuses more on label-1 data than on label-0 data. This is exactly what makes the model more “aggressive” in predicting churn.
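A simplified single-tree view of this effect: a leaf predicts churn when the weighted share of churners in it exceeds 0.5. With 4 churners and 5 non-churners, the unweighted leaf predicts “no churn”, but with a 0.7 weight on non-churners the same leaf flips to “churn”. This only illustrates how class weights shift decisions; gradient boosted trees apply the weights inside their loss function rather than to a leaf’s majority vote. `leaf_prediction` is a hypothetical helper.

```python
def leaf_prediction(n_pos, n_neg, neg_weight=1.0, pos_weight=1.0):
    """Predict churn (1) if the weighted share of churners in a leaf exceeds 0.5."""
    pos_mass = pos_weight * n_pos
    neg_mass = neg_weight * n_neg
    share = pos_mass / (pos_mass + neg_mass)
    return int(share > 0.5), share


print(leaf_prediction(4, 5))                  # predicts 0 (no churn), share about 0.44
print(leaf_prediction(4, 5, neg_weight=0.7))  # predicts 1 (churn), share about 0.53
```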

Model Inference

If we deploy our best model, the weighted GBT model, and have 10,000 active paid users, then:

  1. 3,164 users will churn: 2,406 of them will be predicted as churning and 758 will not
  2. 6,836 users will stay on the paid plan, while 1,318 of them will be wrongly predicted as churning
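These population numbers presumably come from scaling the test-set rates up to 10,000 users. A sketch of that projection, assuming the deployed population has the same churn rate as the test set; since the full test confusion matrix isn’t listed here, the usage example uses made-up counts, and `project_to_population` is a hypothetical name.

```python
def project_to_population(tp, fn, fp, tn, n_users):
    """Scale test-set confusion-matrix rates to a population of n_users.

    Assumes the population has the same churn rate as the test set.
    """
    total = tp + fn + fp + tn
    churners = n_users * (tp + fn) / total
    stayers = n_users - churners
    recall = tp / (tp + fn)               # share of churners we catch
    false_positive_rate = fp / (fp + tn)  # share of stayers wrongly flagged
    return {
        "churners": round(churners),
        "churners_predicted": round(churners * recall),
        "churners_missed": round(churners * (1 - recall)),
        "stayers": round(stayers),
        "stayers_flagged": round(stayers * false_positive_rate),
    }


# Made-up test-set counts, for illustration only
print(project_to_population(tp=30, fn=10, fp=15, tn=45, n_users=10_000))
```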

Let’s talk about money

Some of you may say, “Hey, what do those numbers even mean?”. To make them more understandable, and to show why we need to retrain with weighted labels, let’s talk about money.

Let’s say you are the CEO of Sparkify. To prevent users from churning, you decide to give a 40% subscription fee discount, and let’s assume the subscription fee is 10 USD. Every user predicted as positive (churning) gets the discount. Let’s look at how this model affects your company’s income.

Prediction influence on 40% discount

First, in the true positive group, we give users a 40% discount, and let’s assume all of them take it. Then we keep 6 USD per user, compared with losing the full 10 USD if we didn’t give them the discount.

Second, the users in the false positive group are actually loyal to our service, yet we label them as churning and give them the discount, so we lose 4 USD for every user in this group.

Third, in the false negative group, we lose 10 USD per user, since they will stop using our paid service.

Then we do some simple math using the weighted GBT model’s results (365 true positives, 115 false negatives and 200 false positives): 365 * 6 + (-10 * 115) + (-4 * 200) = 240. This means we still gain 240 USD even though we give predicted churners a 40% discount. With the unweighted GBT model, we would lose 440 USD while giving predicted churners the same discount.
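The same calculation as a small function, with the fee, discount and per-group gains and losses as described above. It assumes, as the article does, that every true positive accepts the discount and stays, and that true negatives neither gain nor lose anything relative to the baseline; `net_gain` is a hypothetical name.

```python
def net_gain(tp, fn, fp, fee=10.0, discount=0.4):
    """Net USD effect of offering a discount to every predicted churner.

    Assumes every true positive accepts the discount and stays, and that true
    negatives (loyal users left alone) change nothing versus the baseline.
    """
    kept_per_tp = fee * (1 - discount)  # churner retained at the discounted price
    lost_per_fn = fee                   # missed churner: full fee lost
    lost_per_fp = fee * discount        # loyal user given an unneeded discount
    return tp * kept_per_tp - fn * lost_per_fn - fp * lost_per_fp


print(net_gain(tp=365, fn=115, fp=200))  # 240.0
```

Making the fee and discount parameters lets you rerun the comparison for other discount levels without redoing the arithmetic by hand.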

Gain and loss comparison for each model


Let’s recap all the steps:

  1. The dataset used in this project is the Sparkify dataset produced by the Udacity team
  2. We try to predict user churn based on this dataset
  3. We extract features from the data to make the model much better
  4. After training, we found that subscription duration, songs played per day, session count and “Cancellation Confirmation” page visits are the most important features
  5. In model evaluation, the first results were not so good, but after weighting the labels we got much better results
  6. Talking in terms of money shows how much difference it makes to know what to tune in our model

Future Improvements

  1. The 0.7 weight I used for the weighted dataset is just an arbitrary value I picked. In the future, this parameter should be tuned too.
  2. Try more algorithms and a larger hyperparameter grid. I had limited resources when doing this project, so if I had more resources, that’s the first thing I would do.
