Breast Cancer Survival Prediction with Machine Learning

Breast cancer is a type of cancer that starts in the breast. It mostly occurs in women, but men can get breast cancer too. It is one of the leading causes of cancer death in women. Because data is now widely used in healthcare, we can use machine learning to predict whether a patient will survive a deadly disease like breast cancer. So if you want to learn how to predict the survival of a breast cancer patient, this article is for you. In this article, I will take you through the task of breast cancer survival prediction with machine learning using Python.

You have a dataset of a few hundred breast cancer patients who underwent surgery for the treatment of breast cancer. Below is a description of each column in the dataset:

  1. Patient_ID: ID of the patient
  2. Age: Age of the patient
  3. Gender: Gender of the patient
  4. Protein1, Protein2, Protein3, Protein4: protein expression levels
  5. Tumour_Stage: Breast cancer stage of the patient
  6. Histology: Infiltrating Ductal Carcinoma, Infiltrating Lobular Carcinoma, Mucinous Carcinoma
  7. ER status: Estrogen receptor status, Positive/Negative
  8. PR status: Progesterone receptor status, Positive/Negative
  9. HER2 status: Human epidermal growth factor receptor 2 status, Positive/Negative
  10. Surgery_type: Lumpectomy, Simple Mastectomy, Modified Radical Mastectomy, Other
  11. Date_of_Surgery: The date of surgery
  12. Date_of_Last_Visit: The date of the patient's last visit
  13. Patient_Status: Alive/Dead

Using this dataset, our task is to predict whether a breast cancer patient will survive after surgery.

I hope you now have an overview of the dataset we are using for the task of breast cancer survival prediction. This dataset was collected from Kaggle, where you can download it. Now, in the section below, I will walk you through the task of predicting breast cancer survival with machine learning using Python.

Breast Cancer Survival Prediction using Python

I will start the task of breast cancer survival prediction by importing the necessary Python libraries and the dataset we need:

import pandas as pd
import numpy as np
import plotly.express as px
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

data = pd.read_csv("BRCA.csv")
print(data.head())
     Patient_ID   Age  Gender  Protein1  Protein2  Protein3  Protein4  \
0  TCGA-D8-A1XD  36.0  FEMALE  0.080353   0.42638   0.54715  0.273680   
1  TCGA-EW-A1OX  43.0  FEMALE -0.420320   0.57807   0.61447 -0.031505   
2  TCGA-A8-A079  69.0  FEMALE  0.213980   1.31140  -0.32747 -0.234260   
3  TCGA-D8-A1XR  56.0  FEMALE  0.345090  -0.21147  -0.19304  0.124270   
4  TCGA-BH-A0BF  56.0  FEMALE  0.221550   1.90680   0.52045 -0.311990   

  Tumour_Stage                      Histology ER status PR status HER2 status  \
0          III  Infiltrating Ductal Carcinoma  Positive  Positive    Negative   
1           II             Mucinous Carcinoma  Positive  Positive    Negative   
2          III  Infiltrating Ductal Carcinoma  Positive  Positive    Negative   
3           II  Infiltrating Ductal Carcinoma  Positive  Positive    Negative   
4           II  Infiltrating Ductal Carcinoma  Positive  Positive    Negative   

                  Surgery_type Date_of_Surgery Date_of_Last_Visit  \
0  Modified Radical Mastectomy       15-Jan-17          19-Jun-17   
1                   Lumpectomy       26-Apr-17          09-Nov-18   
2                        Other       08-Sep-17          09-Jun-18   
3  Modified Radical Mastectomy       25-Jan-17          12-Jul-17   
4                        Other       06-May-17          27-Jun-19   

  Patient_Status  
0          Alive  
1           Dead  
2          Alive  
3          Alive  
4           Dead
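The two date columns are stored as strings such as 15-Jan-17. The article does not use them as features, but if you want to explore them, a minimal sketch like the one below could parse them with pandas and compute the number of days between surgery and the last visit. It uses two rows copied from the output above; the helper name `add_followup_days` and the `Followup_Days` column are hypothetical, and the `%d-%b-%y` format string is an assumption based on the values shown.

```python
import pandas as pd

def add_followup_days(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of df with a hypothetical Followup_Days column.

    Assumes dates look like '15-Jan-17' (day, abbreviated month, two-digit year).
    """
    out = df.copy()
    surgery = pd.to_datetime(out["Date_of_Surgery"], format="%d-%b-%y")
    last_visit = pd.to_datetime(out["Date_of_Last_Visit"], format="%d-%b-%y")
    # Subtracting two datetime columns gives timedeltas; .dt.days keeps whole days
    out["Followup_Days"] = (last_visit - surgery).dt.days
    return out

# Two rows copied from the data.head() output above
sample = pd.DataFrame({
    "Date_of_Surgery": ["15-Jan-17", "26-Apr-17"],
    "Date_of_Last_Visit": ["19-Jun-17", "09-Nov-18"],
})
result = add_followup_days(sample)
print(result["Followup_Days"].tolist())  # [155, 562]
```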

Let’s check whether any of the columns in this dataset contain null values:

print(data.isnull().sum())
Patient_ID             7
Age                    7
Gender                 7
Protein1               7
Protein2               7
Protein3               7
Protein4               7
Tumour_Stage           7
Histology              7
ER status              7
PR status              7
HER2 status            7
Surgery_type           7
Date_of_Surgery        7
Date_of_Last_Visit    24
Patient_Status        20
dtype: int64

Every column in this dataset contains some null values, so I will drop all rows that contain at least one null value:

data = data.dropna()
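Note that `dropna()` with its default behaviour (`how="any"`) removes every row that has at least one missing value, not only completely empty rows. The toy example below (made-up values, not the BRCA data) contrasts it with `how="all"`, which drops only rows where every value is missing. Here the default is the sensible choice, since rows without a Patient_Status label cannot be used for training anyway.

```python
import numpy as np
import pandas as pd

# Made-up toy data: one complete row, one partially missing row, one empty row
toy = pd.DataFrame({
    "Age": [36.0, 43.0, np.nan],
    "Patient_Status": ["Alive", np.nan, np.nan],
})

any_missing = toy.dropna()           # default: drops rows 1 and 2
all_missing = toy.dropna(how="all")  # drops only row 2, which is entirely NaN

print(len(any_missing), len(all_missing))  # 1 2
```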

Now let’s look at a summary of the columns of this data:

data.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 317 entries, 0 to 333
Data columns (total 16 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   Patient_ID          317 non-null    object 
 1   Age                 317 non-null    float64
 2   Gender              317 non-null    object 
 3   Protein1            317 non-null    float64
 4   Protein2            317 non-null    float64
 5   Protein3            317 non-null    float64
 6   Protein4            317 non-null    float64
 7   Tumour_Stage        317 non-null    object 
 8   Histology           317 non-null    object 
 9   ER status           317 non-null    object 
 10  PR status           317 non-null    object 
 11  HER2 status         317 non-null    object 
 12  Surgery_type        317 non-null    object 
 13  Date_of_Surgery     317 non-null    object 
 14  Date_of_Last_Visit  317 non-null    object 
 15  Patient_Status      317 non-null    object 
dtypes: float64(5), object(11)
memory usage: 42.1+ KB

Breast cancer is found mostly in females, so let’s look at the Gender column to see how many female and male patients there are:

print(data.Gender.value_counts())
FEMALE    313
MALE        4
Name: Gender, dtype: int64
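To see this split as proportions rather than raw counts, `value_counts(normalize=True)` can be used. Since the CSV itself is not bundled here, the sketch below rebuilds a Series with the same counts as the output above (313 female, 4 male).

```python
import pandas as pd

# Series with the same counts as the output above: 313 FEMALE, 4 MALE
gender = pd.Series(["FEMALE"] * 313 + ["MALE"] * 4, name="Gender")

# normalize=True returns fractions of the total instead of counts
shares = gender.value_counts(normalize=True)
print(shares.round(3))  # FEMALE ~0.987, MALE ~0.013
```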

As expected, almost all of the patients are female (313 female and 4 male). Now let’s look at the tumour stages of the patients:

# Tumour Stage: count each stage and plot the counts as a donut chart
stage = data["Tumour_Stage"].value_counts()

figure = px.pie(values=stage.values, 
             names=stage.index, hole=0.5, 
             title="Tumour Stages of Patients")
figure.show()
[Figure: tumour stages of the breast cancer patients]

Most of the patients are in stage II. Now let’s look at the histology of the patients’ tumours. (Histology here means the histological type of the tumour, i.e. the kind of tissue the cancer cells resemble under a microscope: for example, ductal carcinoma starts in the milk ducts and lobular carcinoma starts in the lobules.)

# Histology: count each histological type and plot the counts as a donut chart
histology = data["Histology"].value_counts()
figure = px.pie(values=histology.values, 
             names=histology.index, hole=0.5, 
             title="Histology of Patients")
figure.show()
[Figure: histology of the breast cancer patients]

Now let’s have a look at the values of ER status, PR status, and HER2 status of the patients:

# ER status
print(data["ER status"].value_counts())
# PR status
print(data["PR status"].value_counts())
# HER2 status
print(data["HER2 status"].value_counts())
Positive    317
Name: ER status, dtype: int64
Positive    317
Name: PR status, dtype: int64
Negative    288
Positive     29
Name: HER2 status, dtype: int64
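In the output above, ER status and PR status each take a single value. A quick way to find such constant columns is to count the distinct values per column with `DataFrame.nunique()`; the toy frame below only mimics that pattern and is not the real data.

```python
import pandas as pd

# Toy frame mimicking the pattern above: two constant columns, one varying column
toy = pd.DataFrame({
    "ER status": ["Positive"] * 4,
    "PR status": ["Positive"] * 4,
    "HER2 status": ["Negative", "Negative", "Positive", "Negative"],
})

# nunique() counts distinct non-null values in each column
constant_cols = toy.columns[toy.nunique() == 1].tolist()
print(constant_cols)  # ['ER status', 'PR status']
```

Such columns could be dropped before training, since a feature that never changes cannot help a model separate patients.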

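ER status and PR status are Positive for all 317 patients, so these two columns carry no information that could help a model tell patients apart. Now let’s look at the types of surgery the patients underwent: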

# Surgery_type: count each type of surgery and plot the counts as a donut chart
surgery = data["Surgery_type"].value_counts()
figure = px.pie(values=surgery.values, 
             names=surgery.index, hole=0.5, 
             title="Type of Surgery of Patients")
figure.show()
[Figure: type of surgery of the patients]

Now that we have explored the data, we can see that it has many categorical features. To train a machine learning model on this data, we need to convert the values of all the categorical columns into numbers. Here is how we can encode the categorical features:

data["Tumour_Stage"] = data["Tumour_Stage"].map({"I": 1, "II": 2, "III": 3})
data["Histology"] = data["Histology"].map({"Infiltrating Ductal Carcinoma": 1, 
                                           "Infiltrating Lobular Carcinoma": 2, "Mucinous Carcinoma": 3})
data["ER status"] = data["ER status"].map({"Positive": 1})
data["PR status"] = data["PR status"].map({"Positive": 1})
data["HER2 status"] = data["HER2 status"].map({"Positive": 1, "Negative": 2})
data["Gender"] = data["Gender"].map({"MALE": 0, "FEMALE": 1})
data["Surgery_type"] = data["Surgery_type"].map({"Other": 1, "Modified Radical Mastectomy": 2, 
                                                 "Lumpectomy": 3, "Simple Mastectomy": 4})
print(data.head())
     Patient_ID   Age  Gender  Protein1  Protein2  Protein3  Protein4  \
0  TCGA-D8-A1XD  36.0       1  0.080353   0.42638   0.54715  0.273680   
1  TCGA-EW-A1OX  43.0       1 -0.420320   0.57807   0.61447 -0.031505   
2  TCGA-A8-A079  69.0       1  0.213980   1.31140  -0.32747 -0.234260   
3  TCGA-D8-A1XR  56.0       1  0.345090  -0.21147  -0.19304  0.124270   
4  TCGA-BH-A0BF  56.0       1  0.221550   1.90680   0.52045 -0.311990   

   Tumour_Stage  Histology  ER status  PR status  HER2 status  Surgery_type  \
0             3          1          1          1            2             2   
1             2          3          1          1            2             3   
2             3          1          1          1            2             1   
3             2          1          1          1            2             2   
4             2          1          1          1            2             1   

  Date_of_Surgery Date_of_Last_Visit Patient_Status  
0       15-Jan-17          19-Jun-17          Alive  
1       26-Apr-17          09-Nov-18           Dead  
2       08-Sep-17          09-Jun-18          Alive  
3       25-Jan-17          12-Jul-17          Alive  
4       06-May-17          27-Jun-19           Dead  
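Mapping Histology and Surgery_type to 1, 2, 3, 4 implies an order between categories that have no natural order, and a distance-based model such as SVC may treat those codes as meaningful distances. Tumour_Stage is genuinely ordinal, so integer codes make sense there. A common alternative for the nominal columns is one-hot encoding, sketched below with `pd.get_dummies` on made-up rows; this is a swapped-in technique, not what the article does.

```python
import pandas as pd

# Made-up rows with the same categorical columns as the dataset
toy = pd.DataFrame({
    "Tumour_Stage": ["III", "II", "I"],
    "Histology": ["Infiltrating Ductal Carcinoma", "Mucinous Carcinoma",
                  "Infiltrating Lobular Carcinoma"],
    "Surgery_type": ["Other", "Lumpectomy", "Simple Mastectomy"],
})

# Keep the ordinal stage as integers, as in the article
toy["Tumour_Stage"] = toy["Tumour_Stage"].map({"I": 1, "II": 2, "III": 3})

# One-hot encode the nominal columns: one 0/1 column per category
encoded = pd.get_dummies(toy, columns=["Histology", "Surgery_type"], dtype=int)
print(encoded.columns.tolist())
```

If you go this way, rows passed to the model at prediction time must be encoded into exactly the same set of columns.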

Breast Cancer Survival Prediction Model

We can now move on to training a machine learning model to predict the survival of a breast cancer patient. Before training the model, we need to split the data into a training set and a test set:

# Splitting data
x = np.array(data[['Age', 'Gender', 'Protein1', 'Protein2', 'Protein3','Protein4', 
                   'Tumour_Stage', 'Histology', 'ER status', 'PR status', 
                   'HER2 status', 'Surgery_type']])
y = np.array(data['Patient_Status'])  # 1-D array of labels, as scikit-learn expects
xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.10, random_state=42)
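With 317 rows and `test_size=0.10`, the test set holds only 32 patients, so its share of Alive and Dead patients can drift by chance. Passing `stratify=y` to `train_test_split` keeps the class proportions the same in both splits. The sketch below uses made-up labels (80 Alive, 20 Dead) rather than the real data.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Made-up labels: 80 Alive and 20 Dead, with one dummy feature per row
y = np.array(["Alive"] * 80 + ["Dead"] * 20)
x = np.arange(len(y)).reshape(-1, 1)

xtrain, xtest, ytrain, ytest = train_test_split(
    x, y, test_size=0.10, random_state=42, stratify=y  # keep class ratios equal
)
print(len(ytest), (ytest == "Dead").sum())  # 10 2
```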

Now we can train a Support Vector Machine classifier (scikit-learn’s SVC, with its default settings) on the training set:

model = SVC()
model.fit(xtrain, ytrain)
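SVC with its default RBF kernel is sensitive to feature scale, and here Age (tens of years) sits next to protein values that lie between about -0.5 and 2 in the rows shown above. One common remedy, sketched below under that assumption, is to put a `StandardScaler` in front of the classifier with `make_pipeline`, and then measure accuracy on the held-out test set, which the article does not report. The data is random and made up, so the score itself means nothing.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

rng = np.random.default_rng(0)
# Made-up features: an Age-like column and a protein-like column on very different scales
x = np.column_stack([rng.uniform(30, 90, 200), rng.normal(0, 1, 200)])
# Toy labels that depend only on the small-scale feature
y = np.where(x[:, 1] > 0, "Alive", "Dead")

xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.10, random_state=42)

# Scale each feature to zero mean and unit variance before the SVC sees it
model = make_pipeline(StandardScaler(), SVC())
model.fit(xtrain, ytrain)

# Mean accuracy on the held-out test set
print(model.score(xtest, ytest))
```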

Now let’s pass a value for every feature used to train the model, in the same order as the training columns, and predict whether this patient will survive:

# Prediction
# features = [['Age', 'Gender', 'Protein1', 'Protein2', 'Protein3','Protein4', 'Tumour_Stage', 'Histology', 'ER status', 'PR status', 'HER2 status', 'Surgery_type']]
features = np.array([[36.0, 1, 0.080353, 0.42638, 0.54715, 0.273680, 3, 1, 1, 1, 2, 2]])
print(model.predict(features))
['Alive']
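Typing the twelve values positionally is easy to get wrong, because their order must match the columns selected for `x`. A hypothetical helper like `make_feature_row` below builds the row from a dictionary keyed by column name, using the same column list as the splitting code; the values are the first patient's encoded features from the output above.

```python
import numpy as np

# Same feature order as the columns selected for x in the splitting step
FEATURES = ['Age', 'Gender', 'Protein1', 'Protein2', 'Protein3', 'Protein4',
            'Tumour_Stage', 'Histology', 'ER status', 'PR status',
            'HER2 status', 'Surgery_type']

def make_feature_row(values: dict) -> np.ndarray:
    """Return a 1 x len(FEATURES) array; raises KeyError if a feature is missing."""
    return np.array([[values[name] for name in FEATURES]], dtype=float)

patient = {
    'Age': 36.0, 'Gender': 1, 'Protein1': 0.080353, 'Protein2': 0.42638,
    'Protein3': 0.54715, 'Protein4': 0.273680, 'Tumour_Stage': 3,
    'Histology': 1, 'ER status': 1, 'PR status': 1, 'HER2 status': 2,
    'Surgery_type': 2,
}
features = make_feature_row(patient)
print(features.shape)  # (1, 12)
# With the trained model from above: model.predict(features)
```

Note that this patient is the first row of the dataset, so it may well be part of the training set; a prediction on it says little about how the model performs on new patients.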

Summary

So this is how we can use machine learning for the task of breast cancer survival prediction. I hope you liked this article on breast cancer survival prediction with machine learning using Python. Feel free to ask valuable questions in the comments section below.

Aman Kharwal

Coder with the ♥️ of a Writer || Data Scientist | Solopreneur | Founder
