Я тренирую модель и сохраняю ее, используя:
saver = tf.train.Saver()
saver.save(session, './my_model_name')
Кроме контрольно-пропускной пункт Файл, который просто содержит указатели на самые последние контрольные точки модели, создает следующие 3 файла в текущем пути:
Интересно, что содержится в каждом из этих файлов?
Я хотел бы загрузить эту модель в C ++ и выполнить вывод. label_image Пример загружает модель из одного .п.н. использование файла ReadBinaryProto()
, Интересно, как я могу загрузить его из этих 3 файлов. Что такое C ++ эквивалент следующего?
new_saver = tf.train.import_meta_graph('./my_model_name.meta')
new_saver.restore(session, './my_model_name')
В настоящее время я борюсь с этим сам, я обнаружил, что это не очень просто сделать в настоящее время. Два наиболее часто цитируемых урока на эту тему:
https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f#.goxwm1e5j
а также
https://medium.com/@hamedmp/exporting-trained-tensorflow-models-to-c-the-right-way-cf24b609d183#.g1gak956i
Эквивалент
new_saver = tf.train.import_meta_graph('./my_model_name.meta')
new_saver.restore(session, './my_model_name')
Просто
Status load_graph_status = LoadGraph(graph_path, &session);
Предполагая, что вы «заморозили график» (использовал скрипт, объединяющий файл графика со значениями контрольных точек).
Также смотрите обсуждение здесь: Tensorflow Различные способы экспорта и запуска графа в C ++
То, что создает ваша заставка, называется «Контрольная точка V2» и было введено в TF 0.12.
У меня получилось довольно неплохо (хотя документы по части C ++ ужасны, так что мне потребовался целый день, чтобы решить). Некоторые люди предлагают преобразование всех переменных в константы или же замораживание графика, но ничего из этого на самом деле не нужно.
Python часть (сохранение)
with tf.Session() as sess:
tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')
Если вы создаете Saver
с tf.trainable_variables()
, вы можете сэкономить себе немного головной боли и места для хранения. Но, возможно, некоторым более сложным моделям нужно сохранить все данные, а затем убрать этот аргумент Saver
Просто убедитесь, что вы создаете Saver
после ваш график создан. Также очень разумно дать всем переменным / слоям уникальные имена, иначе вы можете столкнуться с различными проблемами.
C ++ часть (вывод)
Обратите внимание, что checkpointPath
не путь к какому-либо из существующих файлов, просто их общий префикс. Если вы ошибочно указали путь к .index
файл, TF не скажет вам, что это не так, но он умрет во время вывода из-за неинициализированных переменных.
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
using namespace std;
using namespace tensorflow;
...
// set up your input paths
const string pathToGraph = "models/my-model.meta"const string checkpointPath = "models/my-model";
...
auto session = NewSession(SessionOptions());
if (session == nullptr) {
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
{{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
{},
{graph_def.saver_def().restore_op_name()},
nullptr);
if (!status.ok()) {
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);
Для полноты вот эквивалент Python:
Вывод в Python
with tf.Session() as sess:
saver = tf.train.import_meta_graph('models/my-model.meta')
saver.restore(sess, tf.train.latest_checkpoint('models/'))
outputTensors = sess.run(outputOps, feed_dict=feedDict)