Supporting a New Model
Learn how to add a new model.
Supporting a new model
Let’s see what the code for the AutoMPG regressor module looks like.
```python
from typing import TYPE_CHECKING, Dict, Optional, Tuple

from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
from joblib import load, dump

if TYPE_CHECKING:
    import logging
    import numpy as np
    import pandas as pd
    from omegaconf import DictConfig

from ml_pipeline.mixins.reporting_mixin import ReportingMixin
from ml_pipeline.mixins.training_mixin import TrainingMixin
from ml_pipeline.model import Model


class AutoMPGRegressor(TrainingMixin, Model, ReportingMixin):
    def __init__(
        self,
        model_params: "DictConfig",
        training_params: "DictConfig",
        artifact_dir: str,
        logger: Optional["logging.Logger"] = None,
    ) -> None:
        self.model = LinearRegression(**model_params)
        self.training_params = training_params
        self.artifact_dir = artifact_dir
        self.logger = logger

    def load(self, model_path: str) -> None:
        # Calls joblib's load(); the method name does not shadow it here.
        self.model = load(model_path)

    def _encode_train_data(
        self, X: "pd.DataFrame" = None, y: "pd.Series" = None
    ) -> Tuple["pd.DataFrame", "pd.Series"]:
        # in this example, we don't do any encoding
        return X, y

    def _encode_test_data(
        self, X: "pd.DataFrame" = None, y: "pd.Series" = None
    ) -> Tuple["pd.DataFrame", "pd.Series"]:
        # in this example, we don't do any encoding
        return X, y

    def _compute_metrics(self, y_true: "pd.Series", y_pred: "pd.Series") -> Dict:
        self.metrics = {}
        self.metrics["mean_squared_error"] = mean_squared_error(y_true, y_pred)
        self.metrics["r2_score"] = r2_score(y_true, y_pred)
        # Return the metrics as well, to match the declared return type.
        return self.metrics

    def create_report(self) -> None:
        self.save_metrics()

    def save(self) -> None:
        filename = f"{self.artifact_dir}/model.joblib"
        dump(self.model, filename)
        # The logger is optional, so only log when one was provided.
        if self.logger is not None:
            self.logger.debug(f"Saved {filename}.")

    def predict(self, X: "pd.DataFrame") -> "np.ndarray":
        # LinearRegression.predict returns one continuous value per row.
        return self.model.predict(X)
```
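Notice that the class defines hooks such as _encode_train_data and _compute_metrics but never calls them itself; presumably, the mixins drive them. The real TrainingMixin and ReportingMixin aren't shown in this lesson, so the following is only a minimal sketch of how such a pattern could work. Every name in it (SketchTrainingMixin, SketchReportingMixin, MeanModel, ToyRegressor, and the train method) is a hypothetical stand-in, not the actual ml_pipeline API.

```python
from typing import Dict, List, Tuple


class SketchTrainingMixin:
    """Hypothetical stand-in for TrainingMixin (the real one is not shown)."""

    def train(self, X: List[float], y: List[float]) -> None:
        # The mixin owns the workflow; the subclass only supplies the hooks.
        X_enc, y_enc = self._encode_train_data(X, y)
        self.model.fit(X_enc, y_enc)
        y_pred = self.predict(X_enc)
        self._compute_metrics(y_enc, y_pred)


class SketchReportingMixin:
    """Hypothetical stand-in for ReportingMixin."""

    def save_metrics(self) -> Dict[str, float]:
        # A real implementation would likely write to disk; we return a copy.
        return dict(self.metrics)


class MeanModel:
    """Tiny estimator with a scikit-learn-like fit/predict interface."""

    def fit(self, X: List[float], y: List[float]) -> "MeanModel":
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X: List[float]) -> List[float]:
        # Always predicts the mean of the training targets.
        return [self.mean_ for _ in X]


class ToyRegressor(SketchTrainingMixin, SketchReportingMixin):
    def __init__(self) -> None:
        self.model = MeanModel()

    def _encode_train_data(
        self, X: List[float], y: List[float]
    ) -> Tuple[List[float], List[float]]:
        return X, y  # no encoding, as in AutoMPGRegressor

    def _compute_metrics(self, y_true: List[float], y_pred: List[float]) -> None:
        squared_errors = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
        self.metrics = {"mean_squared_error": sum(squared_errors) / len(y_true)}

    def predict(self, X: List[float]) -> List[float]:
        return self.model.predict(X)

    def create_report(self) -> Dict[str, float]:
        return self.save_metrics()


toy = ToyRegressor()
toy.train([1.0, 2.0, 3.0], [10.0, 20.0, 30.0])
report = toy.create_report()
```

Under this assumption, supporting a new model mostly means supplying a different estimator and overriding the hooks, while the training and reporting workflow stays in the mixins.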
If we compare this to the iris classifier module, we see similarities, mainly because, like IrisClassifier, this class derives from TrainingMixin, Model, and ReportingMixin. We also see several differences:
- In its __init__ method, this model uses scikit-learn's LinearRegression class instead of the LogisticRegression class we used for iris classification.
- Iris classification is a classification problem, so it required classification metrics, such as accuracy and a confusion matrix. This project, however, computes regression metrics, such as the mean squared error and the coefficient of determination (R²). ...
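To make the two regression metrics concrete, here is a small sketch that computes them by hand in plain Python. The helper names manual_mse and manual_r2 are hypothetical and exist only for illustration; in the module itself, scikit-learn's mean_squared_error and r2_score do this work. The sketch assumes the targets are not all identical, because otherwise the denominator of R² is zero.

```python
from typing import Sequence


def manual_mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean of the squared differences between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def manual_r2(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    # Residual sum of squares: error left over after the model's predictions.
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    # Total sum of squares: error of always predicting the mean target.
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [3.0, -0.5, 2.0, 7.0]
y_pred = [2.5, 0.0, 2.0, 8.0]

metrics = {
    "mean_squared_error": manual_mse(y_true, y_pred),  # 0.375
    "r2_score": manual_r2(y_true, y_pred),             # roughly 0.9486
}
```

A lower mean squared error is better, while an R² close to 1 means the model explains most of the variance in the target. Neither metric has a direct counterpart in accuracy or the confusion matrix, which count correct and incorrect class labels.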