libtorch入门例程

libtorch C++版可以直接在官网下载。自己学习如果没有合适的显卡可以选择下载CPU版的。下面是官网链接:

下载后就可以把开发包包含到VS的项目中使用。注意libtorch官网提供的Release/Debug的开发包,Debug版的程序用Debug版的库,Release版的程序用Release版的库,不能混用。另外libtorch更新很快是用最新的C++版本写的,需要在编译器设置中设置合适的C++语言版本。比如我的是libtorch1.13.1,它只能在C++14版本下使用,C++11或C++17都不行。在使用时如果编译报错有很多“std不明确的符号”,可用的改正方法是:打开项目属性→属性→C/C++→语言→符合模式→改为“否”。

下面给出一个可以运行的代码。我的测试环境是VS2017(C++14)和libtorch1.13.1。如果在你的编译器下能正常编译运行,那么说明libtorch是正常的。此代码的功能是拟合函数${ z=3x+y+2.5 }$。请注意torch::Tensor loss = lossFunc(predict, c);这一句,predict和c的位置不能反过来。

#include "torch/all.h"

int main()
{
    torch::nn::Linear linear(2, 1);

    /* 30个样本。在这里是一行一个样本 */
    at::Tensor b = torch::rand({ 30, 2 });
    at::Tensor c = torch::zeros({ 30, 1 });
    for (int i = 0; i < 30; i++)
    {
        c[i] = 3 * b[i][0] + b[i][1] + 2.5f;
    }

    cout << b << endl;
    cout << c << endl;

    /* 训练过程 */
    torch::optim::SGD optim(linear->parameters(), torch::optim::SGDOptions(0.01));
    torch::nn::MSELoss lossFunc;
    linear->train();
    for (int i = 0; i < 10000; i++)
    {
        torch::Tensor predict = linear(b);
        torch::Tensor loss = lossFunc(predict, c);
        optim.zero_grad();
        loss.backward();
        optim.step();
        if (i % 1000 == 0)
        {
            /* 每1000次循环输出一次损失函数值 */
            cout << "LOOP:" << i << ",LOSS=" << loss.item() << endl;
        }
    }
    /* 输出训练之后的网络参数 */
    cout << linear->parameters() << endl;

    /* 做个测试 */
    at::Tensor x = torch::tensor({ 1.5f, 2.0f });
    at::Tensor y = linear(x);
    cout << "3*1.5+1*2+2.5=" << y.item();

    return 0;
}

输出内容是:

 0.0341  0.6551
 0.9524  0.1005
 0.3764  0.5524
 0.8860  0.6767
 0.6554  0.9601
 0.7736  0.0955
 0.4260  0.3402
 0.1248  0.1497
 0.2288  0.2765
 0.4508  0.6151
 0.1954  0.0717
 0.5392  0.5821
 0.8622  0.2375
 0.9371  0.0668
 0.6593  0.2563
 0.1854  0.8515
 0.1299  0.4341
 0.8148  0.6432
 0.7303  0.0794
 0.6853  0.5018
 0.7687  0.8698
 0.6909  0.7306
 0.8921  0.8072
 0.6477  0.0745
 0.5048  0.8875
 0.6906  0.4306
 0.7410  0.6294
 0.0095  0.8609
 0.0862  0.8630
 0.6828  0.5330
[ CPUFloatType{30,2} ]
 3.2576
 5.4577
 4.1815
 5.8348
 5.4263
 4.9163
 4.1181
 3.0239
 3.4628
 4.4676
 3.1578
 4.6996
 5.3240
 5.3781
 4.7343
 3.9076
 3.3239
 5.5875
 4.7702
 5.0576
 5.6757
 5.3035
 5.9835
 4.5176
 4.9018
 5.0025
 5.3523
 3.3893
 3.6218
 5.0814
[ CPUFloatType{30,1} ]
LOOP:0,LOSS=33.1978
LOOP:1000,LOSS=0.0120936
LOOP:2000,LOSS=0.00164465
LOOP:3000,LOSS=0.000271623
LOOP:4000,LOSS=4.59804e-05
LOOP:5000,LOSS=7.80881e-06
LOOP:6000,LOSS=1.33087e-06
LOOP:7000,LOSS=2.33062e-07
LOOP:8000,LOSS=4.16838e-08
LOOP:9000,LOSS=8.05691e-09
 2.9999  0.9999
[ CPUFloatType{1,2} ]  2.5001
[ CPUFloatType{1} ]
3*1.5+1*2+2.5=8.99974

 

posted @ 2024-03-09 13:18  兜尼完  阅读(22)  评论(0编辑  收藏  举报