Solution Review: Performing a Multiclass Classification
Explore how to use ML.NET for multiclass classification by modifying the prediction method to output a single predicted label. Understand how to refine your model's output so that it focuses on the most likely category in a supervised learning task.
We'll cover the following...
Our complete solution is presented in the playground below:
// This file was auto-generated by ML.NET Model Builder.
using Microsoft.ML;
using Microsoft.ML.Data;
using System;
using System.Linq;
using System.IO;
using System.Collections.Generic;
namespace MulticlassModelDemo.ConsoleApp
{
    public partial class MulticlassModelDemo
    {
        #region model input class
        /// <summary>
        /// Model input class for MulticlassModelDemo. Each property is bound to a data column
        /// through its <c>LoadColumn</c> index and <c>ColumnName</c>.
        /// </summary>
        public class ModelInput
        {
            /// <summary>Label column (the category to predict), loaded from column 0.</summary>
            [LoadColumn(0)]
            [ColumnName(@"Area")]
            public string Area { get; set; }

            /// <summary>Title text, loaded from column 1.</summary>
            [LoadColumn(1)]
            [ColumnName(@"Title")]
            public string Title { get; set; }

            /// <summary>Description text, loaded from column 2.</summary>
            [LoadColumn(2)]
            [ColumnName(@"Description")]
            public string Description { get; set; }
        }
        #endregion

        #region model output class
        /// <summary>
        /// Model output class for MulticlassModelDemo. The prediction engine fills each property
        /// from the output-schema column with the matching <c>ColumnName</c>.
        /// </summary>
        public class ModelOutput
        {
            /// <summary>
            /// Key-encoded label value. The output schema's "Area" column carries the label names
            /// as key values; they are read by <see cref="GetLabels"/>.
            /// </summary>
            [ColumnName(@"Area")]
            public uint Area { get; set; }

            /// <summary>Title as transformed by the model pipeline (a numeric vector).</summary>
            [ColumnName(@"Title")]
            public float[] Title { get; set; }

            /// <summary>Description as transformed by the model pipeline (a numeric vector).</summary>
            [ColumnName(@"Description")]
            public float[] Description { get; set; }

            /// <summary>Feature vector produced by the model pipeline.</summary>
            [ColumnName(@"Features")]
            public float[] Features { get; set; }

            /// <summary>The label name chosen by the model.</summary>
            [ColumnName(@"PredictedLabel")]
            public string PredictedLabel { get; set; }

            /// <summary>
            /// Per-label scores. <see cref="GetSortedScoresWithLabels"/> assumes they are in the
            /// same order as the key values of the "Area" output column.
            /// </summary>
            [ColumnName(@"Score")]
            public float[] Score { get; set; }
        }
        #endregion

        private static string MLNetModelPath = Path.GetFullPath("/models/MulticlassModelDemo.mlnet");

        /// <summary>
        /// Shared prediction engine, created on first access by loading the model from
        /// <see cref="MLNetModelPath"/>. Passing <c>true</c> to <see cref="Lazy{T}"/> makes only the
        /// creation thread-safe. NOTE(review): ML.NET documents <c>PredictionEngine</c> as not
        /// thread-safe, so concurrent <c>Predict</c> calls on this shared instance are not
        /// synchronized here.
        /// </summary>
        public static readonly Lazy<PredictionEngine<ModelInput, ModelOutput>> PredictEngine = new Lazy<PredictionEngine<ModelInput, ModelOutput>>(() => CreatePredictEngine(), true);

        /// <summary>
        /// Loads the saved model from <see cref="MLNetModelPath"/> and wraps it in a prediction
        /// engine. The input schema reported by the loader is not needed and is discarded.
        /// </summary>
        /// <returns>A prediction engine mapping <see cref="ModelInput"/> to <see cref="ModelOutput"/>.</returns>
        private static PredictionEngine<ModelInput, ModelOutput> CreatePredictEngine()
        {
            var mlContext = new MLContext();
            ITransformer mlModel = mlContext.Model.Load(MLNetModelPath, out var _);
            return mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);
        }

        /// <summary>
        /// Use this method to predict scores for all possible labels.
        /// </summary>
        /// <param name="input">Model input.</param>
        /// <returns>Every label name paired with its score, ordered from highest to lowest score.</returns>
        public static IOrderedEnumerable<KeyValuePair<string, float>> PredictAllLabels(ModelInput input)
        {
            var predEngine = PredictEngine.Value;
            var result = predEngine.Predict(input);
            return GetSortedScoresWithLabels(result);
        }

        /// <summary>
        /// Map the unlabeled result score array to the predicted label names.
        /// </summary>
        /// <param name="result">Prediction to get the labeled scores from.</param>
        /// <returns>Label/score pairs ordered by descending score.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="result"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="result"/> has no <c>Score</c> values.</exception>
        /// <exception cref="InvalidOperationException">
        /// The label column is missing from the output schema, or the number of labels does not
        /// match the number of scores.
        /// </exception>
        public static IOrderedEnumerable<KeyValuePair<string, float>> GetSortedScoresWithLabels(ModelOutput result)
        {
            if (result == null)
            {
                throw new ArgumentNullException(nameof(result));
            }

            float[] unlabeledScores = result.Score
                ?? throw new ArgumentException("The prediction does not contain any Score values.", nameof(result));

            // Materialize the label names once so they can be indexed directly. Calling
            // Count()/ElementAt() on a lazy sequence re-enumerates it on every iteration.
            string[] labelNames = GetLabels(result).ToArray();

            // Scores are positional. If the counts differ, any label/score pairing would be wrong.
            if (labelNames.Length != unlabeledScores.Length)
            {
                throw new InvalidOperationException(
                    $"The model produced {unlabeledScores.Length} scores but the label column has {labelNames.Length} labels; scores cannot be mapped to label names.");
            }

            var labeledScores = new Dictionary<string, float>(labelNames.Length);
            for (int i = 0; i < labelNames.Length; i++)
            {
                // Score i belongs to the label with key value i.
                labeledScores.Add(labelNames[i], unlabeledScores[i]);
            }

            return labeledScores.OrderByDescending(pair => pair.Value);
        }

        /// <summary>
        /// Get the ordered label names from the key values of the "Area" column in the prediction
        /// engine's output schema.
        /// </summary>
        /// <param name="result">Predicted result. Currently unused: the labels come from the shared engine's schema.</param>
        /// <returns>Label names, in key-value order.</returns>
        /// <exception cref="InvalidOperationException">The "Area" column is not present in the output schema.</exception>
        private static IEnumerable<string> GetLabels(ModelOutput result)
        {
            var schema = PredictEngine.Value.OutputSchema;
            var labelColumn = schema.GetColumnOrNull("Area");
            if (labelColumn == null)
            {
                throw new InvalidOperationException("Area column not found. Make sure the name searched for matches the name in the schema.");
            }

            // The key values are an ordered array of the possible labels, which lets the score
            // array be mapped back to the correct label names.
            var keyNames = new VBuffer<ReadOnlyMemory<char>>();
            labelColumn.Value.GetKeyValues(ref keyNames);
            return keyNames.DenseValues().Select(x => x.ToString()).ToArray();
        }

        /// <summary>
        /// Use this method to predict on <see cref="ModelInput"/>.
        /// </summary>
        /// <param name="input">Model input.</param>
        /// <returns>The full <see cref="ModelOutput"/>, including <see cref="ModelOutput.PredictedLabel"/> and <see cref="ModelOutput.Score"/>.</returns>
        public static ModelOutput Predict(ModelInput input)
        {
            var predEngine = PredictEngine.Value;
            return predEngine.Predict(input);
        }
    }
}
Complete multiclass classification solution
Solving the challenge
Since ...