Я пытаюсь построить мультиклассовый логистический классификатор в C ++. Я использую для этого ML библиотеки OpenCV. Вот код:
void LRtrain(cv::Ptr<cv::ml::LogisticRegression>& lr, cv::Mat& trainMat, cv::Mat& trainLabels, cv::Mat& testResponse, cv::Mat& testMat) {
//Dataset must be changed to CV_32F
trainLabels.convertTo(trainLabels, CV_32F);
trainMat.convertTo(trainMat, CV_32F);
testResponse.convertTo(testResponse, CV_32F);
testMat.convertTo(testMat, CV_32F);
lr->setLearningRate(0.0001);
lr->setIterations(100);
lr->setRegularization(cv::ml::LogisticRegression::REG_L2);
lr->setTrainMethod(cv::ml::LogisticRegression::BATCH);
lr->setMiniBatchSize(1);
cv::Ptr<cv::ml::TrainData> tData = cv::ml::TrainData::create(trainMat, cv::ml::SampleTypes::ROW_SAMPLE,
trainLabels);
//train
lr->train(tData);
lr->save("lr.yml");
lr->predict(testMat, testResponse);
std::cout << lr->get_learnt_thetas().size() << std::endl;
std::cout << testResponse << std::endl;
getLRParam(lr);
}
void SVMevaluate(cv::Mat& testResponse, float& count, float& accuracy, cv::Mat& testLabels) {
cv::Mat confusionMatrix = cv::Mat::zeros(10,10, CV_32S);
for (int i = 0; i < testResponse.rows; i++) {
std::cout << "Test Response , TestLabels " << testResponse.at<int>(i, 0) << " " << testLabels.at<int>(i, 0) << std::endl;
if (testResponse.at<int>(i, 0) == testLabels.at<int>(i,0)) count = count + 1;
confusionMatrix.at<int>(testLabels.at<int>(i, 0)+1, testResponse.at<int>(i, 0)+1) += 1;
}
accuracy = (count / testResponse.rows) * 100;
std::cout << confusionMatrix << std::endl;
}
У меня 8 классов, но ответов я получаю в основном 2. Я не получаю метки, классифицированные как 3,4 или более. Я был бы признателен, если бы кто-нибудь мог дать мне некоторое представление о том, в чем может быть ошибка или почему я получаю ответы о явной бинарной логистической регрессии. (Матрица learnnt_theta — это матрица 513×8, которая, по-видимому, подходит для регистрации мультикласса).
Вот сгенерированная Матрица Путаницы:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0;
0, 0, 0, 0, 0, 0, 0, 0, 0, 0;
0, 0, 3, 257, 0, 0, 0, 0, 0, 0;
0, 0, 0, 228, 0, 0, 0, 0, 0, 0;
0, 0, 2, 158, 0, 0, 0, 0, 0, 0;
0, 0, 0, 208, 0, 0, 0, 0, 0, 0;
0, 0, 0, 274, 0, 0, 0, 0, 0, 0;
0, 0, 1, 309, 0, 0, 0, 0, 0, 0;
0, 0, 0, 192, 0, 0, 0, 0, 0, 0;
0, 0, 0, 256, 0, 0, 0, 0, 0, 0]
the accuracy is: 12.2352
Спасибо вам всем! 🙂
Задача ещё не решена.
Других решений пока нет …