C ++ обратное автоматическое дифференцирование с графом

Я пытаюсь сделать обратный режим автоматического дифференцирования в C ++.

Идея, которую я придумал, состоит в том, что каждая переменная, которая является результатом операции с одной или двумя другими переменными, будет сохранять градиенты в векторе.

Это код:

class Var {
private:
double value;
char character;
std::vector<std::pair<double, const Var*> > children;

public:
Var(const double& _value=0, const char& _character='_') : value(_value), character(_character) {};
void set_character(const char& character){ this->character = character; }

// computes the derivative of the current object with respect to 'var'
double gradient(Var* var) const{
if(this==var){
return 1.0;
}

double sum=0.0;
for(auto& pair : children){
// std::cout << "(" << this->character << " -> " <<  pair.second->character << ", " << this << " -> " << pair.second << ", weight=" << pair.first << ")" << std::endl;
sum += pair.first*pair.second->gradient(var);
}
return sum;
}

friend Var operator+(const Var& l, const Var& r){
Var result(l.value+r.value);
result.children.push_back(std::make_pair(1.0, &l));
result.children.push_back(std::make_pair(1.0, &r));
return result;
}

friend Var operator*(const Var& l, const Var& r){
Var result(l.value*r.value);
result.children.push_back(std::make_pair(r.value, &l));
result.children.push_back(std::make_pair(l.value, &r));
return result;
}

friend std::ostream& operator<<(std::ostream& os, const Var& var){
os << var.value;
return os;
}
};

Я попытался запустить код следующим образом:

int main(int argc, char const *argv[]) {
Var x(5,'x'), y(6,'y'), z(7,'z');

Var k = z + x*y;
k.set_character('k');

std::cout << "k = " << k << std::endl;
std::cout << "∂k/∂x = " << k.gradient(&x) << std::endl;
std::cout << "∂k/∂y = " << k.gradient(&y) << std::endl;
std::cout << "∂k/∂z = " << k.gradient(&z) << std::endl;

return 0;
}

Вычислительный граф, который должен быть построен, выглядит следующим образом:

       x(5)   y(6)              z(7)
\     /                 /
∂w/∂x=y  \   /  ∂w/∂y=x        /
\ /                 /
w=x*y               /
\               /  ∂k/∂z=1
\             /
∂k/∂w=1  \           /
\_________/
|
k=w+z

Тогда, если я хочу рассчитать ∂k/∂x например, я должен умножить градиенты по краям и суммировать результат для каждого ребра. Это делается рекурсивно double gradient(Var* var) const, Так что я ∂k/∂x = ∂k/∂w * ∂w/∂x + ∂k/∂z * ∂z/∂x,

Если у меня есть промежуточный расчет, такой как x*y здесь что-то идет не так. когда std::cout здесь есть комментарий:

k = 37
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂x = 0
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂y = 5
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂z = 1

Он печатает, какая переменная связана с какой, затем их адреса и вес соединения (который должен быть градиентом).

Проблема в weight=0 между x и промежуточная переменная, которая содержит результат x*y (и который я обозначил как w в моем графике).
Я понятия не имею, почему этот ноль, а не другой вес, связанный с y,

Еще одна вещь, которую я заметил, это то, что если вы переключите линии в operator* вот так :

result.children.push_back(std::make_pair(1.0, &r));
result.children.push_back(std::make_pair(1.0, &l));

Тогда это y соединения, которые отменяются.

Заранее благодарю за любую помощь.

4

Решение

Линия:

Var k = z + x*y;

Вызовы operator*, который возвращает Var временный, который затем используется для r аргумент operator+где pair хранит адрес временный. После завершения строки k дети включают указатель на место временного имел был, но его больше не существует.


Хотя это не защищает от вышеуказанной ошибки, вы можете создать предполагаемое поведение, избегая неназванного временного …

Var xy = x * y;
xy.set_character('*');
Var k = z + xy;
k.set_character('k');

…с которой ваша программа выдает:

k = 37
∂k/∂x = 6
∂k/∂y = 5
∂k/∂z = 1

Лучшим решением может быть захват детей по значению.


В качестве общего совета по обнаружению таких ошибок … когда ваша программа, кажется, делает что-то необъяснимое (и / или происходит сбой), попробуйте запустить ее под детектором ошибок памяти, таким как Valgrind. Для вашего кода отчет начинается с:

==22137== Invalid read of size 8
==22137==    at 0x1090EA: Var::gradient(Var*) const (in /home/median/so/deriv)
==22137==    by 0x109109: Var::gradient(Var*) const (in /home/median/so/deriv)
==22137==    by 0x108E12: main (in /home/median/so/deriv)
==22137==  Address 0x5b82cd0 is 0 bytes inside a block of size 32 free'd
==22137==    at 0x4C3123B: operator delete(void*) (in /usr/lib/valgrind/vgpreload_memcheck-amd64-linux.so)
==22137==    by 0x109FC1: __gnu_cxx::new_allocator<std::pair<double, Var const*> >::deallocate(std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
==22137==    by 0x109CDD: std::allocator_traits<std::allocator<std::pair<double, Var const*> > >::deallocate(std::allocator<std::pair<double, Var const*> >&, std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
==22137==    by 0x109963: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_deallocate(std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
==22137==    by 0x1097BC: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::~_Vector_base() (in /home/median/so/deriv)
==22137==    by 0x1095EA: std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::~vector() (in /home/median/so/deriv)
==22137==    by 0x109161: Var::~Var() (in /home/median/so/deriv)
==22137==    by 0x108D95: main (in /home/median/so/deriv)
==22137==  Block was alloc'd at
==22137==    at 0x4C3017F: operator new(unsigned long) (in /usr/lib/valgrind/vgpreload_memcheck-amd64-linux.so)
==22137==    by 0x10A153: __gnu_cxx::new_allocator<std::pair<double, Var const*> >::allocate(unsigned long, void const*) (in /home/median/so/deriv)
==22137==    by 0x10A060: std::allocator_traits<std::allocator<std::pair<double, Var const*> > >::allocate(std::allocator<std::pair<double, Var const*> >&, unsigned long) (in /home/median/so/deriv)
==22137==    by 0x109F03: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_allocate(unsigned long) (in /home/median/so/deriv)
==22137==    by 0x109A8D: void std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_realloc_insert<std::pair<double, Var const*> >(__gnu_cxx::__normal_iterator<std::pair<double, Var const*>*, std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > > >, std::pair<double, Var const*>&&) (in /home/median/so/deriv)
==22137==    by 0x1098CF: void std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::emplace_back<std::pair<double, Var const*> >(std::pair<double, Var const*>&&) (in /home/median/so/deriv)
==22137==    by 0x10973F: std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::push_back(std::pair<double, Var const*>&&) (in /home/median/so/deriv)
==22137==    by 0x109520: operator*(Var const&, Var const&) (in /home/median/so/deriv)
==22137==    by 0x108D6F: main (in /home/median/so/deriv)

Другой способ поймать это — добавить запись в деструктор, чтобы вы знали, когда адреса объектов, упомянутые в вашей записи, больше не действительны.

4

Другие решения

Других решений пока нет …

По вопросам рекламы ammmcru@yandex.ru
Adblock
detector