In Machine Learning, the Random Forest algorithm is an ensemble learning method for classification and regression tasks. It is composed of multiple decision trees, where each tree is trained on a random sample of the training data and considers only a random subset of the input features when choosing each split. So, if you are new to machine learning and want to know how the Random Forest algorithm works, this article is for you. In this article, I will introduce how the Random Forest algorithm works and how to implement it using Python.
Here’s How the Random Forest Algorithm Works
Random Forests are known for coping well with high-dimensional data and for providing feature importance scores that can help with feature selection. Some implementations can also handle missing values directly. Let’s understand how the Random Forest algorithm works by taking an example of a real-world business problem.
Suppose you own a store and want to predict which products will be popular with your customers. You have a lot of data about your customers, such as their age, gender, and previous purchases. You also have product data such as price, category and features.
Now, to predict whether a product will be popular, you can seek the opinion of a panel of experts. Each expert may have a different point of view depending on their knowledge and experience. For example, one expert may be good at predicting which products are popular with young people, while another expert may be good at predicting which products are popular with women.
The Random Forest algorithm works similarly. It combines the opinions of many “experts”, called decision trees. Each decision tree looks at a subset of the data and makes predictions using rules it has learned from that subset. For example, a decision tree might learn that if a customer is female and has purchased a product in the same category before, she is likely to purchase a new product in that category.
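To make the "panel of experts" idea concrete, here is a minimal sketch that trains a few decision trees by hand, each on its own random sample of the rows. The toy data, the number of trees and the variable names are assumptions made for illustration; they are not part of the store example.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)

# Hypothetical toy data: two numeric features and a binary label
X = rng.normal(size=(100, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

trees = []
for seed in range(5):
    # Bootstrap sample: draw row indices with replacement
    idx = rng.integers(0, len(X), size=len(X))
    # max_features="sqrt" makes each split consider a random subset of features
    tree = DecisionTreeClassifier(max_features="sqrt", random_state=seed)
    tree.fit(X[idx], y[idx])
    trees.append(tree)

# Each "expert" gives its own opinion on the same new point
new_point = np.array([[0.5, 0.2]])
opinions = [int(t.predict(new_point)[0]) for t in trees]
print(opinions)
```

Because every tree sees a slightly different sample, the trees can disagree with each other, which is exactly what the aggregation step relies on.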
When you have many decision trees, you can combine their predictions: by taking the most common answer for a classification problem, or the average for a regression problem. This reduces the impact of an individual tree making a mistake. For example, if one tree predicts that a product will be very popular, but most of the other trees predict that it will be average, then the Random Forest algorithm will predict that it will be average.
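The combining step can be sketched in a few lines of plain Python. The tree predictions below are made-up values for illustration, and the helper names `majority_vote` and `average` are hypothetical.

```python
from collections import Counter

def majority_vote(labels):
    """Return the most common label among the trees' predictions."""
    return Counter(labels).most_common(1)[0][0]

def average(values):
    """Return the mean of the trees' numeric predictions."""
    return sum(values) / len(values)

# Hypothetical predictions from five trees for the same product
tree_labels = ["average", "average", "very popular", "average", "average"]
tree_scores = [3.0, 3.5, 9.0, 3.0, 3.5]

print(majority_vote(tree_labels))  # average
print(average(tree_scores))        # 4.4
```

With a vote, the single optimistic tree is simply outvoted. With an average, it still pulls the result upwards a little, but far less than if it were the only model.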
Thus, the Random Forest algorithm is an ensemble learning method for classification and regression tasks. It is composed of multiple decision trees, where each tree is trained on a random sample of the training data (usually drawn with replacement) and considers only a random subset of the input features at each split. During the learning phase, each decision tree is built by recursively partitioning the data into subsets based on the values of the input features. Split points are chosen to maximize information gain, that is, to reduce the impurity (such as Gini impurity or entropy) of the resulting nodes as much as possible. The final result of the algorithm is obtained by aggregating the predictions of all the trees, typically taking a majority vote for classification problems or the average for regression problems.
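As a rough illustration of how a split is scored, the sketch below computes Gini impurity and the impurity decrease for two candidate splits. The labels are the `popular` values from the store example used later in this article, grouped by hand; the helper names `gini` and `impurity_decrease` are hypothetical.

```python
def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum of squared class shares."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

def impurity_decrease(parent, left, right):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    weighted = len(left) / n * gini(left) + len(right) / n * gini(right)
    return gini(parent) - weighted

# 'popular' labels for all eight customers
parent = [1, 0, 1, 0, 1, 0, 1, 0]

# Candidate split on previous_purchase: 'yes' rows vs 'no' rows
by_purchase = impurity_decrease(parent, [1, 1, 1, 1], [0, 0, 0, 0])

# Candidate split on gender: 'M' rows vs 'F' rows
by_gender = impurity_decrease(parent, [1, 0, 0, 0], [0, 1, 1, 1])

print(by_purchase)  # 0.5
print(by_gender)    # 0.125
```

The split on `previous_purchase` removes all impurity, so a tree that is allowed to consider that feature would prefer it over `gender` at this node.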
Implementation of Random Forest Algorithm using Python
Now let’s see how to implement the Random Forest algorithm using Python. We can use the scikit-learn library, which provides ready-made implementations of many common Machine Learning algorithms, including Random Forest.
Let’s first import the necessary Python libraries and create some sample data based on the example we discussed above:
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# Define sample data
data = {'age': [25, 30, 35, 40, 45, 50, 55, 60],
        'gender': ['M', 'F', 'F', 'M', 'F', 'M', 'F', 'M'],
        'previous_purchase': ['yes', 'no', 'yes', 'no', 'yes', 'no', 'yes', 'no'],
        'category': ['electronics', 'books', 'fashion', 'electronics',
                     'fashion', 'books', 'electronics', 'fashion'],
        'price': [100, 50, 80, 150, 70, 60, 120, 90],
        'popular': [1, 0, 1, 0, 1, 0, 1, 0]}
data = pd.DataFrame(data)
print(data)
   age gender previous_purchase     category  price  popular
0   25      M               yes  electronics    100        1
1   30      F                no        books     50        0
2   35      F               yes      fashion     80        1
3   40      M                no  electronics    150        0
4   45      F               yes      fashion     70        1
5   50      M                no        books     60        0
6   55      F               yes  electronics    120        1
7   60      M                no      fashion     90        0
Now here’s how to train a Machine Learning model using the Random Forest algorithm:
from sklearn.preprocessing import OneHotEncoder

# Split data into features and target
X = data.drop(['popular'], axis=1)
y = data['popular']

# One-hot encode categorical columns
# (sparse_output requires scikit-learn 1.2 or newer; older versions use sparse=False)
cat_cols = ['gender', 'previous_purchase', 'category']
encoder = OneHotEncoder(drop='first', sparse_output=False)
X_cat = encoder.fit_transform(X[cat_cols])
X_cat_df = pd.DataFrame(X_cat, columns=encoder.get_feature_names_out(cat_cols))
X = pd.concat([X.drop(cat_cols, axis=1), X_cat_df], axis=1)

# Train the model
rf = RandomForestClassifier(n_estimators=10, random_state=42)
rf.fit(X, y)
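To see that the trained forest really is a collection of decision trees, the sketch below rebuilds the same model and prints the rules of one of its trees. It is self-contained, so it repeats the sample data, and it uses `pd.get_dummies` as a shortcut for the encoding step; that shortcut is my assumption here and is not the encoder used in the article's code.

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_text

data = pd.DataFrame({
    'age': [25, 30, 35, 40, 45, 50, 55, 60],
    'gender': ['M', 'F', 'F', 'M', 'F', 'M', 'F', 'M'],
    'previous_purchase': ['yes', 'no', 'yes', 'no', 'yes', 'no', 'yes', 'no'],
    'category': ['electronics', 'books', 'fashion', 'electronics',
                 'fashion', 'books', 'electronics', 'fashion'],
    'price': [100, 50, 80, 150, 70, 60, 120, 90],
    'popular': [1, 0, 1, 0, 1, 0, 1, 0],
})

# Shortcut encoding (assumption): one 0/1 column per category level, first level dropped
X = pd.get_dummies(data.drop(columns='popular'), drop_first=True, dtype=int)
y = data['popular']

rf = RandomForestClassifier(n_estimators=10, random_state=42).fit(X, y)

# The fitted forest stores its individual trees in estimators_
print(len(rf.estimators_))  # 10

# Print the learned if/else rules of the first tree
rules = export_text(rf.estimators_[0], feature_names=list(X.columns))
print(rules)
```

Printing a few different trees from `rf.estimators_` usually shows that they split on different features, because each one was trained on a different random sample.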
Now here’s how we can make predictions on new data:
# New Data
new_samples = {'age': [25, 30, 35],
               'gender': ['F', 'M', 'F'],
               'previous_purchase': ['no', 'yes', 'no'],
               'category': ['fashion', 'books', 'electronics'],
               'price': [80, 50, 120]}
new_df = pd.DataFrame(new_samples)

# One-hot encode categorical columns in new data (reuse the fitted encoder)
new_cat = encoder.transform(new_df[cat_cols])
new_cat_df = pd.DataFrame(new_cat, columns=encoder.get_feature_names_out(cat_cols))
new_df = pd.concat([new_df.drop(cat_cols, axis=1), new_cat_df], axis=1)

predictions = rf.predict(new_df)
print(predictions)
[0 1 0]
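One way to keep the encoding and the model together, so that new data always goes through the same preprocessing, is a scikit-learn `Pipeline`. The sketch below is an alternative arrangement of the same steps, not the approach used above, and it also prints the class probabilities, which show how strongly the trees agree. It repeats the sample data so that it runs on its own.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

data = pd.DataFrame({
    'age': [25, 30, 35, 40, 45, 50, 55, 60],
    'gender': ['M', 'F', 'F', 'M', 'F', 'M', 'F', 'M'],
    'previous_purchase': ['yes', 'no', 'yes', 'no', 'yes', 'no', 'yes', 'no'],
    'category': ['electronics', 'books', 'fashion', 'electronics',
                 'fashion', 'books', 'electronics', 'fashion'],
    'price': [100, 50, 80, 150, 70, 60, 120, 90],
    'popular': [1, 0, 1, 0, 1, 0, 1, 0],
})
cat_cols = ['gender', 'previous_purchase', 'category']

# Encode the categorical columns and pass the numeric ones through unchanged
preprocess = ColumnTransformer(
    [('cat', OneHotEncoder(drop='first'), cat_cols)],
    remainder='passthrough',
)
model = Pipeline([
    ('preprocess', preprocess),
    ('rf', RandomForestClassifier(n_estimators=10, random_state=42)),
])
model.fit(data.drop(columns='popular'), data['popular'])

new_df = pd.DataFrame({
    'age': [25, 30, 35],
    'gender': ['F', 'M', 'F'],
    'previous_purchase': ['no', 'yes', 'no'],
    'category': ['fashion', 'books', 'electronics'],
    'price': [80, 50, 120],
})

# One row per new sample, one column per class (in the order of model.classes_)
proba = model.predict_proba(new_df)
print(model.classes_)
print(proba)
```

In scikit-learn, the forest's class probabilities are the average of the individual trees' probability estimates, so a value close to 0.5 means the trees are split on that sample.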
So this is how the Random Forest algorithm works.
Advantages and Disadvantages of Random Forest Algorithm
Advantages:
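- Random Forest can handle a large number of input features, including both categorical and numerical features (in scikit-learn, categorical features must first be encoded as numbers, as in the example above).
- It is less prone to overfitting than a single decision tree, because averaging over many trees smooths out the mistakes of individual trees.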
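A simple way to check the overfitting claim on your own data is to compare a single decision tree with a Random Forest using cross-validation. The sketch below uses a synthetic dataset from `make_classification`; the dataset size and model settings are arbitrary assumptions, and the exact scores will depend on them.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (assumption): 500 samples, 20 features, 5 of them informative
X, y = make_classification(n_samples=500, n_features=20,
                           n_informative=5, random_state=0)

tree = DecisionTreeClassifier(random_state=0)
forest = RandomForestClassifier(n_estimators=100, random_state=0)

# Mean accuracy over 5 held-out folds for each model
tree_acc = cross_val_score(tree, X, y, cv=5).mean()
forest_acc = cross_val_score(forest, X, y, cv=5).mean()

print(f"single tree:   {tree_acc:.3f}")
print(f"random forest: {forest_acc:.3f}")
```

Because the scores are measured on data the models did not see during training, a higher score for the forest suggests it generalizes better rather than simply memorizing the training set.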
Disadvantages:
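- Random Forest is harder to interpret than a single decision tree. With many trees voting together, it is difficult to trace exactly how a prediction was made, although feature importance scores give a partial picture.
- It can take longer to train and to make predictions than a single decision tree, especially on large datasets, because many trees have to be built and evaluated.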
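Both disadvantages can be softened somewhat. As a minimal sketch, the code below reads the forest's `feature_importances_` to see which inputs matter most, and passes `n_jobs=-1` so the trees are trained in parallel on all available CPU cores. The synthetic data and the generic feature names are assumptions for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data (assumption): 6 features, 3 of them informative
X, y = make_classification(n_samples=300, n_features=6, n_informative=3,
                           n_redundant=0, random_state=0)

# n_jobs=-1 trains the trees in parallel on all available cores
rf = RandomForestClassifier(n_estimators=50, random_state=0, n_jobs=-1)
rf.fit(X, y)

# Impurity-based importances: one non-negative score per feature, summing to 1
importances = rf.feature_importances_
for i, score in enumerate(importances):
    print(f"feature_{i}: {score:.3f}")
```

These scores describe the model as a whole rather than any single prediction, so they make the forest less of a black box without fully explaining it.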
Summary
The Random Forest algorithm is an ensemble learning method for classification and regression tasks. It is composed of multiple decision trees, where each tree is trained on a random sample of the training data and considers only a random subset of the input features at each split. I hope you liked this article on how the Random Forest algorithm works and how to implement it using Python. Feel free to ask questions in the comments section below.