Файлы моделей TensorFlow 0.12

Я тренирую модель и сохраняю ее, используя:

saver = tf.train.Saver()
saver.save(session, './my_model_name')

Кроме контрольно-пропускной пункт Файл, который просто содержит указатели на самые последние контрольные точки модели, создает следующие 3 файла в текущем пути:

  1. my_model_name.meta
  2. my_model_name.index
  3. my_model_name.data-00000-из-00001

Интересно, что содержится в каждом из этих файлов?

Я хотел бы загрузить эту модель в C ++ и выполнить вывод. label_image Пример загружает модель из одного .п.н. использование файла ReadBinaryProto(), Интересно, как я могу загрузить его из этих 3 файлов. Что такое C ++ эквивалент следующего?

new_saver = tf.train.import_meta_graph('./my_model_name.meta')
new_saver.restore(session, './my_model_name')

3

Решение

В настоящее время я борюсь с этим сам, я обнаружил, что это не очень просто сделать в настоящее время. Два наиболее часто цитируемых урока на эту тему:
https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f#.goxwm1e5j
а также
https://medium.com/@hamedmp/exporting-trained-tensorflow-models-to-c-the-right-way-cf24b609d183#.g1gak956i

Эквивалент

new_saver = tf.train.import_meta_graph('./my_model_name.meta')
new_saver.restore(session, './my_model_name')

Просто

Status load_graph_status = LoadGraph(graph_path, &session);

Предполагая, что вы «заморозили график» (использовал скрипт, объединяющий файл графика со значениями контрольных точек).
Также смотрите обсуждение здесь: Tensorflow Различные способы экспорта и запуска графа в C ++

3

Другие решения

То, что создает ваша заставка, называется «Контрольная точка V2» и было введено в TF 0.12.

У меня получилось довольно неплохо (хотя документы по части C ++ ужасны, так что мне потребовался целый день, чтобы решить). Некоторые люди предлагают преобразование всех переменных в константы или же замораживание графика, но ничего из этого на самом деле не нужно.

Python часть (сохранение)

with tf.Session() as sess:
tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')

Если вы создаете Saver с tf.trainable_variables(), вы можете сэкономить себе немного головной боли и места для хранения. Но, возможно, некоторым более сложным моделям нужно сохранить все данные, а затем убрать этот аргумент SaverПросто убедитесь, что вы создаете Saver после ваш график создан. Также очень разумно дать всем переменным / слоям уникальные имена, иначе вы можете столкнуться с различными проблемами.

C ++ часть (вывод)

Обратите внимание, что checkpointPath не путь к какому-либо из существующих файлов, просто их общий префикс. Если вы ошибочно указали путь к .index файл, TF не скажет вам, что это не так, но он умрет во время вывода из-за неинициализированных переменных.

#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>

using namespace std;
using namespace tensorflow;

...
// set up your input paths
const string pathToGraph = "models/my-model.meta"const string checkpointPath = "models/my-model";
...

auto session = NewSession(SessionOptions());
if (session == nullptr) {
throw runtime_error("Could not create Tensorflow session.");
}

Status status;

// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}

// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
throw runtime_error("Error creating graph: " + status.ToString());
}

// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
{{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
{},
{graph_def.saver_def().restore_op_name()},
nullptr);
if (!status.ok()) {
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}

// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);

Для полноты вот эквивалент Python:

Вывод в Python

with tf.Session() as sess:
saver = tf.train.import_meta_graph('models/my-model.meta')
saver.restore(sess, tf.train.latest_checkpoint('models/'))
outputTensors = sess.run(outputOps, feed_dict=feedDict)
6

По вопросам рекламы [email protected]