免费注册 查看新帖 |

Chinaunix

  平台 论坛 博客 文库
最近访问板块 发新帖
查看: 2636 | 回复: 2
打印 上一主题 下一主题

矩阵乘法的 template class 该怎么声明? [复制链接]

论坛徽章:
1
2015年迎新春徽章
日期:2015-03-04 09:49:45
跳转到指定楼层
1 [收藏(0)] [报告]
发表于 2007-10-18 12:43 |只看该作者 |倒序浏览
我的矩阵是这样定义的:

  1. template <class Numeric, int rows, int cols>
  2. class Matrix {
  3.   ...
  4. };
复制代码

如果要重载一个 * 运算符, 这个重载的运算符的声明该怎么写? 这样写对吗?

  1. template <class Numeric, int rows, int cols>
  2. class Matrix {
  3. public:
  4.   template <int cols2>
  5.   Matrix<Numeric, rows, cols2>& operator*(Matrix<Numeric, cols, cols2>&);
  6. };
复制代码

[ 本帖最后由 koolcoy 于 2007-10-18 12:45 编辑 ]

论坛徽章:
1
2015年迎新春徽章
日期:2015-03-04 09:49:45
2 [报告]
发表于 2007-10-18 13:31 |只看该作者
我是这么写的, 但是有几个问题:
matrix.h:

  1. #ifndef MATRIX_H__
  2. #define MATRIX_H__

  3. #include <iostream>

  4. template <class Numeric, int rows, int cols>
  5. class Matrix {
  6. public:
  7.         Numeric elements[rows][cols];
  8.         Matrix(Numeric arr[rows][cols]) {
  9.                 for (int i = 0; i < rows; ++i) {
  10.                         for (int j = 0; j < cols; ++j) {
  11.                                 elements[i][j] = arr[i][j];
  12.                         }
  13.                 }
  14.         }

  15.         template <int cols2>
  16.         const Matrix<Numeric, rows, cols2>&
  17.         operator*(Matrix<Numeric, cols, cols2>& m) {
  18.                 Numeric arr[rows][cols2];
  19.                 for (int i = 0; i < rows; ++i) {
  20.                         for (int j = 0; j < cols2; ++j) {
  21.                                 Numeric sum = 0;
  22.                                 for (int k = 0; k < cols; ++k) {
  23.                                         sum += elements[i][k] * m.elements[k][j];
  24.                                 }
  25.                                 arr[i][j] = sum;
  26.                         }
  27.                 }
  28.                 return Matrix<Numeric, rows, cols2>(arr);
  29.         }

  30.         void dump() {
  31.                 for (int i = 0; i < rows; ++i) {
  32.                         for (int j = 0; j < cols; ++j) {
  33.                                 std::cout << elements[i][j] << ' ';
  34.                         }
  35.                         std::cout << std::endl;
  36.                 }
  37.         }
  38. };

  39. #endif
复制代码

test.cpp

  1. #include "matrix.h"
  2. #include <iostream>

  3. using namespace std;

  4. int main() {
  5.         int arr[2][3] = {{1,2,3},{4,5,6}};
  6.         int arr2[3][2] = {{1,2},{3,4},{5,6}};
  7.         Matrix<int, 2, 3> m1(arr);
  8.         Matrix<int, 3, 2> m2(arr2);
  9.         m1.dump();
  10.         m2.dump();
  11.         Matrix<int, 2, 2> m3 = m1 * m2;
  12.         m3.dump();
  13. }
复制代码


1. 如果我把elements这个成员声明成private的话, 代码编译不过去。
2. return Matrix<Numeric, rows, cols2>(arr); 这一行会报warning:
  matrix.h:31: warning: returning reference to temporary
为什么? 怎么修改?

论坛徽章:
0
3 [报告]
发表于 2008-03-31 12:09 |只看该作者
#ifndef MATRIX_H__
#define MATRIX_H__

#include <iostream>

template <class Numeric, int rows, int cols>
class Matrix {
public:
        Numeric elements[rows][cols];
        Matrix(Numeric arr[rows][cols]) {
                for (int i = 0; i < rows; ++i) {
                        for (int j = 0; j < cols; ++j) {
                                elements[j] = arr[j];
                        }
                }
        }

        template <int cols2>
        const Matrix<Numeric, rows, cols2>&
        operator*(Matrix<Numeric, cols, cols2>& m) {
                Numeric arr[rows][cols2];
                for (int i = 0; i < rows; ++i) {
                        for (int j = 0; j < cols2; ++j) {
                                Numeric sum = 0;
                                for (int k = 0; k < cols; ++k) {
                                        sum += elements[k] * m.elements[k][j];
                                }
                                arr[j] = sum;
                        }
                }
                return Matrix<Numeric, rows, cols2>(arr);
        }

        void dump() {
                for (int i = 0; i < rows; ++i) {
                        for (int j = 0; j < cols; ++j) {
                                std::cout << elements[j] << ' ';
                        }
                        std::cout << std::endl;
                }
        }
};

#endif
上面这个算是声明还是定义呀 ??
您需要登录后才可以回帖 登录 | 注册

本版积分规则 发表回复

  

北京盛拓优讯信息技术有限公司. 版权所有 京ICP备16024965号-6 北京市公安局海淀分局网监中心备案编号:11010802020122 niuxiaotong@pcpop.com 17352615567
未成年举报专区
中国互联网协会会员  联系我们:huangweiwei@itpub.net
感谢所有关心和支持过ChinaUnix的朋友们 转载本站内容请注明原作者名及出处

清除 Cookies - ChinaUnix - Archiver - WAP - TOP