libtorch使用model.forward报std::runtime_error错误

1、原因

模型向GPU拷贝时发生异常

	model = torch::jit::load(ptFile);
	if (isHalf)
	{
		model.to(torch::kHalf);
	}
	model.to(device);//GPU版异常,可能模型并没有完全放到GPU上

2、解决方法

model = torch::jit::load(ptFile, torch::kCUDA);

参考:https://github.com/pytorch/pytorch/issues/19302

posted @ 2024-05-03 09:01  珠峰上吹泡泡  阅读(291)  评论(0)    收藏  举报