Как сохранить модель PyTorch, если я хочу, чтобы она была загружена модулем OpenCV dnn

Я тренирую простую модель классификации с помощью PyTorch и загружаю ее с помощью opencv3.3, но она выдает исключение и говорит

Ошибка OpenCV: функция / функция не реализована (неподдерживаемый тип Lua) в файле readObject
/home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp,
линия 797
/home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp:797:
ошибка: (-213) Неподдерживаемый тип Lua в функции readObject

Определение модели

class conv_block(nn.Module):
def __init__(self, in_filter, out_filter, kernel):
super(conv_block, self).__init__()

self.conv1 = nn.Conv2d(in_filter, out_filter, kernel, 1, (kernel - 1)//2)
self.batchnorm = nn.BatchNorm2d(out_filter)
self.maxpool = nn.MaxPool2d(2, 2)

def forward(self, x):
x = self.conv1(x)
x = self.batchnorm(x)
x = F.relu(x)
x = self.maxpool(x)

return x

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()

self.conv1 = conv_block(3, 6, 3)
self.conv2 = conv_block(6, 16, 3)
self.fc1 = nn.Linear(16 * 8 * 8, 120)
self.bn1 = nn.BatchNorm1d(120)
self.fc2 = nn.Linear(120, 84)
self.bn2 = nn.BatchNorm1d(84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size()[0], -1)
x = F.relu(self.bn1(self.fc1(x)))
x = F.relu(self.bn2(self.fc2(x)))
x = self.fc3(x)
return x

Эта модель использует только Conv2d, ReLU, BatchNorm2d, MaxPool2d и линейный слой, все слои поддерживаются opencv3.3

Я сохраняю это с помощью state_dict

torch.save(net.state_dict(), 'cifar10_model')

Загрузите его на C ++ как

std::string const model_file("/home/some_folder/cifar10_model");

std::cout<<"read net from torch"<<std::endl;
dnn::Net net = dnn::readNetFromTorch(model_file);

Я думаю, что я сохраняю модель неверным способом, как правильно сохранить модель PyTorch для загрузки с использованием OpenCV? Спасибо

Редактировать :

Я использую другой способ сохранить модель, но она также не может быть загружена

torch.save(net, 'cifar10_model.net')

Это ошибка? Или я делаю что-то не так?

2

Решение

Я нашел ответ, opencv3.3 не поддерживает PyTorch (https://github.com/pytorch/pytorch) но пыторь (https://github.com/hughperkins/pytorch), это большой сюрприз, я никогда не знаю, что существует еще одна версия pytorch (выглядит как мертвый проект, долгое время не обновлялась), я надеюсь, что они могли бы упомянуть, какой pytorch они поддерживают в вики.

2

Другие решения

Других решений пока нет …

По вопросам рекламы [email protected]