Я использую libsvm версии 3.16. Я прошел обучение в Matlab и создал модель. Теперь я хотел бы сохранить эту модель на диск и загрузить ее в мою программу на C ++. До сих пор я нашел следующие альтернативы:
Таким образом, оба эти варианта не являются удовлетворительными,
У кого-нибудь есть идея?
Вариант 1 на самом деле довольно разумный. Если вы сохраните модель в C-формате libsvm через matlab, то работать с моделью в C / C ++ будет просто, используя функции, предоставляемые libsvm. Попытка работать с данными в формате Matlab в C ++, вероятно, будет намного сложнее.
main
Функция в «svm-Forext.c» (находится в корневом каталоге пакета libsvm), вероятно, имеет большую часть того, что вам нужно:
if((model=svm_load_model(argv[i+1]))==0)
{
fprintf(stderr,"can't open model file %s\n",argv[i+1]);
exit(1);
}
Предсказать метку, например x
используя модель, вы можете запустить
int predict_label = svm_predict(model,x);
Самая хитрая часть этого будет состоять в том, чтобы перевести ваши данные в формат libsvm (если ваши данные не в формате текстового файла libsvm, и в этом случае вы можете просто использовать predict
функция в «svm-Forext.C»).
Вектор libsvm, x
, это массив struct svm_node
это представляет собой редкий массив данных. Каждый svm_node имеет индекс и значение, а вектор должен заканчиваться индексом, установленным в -1. Например, для кодирования вектора [0,1,0,5]
Вы могли бы сделать следующее:
struct svm_node *x = (struct svm_node *) malloc(3*sizeof(struct svm_node));
x[0].index=2; //NOTE: libsvm indices start at 1
x[0].value=1.0;
x[1].index=4;
x[1].value=5.0;
x[2].index=-1;
Для типов SVM, отличных от классификатора (C_SVC), посмотрите на predict
функция в «svm-Forext.C».
Мое решение состояло в том, чтобы переучиться на C ++, потому что я не мог найти хороший способ напрямую сохранить модель. Вот мой код Вам нужно будет адаптировать его и немного почистить. Самое большое изменение, которое вы должны сделать, это не трудно кодировать svm_parameter
ценности, как я сделал. Вам также придется заменить FilePath
с std::string
, Я копирую, вставляю и делаю небольшие правки здесь, в SO, поэтому форматирование не будет идеальным:
Используется так:
auto targetsPath = FilePath("targets.txt");
auto observationsPath = FilePath("observations.txt");
auto targetsMat = MatlabMatrixFileReader::Read(targetsPath, ',');
auto observationsMat = MatlabMatrixFileReader::Read(observationsPath, ',');
auto v = MiscVector::ConvertVecOfVecToVec(targetsMat);
auto model = SupportVectorRegressionModel{ observationsMat, v };
std::vector<double> observation{ { // 32 feature observation
0.883575729725847,0.919446119013878,0.95359403450317,
0.968233630936732,0.91891307107125,0.887897763183844,
0.937588566544751,0.920582702918882,0.888864454119387,
0.890066735260163,0.87911085669864,0.903745573664995,
0.861069296586979,0.838606194934074,0.856376230548304,
0.863011311537075,0.807688936997926,0.740434984165146,
0.738498042748759,0.736410940165691,0.697228384912424,
0.608527698289016,0.632994967880269,0.66935784966765,
0.647761430696238,0.745961037635717,0.560761134660957,
0.545498063585615,0.590854855113663,0.486827902942118,
0.187128866890822,- 0.0746523069562551
} };
double prediction = model.Predict(observation);
miscvector.h
static vector<double> ConvertVecOfVecToVec(const vector<vector<double>> &mat)
{
vector<double> targetsVec;
targetsVec.reserve(mat.size());
for (size_t i = 0; i < mat.size(); i++)
{
targetsVec.push_back(mat[i][0]);
}
return targetsVec;
}
libsvmtargetobjectconvertor.h
#pragma once
#include "machinelearning.h"
struct svm_node;
class LibSvmTargetObservationConvertor
{
public:
svm_node ** LibSvmTargetObservationConvertor::ConvertObservations(const vector<MlObservation> &observations, size_t numFeatures) const
{
svm_node **svmObservations = (svm_node **)malloc(sizeof(svm_node *) * observations.size());
for (size_t rowI = 0; rowI < observations.size(); rowI++)
{
svm_node *row = (svm_node *)malloc(sizeof(svm_node) * numFeatures);
for (size_t colI = 0; colI < numFeatures; colI++)
{
row[colI].index = colI;
row[colI].value = observations[rowI][colI];
}
row[numFeatures].index = -1; // apparently needed
svmObservations[rowI] = row;
}
return svmObservations;
}
svm_node* LibSvmTargetObservationConvertor::ConvertMatToSvmNode(const MlObservation &observation) const
{
size_t numFeatures = observation.size();
svm_node *obsNode = (svm_node *)malloc(sizeof(svm_node) * numFeatures);
for (size_t rowI = 0; rowI < numFeatures; rowI++)
{
obsNode[rowI].index = rowI;
obsNode[rowI].value = observation[rowI];
}
obsNode[numFeatures].index = -1; // apparently needed
return obsNode;
}
};
machinelearning.h
#pragma once
#include <vector>
using std::vector;
using MlObservation = vector<double>;
using MlTarget = double;
//machinelearningmodel.h
#pragma once
#include <vector>
#include "machinelearning.h"class MachineLearningModel
{
public:
virtual ~MachineLearningModel() {}
virtual double Predict(const MlObservation &observation) const = 0;
};
matlabmatrixfilereader.h
#pragma once
#include <vector>
using std::vector;
class FilePath;
// Matrix created with command:
// dlmwrite('my_matrix.txt', somematrix, 'delimiter', ',', 'precision', 15);
// In these files, each row is a matrix row. Commas separate elements on a row.
// There is no space at the end of a row. There is a blank line at the bottom of the file.
// File format:
// 0.4,0.7,0.8
// 0.9,0.3,0.5
// etc.
static class MatlabMatrixFileReader
{
public:
static vector<vector<double>> Read(const FilePath &asciiFilePath, char delimiter)
{
vector<vector<double>> values;
vector<double> valueline;
std::ifstream fin(asciiFilePath.Path());
string item, line;
while (getline(fin, line))
{
std::istringstream in(line);
while (getline(in, item, delimiter))
{
valueline.push_back(atof(item.c_str()));
}
values.push_back(valueline);
valueline.clear();
}
fin.close();
return values;
}
};
supportvectorregressionmodel.h
#pragma once
#include <vector>
using std::vector;
#include "machinelearningmodel.h"
#include "svm.h" // libsvm
class FilePath;
class SupportVectorRegressionModel : public MachineLearningModel
{
public:
SupportVectorRegressionModel::~SupportVectorRegressionModel()
{
svm_free_model_content(model_);
svm_destroy_param(¶m_);
svm_free_and_destroy_model(&model_);
}
SupportVectorRegressionModel::SupportVectorRegressionModel(const vector<MlObservation>& observations, const vector<MlTarget>& targets)
{
// assumes all observations have same number of features
size_t numFeatures = observations[0].size();
//setup targets
//auto v = ConvertVecOfVecToVec(targetsMat);
double *targetsPtr = const_cast<double *>(&targets[0]); // why aren't the targets const?
LibSvmTargetObservationConvertor conv;
svm_node **observationsPtr = conv.ConvertObservations(observations, numFeatures);
// setup observations
//svm_node **observations = BuildObservations(observationsMat, numFeatures);
// setup problem
svm_problem problem;
problem.l = targets.size();
problem.y = targetsPtr;
problem.x = observationsPtr;
// specific to out training sets
// TODO: This is hard coded.
// Bust out these values for use in constructor
param_.C = 0.4; // cost
param_.svm_type = 4; // SVR
param_.kernel_type = 2; // radial
param_.nu = 0.6; // SVR nu
// These values are the defaults used in the Matlab version
// as found in svm_model_matlab.c
param_.gamma = 1.0 / (double)numFeatures;
param_.coef0 = 0;
param_.cache_size = 100; // in MB
param_.shrinking = 1;
param_.probability = 0;
param_.degree = 3;
param_.eps = 1e-3;
param_.p = 0.1;
param_.shrinking = 1;
param_.probability = 0;
param_.nr_weight = 0;
param_.weight_label = NULL;
param_.weight = NULL;
// suppress command line output
svm_set_print_string_function([](auto c) {});
model_ = svm_train(&problem, ¶m_);
}
double SupportVectorRegressionModel::Predict(const vector<double>& observation) const
{
LibSvmTargetObservationConvertor conv;
svm_node *obsNode = conv.ConvertMatToSvmNode(observation);
double prediction = svm_predict(model_, obsNode);
return prediction;
}
SupportVectorRegressionModel::SupportVectorRegressionModel(const FilePath & modelFile)
{
model_ = svm_load_model(modelFile.Path().c_str());
}
private:
svm_model *model_;
svm_parameter param_;
};