Solution Review: Retraining a Machine Learning Model
Explore the process of retraining and saving machine learning models using ML.NET. Understand how to load data, train, and save models, and work with ONNX and TensorFlow formats to deploy models effectively.
We'll cover the following...
The complete solution is available in the following playground:
using Microsoft.ML;
using Microsoft.ML.Data;
using System;
using System.Linq;
using System.IO;
using System.Collections.Generic;
namespace ModelSavingExample.ConsoleApp
{
public partial class ModelSavingExample
{
#region model input class
/// <summary>
/// Input row for ModelSavingExample: a string field and a float field, loaded from
/// source columns 0 and 1 and exposed under the column names "col0" and "col1".
/// </summary>
public class ModelInput
{
    /// <summary>String value loaded from source column 0 and exposed as "col0".</summary>
    [LoadColumn(0), ColumnName(@"col0")]
    public string Col0 { get; set; }

    /// <summary>Float value loaded from source column 1 and exposed as "col1".</summary>
    [LoadColumn(1), ColumnName(@"col1")]
    public float Col1 { get; set; }
}
#endregion
#region model output class
/// <summary>
/// Output row for ModelSavingExample. Each property is bound by name to a column
/// of the prediction engine's output.
/// </summary>
public class ModelOutput
{
    /// <summary>Values of the "col0" output column.</summary>
    [ColumnName(@"col0")] public float[] Col0 { get; set; }

    /// <summary>Value of the "col1" output column (the column whose key values GetLabels reads).</summary>
    [ColumnName(@"col1")] public uint Col1 { get; set; }

    /// <summary>Values of the "Features" output column.</summary>
    [ColumnName(@"Features")] public float[] Features { get; set; }

    /// <summary>Value of the "PredictedLabel" output column.</summary>
    [ColumnName(@"PredictedLabel")] public float PredictedLabel { get; set; }

    /// <summary>
    /// Values of the "Score" output column; GetSortedScoresWithLabels pairs entry i
    /// with the i-th label.
    /// </summary>
    [ColumnName(@"Score")] public float[] Score { get; set; }
}
#endregion
// Absolute path of the saved model file; the relative file name is expanded by
// Path.GetFullPath when this static initializer runs.
private static string MLNetModelPath = Path.GetFullPath("ModelSavingExample.mlnet");

/// <summary>
/// Shared prediction engine. It is built by <see cref="CreatePredictEngine"/> the first
/// time <c>Value</c> is read and the same instance is handed out afterwards; the lazy
/// initialization itself is thread-safe.
/// </summary>
/// <remarks>
/// NOTE(review): ML.NET documents PredictionEngine as not thread-safe, so concurrent
/// calls through this single instance may need external synchronization - confirm
/// how callers use it.
/// </remarks>
public static readonly Lazy<PredictionEngine<ModelInput, ModelOutput>> PredictEngine =
    new Lazy<PredictionEngine<ModelInput, ModelOutput>>(CreatePredictEngine, isThreadSafe: true);
/// <summary>
/// Loads the model stored at <c>MLNetModelPath</c> and wraps it in a prediction engine.
/// </summary>
/// <returns>A prediction engine that maps <see cref="ModelInput"/> to <see cref="ModelOutput"/>.</returns>
private static PredictionEngine<ModelInput, ModelOutput> CreatePredictEngine()
{
    var context = new MLContext();

    // Load also reports the model's input schema; it is not needed here, so it is discarded.
    var loadedModel = context.Model.Load(MLNetModelPath, out _);

    return context.Model.CreatePredictionEngine<ModelInput, ModelOutput>(loadedModel);
}
/// <summary>
/// Runs a prediction for <paramref name="input"/> on the shared engine and returns the
/// score of every label, highest score first.
/// </summary>
/// <param name="input">The row to score.</param>
/// <returns>Label/score pairs ordered by descending score.</returns>
public static IOrderedEnumerable<KeyValuePair<string, float>> PredictAllLabels(ModelInput input)
    => GetSortedScoresWithLabels(PredictEngine.Value.Predict(input));
/// <summary>
/// Pairs each label known to the model with its score from <paramref name="result"/>
/// and orders the pairs from highest to lowest score.
/// </summary>
/// <param name="result">
/// Prediction whose <c>Score</c> array is assumed to be in the same order as the labels
/// returned by <c>GetLabels</c>.
/// </param>
/// <returns>Label/score pairs ordered by descending score.</returns>
/// <exception cref="ArgumentNullException"><paramref name="result"/> is null.</exception>
public static IOrderedEnumerable<KeyValuePair<string, float>> GetSortedScoresWithLabels(ModelOutput result)
{
    if (result == null)
    {
        throw new ArgumentNullException(nameof(result));
    }

    float[] scores = result.Score;

    // Materialize the labels once. GetLabels may return a lazily evaluated sequence, and
    // calling Count()/ElementAt(i) on it inside the loop would re-run the query on every
    // iteration.
    string[] labels = GetLabels(result).ToArray();

    var labeledScores = new Dictionary<string, float>(labels.Length);
    for (int i = 0; i < labels.Length; i++)
    {
        // Map the i-th label to the i-th entry of the score array.
        labeledScores.Add(labels[i], scores[i]);
    }

    return labeledScores.OrderByDescending(pair => pair.Value);
}
/// <summary>
/// Reads the label names from the key values attached to the "col1" column of the
/// prediction engine's output schema.
/// </summary>
/// <param name="result">
/// Not read by this method; the labels come from the engine's output schema rather than
/// from an individual prediction.
/// </param>
/// <returns>The labels in key-value order, each formatted with <c>float.ToString()</c>.</returns>
/// <exception cref="InvalidOperationException">The output schema has no "col1" column.</exception>
private static IEnumerable<string> GetLabels(ModelOutput result)
{
    var schema = PredictEngine.Value.OutputSchema;
    var labelColumn = schema.GetColumnOrNull("col1");
    if (labelColumn == null)
    {
        throw new InvalidOperationException("col1 column not found. Make sure the name searched for matches the name in the schema.");
    }

    // The key values contain an ordered array of the possible labels. This allows callers
    // to map the results to the correct label value by position.
    var keyValues = new VBuffer<float>();
    labelColumn.Value.GetKeyValues(ref keyValues);

    // Return a materialized array so callers can count and index the labels without
    // re-running the projection on every access.
    // NOTE(review): float.ToString() formats with the current culture, so the label text
    // can vary by locale - confirm whether callers need invariant formatting.
    return keyValues.DenseValues().Select(value => value.ToString()).ToArray();
}
/// <summary>
/// Runs a single prediction for <paramref name="input"/> on the shared prediction engine.
/// </summary>
/// <param name="input">The row to score.</param>
/// <returns>The engine's output for <paramref name="input"/>.</returns>
public static ModelOutput Predict(ModelInput input) => PredictEngine.Value.Predict(input);
}
}
Playground showing the complete solution with the code to train the model
Solving the challenge
First, we insert some code into the Train() method inside the ...