Introduction to Linear Regression

Linear regression

Linear regression models a linear relationship between a dependent variable (Y) and one or more independent variables (X). It fits a line through the observations so that the vertical distances between the observations and the regression line are as small as possible. The most common fitting method, ordinary least squares, minimizes the sum of the squared vertical distances.

$$Y = a + bX$$

The equation above is a simple linear regression for a single independent variable, where Y is the predicted value and X is the independent variable. The parameter a is the intercept, where the regression line crosses the y-axis (the value of Y when X = 0), and the parameter b is the coefficient (slope) of X, that is, the change in Y for a one-unit increase in X.
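To make the fitting criterion concrete: for n observations (x_i, y_i), ordinary least squares chooses a and b to minimize the sum of squared vertical distances, and for a single X this has a closed-form solution:

$$\min_{a,b} \sum_{i=1}^{n} \left(y_i - (a + b x_i)\right)^2, \qquad b = \frac{\sum_{i=1}^{n}(x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n}(x_i - \bar{x})^2}, \qquad a = \bar{y} - b\bar{x}$$

The minimal sketch below evaluates these formulas with NumPy. The function name fit_simple_linear_regression is a hypothetical helper of our own, not a library API, and the sample points are made up so the result is easy to check.

```python
import numpy as np


def fit_simple_linear_regression(x, y):
    """Hypothetical helper: return (a, b) for Y = a + bX via ordinary least squares."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean, y_mean = x.mean(), y.mean()
    # Slope: covariance of x and y divided by the variance of x (common 1/n factors cancel)
    b = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    # Intercept: the fitted line always passes through the point (x_mean, y_mean)
    a = y_mean - b * x_mean
    return a, b


# Made-up points lying exactly on Y = 1 + 2X
a, b = fit_simple_linear_regression([0, 1, 2, 3], [1, 3, 5, 7])
print(a, b)  # 1.0 2.0
```

For noisy data the same formulas give the line with the smallest sum of squared residuals, which is what library implementations such as scikit-learn's LinearRegression compute as well.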

Let’s imagine a hypothetical scenario in which an employee’s yearly salary (dependent variable) increases linearly with their number of years of experience (independent variable). The following graph shows this relationship as a simple linear regression model with a single independent variable. By following the regression line, we can predict an employee’s salary from the number of years they have worked in the industry.

[Figure: Salary vs. Experience]
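Reading a prediction off the regression line is the same as evaluating Y = a + bX for a given X. In the minimal sketch below, the intercept a = 30,000 and slope b = 5,000 are made-up values for illustration only; they are not taken from the graph or from any real salary data.

```python
def predict_salary(years_of_experience: float, a: float = 30_000.0, b: float = 5_000.0) -> float:
    """Evaluate Y = a + bX for one employee.

    The default a (intercept) and b (slope) are hypothetical values,
    not coefficients fitted to real salary data.
    """
    return a + b * years_of_experience


# With 0 years of experience, the prediction equals the intercept a
print(predict_salary(0))  # 30000.0
# Each additional year of experience adds b to the predicted salary
print(predict_salary(4))  # 50000.0
```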

Keep in mind that linear regression predicts a continuous target variable, not a category. Within that scope, it suits many use cases well: for example, predicting future customer spending, identifying effective marketing channels, or forecasting sales at a store.

Machine learning steps

The key steps involved in building machine-learning models are:

  1. Data collection

  2. Data preparation

  3. Exploratory data analysis (EDA)

  4. Feature engineering

  5. Training a model

  6. Evaluating the model

  7. Model deployment

[Figure: Machine learning workflow]

In the following section, we'll build a linear regression model to predict sales numbers based on advertising spending.

Predicting sales from advertising spending

Let’s say we are a fashion retailer with several marketing channels (TV, radio, and newspaper) for spreading the word about our business and what we offer. Broadly, we know that marketing drives sales growth, but it is not clear which channel performs best. As a first step, we’d like to know how TV advertising affects sales figures and whether the two have a linear relationship.

Data exploration

Let’s start by loading the historical advertising data. The dataset includes weekly sales and spending on each marketing channel. Let’s look at the code in action.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the historical advertising data (the first column is used as the index)
df_advertising = pd.read_csv('advertising.csv', index_col=0)
print(df_advertising.head())

print('\nDetail about numerical columns')
print(df_advertising.describe())

print('\nCount of empty records in the dataset')
print(df_advertising.isna().sum())

Once a dataset is imported, we highly recommend examining it from several perspectives.

  • Line 7 loads the dataset using the Pandas read_csv() function.

  • Line 8 prints the first five rows using the head() method.

  • Line 11 summarizes the numerical columns using the Pandas describe() method.

  • Line 14 counts the empty (missing) values in each column using the isna().sum() method.

Before we build the model, let’s take a look at the relationship between TV ads and sales through a scatterplot.

# Plot TV ad spending against sales, one point per record
sns.scatterplot(data=df_advertising, x='TV', y='sales')
plt.xlabel('TV ads spending')
plt.ylabel('Sales')
plt.savefig('output/output.png', dpi=800)

It looks like the sales numbers increase as the TV ad spending increases.
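The scatterplot gives a visual impression; the Pearson correlation coefficient puts a number on how strongly two variables move together linearly, from -1 to +1. With the real data, the call would be df_advertising['TV'].corr(df_advertising['sales']). In the minimal sketch below, the toy values are hypothetical stand-ins used only to keep the example runnable without advertising.csv.

```python
import pandas as pd

# Hypothetical toy values standing in for the TV and sales columns of df_advertising
toy = pd.DataFrame({
    'TV': [10.0, 50.0, 90.0, 150.0, 200.0, 260.0],
    'sales': [6.0, 9.5, 11.0, 15.5, 18.0, 22.5],
})

# Series.corr uses the Pearson method by default; values near +1 indicate
# a strong positive linear relationship
r = toy['TV'].corr(toy['sales'])
print(round(r, 3))
```

A coefficient close to +1 supports trying a linear model, although it does not by itself prove that TV spending causes the change in sales.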

Training the model

Now, let’s model this relationship with linear regression.

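from sklearn.linear_model import LinearRegression

lr = LinearRegression()  # create an ordinary least squares linear regression model
lr.fit(df_advertising[['TV']], df_advertising['sales'])  # train on TV spending (X) and sales (y)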


  • In line 3, we create a linear regression object.

  • In line 4, we train the model by calling its fit() method.

The fit() method of a linear regression model requires two parameters:

  • The training data X

  • The target label y

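In our case, we use the TV advertising spending as the training data X and the sales numbers as the target label y. The double brackets in df_advertising[['TV']] select a one-column DataFrame rather than a Series, because scikit-learn expects X to be two-dimensional (rows are samples, columns are features).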
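After fitting, scikit-learn stores the learned parameters in the intercept_ and coef_ attributes, which correspond to a and b in Y = a + bX. The minimal sketch below fits a toy DataFrame instead of the real advertising data; its values are hypothetical and lie exactly on sales = 5 + 0.05 × TV, so the recovered parameters are easy to check. It also prints the shapes that single and double brackets produce.

```python
import pandas as pd
from sklearn.linear_model import LinearRegression

# Hypothetical toy data lying exactly on sales = 5 + 0.05 * TV
toy = pd.DataFrame({'TV': [0.0, 100.0, 200.0, 300.0]})
toy['sales'] = 5 + 0.05 * toy['TV']

# Double brackets give a 2-D DataFrame; single brackets give a 1-D Series
print(toy[['TV']].shape)  # (4, 1)
print(toy['TV'].shape)    # (4,)

lr = LinearRegression()
lr.fit(toy[['TV']], toy['sales'])

# intercept_ plays the role of a, and coef_[0] the role of b
print('a =', lr.intercept_)
print('b =', lr.coef_[0])
```

With more than one feature, coef_ holds one coefficient per column of X, in the same order as the columns.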

Model evaluation

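After fitting, the model is ready to make predictions with the predict() method that scikit-learn provides. To get a first look at its performance, we predict sales for every TV spending value in the dataset and plot the predictions against the actual observations.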

predictions = lr.predict(df_advertising[['TV']])

sns.scatterplot(data=df_advertising, x='TV', y='sales')  # actual observations
plt.plot(df_advertising['TV'], predictions, color='r')  # fitted regression line
plt.xlabel('TV ads spending')
plt.ylabel('Sales')
plt.savefig('output/output.png', dpi=800)


  • In line 1, we predict the sales number for each TV ad spending value in the dataset.

  • In lines 3–7, we draw the predictions as a red regression line over the previous scatterplot to see how closely the model follows the data.
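The plot judges the fit by eye, using the same rows the model was trained on. A common, more quantitative complement is to hold out part of the data and score the predictions with metrics such as R² (the share of variance in sales the model explains) and mean absolute error (the average prediction error, in sales units). The sketch below uses scikit-learn's train_test_split, r2_score, and mean_absolute_error; the synthetic data (intercept 7, slope 0.05, noise standard deviation 2) is an assumption that stands in for advertising.csv.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split

# Synthetic, hypothetical data: sales grows linearly with TV spending plus Gaussian noise
rng = np.random.default_rng(0)
tv = rng.uniform(0, 300, size=200)
sales = 7 + 0.05 * tv + rng.normal(0, 2, size=200)
df = pd.DataFrame({'TV': tv, 'sales': sales})

# Hold out 20% of the rows so the scores reflect data the model has not seen
X_train, X_test, y_train, y_test = train_test_split(
    df[['TV']], df['sales'], test_size=0.2, random_state=42
)

lr = LinearRegression().fit(X_train, y_train)
predictions = lr.predict(X_test)

r2 = r2_score(y_test, predictions)
mae = mean_absolute_error(y_test, predictions)
print('R^2:', r2)
print('MAE:', mae)
```

Scoring on held-out rows avoids the optimistic picture that evaluating on the training data can give.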
