]> Dogcows Code - chaz/yoink/blob - src/Moof/cml/matvec/matvec_mul.h
extreme refactoring
[chaz/yoink] / src / Moof / cml / matvec / matvec_mul.h
1 /* -*- C++ -*- ------------------------------------------------------------
2
3 Copyright (c) 2007 Jesse Anders and Demian Nave http://cmldev.net/
4
5 The Configurable Math Library (CML) is distributed under the terms of the
6 Boost Software License, v1.0 (see cml/LICENSE for details).
7
8 *-----------------------------------------------------------------------*/
9 /** @file
10 * @brief Multiply a matrix and a vector.
11 *
12 * @todo Implement smarter temporary generation.
13 *
14 * @todo Does it make sense to put mat-vec multiplication as a node into the
15 * expression tree?
16 *
17 * @internal This does not need to return an expression type, since the
18 * temporary generation for the vector result is handled automatically by the
19 * compiler. i.e. when used in an expression, the result is automatically
20 * included in the expression tree as a temporary by the compiler.
21 */
22
23 #ifndef matvec_mul_h
24 #define matvec_mul_h
25
26 #include <cml/core/cml_meta.h>
27 #include <cml/vector/vector_expr.h>
28 #include <cml/matrix/matrix_expr.h>
29 #include <cml/matvec/matvec_promotions.h>
30
/* These two incomplete types exist only to give a more meaningful
 * compile-time error: they are handed to CML_STATIC_REQUIRE_M() below when
 * mat-vec mul() is called with operands of the wrong kinds.
 */
struct mvmul_expects_one_vector_and_one_matrix_arg_error;
struct mvmul_expects_one_matrix_and_one_vector_arg_error;
36
37 namespace cml {
38 namespace detail {
39
40 /* For choosing the proper multiplication order: */
41 typedef true_type mul_Ax;
42 typedef false_type mul_xA;
43
44 /** Compute y = A*x. */
45 template<typename LeftT, typename RightT> inline
46 typename et::MatVecPromote<
47 typename et::ExprTraits<LeftT>::result_type,
48 typename et::ExprTraits<RightT>::result_type
49 >::temporary_type
50 mul(const LeftT& A, const RightT& x, mul_Ax)
51 {
52 /* Shorthand: */
53 typedef et::ExprTraits<LeftT> left_traits;
54 typedef et::ExprTraits<RightT> right_traits;
55 typedef typename left_traits::result_tag left_result;
56 typedef typename right_traits::result_tag right_result;
57
58 /* mul()[A*x] requires a matrix and a vector expression: */
59 CML_STATIC_REQUIRE_M(
60 (same_type<left_result, et::matrix_result_tag>::is_true
61 && same_type<right_result, et::vector_result_tag>::is_true),
62 mvmul_expects_one_matrix_and_one_vector_arg_error);
63 /* Note: parens are required here so that the preprocessor ignores the
64 * commas.
65 */
66
67 /* Get result type: */
68 typedef typename et::MatVecPromote<
69 typename left_traits::result_type,
70 typename right_traits::result_type
71 >::temporary_type result_type;
72
73 /* Record size type: */
74 typedef typename result_type::size_tag size_tag;
75
76 /* Check the size: */
77 size_t N = et::CheckedSize(A, x, size_tag());
78
79 /* Initialize the new vector: */
80 result_type y; cml::et::detail::Resize(y, N);
81
82 /* Compute y = A*x: */
83 typedef typename result_type::value_type sum_type;
84 for(size_t i = 0; i < N; ++i) {
85 /* XXX This should be unrolled. */
86 sum_type sum(A(i,0)*x[0]);
87 for(size_t k = 1; k < x.size(); ++k) {
88 sum += (A(i,k)*x[k]);
89 }
90 y[i] = sum;
91 }
92
93 return y;
94 }
95
96 /** Compute y = x*A. */
97 template<typename LeftT, typename RightT> inline
98 typename et::MatVecPromote<
99 typename et::ExprTraits<LeftT>::result_type,
100 typename et::ExprTraits<RightT>::result_type
101 >::temporary_type
102 mul(const LeftT& x, const RightT& A, mul_xA)
103 {
104 /* Shorthand: */
105 typedef et::ExprTraits<LeftT> left_traits;
106 typedef et::ExprTraits<RightT> right_traits;
107 typedef typename left_traits::result_tag left_result;
108 typedef typename right_traits::result_tag right_result;
109
110 /* mul()[x*A] requires a vector and a matrix expression: */
111 CML_STATIC_REQUIRE_M(
112 (same_type<left_result, et::vector_result_tag>::is_true
113 && same_type<right_result, et::matrix_result_tag>::is_true),
114 mvmul_expects_one_vector_and_one_matrix_arg_error);
115 /* Note: parens are required here so that the preprocessor ignores the
116 * commas.
117 */
118
119 /* Get result type: */
120 typedef typename et::MatVecPromote<
121 typename left_traits::result_type,
122 typename right_traits::result_type
123 >::temporary_type result_type;
124
125 /* Record size type: */
126 typedef typename result_type::size_tag size_tag;
127
128 /* Check the size: */
129 size_t N = et::CheckedSize(x, A, size_tag());
130
131 /* Initialize the new vector: */
132 result_type y; cml::et::detail::Resize(y, N);
133
134 /* Compute y = x*A: */
135 typedef typename result_type::value_type sum_type;
136 for(size_t i = 0; i < N; ++i) {
137 /* XXX This should be unrolled. */
138 sum_type sum(x[0]*A(0,i));
139 for(size_t k = 1; k < x.size(); ++k) {
140 sum += (x[k]*A(k,i));
141 }
142 y[i] = sum;
143 }
144
145 return y;
146 }
147
148 } // namespace detail
149
150
151 /** operator*() for a matrix and a vector. */
152 template<typename E1, class AT1, typename BO, class L,
153 typename E2, class AT2>
154 inline typename et::MatVecPromote<
155 matrix<E1,AT1,BO,L>, vector<E2,AT2>
156 >::temporary_type
157 operator*(const matrix<E1,AT1,BO,L>& left,
158 const vector<E2,AT2>& right)
159 {
160 return detail::mul(left,right,detail::mul_Ax());
161 }
162
163 /** operator*() for a matrix and a VectorXpr. */
164 template<typename E, class AT, class L, typename BO, typename XprT>
165 inline typename et::MatVecPromote<
166 matrix<E,AT,BO,L>, typename XprT::result_type
167 >::temporary_type
168 operator*(const matrix<E,AT,BO,L>& left,
169 const et::VectorXpr<XprT>& right)
170 {
171 /* Generate a temporary, and compute the right-hand expression: */
172 typename et::VectorXpr<XprT>::temporary_type right_tmp;
173 cml::et::detail::Resize(right_tmp,right.size());
174 right_tmp = right;
175
176 return detail::mul(left,right_tmp,detail::mul_Ax());
177 }
178
179 /** operator*() for a MatrixXpr and a vector. */
180 template<typename XprT, typename E, class AT>
181 inline typename et::MatVecPromote<
182 typename XprT::result_type, vector<E,AT>
183 >::temporary_type
184 operator*(const et::MatrixXpr<XprT>& left,
185 const vector<E,AT>& right)
186 {
187 /* Generate a temporary, and compute the left-hand expression: */
188 typename et::MatrixXpr<XprT>::temporary_type left_tmp;
189 cml::et::detail::Resize(left_tmp,left.rows(),left.cols());
190 left_tmp = left;
191
192 return detail::mul(left_tmp,right,detail::mul_Ax());
193 }
194
195 /** operator*() for a MatrixXpr and a VectorXpr. */
196 template<typename XprT1, typename XprT2>
197 inline typename et::MatVecPromote<
198 typename XprT1::result_type, typename XprT2::result_type
199 >::temporary_type
200 operator*(const et::MatrixXpr<XprT1>& left,
201 const et::VectorXpr<XprT2>& right)
202 {
203 /* Generate a temporary, and compute the left-hand expression: */
204 typename et::MatrixXpr<XprT1>::temporary_type left_tmp;
205 cml::et::detail::Resize(left_tmp,left.rows(),left.cols());
206 left_tmp = left;
207
208 /* Generate a temporary, and compute the right-hand expression: */
209 typename et::VectorXpr<XprT2>::temporary_type right_tmp;
210 cml::et::detail::Resize(right_tmp,right.size());
211 right_tmp = right;
212
213 return detail::mul(left_tmp,right_tmp,detail::mul_Ax());
214 }
215
216 /** operator*() for a vector and a matrix. */
217 template<typename E1, class AT1, typename E2, class AT2, typename BO, class L>
218 inline typename et::MatVecPromote<
219 vector<E1,AT1>, matrix<E2,AT2,BO,L>
220 >::temporary_type
221 operator*(const vector<E1,AT1>& left,
222 const matrix<E2,AT2,BO,L>& right)
223 {
224 return detail::mul(left,right,detail::mul_xA());
225 }
226
227 /** operator*() for a vector and a MatrixXpr. */
228 template<typename XprT, typename E, class AT>
229 inline typename et::MatVecPromote<
230 typename XprT::result_type, vector<E,AT>
231 >::temporary_type
232 operator*(const vector<E,AT>& left,
233 const et::MatrixXpr<XprT>& right)
234 {
235 /* Generate a temporary, and compute the right-hand expression: */
236 typename et::MatrixXpr<XprT>::temporary_type right_tmp;
237 cml::et::detail::Resize(right_tmp,right.rows(),right.cols());
238 right_tmp = right;
239
240 return detail::mul(left,right_tmp,detail::mul_xA());
241 }
242
243 /** operator*() for a VectorXpr and a matrix. */
244 template<typename XprT, typename E, class AT, typename BO, class L>
245 inline typename et::MatVecPromote<
246 typename XprT::result_type, matrix<E,AT,BO,L>
247 >::temporary_type
248 operator*(const et::VectorXpr<XprT>& left,
249 const matrix<E,AT,BO,L>& right)
250 {
251 /* Generate a temporary, and compute the left-hand expression: */
252 typename et::VectorXpr<XprT>::temporary_type left_tmp;
253 cml::et::detail::Resize(left_tmp,left.size());
254 left_tmp = left;
255
256 return detail::mul(left_tmp,right,detail::mul_xA());
257 }
258
259 /** operator*() for a VectorXpr and a MatrixXpr. */
260 template<typename XprT1, typename XprT2>
261 inline typename et::MatVecPromote<
262 typename XprT1::result_type, typename XprT2::result_type
263 >::temporary_type
264 operator*(const et::VectorXpr<XprT1>& left,
265 const et::MatrixXpr<XprT2>& right)
266 {
267 /* Generate a temporary, and compute the left-hand expression: */
268 typename et::VectorXpr<XprT1>::temporary_type left_tmp;
269 cml::et::detail::Resize(left_tmp,left.size());
270 left_tmp = left;
271
272 /* Generate a temporary, and compute the right-hand expression: */
273 typename et::MatrixXpr<XprT2>::temporary_type right_tmp;
274 cml::et::detail::Resize(right_tmp,right.rows(),right.cols());
275 right_tmp = right;
276
277 return detail::mul(left_tmp,right_tmp,detail::mul_xA());
278 }
279
280 } // namespace cml
281
282 #endif
283
284 // -------------------------------------------------------------------------
285 // vim:ft=cpp
This page took 0.043415 seconds and 4 git commands to generate.