免费注册 查看新帖 |

Chinaunix

  平台 论坛 博客 文库
12下一页
最近访问板块 发新帖
查看: 11509 | 回复: 14
打印 上一主题 下一主题

[C++] 一段自动求导的代码,C++er 来看下 [复制链接]

论坛徽章:
1
2015年辞旧岁徽章
日期:2015-03-03 16:54:15
跳转到指定楼层
1 [收藏(0)] [报告]
发表于 2013-04-11 21:01 |只看该作者 |倒序浏览
本帖最后由 lost_templar 于 2013-04-12 17:11 编辑

自己写了做自动数值求导的;

一次导数正常,二次就会出错,有些莫名其妙。

一次求导正确:
  1. double ds( double x )
  2. {
  3.     return std::sin(x);
  4. }
复制代码
求导使用:
  1. derivative<0, double(double)>( ds )( 0.0 );
复制代码
输出 1

二次求导则会出错:
  1.     auto const& dds = derivative<0, double(double)>(ds);

  2.     derivative<0, double(double)>(dds)( 0.0 );
复制代码
依旧输出 1,不知为何


derivitative 实现的大致代码如下:
  1.     template<std::size_t M, typename Dummy_Type> struct derivative;

  2.     template< std::size_t M, typename R, typename... Types >
  3.     struct derivative< M, R(Types...) >
  4.     {
  5.         typedef R                                               return_type;
  6.         typedef std::function< R(Types...) >                    function_type;
  7.         typedef typename type_at< M, Types... >::result_type    result_type;
  8.         typedef std::size_t                                     size_type;

  9.         function_type ff;

  10.         template< typename Function >
  11.         derivative( const Function& ff_ ) : ff( ff_ ) { }

  12.         result_type operator()( Types... vts ) const
  13.         {
  14.             typedef typename stepper_value_type<result_type>::value_type value_type; //fix for complex

  15.             const value_type decrease_step     = 1.6180339887498948482;
  16.             const value_type decrease_step_2   = 2.6180339887498948482;
  17.             const value_type safe_boundary     = 2.6180339887498948482;
  18.             const value_type start_step        = 1.0e-8;
  19.             const result_type x                = value_at<M, Types...>()( vts... );
  20.             const size_type iter_depth         = 64;

  21.             result_type error   = std::numeric_limits<value_type>::max();
  22.             result_type ans     = std::numeric_limits<value_type>::quiet_NaN();
  23.             value_type step     = start_step;

  24.             auto const& lhs_f = [&step]( result_type x ) { return x - step; };
  25.             auto const& rhs_f = [&step]( result_type x ) { return x + step; };

  26.             oscillate_function< M, return_type, Types... > const lhs_of( ff, lhs_f );
  27.             oscillate_function< M, return_type, Types... > const rhs_of( ff, rhs_f );

  28.             //matrix<result_type> a( iter_depth, iter_depth );
  29.             result_type a[iter_depth][iter_depth];

  30.             a[0][0] = ( rhs_of( vts... ) - lhs_of( vts... ) ) / ( step + step );

  31.             //for ( size_type i = 1; i != a.row(); ++i )
  32.             for ( size_type i = 1; i != iter_depth; ++i )
  33.             {
  34.                 step /= decrease_step;
  35.                 a[i][0] = ( rhs_of( vts... ) - lhs_of( vts... ) ) / ( step + step );

  36.                 result_type factor = decrease_step_2;

  37.                 for ( size_type j = 1; j <= i; ++j )
  38.                 {
  39.                     const result_type factor_1 = factor - result_type(1);

  40.                     a[i][j]  = ( a[i][j-1] * factor - a[i-1][j-1] ) / factor_1;

  41.                     factor *= decrease_step_2;

  42.                     const result_type error_so_far = std::max( std::abs(a[i][j]-a[i][j-1]), std::abs(a[i][j]-a[i-1][j-1]) );

  43.                     if ( error > error_so_far )
  44.                     {
  45.                         error = error_so_far;
  46.                         ans = a[i][j];
  47.                     }
  48.                 }

  49.                 if ( std::abs( a[i][i] - a[i-1][i-1] ) >=  safe_boundary * error )
  50.                     return ans;

  51.             }

  52.             return ans;
  53.         }

  54.     };
复制代码
其中 type_at 如下实现:
  1.     template< std::size_t N, typename T, typename... Types >
  2.     struct type_at
  3.     {
  4.         static_assert( N < sizeof...(Types)+1, "dim size exceeds limitation!" );
  5.         typedef typename type_at<N-1, Types...>::result_type result_type;
  6.     };

  7.     template<typename T, typename... Types>
  8.     struct type_at< 0, T, Types...>
  9.     {
  10.         typedef T result_type;
  11.     };
复制代码
那个 matrix 可以用个二维数组代替,应该不难理解。

论坛徽章:
1
2015年辞旧岁徽章
日期:2015-03-03 16:54:15
2 [报告]
发表于 2013-04-11 21:13 |只看该作者
不好意思,漏了 oscillate_function 的定义:
  1.     namespace oscillate_function_private_fdspojiasldkjasasfdioj4asfd4d
  2.     {
  3.     template< typename R, std::size_t Bn, std::size_t N >
  4.     struct oscillate_backward_impl
  5.     {
  6.         template<typename F, typename Type, typename... Types >
  7.         R impl( F F_, Type t, Types... vts ) const
  8.         {
  9.             return oscillate_backward_impl<R, Bn+1, N>().impl( F_, vts..., t );
  10.         }
  11.     };

  12.     template< typename R, std::size_t N >
  13.     struct oscillate_backward_impl<R, N, N>
  14.     {
  15.         template< typename F, typename... Types >
  16.         R impl( F F_, Types... vts ) const
  17.         {
  18.             return F_( vts... );
  19.         }
  20.     };

  21.     template<typename R, std::size_t Fn, std::size_t Bn>
  22.     struct oscillate_forward_impl
  23.     {
  24.         template<typename F, typename f, typename Type, typename... Types>
  25.         R impl( F F_, f f_, Type t, Types... vts ) const
  26.         {
  27.             return oscillate_forward_impl< R, Fn-1, Bn >().impl( F_, f_, vts..., t );
  28.         }
  29.     };

  30.     template<typename R, std::size_t Bn>
  31.     struct oscillate_forward_impl< R, 0, Bn>
  32.     {
  33.         template<typename F, typename f, typename Type, typename... Types>
  34.         R impl( F F_, f f_, Type t, Types... vts ) const
  35.         {
  36.             return oscillate_backward_impl<R, 0, Bn>().impl( F_, vts..., f_(t) );
  37.         }
  38.     };

  39.     }//namespace oscillate_function_private_fdspojiasldkjasasfdioj4asfd4d

  40.     template< std::size_t N, typename R, typename... Types >
  41.     struct oscillate_function
  42.     {
  43.         typedef R return_type;
  44.         typedef typename type_at< N, Types... >::result_type oscillate_type;
  45.         typedef std::function<R(Types...)> function_type;
  46.         typedef std::function<oscillate_type(oscillate_type)> oscillate_function_type;

  47.         static_assert( N < sizeof...(Types), "dim size exceeds arguments limitation." );

  48.         function_type F;
  49.         oscillate_function_type f;

  50.         oscillate_function( const function_type& F_, const oscillate_function_type& f_ ) : F( F_ ), f( f_ ) {}

  51.         return_type operator()( Types... vts ) const
  52.         {
  53.             using namespace oscillate_function_private_fdspojiasldkjasasfdioj4asfd4d;
  54.             return oscillate_forward_impl< R, N, sizeof...(vts)-N-1 >().impl( F, f, vts... );
  55.         }

  56.     };
复制代码

论坛徽章:
1
2015年辞旧岁徽章
日期:2015-03-03 16:54:15
3 [报告]
发表于 2013-04-11 21:15 |只看该作者
本帖最后由 lost_templar 于 2013-04-11 21:28 编辑

好吧,还有个 value_at
  1.     template< std::size_t N, typename T, typename... Types >
  2.     struct value_at
  3.     {
  4.         typedef typename type_at< N, T, Types...>::result_type result_type;

  5.         result_type operator()( T, Types... vts ) const
  6.         {
  7.             return value_at<N-1, Types...>()( vts... );
  8.         }
  9.     };

  10.     template< typename T, typename... Types >
  11.     struct value_at< 0, T, Types... >
  12.     {
  13.         typedef T result_type;

  14.         result_type operator()( T vt, Types... ) const
  15.         {
  16.             return vt;
  17.         }
  18.     };
复制代码
淫呢?.......

论坛徽章:
0
4 [报告]
发表于 2013-04-12 16:12 |只看该作者
给matrix定义吧,用二维数组代替编译不过……

论坛徽章:
0
5 [报告]
发表于 2013-04-12 16:12 |只看该作者
还有stepper_value_type呢。
typedef typename stepper_value_type<result_type>::value_type value_type;
虽然这行用typedef double value_type;直接代替好像问题不大?

论坛徽章:
1
2015年辞旧岁徽章
日期:2015-03-03 16:54:15
6 [报告]
发表于 2013-04-12 16:33 |只看该作者
本帖最后由 lost_templar 于 2013-06-10 05:59 编辑

post deleted.........

论坛徽章:
1
2015年辞旧岁徽章
日期:2015-03-03 16:54:15
7 [报告]
发表于 2013-04-12 16:35 |只看该作者
回复 5# 幻の上帝


    steper_value_type 可以直接用 value_type 代替;

至于矩阵那个,可以修改下两行通过:
  1.             oscillate_function< M, return_type, Types... > const rhs_of( ff, rhs_f );

  2.             //matrix<result_type> a( iter_depth, iter_depth );
  3.             result_type a[iter_depth][iter_depth];

  4.             a[0][0] = ( rhs_of( vts... ) - lhs_of( vts... ) ) / ( step + step );

  5.             //for ( size_type i = 1; i != a.row(); ++i )
  6.             for ( size_type i = 1; i != iter_depth; ++i )
  7.             {
复制代码

论坛徽章:
0
8 [报告]
发表于 2013-04-13 08:44 |只看该作者
似乎迭代一次就在std::abs( a[i][i] - a[i-1][i-1] ) >=  safe_boundary * error后面停了,是预期吗?(不过就算不停结果也一样不对……)
另外matrix\functional下min.hpp和max.hpp在debug时有问题, assert( m.size() ) > 0 ; 看来得是assert( m.size() > 0 );。

论坛徽章:
1
2015年辞旧岁徽章
日期:2015-03-03 16:54:15
9 [报告]
发表于 2013-04-13 20:09 |只看该作者
本帖最后由 lost_templar 于 2013-04-13 20:10 编辑
幻の上帝 发表于 2013-04-13 08:44
似乎迭代一次就在std::abs( a - a ) >=  safe_boundary * error后面停了,是预期吗?(不过就算不停结果也一 ...


停顿是预期的行为;

后边那两个 assert 是我的失误,那两个函数自从写出来就从来没有用过,也没有测试过{:3_190:}

另外我检查过 d(sin x)/dx 的输出,
  1. double ds( double x )
  2. {
  3.     return std::sin(x);
  4. }
复制代码
  1.       for ( double d_start = -0.1; d_start <= 0.1; d_start += 0.001 )
  2.           std::cout << derivative<0, double(double)>( ds )( d_start ) << "\n";
复制代码
是一个预期的 cos 函数:
  1. 0.995004
  2. 0.995104
  3. 0.995202
  4. 0.995299
  5. 0.995396
  6. 0.995491
  7. 0.995585
  8. 0.995679
  9. 0.995771
  10. 0.995862
  11. 0.995953
  12. 0.996042
  13. 0.99613
  14. 0.996218
  15. 0.996304
  16. 0.99639
  17. 0.996474
  18. 0.996557
  19. 0.99664
  20. 0.996721
  21. 0.996802
  22. 0.996881
  23. 0.99696
  24. 0.997037
  25. 0.997113
  26. 0.997189
  27. 0.997263
  28. 0.997337
  29. 0.997409
  30. 0.997481
  31. 0.997551
  32. 0.99762
  33. 0.997689
  34. 0.997756
  35. 0.997823
  36. 0.997888
  37. 0.997953
  38. 0.998016
  39. 0.998079
  40. 0.99814
  41. 0.998201
  42. 0.99826
  43. 0.998318
  44. 0.998376
  45. 0.998432
  46. 0.998488
  47. 0.998542
  48. 0.998596
  49. 0.998648
  50. 0.9987
  51. 0.99875
  52. 0.9988
  53. 0.998848
  54. 0.998896
  55. 0.998942
  56. 0.998988
  57. 0.999032
  58. 0.999076
  59. 0.999118
  60. 0.99916
  61. 0.9992
  62. 0.99924
  63. 0.999278
  64. 0.999316
  65. 0.999352
  66. 0.999388
  67. 0.999422
  68. 0.999456
  69. 0.999488
  70. 0.99952
  71. 0.99955
  72. 0.99958
  73. 0.999608
  74. 0.999636
  75. 0.999662
  76. 0.999688
  77. 0.999712
  78. 0.999736
  79. 0.999758
  80. 0.99978
  81. 0.9998
  82. 0.99982
  83. 0.999838
  84. 0.999856
  85. 0.999872
  86. 0.999888
  87. 0.999902
  88. 0.999916
  89. 0.999928
  90. 0.99994
  91. 0.99995
  92. 0.99996
  93. 0.999968
  94. 0.999976
  95. 0.999982
  96. 0.999987
  97. 0.999992
  98. 0.999995
  99. 0.999998
  100. 0.999999
  101. 1
  102. 0.999999
  103. 0.999998
  104. 0.999995
  105. 0.999992
  106. 0.999987
  107. 0.999982
  108. 0.999976
  109. 0.999968
  110. 0.99996
  111. 0.99995
  112. 0.99994
  113. 0.999928
  114. 0.999916
  115. 0.999902
  116. 0.999888
  117. 0.999872
  118. 0.999856
  119. 0.999838
  120. 0.99982
  121. 0.9998
  122. 0.99978
  123. 0.999758
  124. 0.999736
  125. 0.999712
  126. 0.999688
  127. 0.999662
  128. 0.999636
  129. 0.999608
  130. 0.99958
  131. 0.99955
  132. 0.99952
  133. 0.999488
  134. 0.999456
  135. 0.999422
  136. 0.999388
  137. 0.999352
  138. 0.999316
  139. 0.999278
  140. 0.99924
  141. 0.9992
  142. 0.99916
  143. 0.999118
  144. 0.999076
  145. 0.999032
  146. 0.998988
  147. 0.998942
  148. 0.998896
  149. 0.998848
  150. 0.9988
  151. 0.99875
  152. 0.9987
  153. 0.998648
  154. 0.998596
  155. 0.998542
  156. 0.998488
  157. 0.998432
  158. 0.998376
  159. 0.998318
  160. 0.99826
  161. 0.998201
  162. 0.99814
  163. 0.998079
  164. 0.998016
  165. 0.997953
  166. 0.997888
  167. 0.997823
  168. 0.997756
  169. 0.997689
  170. 0.99762
  171. 0.997551
  172. 0.997481
  173. 0.997409
  174. 0.997337
  175. 0.997263
  176. 0.997189
  177. 0.997113
  178. 0.997037
  179. 0.99696
  180. 0.996881
  181. 0.996802
  182. 0.996721
  183. 0.99664
  184. 0.996557
  185. 0.996474
  186. 0.99639
  187. 0.996304
  188. 0.996218
  189. 0.99613
  190. 0.996042
  191. 0.995953
  192. 0.995862
  193. 0.995771
  194. 0.995679
  195. 0.995585
  196. 0.995491
  197. 0.995396
  198. 0.995299
  199. 0.995202
  200. 0.995104
复制代码

论坛徽章:
1
2015年辞旧岁徽章
日期:2015-03-03 16:54:15
10 [报告]
发表于 2013-04-13 20:38 |只看该作者
回复 8# 幻の上帝


    我觉得可能是问题出在用 derivative<0, double(double)> 这样一个类型的参数来初始化另外一个 derivative<0, double(double)> 这样一个类型的时候:
  1. auto const& dds = derivative<0, double(double)>(ds);
  2. derivative<0, double(double)>(dds)( 0.0 )
复制代码
上边代码的第二句本来预期的行为是调用这个带 template 的 constructor:
  1. struct derivative
  2. {
  3.        function_type ff;

  4.         template< typename Function >
  5.         derivative( const Function& ff_ ) : ff( ff_ ) { }
  6. };
复制代码
但是在实际调用的是一个缺省的隐藏构造函数:
  1. struct derivative
  2. {
  3.     derivative( const derivative& other );
  4. };
复制代码
因此得到错误的结果

您需要登录后才可以回帖 登录 | 注册

本版积分规则 发表回复

  

北京盛拓优讯信息技术有限公司. 版权所有 京ICP备16024965号-6 北京市公安局海淀分局网监中心备案编号:11010802020122 niuxiaotong@pcpop.com 17352615567
未成年举报专区
中国互联网协会会员  联系我们:huangweiwei@itpub.net
感谢所有关心和支持过ChinaUnix的朋友们 转载本站内容请注明原作者名及出处

清除 Cookies - ChinaUnix - Archiver - WAP - TOP