tf拟合
https://files.cnblogs.com/files/chinasoft/tf.js-demo-v2.rar?t=1656483198
<script src = "tf.min.js"> </script>
<script>
/* 根据身高推测体重 */
//把数据处理成符合模型要求的格式
function getData() {
//学习数据
const heights = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
const weights = [3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23];
//验证数据
const testh = [100, 101, 102, 103, 104, 105, 106];
const testw = [201, 203, 205, 207, 208, 210, 212];
//归一化数据
const inputs = tf.tensor(heights);//.sub(150).div(50);
const labels = tf.tensor(weights);//.sub(40).div(60);
const xs = tf.tensor(testh);//.;//sub(150).div(50);
const ys = tf.tensor(testw);//.sub(40).div(60);
// //绘制图表
// tfvis.render.scatterplot(
// { name: '身高体重' },
// //x轴身高,y轴体重
// { values: heights.map((x, i) => ({ x, y: weights[i] })) },
// //设置x轴范围,设置y轴范围
// { xAxisDomain: [140, 200], yAxisDomain: [40, 110] }
// );
return { inputs, labels, xs, ys };
}
async function run(){
const { inputs, labels, xs, ys } = getData();
//设置连续模型
const model = tf.sequential();
//设置全连接层
model.add(tf.layers.dense({
units: 1,
inputShape: [1]
}));
// model.add(tf.layers.dense({
// units: 1
// }));
//设置损失函数,优化函数学习速率为0.1
model.compile({
loss: tf.losses.meanSquaredError,
optimizer: tf.train.adam(0.1)
});
await model.fit(inputs, labels, {
batchSize: 1,
epochs: 20,
//设置验证集
validationData: [xs, ys],
// callbacks: tfvis.show.fitCallbacks(
// { name: '训练过程' },
// ['loss', 'val_loss', 'acc', 'val_acc'],
// { callbacks: ['onEpochEnd'] }
// )
callbacks:function(){
console.log("1");
}
});
//对身高180的体重进行推测
// let res = model.predict(tf.tensor([180]).sub(150).div(50));
// console.log(res.mul(60).add(40).dataSync()[0]);
let res = model.predict(tf.tensor([180]));
console.log(res.dataSync()[0]);
//保存模型
window.download = async () => {
await model.save('downloads://my-model');
}
}
run();
</script>

浙公网安备 33010602011771号