Introduction to Linear Regression
Learn how to apply linear regression in marketing analytics use cases.
Linear regression
Linear regression is an approach to modeling a linear relationship between a dependent variable (Y) and one or more independent variables (X). It draws a line through the observations in a way that minimizes the vertical distances (typically the sum of their squares) between the observations and the regression line.

Y = a + bX

The equation above is a simple linear regression for a single independent variable, where Y is the predicted value and X is the independent variable. The parameter a is the intercept, where the regression line crosses the y-axis (the value of Y when X = 0), and the parameter b is the coefficient of X, that is, the slope of the line.
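To make the roles of a and b concrete, here is a minimal sketch that estimates both with the standard closed-form least-squares formulas for a single independent variable. The function name fit_simple_linear_regression and the observations are made up for illustration; they are not part of the lesson's dataset.

```python
import numpy as np

def fit_simple_linear_regression(x, y):
    """Return (a, b) for Y = a + bX using ordinary least squares."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean, y_mean = x.mean(), y.mean()
    # Slope: co-variation of X and Y divided by the variation of X.
    b = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    # Intercept: the fitted line passes through the point of means.
    a = y_mean - b * x_mean
    return a, b

# Made-up observations that lie exactly on Y = 2 + 3X.
x_demo = [0, 1, 2, 3, 4]
y_demo = [2, 5, 8, 11, 14]

a, b = fit_simple_linear_regression(x_demo, y_demo)
print(f"intercept a = {a:.2f}, coefficient b = {b:.2f}")
# prints: intercept a = 2.00, coefficient b = 3.00
```

Because the made-up points sit exactly on a line, the recovered intercept and coefficient match the values used to generate them. With noisy real data, the same formulas give the line with the smallest sum of squared vertical distances.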
Let’s imagine a hypothetical scenario where an employee’s yearly salary (dependent variable) increases linearly with their number of years of experience (independent variable). The following graph shows this relationship in a simple linear regression model. By following the regression line, we can predict the employee’s salary if we know the number of years they have worked in the industry.
Remember that a linear regression model predicts a continuous variable, so it is not suited to predicting categories. Even so, many use cases fit it well. For example, we could use it to predict future customer spending, identify effective marketing channels, or predict sales at a store.
Machine learning steps
The key steps involved in building machine-learning models are:
Data collection
Data preparation
Exploratory data analysis (EDA)
Feature engineering
Training a model
Evaluating the model
Model deployment
In the following section, we'll build a linear regression model to predict sales numbers based on advertising spending.
Predicting sales from advertising spending
Let’s say we are a fashion retailer and have different marketing channels (TV, radio, and newspaper) to spread the word about our business name and what we offer. In a broad sense, we are aware of marketing effects on sales growth, but it is not clear which channel performs better than the others. As our first step, we’d like to know how TV advertising affects the sales figures and if they have a linear relationship.
Data exploration
Let’s start by loading the historical advertising data. The dataset includes weekly sales and spending on each marketing channel. Let’s look at the code in action.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the historical advertising data; the first CSV column is used as the index
df_advertising = pd.read_csv('advertising.csv', index_col=0)

print('Dataset')
print(df_advertising.head())
print('\nDetail about numerical columns')
print(df_advertising.describe())

print('\nCount of empty records in the dataset')
print(df_advertising.isna().sum())
Once a dataset is loaded, it is worth inspecting it from several perspectives before building a model.
Explanation
Line 7 loads the dataset using the Pandas read_csv() function.
Line 12 summarizes the numerical columns using the Pandas describe() method.
Line 15 counts the empty values in each column using the isna().sum() method chain.
Before we build the model, let’s take a look at the relationship between TV ads and sales through a scatterplot.
plt.xlabel('TV ads spending')
plt.ylabel('Sales')
sns.scatterplot(data=df_advertising, x='TV', y='sales')
plt.savefig('output/output.png', dpi=800)
It looks like the sales numbers increase as the TV ad spending increases.
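A scatterplot gives a visual impression; a correlation coefficient puts a number on it. The sketch below computes the Pearson correlation between TV spending and sales. Since advertising.csv is not reproduced here, it uses a synthetic stand-in DataFrame (df_demo) whose values are invented; only the column names TV and sales follow the lesson.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for the advertising data (all values are made up).
rng = np.random.default_rng(seed=0)
tv = rng.uniform(0, 300, size=200)
sales = 7.0 + 0.05 * tv + rng.normal(0, 1.5, size=200)
df_demo = pd.DataFrame({"TV": tv, "sales": sales})

# Pearson correlation ranges from -1 to 1; values close to 1 suggest
# a strong positive linear relationship.
corr_tv_sales = df_demo["TV"].corr(df_demo["sales"])
print(f"Pearson correlation between TV and sales: {corr_tv_sales:.2f}")
```

On the real dataset, the same call would be df_advertising['TV'].corr(df_advertising['sales']). A value near 1 would support modeling the relationship with a straight line.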
Training the model
Now, let’s model this relationship with linear regression.
from sklearn.linear_model import LinearRegression

lr = LinearRegression()
lr.fit(df_advertising[['TV']], df_advertising['sales'])
Explanation
In line 3, we create a linear regression object.
In line 4, we train the model using its fit() method.
The fit() method of a linear regression model requires two parameters:
The training data X
The target label y
In our case, we use the TV advertisement spending as training data and the sales numbers as the target label. The double brackets in df_advertising[['TV']] keep X two-dimensional (one row per observation, one column per feature), which is the shape scikit-learn expects.
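After fitting, the learned parameters can be read back from the model and mapped onto the equation Y = a + bX: scikit-learn stores the intercept a in intercept_ and the coefficients in coef_. The sketch below shows this on synthetic, noise-free data (df_demo and lr_demo are illustrative names, and the numbers are invented), so the recovered values are known in advance.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# Synthetic data generated from sales = 7 + 0.05 * TV with no noise.
tv = np.arange(0, 300, 10, dtype=float)
df_demo = pd.DataFrame({"TV": tv, "sales": 7.0 + 0.05 * tv})

lr_demo = LinearRegression()
# X is a 2-D DataFrame (double brackets); y is a 1-D Series.
lr_demo.fit(df_demo[["TV"]], df_demo["sales"])

a = lr_demo.intercept_  # value of sales when TV = 0
b = lr_demo.coef_[0]    # change in sales per unit of TV spending
print(f"sales = {a:.2f} + {b:.3f} * TV")
# prints: sales = 7.00 + 0.050 * TV
```

Applied to the lesson's model, lr.intercept_ and lr.coef_[0] would give the estimated baseline sales and the estimated change in sales per unit of TV spending.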
Model evaluation
After fitting, the model is ready to make predictions. To check how well it fits the data, we can use the predict() method that the scikit-learn library offers.
predictions = lr.predict(df_advertising[['TV']])

plt.xlabel('TV ads spending')
plt.ylabel('Sales')
sns.scatterplot(data=df_advertising, x='TV', y='sales')
plt.plot(df_advertising['TV'], predictions, color='r')
plt.savefig('output/output.png', dpi=800)
Explanation
In line 1, the code predicts the sales number for each TV ads spending value.
In lines 3–7, we plot the sales predictions along with the previous scatterplot to understand the model performance.
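A plot shows the fit qualitatively. One common way to quantify it, sketched here as an optional extension rather than as part of the lesson's code, is to hold out part of the data and compute R² and the root mean squared error (RMSE) on the held-out rows. The data below is synthetic, and the names df_demo and lr_demo, the 25% test size, and the random seeds are arbitrary assumptions.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the advertising data (all values are made up).
rng = np.random.default_rng(seed=42)
tv = rng.uniform(0, 300, size=200)
sales = 7.0 + 0.05 * tv + rng.normal(0, 1.5, size=200)
df_demo = pd.DataFrame({"TV": tv, "sales": sales})

# Hold out 25% of the rows so the model is scored on data it has not seen.
X_train, X_test, y_train, y_test = train_test_split(
    df_demo[["TV"]], df_demo["sales"], test_size=0.25, random_state=0
)

lr_demo = LinearRegression().fit(X_train, y_train)
test_predictions = lr_demo.predict(X_test)

# R^2: share of the variance in sales explained by the model (1.0 is perfect).
r2 = r2_score(y_test, test_predictions)
# RMSE: typical prediction error, in the same units as sales.
rmse = np.sqrt(mean_squared_error(y_test, test_predictions))
print(f"R^2 on held-out data: {r2:.2f}")
print(f"RMSE on held-out data: {rmse:.2f}")
```

Scoring on held-out rows gives a more honest picture than predicting on the same rows used for training, which is what the plot above does.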