
Linear Regression Using Gradient Descent: Intuition and Implementation


This is the second article of a series I’m working on, in which we’ll discuss and define introductory machine learning concepts. If you haven’t read my first article on Data Pre-Processing, make sure to check that out.

In this article, we talk about linear regression, and how it can be used to make predictions when dealing with labeled data.

Let’s get right into it.

Case Study

The best way to understand the concept of linear regression is through an example. So, while reading the rest of this article, imagine yourself in the following scenario:

You’re a data scientist living in Boston. Your friend comes up to you and tells you he’s having trouble selling his house, which he has listed on the market for $500 000 USD. You ask him for the details of the house and notice from his description that the evaluation he’s placed on the property is way too high. As such, you wish to help your friend by writing an algorithm that will look at the current housing market, and predict how much he can sell his house for.

A Few Definitions

Before describing linear regression, it’s important that we understand a few basic concepts:

  1. Dependent vs Independent Variables: A variable that changes depending on changes in some other variable is called the dependent variable. In our housing scenario, an example of a dependent variable would be the house’s price, which changes depending on the size of the house, the number of bedrooms, etc. Size, number of bedrooms, and number of bathrooms are examples of independent variables: their behavior doesn’t depend on any other variable in the model.
  2. Supervised vs Unsupervised Learning Algorithms: A supervised learning algorithm is an algorithm that learns from pre-existing, labeled data in order to understand its behavior and make predictions about future data. Regression algorithms are supervised learning algorithms, as we’ll see. On the other hand, an unsupervised learning algorithm is one that analyzes an unlabeled dataset and, on its own, separates the data into groups of points that have attributes in common. If you wish to learn more about the differences, here’s a great article on supervised vs unsupervised learning.
  3. Regression vs Classification: Regression is a statistical method used to predict the value of a dependent variable using one or more independent variables. In regression, we aim to find a line that best describes our data, in an attempt to later predict new, incoming, data. Predicting the price of a house is an example of a regression problem. In classification, the goal is to classify data into a discrete number of groups, based on their attributes. For example, we can classify a person with COVID-19 symptoms as a carrier or non-carrier. Here, we classify patients into two different groups.

Linear Regression

Linear regression predicts the value of a continuous dependent variable. The key term here is continuous. In linear regression, we aren’t looking to classify our data into a discrete number of different groups. Instead, our prediction can theoretically take any real number. We use this variant of regression when the relationship between our dependent and independent variable is linear.

Intuition

We’ve all seen the equation of a line, right?

y = m * x + b

This equation gives us the value of y for its respective value of x, by defining a line with m as its slope and b as its y-intercept, i.e. where the line intersects the y-axis. Depending on the values of m and b, our line will change.

For purposes that will become more clear later on, we can express the equation of a line in the following way:

h(x) = Theta_1 * x + Theta_0

Equation 1: Univariate Linear Regression

Where Theta_1 and Theta_0 correspond to m and b respectively, and h corresponds to y. We can also change the way in which we interpret x and h. Instead of thinking of x as an arbitrary real number, think of it as a descriptive feature impacting h. For example, if h is the price, then x could be the number of rooms in the house. The figures below show the effects of changing Theta_0 and Theta_1.

Figure 1: m=0.8 and b=2
Figure 2: m=-0.8 and b=2
Figure 3: m=2 and b=0
Figure 4: m=1 and b=2
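
In code, equation 1 is a one-liner. Here’s a minimal sketch (the function and argument names are my own, chosen to mirror the notation above):

# Equation 1 as a Python function: h(x) = Theta_1 * x + Theta_0
def h(x, theta_0, theta_1):
    # theta_1 plays the role of the slope m, theta_0 of the y-intercept b
    return theta_1 * x + theta_0

print(h(8, theta_0=2, theta_1=0.8))  # the line from figure 1 evaluated at x=8: 8.4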

Which line best describes the behavior of our points? Visually, we can see that it’s either the one drawn in figure 1 or the one drawn in figure 4. With this, we can start making predictions using our graph. For example, figure 4 tells us that a house with eight rooms will cost about $100 000 USD.

But our models don’t understand visuals. So how do we compare one line to the other? How do we know which line equation best describes our data points? For this, we use the mean squared error equation:

MSE = (1/n) * Σ_{i=1..n} (Y_i − Ŷ_i)²

Equation 2: Mean Squared Errors

This equation is pretty simple. Using figure 4 as an example, MSE will calculate the mean squared vertical distance between every red point and the blue line. The larger this mean, the worse our line is at describing the data points. Pretty intuitive, right? If our points are far off the line, then it doesn’t properly describe the behavior of our data. Since we’re dealing with linear regression, we can replace Ŷ_i with the equation of our line:

MSE = (1/n) * Σ_{i=1..n} (Y_i − (Theta_1 * x_i + Theta_0))²

Equation 3: Mean Squared Errors for Linear Regression

The idea behind linear regression is no more complicated than determining the values of Theta_0 and Theta_1 which will give us a line that best fits our training data, i.e. one with the smallest MSE.
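
To make equation 3 concrete, here’s how the MSE of a candidate line could be computed with NumPy (a minimal sketch; the toy data points are made up):

import numpy as np

def mse(x, y, theta_0, theta_1):
    # Equation 3: average squared residual of the line Theta_1 * x + Theta_0
    predictions = theta_1 * x + theta_0
    return np.mean((y - predictions) ** 2)

x = np.array([1.0, 2.0, 3.0, 4.0])
y = np.array([3.1, 4.9, 7.2, 8.8])
print(mse(x, y, theta_0=1.0, theta_1=2.0))   # small: this line fits well
print(mse(x, y, theta_0=5.0, theta_1=-1.0))  # large: this line fits poorly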

However, is a house’s price determined solely by the number of rooms it has? Of course not. We need to look at its size, its location, etc. To account for this, we can generalize equation 1:

h(x) = Theta_0 + Theta_1 * x_1 + Theta_2 * x_2 + ... + Theta_j * x_j

Equation 4: Multivariate Linear Regression

Where {x_1, x_2, ..., x_j} correspond to different features of the house, and are the inputs received by the equation h in order to come up with a prediction. For simplicity purposes, we’re mostly going to be working with the univariate formula shown in equation 1, but all the concepts can be extended to the multivariate scenario shown in equation 4.
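
Equation 4 is just a dot product between a vector of parameters and a vector of features, with a constant 1 slotted in to pair with the intercept. A quick sketch, with made-up feature values:

import numpy as np

# [Theta_0, Theta_1, Theta_2, Theta_3]: intercept, then one weight per feature
theta = np.array([10.0, 4.5, 0.02, -1.0])
# [1, rooms, size in sq ft, distance to downtown]: the 1 pairs with Theta_0
x = np.array([1.0, 6.0, 1500.0, 2.0])
print(theta @ x)  # h(x) = 10 + 27 + 30 - 2 = 65.0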

So, how do we decide on the {Theta_0, Theta_1, ..., Theta_j} that best describe our data? Enter, Gradient Descent.

Gradient Descent

In the case of linear regression, gradient descent is used to find the Theta values that minimize the MSE. Note that this is not its only use case. We can use this algorithm in many other optimization and machine learning problems, but we won’t get into that in this article.

A lot (and I mean a lot) can be said about the gradient descent algorithm. In this article, we’ll touch on the points that are most important to linear regression. If you wish to get a more detailed understanding, have a look at Gradient Descent Algorithm and Its Variants.

Intuition

We want to find the line that best fits the following points:

Figure 5: Boston Housing Prices vs Number of Rooms

To make our graphs simpler to understand, let’s assume that Theta_0 = 0. That leaves us with only Theta_1 to find. The following graph compares the MSE we get depending on an arbitrary choice of Theta_1:

Figure 6: MSE vs Different Choices of Theta_1
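
As an aside, a curve like figure 6 can be traced in a few lines by sweeping candidate values of Theta_1 and recording the MSE of each. Here’s a sketch, with made-up data scattered around a slope of 1.1:

import numpy as np
import matplotlib.pyplot as plt

# Points scattered around the line y = 1.1 * x (Theta_0 fixed at 0)
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = 1.1 * x + np.array([0.1, -0.2, 0.15, -0.1, 0.05])

# One MSE value per candidate slope Theta_1
thetas = np.linspace(0.0, 2.2, 100)
errors = [np.mean((y - t * x) ** 2) for t in thetas]

plt.plot(thetas, errors)
plt.xlabel('Theta_1')
plt.ylabel('MSE')
plt.show()  # a bowl shape with its minimum near Theta_1 = 1.1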

Easy to find the best Theta_1 with this graph, isn’t it? We just need to look at the lowest point to realize that our MSE is at its minimum when Theta_1 ≈ 1.1. Once again, though, we need a way to mathematically calculate that value. And that’s what gradient descent does. Here’s the basic idea: repeat until convergence, updating every parameter simultaneously,

Theta_j := Theta_j − alpha * (∂/∂Theta_j) MSE

Let’s break this down.

The most important part is the partial derivative. Why are we partially deriving the MSE? Think about it this way: the partial derivative represents the slope of the MSE curve at the current value of Theta_j. This slope can be positive, negative, or zero. When it’s positive, the update decreases the value of Theta_j. When it’s negative, Theta_j increases. When it’s zero, it means we’ve reached a minimum and nothing happens to Theta_j. Consider the point (0.8, 4) in figure 6. The slope at that point is negative. If we run one iteration of the gradient descent algorithm, the value of Theta_1 will increase and get closer and closer to the minimum.

Alpha is called the learning rate, and it represents how large of a step we take towards the minimum. This value can’t be too small, or else the algorithm will converge very slowly, but it also can’t be too large, otherwise each step may overshoot the minimum and the algorithm may never converge.
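
Putting the pieces together, here’s a minimal NumPy sketch of gradient descent for the univariate case. The learning rate and iteration count are arbitrary choices for this toy data, not tuned values:

import numpy as np

def gradient_descent(x, y, alpha, iterations):
    # Minimize the MSE of the line Theta_1 * x + Theta_0
    theta_0, theta_1 = 0.0, 0.0
    n = len(x)
    for _ in range(iterations):
        error = y - (theta_1 * x + theta_0)    # residual at every point
        grad_0 = -2.0 / n * np.sum(error)      # partial derivative w.r.t. Theta_0
        grad_1 = -2.0 / n * np.sum(error * x)  # partial derivative w.r.t. Theta_1
        theta_0 -= alpha * grad_0              # step against the gradient
        theta_1 -= alpha * grad_1
    return theta_0, theta_1

# Noisy points scattered around y = 2x + 1
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([3.2, 4.8, 7.1, 9.2, 10.9])
print(gradient_descent(x, y, alpha=0.05, iterations=5000))  # roughly (1, 2)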

Implementation

Let’s see how we can use scikit-learn’s LinearRegression class and its built-in Boston Housing dataset to find a best-fitting line.

First, we import the libraries we need:

# Scikit learn's built-in Boston Housing dataset
from sklearn.datasets import load_boston
# Library for scikit-learn compatible arrays and matrices
import numpy as np
# Library for plotting nice graphs
import matplotlib.pyplot as plt

Then separate our features (number of rooms) from our target variable (price):

dataset = load_boston()     # Loads sklearn's Boston dataset
X = dataset.data[:100,5]    # Set x_1 as the number of rooms
y = dataset.target[:100]    # Set h as the house's price

And split them so that 20% of our data is used for testing and the rest for training our model:

# Split data into 20% testing and 80% training
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0)

Our data is now ready for use in our linear regression model. The first thing we need to do is train it using our training set. Remember that linear regression is a supervised learning algorithm, meaning it learns from previous data to predict the value of new, incoming data:

# Train the model with our training set using linear regression
from sklearn.linear_model import LinearRegression
regressor = LinearRegression()
# Fit the model to find Theta_1 and Theta_0 (note: scikit-learn solves this
# with a least-squares solver rather than iterative gradient descent)
regressor.fit(X_train.reshape(-1,1),y_train)

We can see the values of Theta_1 and Theta_0 obtained:

print(regressor.coef_) # Theta_1
print(regressor.intercept_) # Theta_0
>> [9.79185794]
>> -38.79201598312946
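
These two numbers are exactly the Theta_1 and Theta_0 of equation 1, so we can reproduce a prediction by hand as a sanity check against the predict() call we’ll make below:

# Plugging the learned parameters back into equation 1
theta_1 = regressor.coef_[0]
theta_0 = regressor.intercept_
print(theta_1 * 8 + theta_0)  # ~39.54, same as regressor.predict([[8]])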

Let’s see how well our model performs on our test set by drawing its graph. Please note that this dataset looks at the median price:

plt.scatter(X_test, y_test, color='red')
plt.plot(X_test, regressor.predict(X_test.reshape(-1,1)), color='blue')
plt.title('Boston Housing Price vs Number of Rooms')
plt.xlabel('Number of Rooms')
plt.ylabel('Median Price (1000s)')
plt.show()
Figure 7: Linear Regression Predictions On Test Data

And finally, we can start making predictions using this line:

regressor.predict([[8]])
>> array([39.54284755])

So, a house with 8 rooms will have a median price of around $40 000 USD.

Conclusion

In this article, we went through the theory behind linear regression, and how the gradient descent algorithm is used to find the parameters that give us the best fitting line to our data points. We also looked at how we can use scikit-learn’s LinearRegression class to easily apply this model to a dataset of our choice.

Although useful, linear regression is only the beginning of what one can do with regression. Here are some things for you to think about:

  • In all the examples we saw in this article, the relationship between our dependent and independent variables was mostly linear. What happens when this isn’t the case? Can we still use linear regression?
  • We set Theta_0 to zero to simplify our graphs. How does gradient descent extend to more than one feature (more than one Theta)?
  • Is it a surprise that the graph in figure 6 had a quadratic shape?
  • I suggest you trace out at least one iteration of gradient descent on paper for a dataset of your choice. This exercise will give you a better understanding of how gradient descent finds the values of Theta that minimize the MSE.

References

  1. Investopedia Regression Definition
  2. The Coding Train: Linear Regression with Gradient Descent — Intelligence and Learning
  3. Machine Learning Fundamentals (1): Cost functions and gradient descent
  4. Andrew Ng’s Machine Learning Coursera Course
  5. Supervised vs. Unsupervised Learning
  6. Introduction to Probability, Statistics, and Random Processes: Mean Squared Error
