自用

clc;
clear;
close all;
warning off;
addpath(genpath(pwd));
rng('default');
 
% 数据集加载
Dataset = imageDatastore('Photograph', 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
%数据集划分
[Training_Dataset, Validation_Dataset, Testing_Dataset] = splitEachLabel(Dataset, 0.8, 0.1, 0.1);
 
% 加载预训练的 GoogLeNet
load googlenet.mat
analyzeNetwork(net);
 
% 调整输入层大小
Input_Layer_Size = net.Layers(1).InputSize(1:2);
 
% 调整数据集大小
Resized_Training_Dataset = augmentedImageDatastore(Input_Layer_Size, Training_Dataset);
Resized_Validation_Dataset = augmentedImageDatastore(Input_Layer_Size, Validation_Dataset);
Resized_Testing_Dataset = augmentedImageDatastore(Input_Layer_Size, Testing_Dataset);
 
% 修改网络架构
Feature_Learner = net.Layers(142).Name;
Output_Classifier = net.Layers(144).Name;
 
Number_of_Classes = numel(categories(Training_Dataset.Labels));
 
New_Feature_Learner = fullyConnectedLayer(Number_of_Classes, ...
'Name', 'Coal Feature Learner', ...
'WeightLearnRateFactor', 10, ...
'BiasLearnRateFactor', 10);
 
New_Classifier_Layer = classificationLayer('Name', 'Coal Classifier');
 
Network_Architecture = layerGraph(net);
 
New_Network = replaceLayer(Network_Architecture, Feature_Learner, New_Feature_Learner);
New_Network = replaceLayer(New_Network, Output_Classifier, New_Classifier_Layer);
 
%深度学习网络分析器界面
analyzeNetwork(New_Network);
 
% 训练深度学习网络
maxEpochs = 1;
Minibatch_Size = 8;
Validation_Frequency = floor(numel(Resized_Training_Dataset.Files) / Minibatch_Size);
%启用训练进程图
Training_Options = trainingOptions('sgdm', ...
'MiniBatchSize', Minibatch_Size, ...
'MaxEpochs', maxEpochs, ...
'InitialLearnRate', 1e-3, ...
'Shuffle', 'every-epoch', ...
'ValidationData', Resized_Validation_Dataset, ...
'ValidationFrequency', Validation_Frequency, ...
'Verbose', false, ...
'Plots', 'training-progress');
 
net = trainNetwork(Resized_Training_Dataset, New_Network, Training_Options);
 
% 保存训练好的深度学习网络
save gnet.mat net;
 
% 提取特征
featureLayer = 'pool5-7x7_s1'; % 选择一个合适的特征提取层
features = activations(net, Resized_Training_Dataset, featureLayer, 'MiniBatchSize', 16);
 
% 初始化标签数组
labels = Training_Dataset.Labels; % 从原始数据集中获取标签
 
% 将特征和标签转换为适合 SVM 的格式
% 压缩特征维度
features = squeeze(mean(features, [2, 3]));
% 将标签转换为 categorical 类型
Y = categorical(labels);
X = features; % 特征
 
% 训练 SVM 分类器
SVMModel = fitcecoc(X, Y, 'Learners', 'svm', 'Coding', 'onevsall');
 
% 保存 SVM 分类器
save svmModel.mat SVMModel;
disp('SVM 分类器训练完成并已保存到 svmModel.mat 文件中。');
------------------------------------------------------------------------------------------
clc; clear; close all; warning off;
addpath(genpath(pwd));
rng('default');
 
% 加载模型
load gnet.mat; % 确保网络输入层为 224x224
load svmModel.mat;
 
% 数据集加载与预处理
Dataset = imageDatastore('Photograph', 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[Training_Dataset, Validation_Dataset, Testing_Dataset] = splitEachLabel(Dataset, 0.8, 0.1, 0.1);
 
% ------------------ 关键修正:添加图像尺寸调整 ------------------
inputSize = net.Layers(1).InputSize(1:2); % 获取网络输入尺寸 [224 224]
 
% 创建调整后的验证集和测试集
Resized_Val = augmentedImageDatastore(inputSize, Validation_Dataset, 'ColorPreprocessing', 'gray2rgb');
Resized_Test = augmentedImageDatastore(inputSize, Testing_Dataset, 'ColorPreprocessing', 'gray2rgb');
% -----------------------------------------------------------
 
% 验证集测试
featureLayer = 'pool5-7x7_s1'; % 确保层名正确
 
% 使用调整后的数据集
% 提取特征
features = activations(net, Resized_Val, featureLayer, 'MiniBatchSize', 16, 'OutputAs', 'channels');
features = squeeze(mean(features, [1 2]))'; % 压缩特征维度
disp(size(features)); % 查看特征矩阵的尺寸
labels = Validation_Dataset.Labels;
 
% 如果尺寸不符合要求,进行重新调整
features = reshape(features, [], size(features, 3))'; % 重新调整为适合 SVM 输入的形状
 
% 预测与显示
Predicted_Label = predict(SVMModel, features);
disp(['验证集准确率: ', num2str(mean(Predicted_Label == labels) * 100), '%']);
 
 
showResults(Testing_Dataset, Predicted_Label, '测试集识别结果');
 
% ------------------ 可视化函数 ------------------
function showResults(dataset, labels, titleStr)
index = randperm(numel(dataset.Files), 16);
figure('Name', titleStr, 'NumberTitle','off');
for i = 1:16
subplot(4,4,i);
I = readimage(dataset, index(i));
imshow(I);
title(string(labels(index(i))));
end
end
posted @ 2025-05-25 17:24  SaulGoodman1  阅读(19)  评论(0)    收藏  举报