BitwiseXOR для Tensorflow

Мой проект запрашивает новый слой, которому нужен новый оператор Tensor для вычисления bitwiseXOR между вводом x и константой Key k.
Например. x = 4 (битовая форма: 100), k = 7 (111), побитовое значение XOR (x, k) ожидается как 3 (011).

Насколько я знаю, Tensor имеет только оператор LogicXOR для типа bool. К счастью, Tensorflow обладает расширенной способностью иметь новый Op. Тем не менее, я прочитал документ в https://www.tensorflow.org/extend/adding_an_op, Я могу понять основную идею, но это далеко от реализации, возможно, из-за недостатка знаний C ++. Любые предложения по внедрению нового оператора будут полезны. Тогда я смогу использовать эту новую опцию Тензор для создания новых слоев.

0

Решение

Если вы не хотите реализовывать свой собственный оператор C ++, попробуйте tf.py_func который позволяет вам определить функцию Python, которая работает на numpy массивов, а затем используется в качестве операции Tensorflow на графике.

Для вашей проблемы вы можете использовать numpybitwise_xor ():

import tensorflow as tf
import numpy as np

t1 = tf.constant([2,4,6], dtype=tf.int64)
t2 = tf.constant([1,3,5], dtype=tf.int64)

t_xor = tf.py_func(np.bitwise_xor, [t1, t2], tf.int64, stateful=False)

with tf.Session() as sess:
val = sess.run(t_xor)
print(val)

который печатает [3,7,3] как и ожидалось.

Пожалуйста, позаботьтесь об известных ограничениях этой функции (взято по ссылке выше):

Нотабене tf.py_func() операция имеет следующие известные ограничения:

  • Тело функции (т.е. func) не будет сериализовано в
    GraphDef, Следовательно, вы не должны использовать эту функцию, если вам нужно
    сериализовать вашу модель и восстановить ее в другой среде.
  • Операция должна выполняться в том же адресном пространстве, что и программа Python
    что вызывает tf.py_func(), Если вы используете распределенный TensorFlow, вы
    должен запустить tf.train.Server в том же процессе, что и программа,
    звонки tf.py_func() и вы должны прикрепить созданную операцию к устройству
    на этом сервере (например, используя with tf.device():).
0

Другие решения

Других решений пока нет …

По вопросам рекламы [email protected]