OpenCV 3.4: логистическая регрессия C ++ (мультикласс)

Я пытаюсь построить мультиклассовый логистический классификатор в C ++. Я использую для этого ML библиотеки OpenCV. Вот код:

void LRtrain(cv::Ptr<cv::ml::LogisticRegression>& lr, cv::Mat& trainMat, cv::Mat& trainLabels, cv::Mat& testResponse, cv::Mat& testMat) {

//Dataset must be changed to CV_32F
trainLabels.convertTo(trainLabels, CV_32F);
trainMat.convertTo(trainMat, CV_32F);
testResponse.convertTo(testResponse, CV_32F);
testMat.convertTo(testMat, CV_32F);

lr->setLearningRate(0.0001);
lr->setIterations(100);
lr->setRegularization(cv::ml::LogisticRegression::REG_L2);
lr->setTrainMethod(cv::ml::LogisticRegression::BATCH);
lr->setMiniBatchSize(1);

cv::Ptr<cv::ml::TrainData> tData = cv::ml::TrainData::create(trainMat, cv::ml::SampleTypes::ROW_SAMPLE,
trainLabels);

//train
lr->train(tData);

lr->save("lr.yml");
lr->predict(testMat, testResponse);

std::cout << lr->get_learnt_thetas().size() << std::endl;

std::cout << testResponse << std::endl;

getLRParam(lr);
}

void SVMevaluate(cv::Mat& testResponse, float& count, float& accuracy, cv::Mat& testLabels) {
cv::Mat confusionMatrix = cv::Mat::zeros(10,10, CV_32S);
for (int i = 0; i < testResponse.rows; i++) {
std::cout << "Test Response , TestLabels " << testResponse.at<int>(i, 0) << " " << testLabels.at<int>(i, 0) << std::endl;
if (testResponse.at<int>(i, 0) == testLabels.at<int>(i,0)) count = count + 1;
confusionMatrix.at<int>(testLabels.at<int>(i, 0)+1, testResponse.at<int>(i, 0)+1) += 1;
}
accuracy = (count / testResponse.rows) * 100;
std::cout << confusionMatrix << std::endl;
}

У меня 8 классов, но ответов я получаю в основном 2. Я не получаю метки, классифицированные как 3,4 или более. Я был бы признателен, если бы кто-нибудь мог дать мне некоторое представление о том, в чем может быть ошибка или почему я получаю ответы о явной бинарной логистической регрессии. (Матрица learnnt_theta — это матрица 513×8, которая, по-видимому, подходит для регистрации мультикласса).

Вот сгенерированная Матрица Путаницы:

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0;
0, 0, 0, 0, 0, 0, 0, 0, 0, 0;
0, 0, 3, 257, 0, 0, 0, 0, 0, 0;
0, 0, 0, 228, 0, 0, 0, 0, 0, 0;
0, 0, 2, 158, 0, 0, 0, 0, 0, 0;
0, 0, 0, 208, 0, 0, 0, 0, 0, 0;
0, 0, 0, 274, 0, 0, 0, 0, 0, 0;
0, 0, 1, 309, 0, 0, 0, 0, 0, 0;
0, 0, 0, 192, 0, 0, 0, 0, 0, 0;
0, 0, 0, 256, 0, 0, 0, 0, 0, 0]
the accuracy is: 12.2352

Спасибо вам всем! 🙂

0

Решение

Задача ещё не решена.

Другие решения

Других решений пока нет …

По вопросам рекламы [email protected]