Linear regression is one of the most popular techniques for modelling a linear relationship between a dependent variable and one or more independent variables. It also forms the basis of many more advanced machine learning methods. In "An Introduction to Statistical Learning," the authors state that "the importance of having a good understanding of linear regression before studying more complex learning methods cannot be overstated."
Simple linear regression is pretty straightforward: we assume a linear relationship between the quantitative response Y and the predictor variable X. There are two coefficients in this model: the intercept and the slope. The intercept is the predicted value of Y when the predictor X is zero; the slope is the change in Y for a one-unit increase in X. Truth be told, if you're interested in all the mathematical details of linear regression (which I strongly recommend learning about), get an econometrics book. In this tutorial, I will briefly show how to do linear regression with scikit-learn, a popular Python machine learning package.
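In formula form, the model is simply Y = β₀ + β₁·X + ε, where β₀ is the intercept, β₁ is the slope, and ε is an error term capturing everything the straight line doesn't explain.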
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
You can download the famous mpg dataset from the UCI Machine Learning Repository, or just google "mpg.csv." Using pandas, you can quickly read the CSV into a DataFrame.
df = pd.read_csv('mpg.csv')
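If you don't feel like hunting for the file, the same dataset also ships with seaborn (assuming you have it installed), with missing values already encoded as NaN:

import seaborn as sns
df = sns.load_dataset('mpg')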
Next up, we'll clean the dataset. Missing values are encoded as question marks, so using pandas we replace them with NaNs and drop the rows that contain them.
df = df.replace('?', np.nan)
df = df.dropna()
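One caveat: because of the question marks, pandas may have read the horsepower column in as text rather than numbers. If that's the case for your file, convert it explicitly before modelling:

# horsepower may have been parsed as strings due to the '?' placeholders
df['horsepower'] = df['horsepower'].astype(float)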
Next, for educational purposes, we'll drop some columns that we won't need in our regression model: the model name, the geographical origin, and the model year. Finally, the 'mpg' column is dropped from the features in the X variable and set as the target in the y variable.
df = df.drop(['name', 'origin', 'model_year'], axis=1)
X = df.drop('mpg', axis=1)
y = df[['mpg']]
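A quick sanity check never hurts; at this point X should contain only numeric predictors:

print(X.dtypes)  # all remaining columns should be numeric
print(y.head())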
Finally, we’ll split the dataset into a train set and a test set. Scikit-learn has a very straightforward train_test_split function for that.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
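With test_size=0.2, roughly 80% of the rows end up in the training set and 20% in the test set; the random_state just makes the split reproducible. You can verify the split like this:

print(X_train.shape, X_test.shape)  # e.g. (313, 5) (79, 5) for the UCI version of the data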
Now we can start building the regression model. First, let's try a model with only one variable: we'll predict the miles per gallon from a car's horsepower.
reg = LinearRegression()
reg.fit(X_train[['horsepower']], y_train)
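The two coefficients from earlier, the intercept and the slope, are stored on the fitted estimator:

print(reg.intercept_)  # predicted mpg at zero horsepower
print(reg.coef_)       # change in mpg per extra unit of horsepower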
Now that we have our model, we can check how well it performs. First, we run the model on our test set. Two good evaluation metrics for linear regression are the mean squared error (MSE) and the R² score, which measures the proportion of variance in the target that the model explains.
y_predicted = reg.predict(X_test[['horsepower']])
print("Mean squared error: %.2f" % mean_squared_error(y_test, y_predicted))
print("R²: %.2f" % r2_score(y_test, y_predicted))
We get a model with a mean squared error of 28.66 and an R² of 0.59. That's … okay. But we can do better, right? Of course we can! Let's add more variables to the model. What about adding weight and the number of cylinders? That makes sense, right?
reg = LinearRegression()
reg.fit(X_train[['horsepower', 'weight', 'cylinders']], y_train)
y_predicted = reg.predict(X_test[['horsepower', 'weight', 'cylinders']])
print("Mean squared error: %.2f" % mean_squared_error(y_test, y_predicted))
print("R²: %.2f" % r2_score(y_test, y_predicted))
By using a model with three variables instead of one, we get a mean squared error of 19.12 and an R² score of 0.72. That's definitely a nice improvement!
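It's worth glancing at the individual coefficients too, for example to check that heavier cars are associated with lower mileage:

# ravel() flattens the coefficient array regardless of the target's shape
for name, coef in zip(['horsepower', 'weight', 'cylinders'], reg.coef_.ravel()):
    print('%s: %.4f' % (name, coef))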
By plotting the true Y values of our test set against the values our model predicts, we can inspect its performance visually.
fig, ax = plt.subplots()
ax.scatter(y_test, y_predicted)
ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=4)
ax.set_xlabel('measured')
ax.set_ylabel('predicted')
plt.show()
A visual inspection of the performance usually reveals some interesting findings. In this case, we can see that there seems to be some non-linearity in our data. Although it's beyond the scope of this article, a proper next step would be to transform the data, for example by taking the logarithm of the mpg values.
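As a rough sketch of what that could look like (purely illustrative, not a tuned model): fit the same regression on log(mpg) and transform the predictions back with the exponential.

# fit on the log of the target; this is a sketch, not a tuned model
log_reg = LinearRegression()
log_reg.fit(X_train[['horsepower', 'weight', 'cylinders']], np.log(y_train))

# predictions come back on the log scale, so invert them with exp
y_pred_log = np.exp(log_reg.predict(X_test[['horsepower', 'weight', 'cylinders']]))
print("Mean squared error: %.2f" % mean_squared_error(y_test, y_pred_log))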
Keep in mind that linear regression is, computationally, a very efficient algorithm, which makes it ideal for some quick insights into the relationships in your data. However, in terms of accuracy, linear regression will rarely be your final choice.