Dogcows Code - chaz/yoink/blob - src/cml/matrix/matrix_mul.h
fixes for newer versions of g++
[chaz/yoink] / src / cml / matrix / matrix_mul.h
1 /* -*- C++ -*- ------------------------------------------------------------
2
3 Copyright (c) 2007 Jesse Anders and Demian Nave http://cmldev.net/
4
5 The Configurable Math Library (CML) is distributed under the terms of the
6 Boost Software License, v1.0 (see cml/LICENSE for details).
7
8 *-----------------------------------------------------------------------*/
9 /** @file
10 * @brief Multiply two matrices.
11 *
12 * @todo Does it make sense to put mat-mat multiplication as a node into the
13 * expression tree?
14 *
15 * @internal This does not need to return an expression type, since the
16 * temporary generation for the matrix result is handled automatically by the
17 * compiler; i.e., when used in an expression, the result is automatically
18 * included in the expression tree as a temporary by the compiler.
19 */
20
21 #ifndef matrix_mul_h
22 #define matrix_mul_h
23
24 #include <cml/et/size_checking.h>
25 #include <cml/matrix/matrix_expr.h>
26
/* Tag types passed as the message argument of CML_STATIC_REQUIRE_M in the
 * mul() code below, so that a failed compile-time requirement produces a
 * more meaningful error:
 *
 *   mul_expects_matrix_args_error
 *       mul() was given an argument that is not a matrix or a matrix
 *       expression (MatrixXpr).
 *
 *   mul_expressions_have_wrong_size_error
 *       fixed-size arguments to mul() have incompatible sizes.
 */
struct mul_expects_matrix_args_error;
struct mul_expressions_have_wrong_size_error;
36
37 namespace cml {
38 namespace detail {
39
40 /** Verify the sizes of the argument matrices for matrix multiplication.
41 *
42 * @returns a matrix_size containing the size of the resulting matrix.
43 */
44 template<typename LeftT, typename RightT> inline matrix_size
45 MatMulCheckedSize(const LeftT&, const RightT&, fixed_size_tag)
46 {
47 CML_STATIC_REQUIRE_M(
48 ((size_t)LeftT::array_cols == (size_t)RightT::array_rows),
49 mul_expressions_have_wrong_size_error);
50 return matrix_size(LeftT::array_rows,RightT::array_cols);
51 }
52
53 /** Verify the sizes of the argument matrices for matrix multiplication.
54 *
55 * @returns a matrix_size containing the size of the resulting matrix.
56 */
57 template<typename LeftT, typename RightT> inline matrix_size
58 MatMulCheckedSize(const LeftT& left, const RightT& right, dynamic_size_tag)
59 {
60 matrix_size left_N = left.size(), right_N = right.size();
61 et::GetCheckedSize<LeftT,RightT,dynamic_size_tag>()
62 .equal_or_fail(left_N.second, right_N.first); /* cols,rows */
63 return matrix_size(left_N.first, right_N.second); /* rows,cols */
64 }
65
66
67 /** Matrix multiplication.
68 *
69 * Computes C = A x B (O(N^3), non-blocked algorithm).
70 */
71 template<class LeftT, class RightT>
72 inline typename et::MatrixPromote<
73 typename et::ExprTraits<LeftT>::result_type,
74 typename et::ExprTraits<RightT>::result_type
75 >::temporary_type
76 mul(const LeftT& left, const RightT& right)
77 {
78 /* Shorthand: */
79 typedef et::ExprTraits<LeftT> left_traits;
80 typedef et::ExprTraits<RightT> right_traits;
81 typedef typename left_traits::result_type left_result;
82 typedef typename right_traits::result_type right_result;
83
84 /* First, require matrix expressions: */
85 CML_STATIC_REQUIRE_M(
86 (et::MatrixExpressions<LeftT,RightT>::is_true),
87 mul_expects_matrix_args_error);
88 /* Note: parens are required here so that the preprocessor ignores the
89 * commas.
90 */
91
92 /* Deduce size type to ensure that a run-time check is performed if
93 * necessary:
94 */
95 typedef typename et::MatrixPromote<
96 typename left_traits::result_type,
97 typename right_traits::result_type
98 >::type result_type;
99 typedef typename result_type::size_tag size_tag;
100
101 /* Require that left has the same number of columns as right has rows.
102 * This automatically checks fixed-size matrices at compile time, and
103 * throws at run-time if the sizes don't match:
104 */
105 matrix_size N = detail::MatMulCheckedSize(left, right, size_tag());
106
107 /* Create an array with the right size (resize() is a no-op for
108 * fixed-size matrices):
109 */
110 result_type C;
111 cml::et::detail::Resize(C, N);
112
113 /* XXX Specialize this for fixed-size matrices: */
114 typedef typename result_type::value_type value_type;
115 for(size_t i = 0; i < left.rows(); ++i) { /* rows */
116 for(size_t j = 0; j < right.cols(); ++j) { /* cols */
117 value_type sum(left(i,0)*right(0,j));
118 for(size_t k = 1; k < right.rows(); ++k) {
119 sum += (left(i,k)*right(k,j));
120 }
121 C(i,j) = sum;
122 }
123 }
124
125 return C;
126 }
127
128 } // namespace detail
129
130
131 /** operator*() for two matrices. */
132 template<typename E1, class AT1, typename L1,
133 typename E2, class AT2, typename L2,
134 typename BO>
135 inline typename et::MatrixPromote<
136 matrix<E1,AT1,BO,L1>, matrix<E2,AT2,BO,L2>
137 >::temporary_type
138 operator*(const matrix<E1,AT1,BO,L1>& left,
139 const matrix<E2,AT2,BO,L2>& right)
140 {
141 return detail::mul(left,right);
142 }
143
144 /** operator*() for a matrix and a MatrixXpr. */
145 template<typename E, class AT, typename BO, typename L, typename XprT>
146 inline typename et::MatrixPromote<
147 matrix<E,AT,BO,L>, typename XprT::result_type
148 >::temporary_type
149 operator*(const matrix<E,AT,BO,L>& left,
150 const et::MatrixXpr<XprT>& right)
151 {
152 /* Generate a temporary, and compute the right-hand expression: */
153 typedef typename et::MatrixXpr<XprT>::temporary_type expr_tmp;
154 expr_tmp tmp;
155 cml::et::detail::Resize(tmp,right.rows(),right.cols());
156 tmp = right;
157
158 return detail::mul(left,tmp);
159 }
160
161 /** operator*() for a MatrixXpr and a matrix. */
162 template<typename XprT, typename E, class AT, typename BO, typename L>
163 inline typename et::MatrixPromote<
164 typename XprT::result_type , matrix<E,AT,BO,L>
165 >::temporary_type
166 operator*(const et::MatrixXpr<XprT>& left,
167 const matrix<E,AT,BO,L>& right)
168 {
169 /* Generate a temporary, and compute the left-hand expression: */
170 typedef typename et::MatrixXpr<XprT>::temporary_type expr_tmp;
171 expr_tmp tmp;
172 cml::et::detail::Resize(tmp,left.rows(),left.cols());
173 tmp = left;
174
175 return detail::mul(tmp,right);
176 }
177
178 /** operator*() for two MatrixXpr's. */
179 template<typename XprT1, typename XprT2>
180 inline typename et::MatrixPromote<
181 typename XprT1::result_type, typename XprT2::result_type
182 >::temporary_type
183 operator*(const et::MatrixXpr<XprT1>& left,
184 const et::MatrixXpr<XprT2>& right)
185 {
186 /* Generate temporaries and compute expressions: */
187 typedef typename et::MatrixXpr<XprT1>::temporary_type left_tmp;
188 left_tmp ltmp;
189 cml::et::detail::Resize(ltmp,left.rows(),left.cols());
190 ltmp = left;
191
192 typedef typename et::MatrixXpr<XprT2>::temporary_type right_tmp;
193 right_tmp rtmp;
194 cml::et::detail::Resize(rtmp,right.rows(),right.cols());
195 rtmp = right;
196
197 return detail::mul(ltmp,rtmp);
198 }
199
200 } // namespace cml
201
202 #endif
203
204 // -------------------------------------------------------------------------
205 // vim:ft=cpp
This page took 0.040023 seconds and 4 git commands to generate.