忘记密码   免费注册 查看新帖 |

ChinaUnix.net

  平台 论坛 博客 认证专区 大话IT 徽章 文库 自测 下载 频道自动化运维 虚拟化 储存备份 C/C++ PHP MySQL 嵌入式 Linux系统
最近访问板块 发新帖
查看: 1044 | 回复: 2

[C++] 矩阵乘 [复制链接]

论坛徽章:
23
狮子座
日期:2013-12-31 10:48:0015-16赛季CBA联赛之吉林
日期:2016-04-18 14:43:1015-16赛季CBA联赛之北控
日期:2016-05-18 15:01:4415-16赛季CBA联赛之上海
日期:2016-06-22 18:00:1315-16赛季CBA联赛之八一
日期:2016-06-25 11:02:2215-16赛季CBA联赛之佛山
日期:2016-08-17 22:48:2615-16赛季CBA联赛之福建
日期:2016-12-27 22:39:272016科比退役纪念章
日期:2017-02-08 23:49:4315-16赛季CBA联赛之八一
日期:2017-02-16 01:05:3415-16赛季CBA联赛之山东
日期:2017-02-22 15:34:5615-16赛季CBA联赛之四川
日期:2016-01-17 18:38:3715-16赛季CBA联赛之广夏
日期:2016-01-05 20:02:21
发表于 2018-02-01 00:32 |显示全部楼层
矩阵乘      

  1. #include <iostream>
  2. #include <glog/logging.h>

  3. #include "ATen/ATen.h"

  4. #include "torch/csrc/autograd/variable.h"
  5. #include "torch/csrc/assertions.h"
  6. #include "torch/csrc/autograd/generated/VariableType.h"
  7. #include "torch/csrc/autograd/generated/Functions.h"
  8. #include "torch/csrc/autograd/functions/accumulate_grad.h"
  9. #include "torch/csrc/autograd/functions/tensor.h"

  10. using namespace at;

  11. using namespace torch::autograd;

  12. //void test_tensor(Type & type);

  13. int main(int argc, char** argv)
  14. {
  15.     //Initialize Google's logging library.
  16.     //google::InitGoogleLogging(argv[0]);
  17.     //FLAGS_log_dir = "./log";

  18.     google::InitGoogleLogging(argv[0]);
  19.     google::LogToStderr();

  20.     LOG(INFO) << argv[0];

  21.     //test_tensor(CPU(kFloat));
  22.     //test_tensor(CPU(kDouble));

  23.     at::Type  &mv_type = at::CPU(at::kDouble);

  24.     //at::Tensor   ta = mv_type.rand({3,5});
  25.     //at::Tensor   tb = mv_type.rand({3,5});

  26.     double data01[] = { 1.0, 2.0, 3.0,
  27.                         4.0, 5.0, 6.3};

  28.     double data02[] = { 2.0, 3.0, 5.0,
  29.                         4.0, 2.5, 1.2};

  30.     at::Tensor  ta = mv_type.tensorFromBlob(data01, {2,3});
  31.     at::Tensor  tb = mv_type.tensorFromBlob(data02, {3,2});

  32.     at::Tensor  tc = mv_type.zeros({2,3});

  33.     std::cout << std::endl;

  34.     std::cout << "Tensor ta :  \n" << ta << "\n" << std::endl;
  35.     std::cout << "Tensor tb :  \n" << tb << "\n" << std::endl;

  36.     Variable  va(ta);
  37.     Variable  vb(tb);
  38.     //Variable  vc(tc);

  39.     //vc = va * vb;
  40.     //vc = va.mul(vb);
  41.     //vc = va.div(vb);
  42.     auto  vc = va.matmul(vb);

  43.     //at::Type  &va_type = at::CPU(at::kDouble);
  44.    
  45.     std::cout << "Variable vc :  \n" << vc << "\n" << std::endl;


  46.     return 0;
  47. }
复制代码





论坛徽章:
0
发表于 2018-02-01 16:16 |显示全部楼层
毕竟不是学这个的,果然看不懂,有助解就好了

论坛徽章:
5
数据库技术版块每日发帖之星
日期:2015-11-27 06:20:00程序设计版块每日发帖之星
日期:2015-12-01 06:20:00每日论坛发贴之星
日期:2015-12-01 06:20:0015-16赛季CBA联赛之佛山
日期:2017-03-26 23:38:0315-16赛季CBA联赛之江苏
日期:2017-07-17 10:08:44
发表于 2018-02-01 17:41 |显示全部楼层
十分 的 佩服
您需要登录后才可以回帖 登录 | 注册

本版积分规则

  

北京盛拓优讯信息技术有限公司. 版权所有 京ICP备16024965号 北京市公安局海淀分局网监中心备案编号:11010802020122
广播电视节目制作经营许可证(京) 字第1234号 中国互联网协会会员  联系我们:
感谢所有关心和支持过ChinaUnix的朋友们 转载本站内容请注明原作者名及出处

清除 Cookies - ChinaUnix - Archiver - WAP - TOP