免费注册 查看新帖 |

Chinaunix

  平台 论坛 博客 文库
最近访问板块 发新帖
查看: 3743 | 回复: 0
打印 上一主题 下一主题

我实现一个随机数生成器。 [复制链接]

论坛徽章:
0
跳转到指定楼层
1 [收藏(0)] [报告]
发表于 2007-04-19 08:06 |只看该作者 |倒序浏览
我做机器学习方面的研究,因此非常需要高质量的(伪)随机数生成器,下面这个是我参考"Numerical Recipes in C (NRIC)"和一些其它资料实现的一个C++ Wrapper for random number generator. 当然网上有很多类似的代码,不过我觉得这个应付常见的需求足够了,至少基本满足自己的需求。
它特点就是简单高效,但是随机数质量高,基于48位伪随机数生成引擎,实现了比较常见的binomial, cauchy, exponential, gamma(erlang), gaussian, uniform integers, uniform, poisson, 还有一个数学函数log(Gamma(x))。
因为参考了一些其它资料,我不知道该怎么发布这个东西,不过我有NRIC的正版授权,因此自己用应该是完全合法的。
使用的时候比较简单,比如需要生成200个正态分布的随机数:

  1. double x[200];
  2. Rand engine; // default to use current time as random seed
  3. for (int i = 0; i < 200; ++i)
  4.     x[i] = engine.gasdev();
复制代码


// rand.h

  1. // -*- C++ -*-
  2. #ifndef RAND_H
  3. #define RAND_H

  4. #include <cmath>
  5. #include <cstdlib>

  6. class Rand {
  7.     static const double PI;
  8.     unsigned short xsubi[3];
  9.     void set_seed(long seed = 0);

  10. public:
  11.     Rand(long seed = 0) {set_seed(seed);}
  12.     static double log_gamma(double xx); // returns log(Gamma(xx))

  13.     // supported random deviates
  14.     double bnldev(double pp, int n); // binomial
  15.     double caudev();                 // cauchy
  16.     double caudev(double location, double scale);
  17.     double expdev();                 // exponential
  18.     double expdev(double intensity);
  19.     double gamdev(int shape);        // gamma(erlang)
  20.     double gamdev(int shape, double intensity);
  21.     double gasdev();                 // gaussian
  22.     double gasdev(double mean, double std);
  23.     int    intdev(int end_point);    // uniform integeters [0, end_point)
  24.     int    intdev(int start_point, int end_point);
  25.     double poidev(double mean);      // poisson
  26.     double unidev();                 // uniform [0, 1)
  27.     double unidev(double start_point, double end_point);
  28. };

  29. inline double Rand::caudev()
  30. {
  31.     return ::tan(PI * (unidev() - 0.5));
  32. }

  33. inline double Rand::caudev(double location, double scale)
  34. {
  35.     return location + scale * caudev();
  36. }

  37. inline double Rand::expdev(double intensity)
  38. {
  39.     return expdev() / intensity;
  40. }

  41. inline double Rand::gamdev(int shape, double intensity)
  42. {
  43.     return gamdev(shape) / intensity;
  44. }

  45. inline double Rand::gasdev(double mean, double std)
  46. {
  47.     return mean + std * gasdev();
  48. }

  49. inline int Rand::intdev(int end_point)
  50. {
  51.     return static_cast<int>(::floor(unidev() * end_point));
  52. }

  53. inline int Rand::intdev(int start_point, int end_point)
  54. {
  55.     return start_point + intdev(end_point - start_point);
  56. }

  57. inline double Rand::unidev()
  58. {
  59.     return ::erand48(xsubi);
  60. }

  61. inline double Rand::unidev(double start_point, double end_point)
  62. {
  63.     return start_point + (end_point - start_point) * unidev();
  64. }

  65. #endif
复制代码


// rand.cc

  1. #include "rand.h"
  2. #include <ctime>

  3. const double Rand::PI = 3.14159265358979324;

  4. void Rand::set_seed(long seed)
  5. {
  6.     long rand_seed = seed;

  7.     if (seed == 0)
  8.         rand_seed = static_cast<long>(::time(0));
  9.     for (int i = 0; i < 3; ++i) {
  10.         xsubi[i] = rand_seed & 0xFFFF;
  11.         rand_seed >>= 16;
  12.     }
  13.     // the first few numbers are somewhat related with the seed
  14.     erand48(xsubi);
  15.     erand48(xsubi);
  16.     erand48(xsubi);
  17.     erand48(xsubi);
  18. }

  19. double Rand::log_gamma(double xx)
  20. {
  21.     double x, y, tmp, ser;
  22.     static double cof[6] = {76.18009172947146,
  23.                             -86.50532032941677,
  24.                             24.01409824083091,
  25.                             -1.231739572450155,
  26.                             0.1208650973866179e-2,
  27.                             -0.5395239384953e-5};
  28.     y = x = xx;
  29.     tmp = x + 5.5;
  30.     tmp -= (x + 0.5) * ::log(tmp);
  31.     ser = 1.000000000190015;
  32.     for (int j = 0; j <= 5; ++j)
  33.         ser += cof[j] / ++y;
  34.     return -tmp + ::log(2.5066282746310005 * ser / x);
  35. }

  36. double Rand::bnldev(double pp, int n)
  37. {
  38.     int j;
  39.     static int nold = -1;
  40.     double am, em, g, angle, p, bnl, sq, t, y;
  41.     static double pold = -1.0, pc, plog, pclog, en, oldg;

  42.     p = (pp <= 0.5 ? pp : 1.0 - pp);
  43.     am = n * p;
  44.     if (n < 25) {
  45.         bnl = 0.0;
  46.         for (j = 1; j <= n; ++j)
  47.             if (unidev() < p)
  48.                 ++bnl;
  49.     } else if (am < 1.0) {
  50.         g = ::exp(-am);
  51.         t = 1.0;
  52.         for (j = 0; j <= n; ++j) {
  53.             t *= unidev();
  54.             if (t < g)
  55.                 break;
  56.         }
  57.         bnl = (j <= n ? j : n);
  58.     } else {
  59.         if (n != nold) {
  60.             en = n;
  61.             oldg = log_gamma(en + 1.0);
  62.             nold = n;
  63.         } if (p != pold) {
  64.             pc = 1.0 - p;
  65.             plog = ::log(p);
  66.             pclog = ::log(pc);
  67.             pold = p;
  68.         }
  69.         sq = ::sqrt(2.0 * am * pc);
  70.         do {
  71.             do {
  72.                 angle = PI * unidev();
  73.                 y = ::tan(angle);
  74.                 em = sq * y + am;
  75.             } while (em < 0.0 || em >= (en + 1.0));
  76.             em = ::floor(em);
  77.             t = 1.2 * sq * (1.0 + y * y) *
  78.                 ::exp(oldg - log_gamma(em + 1.0) -
  79.                       log_gamma(en - em + 1.0) +
  80.                       em * plog + (en - em) * pclog);
  81.         } while (unidev() > t);
  82.         bnl = em;
  83.     }
  84.     if (p != pp)
  85.         bnl = n - bnl;
  86.     return bnl;
  87. }

  88. double Rand::expdev()
  89. {
  90.     double dum;

  91.     do {
  92.         dum = unidev();
  93.     } while (dum == 0.0);
  94.     return -::log(dum);
  95. }

  96. double Rand::gamdev(int shape)
  97. {
  98.     double am, e, s, v1, v2, x, y;

  99.     if (shape < 6) {
  100.         x = 1.0;
  101.         for (int j = 1; j <= shape; ++j)
  102.             x *= unidev();
  103.         x = -::log(x);
  104.     } else {
  105.         do {
  106.             do {
  107.                 do {
  108.                     v1 = unidev();
  109.                     v2 = 2.0 * unidev() - 1.0;
  110.                 } while (v1 * v1 + v2 * v2 > 1.0);
  111.                 y = v2 / v1;
  112.                 am = shape - 1;
  113.                 s = sqrt(2.0 * am + 1.0);
  114.                 x = s * y + am;
  115.             } while (x <= 0.0);
  116.             e = (1.0 + y * y) * ::exp(am * ::log(x / am) - s * y);
  117.         } while (unidev() > e);
  118.     }
  119.     return x;
  120. }

  121. double Rand::gasdev()
  122. {
  123.     static int iset = 0;
  124.     static double gset;
  125.     double fac, rsq, v1, v2;

  126.     if (iset == 0) {
  127.         do {
  128.             v1 = 2.0 * unidev() - 1.0;
  129.             v2 = 2.0 * unidev() - 1.0;
  130.             rsq = v1 * v1 + v2 * v2;
  131.         } while (rsq >= 1.0 || rsq == 0.0);
  132.         fac = sqrt(-2.0 * ::log(rsq) / rsq);
  133.         gset = v1 * fac;
  134.         iset = 1;
  135.         return v2 * fac;
  136.     }
  137.     iset = 0;
  138.     return gset;
  139. }

  140. double Rand::poidev(double mean)
  141. {
  142.     static double sq, alxm, g, oldm = -1.0;
  143.     double em, t, y;

  144.     if (mean < 12.0) {
  145.         if (mean != oldm) {
  146.             oldm = mean;
  147.             g = ::exp(-mean);
  148.         }
  149.         em = -1;
  150.         t = 1.0;
  151.         do {
  152.             ++em;
  153.             t *= unidev();
  154.         } while (t > g);
  155.     } else {
  156.         if (mean != oldm) {
  157.             oldm = mean;
  158.             sq = sqrt(2.0 * mean);
  159.             alxm = ::log(mean);
  160.             g = mean * alxm - log_gamma(mean + 1.0);
  161.         }
  162.         do {
  163.             do {
  164.                 y = ::tan(PI * unidev());
  165.                 em = sq * y + mean;
  166.             } while (em < 0.0);
  167.             em = ::floor(em);
  168.             t = 0.9 * (1.0 + y * y) *
  169.                 ::exp(em * alxm - log_gamma(em + 1.0) - g);
  170.         } while (unidev() > t);
  171.     }
  172.     return em;
  173. }
复制代码
您需要登录后才可以回帖 登录 | 注册

本版积分规则 发表回复

  

北京盛拓优讯信息技术有限公司. 版权所有 京ICP备16024965号-6 北京市公安局海淀分局网监中心备案编号:11010802020122 niuxiaotong@pcpop.com 17352615567
未成年举报专区
中国互联网协会会员  联系我们:huangweiwei@itpub.net
感谢所有关心和支持过ChinaUnix的朋友们 转载本站内容请注明原作者名及出处

清除 Cookies - ChinaUnix - Archiver - WAP - TOP