基于 ONNX 在 ML.NET 中使用 Pytorch 训练的垃圾分类模型

ML.NET 在经典机器学习范畴内,对分类、回归、异常检测等问题开发模型已经有非常棒的表现了,我之前的文章都有过介绍。当然我们希望在更高层次的领域加以使用,例如计算机视觉、自然语言处理和信号处理等等领域。

图像识别是计算机视觉的一类分支,AI研发者们较为熟悉的是使用TensorFlow、Pytorch、Keras、MXNET等框架来训练深度神经网络模型,其中会涉及到CNN(卷积神经网络)、DNN(深度神经网络)的相关算法。

ML.NET 在较早期的版本是无法支持这类研究的,可喜的是最新的版本不但能很好地集成 TensorFlow 的模型做迁移学习,还可以直接导入 DNN 常见预编译模型:AlexNet、ResNet18、ResNet50、ResNet101 实现对图像的分类、识别等。

我特别想推荐的是,ML.NET 最新版本对 ONNX 的支持也是非常强劲,通过 ONNX 可以把众多其他优秀深度学习框架的模型引入到 .NET Core 运行时中,极大地扩充了 .NET 应用在智能认知服务的丰富程度。在 Microsoft Docs 中已经提供了一个基于 ONNX 使用 Tiny YOLOv2 做对象检测的例子。为了展现 ML.NET 在其他框架上的通用性,本文将介绍使用 Pytorch 训练的垃圾分类的模型,基于 ONNX 导入到 ML.NET 中完成预测。

在2019年9月华为云举办了一次人工智能大赛·垃圾分类挑战杯,首次将AI与环保主题结合,展现人工智能技术在生活中的运用。有幸我看到了本次大赛亚军方案的分享,并且在 github 上找到了开源代码,按照 README 说明,我用 Pytorch 训练出了一个模型,并保存为garbage.pt 文件。

生成 ONNX 模型

首先,我使用以下 Pytorch 代码来生成一个garbage.pt 对应的文件,命名为 garbage.onnx

torch_model = torch.load("garbage.pt") # pytorch模型加载
    batch_size = 1  #批处理大小
    input_shape = (3,224,224)   #输入数据

    # # set the model to inference mode
    torch_model.eval()

    x = torch.randn(batch_size, *input_shape, device='cuda')        # 生成张量
    export_onnx_file = "garbage.onnx"                    # 目的ONNX文件名
 
    
    torch.onnx.export(torch_model.module,
                        x,
                        export_onnx_file,
                        input_names=["input"],        # 输入名
                        output_names=["output"]    # 输出名

准备 ML.NET 项目

创建一个 .NET Core 控制台应用,按如下结构创建好合适的目录。assets 目录下的 images 子目录将放置待预测的图片,而 Model 子目录放入前一个步骤生成的 garbage.onnx 模型文件。

ImageNetData 和 ImageNetPrediction 类定义了输入和输出的数据结构。

using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;

namespace GarbageDemo.DataStructures
{
    public class ImageNetData
    {
        [LoadColumn(0)]
        public string ImagePath;

        [LoadColumn(1)]
        public string Label;

        public static IEnumerable<ImageNetData> ReadFromFile(string imageFolder)
        {
            return Directory
               .GetFiles(imageFolder)
               .Where(filePath => Path.GetExtension(filePath) == ".jpg")
               .Select(filePath => new ImageNetData { ImagePath = filePath, Label = Path.GetFileName(filePath) });

        }
    }

    public class ImageNetPrediction : ImageNetData
    {
        public float[] Score;

        public string PredictedLabelValue;
    }
}

 OnnxModelScorer 类定义了 ONNX 模型加载、打分预测的过程。注意 ImageNetModelSettings 的属性和第一步中指定的输入输出字段名要一致。

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.Onnx;
using Microsoft.ML.Transforms.Image;
using GarbageDemo.DataStructures;

namespace GarbageDemo
{
    class OnnxModelScorer
    {
        private readonly string imagesFolder;
        private readonly string modelLocation;
        private readonly MLContext mlContext;


        public OnnxModelScorer(string imagesFolder, string modelLocation, MLContext mlContext)
        {
            this.imagesFolder = imagesFolder;
            this.modelLocation = modelLocation;
            this.mlContext = mlContext;
        }

        public struct ImageNetSettings
        {
            public const int imageHeight = 224;
            public const int imageWidth = 224;    
            public const float Mean = 127;
            public const float Scale = 1;
            public const bool ChannelsLast = false;
        } 
        
        public struct ImageNetModelSettings
        {
            // input tensor name
            public const string ModelInput = "input";

            // output tensor name
            public const string ModelOutput = "output";
        }

        private ITransformer LoadModel(string modelLocation)
        {
            Console.WriteLine("Read model");
            Console.WriteLine($"Model location: {modelLocation}");
            Console.WriteLine($"Default parameters: image size=({ImageNetSettings.imageWidth},{ImageNetSettings.imageHeight})");

            // Create IDataView from empty list to obtain input data schema
            var data = mlContext.Data.LoadFromEnumerable(new List<ImageNetData>());

            // Define scoring pipeline
            var pipeline = mlContext.Transforms.LoadImages(outputColumnName: ImageNetModelSettings.ModelInput, imageFolder: "", inputColumnName: nameof(ImageNetData.ImagePath))                           
                            .Append(mlContext.Transforms.ResizeImages(outputColumnName: ImageNetModelSettings.ModelInput, 
                                                                        imageWidth: ImageNetSettings.imageWidth, 
                                                                        imageHeight: ImageNetSettings.imageHeight, 
                                                                        inputColumnName: ImageNetModelSettings.ModelInput,
                                                                        resizing: ImageResizingEstimator.ResizingKind.IsoCrop,
                                                                        cropAnchor: ImageResizingEstimator.Anchor.Center
                                                                        ))
                            .Append(mlContext.Transforms.ExtractPixels(outputColumnName: ImageNetModelSettings.ModelInput, interleavePixelColors: ImageNetSettings.ChannelsLast))
                            .Append(mlContext.Transforms.NormalizeGlobalContrast(outputColumnName: ImageNetModelSettings.ModelInput, 
                                                                                 inputColumnName: ImageNetModelSettings.ModelInput, 
                                                                                 ensureZeroMean : true, 
                                                                                 ensureUnitStandardDeviation: true, 
                                                                                 scale: ImageNetSettings.Scale))
                            .Append(mlContext.Transforms.ApplyOnnxModel(modelFile: modelLocation, outputColumnNames: new[] { ImageNetModelSettings.ModelOutput }, inputColumnNames: new[] { ImageNetModelSettings.ModelInput }));

            // Fit scoring pipeline
            var model = pipeline.Fit(data);

            return model;
        }

        private IEnumerable<float[]> PredictDataUsingModel(IDataView testData, ITransformer model)
        {
            Console.WriteLine($"Images location: {imagesFolder}");
            Console.WriteLine("");
            Console.WriteLine("=====Identify the objects in the images=====");
            Console.WriteLine("");

            IDataView scoredData = model.Transform(testData);

            IEnumerable<float[]> probabilities = scoredData.GetColumn<float[]>(ImageNetModelSettings.ModelOutput);

            return probabilities;
        }

        public IEnumerable<float[]> Score(IDataView data)
        {
            var model = LoadModel(modelLocation);

            return PredictDataUsingModel(data, model);
        }
    }
}

Program 类中定义了调用过程,完成预测结果呈现。

using GarbageDemo.DataStructures;
using Microsoft.ML;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace GarbageDemo
{
    class Program
    {
        static void Main(string[] args)
        {
            var assetsRelativePath = @"../../../assets";
            string assetsPath = GetAbsolutePath(assetsRelativePath);
            var modelFilePath = Path.Combine(assetsPath, "Model", "garbage.onnx");
            var imagesFolder = Path.Combine(assetsPath, "images");// Initialize MLContext
            MLContext mlContext = new MLContext();

            try
            {
                // Load Data
                IEnumerable<ImageNetData> images = ImageNetData.ReadFromFile(imagesFolder);
                IDataView imageDataView = mlContext.Data.LoadFromEnumerable(images);

                // Create instance of model scorer
                var modelScorer = new OnnxModelScorer(imagesFolder, modelFilePath, mlContext);

                // Use model to score data
                IEnumerable<float[]> probabilities = modelScorer.Score(imageDataView);

                int index = 0;
                foreach (var probable in probabilities)
                {
                    var scores = Softmax(probable);

                    var (topResultIndex, topResultScore) = scores.Select((predictedClass, index) => (Index: index, Value: predictedClass))
                        .OrderByDescending(result => result.Value)
                        .First();
                    Console.WriteLine("图片:{3} \r\n 分类{2} {0}:{1}", labels[topResultIndex], topResultScore, topResultIndex, images.ElementAt(index).ImagePath);
                    Console.WriteLine("=============================");
                    index++;
                }

            }
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
            }

            Console.WriteLine("========= End of Process..Hit any Key ========");
            Console.ReadLine();
        }

        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
            string assemblyFolderPath = _dataRoot.Directory.FullName;

            string fullPath = Path.Combine(assemblyFolderPath, relativePath);

            return fullPath;
        }

        private static float[] Softmax(float[] values)
        {
            var maxVal = values.Max();
            var exp = values.Select(v => Math.Exp(v - maxVal));
            var sumExp = exp.Sum();

            return exp.Select(v => (float)(v / sumExp)).ToArray();
        }

        private static string[] labels = new string[]
        {
            "其他垃圾/一次性快餐盒",
            "其他垃圾/污损塑料",
            "其他垃圾/烟蒂",
            "其他垃圾/牙签",
            "其他垃圾/破碎花盆及碟碗",
            "其他垃圾/竹筷",
            "厨余垃圾/剩饭剩菜",
            "厨余垃圾/大骨头",
            "厨余垃圾/水果果皮",
            "厨余垃圾/水果果肉",
            "厨余垃圾/茶叶渣",
            "厨余垃圾/菜叶菜根",
            "厨余垃圾/蛋壳",
            "厨余垃圾/鱼骨",
            "可回收物/充电宝",
            "可回收物/包",
            "可回收物/化妆品瓶",
            "可回收物/塑料玩具",
            "可回收物/塑料碗盆",
            "可回收物/塑料衣架",
            "可回收物/快递纸袋",
            "可回收物/插头电线",
            "可回收物/旧衣服",
            "可回收物/易拉罐",
            "可回收物/枕头",
            "可回收物/毛绒玩具",
            "可回收物/洗发水瓶",
            "可回收物/玻璃杯",
            "可回收物/皮鞋",
            "可回收物/砧板",
            "可回收物/纸板箱",
            "可回收物/调料瓶",
            "可回收物/酒瓶",
            "可回收物/金属食品罐",
            "可回收物/锅",
            "可回收物/食用油桶",
            "可回收物/饮料瓶",
            "有害垃圾/干电池",
            "有害垃圾/软膏",
            "有害垃圾/过期药物",
            "可回收物/毛巾",
            "可回收物/饮料盒",
            "可回收物/纸袋"
        };

选择一张图片放到 images 目录中,运行结果如下:

有 0.88 的得分说明照片中的物品属于污损塑料,让我们看一下图片真相。

果然是相当准确 ,并且把周边的附属物都过滤掉了。

对于 ML.NET 训练深度神经网络模型支持更复杂的场景是不是更有信心了!

posted on 2020-06-22 13:53  Bean.Hsiang  阅读(2876)  评论(6编辑  收藏  举报