X-Git-Url: https://git.dogcows.com/gitweb?a=blobdiff_plain;ds=sidebyside;f=src%2Fcml%2Fmatvec%2Fmatvec_mul.h;fp=src%2Fcml%2Fmatvec%2Fmatvec_mul.h;h=5f5ea92e0d4c18ea7c1bf6cac25ea8c6783912cb;hb=6b0a0d0efafe34d48ab344fca3b479553bd4e62c;hp=0000000000000000000000000000000000000000;hpb=85783316365181491a3e3c0c63659972477cebba;p=chaz%2Fyoink diff --git a/src/cml/matvec/matvec_mul.h b/src/cml/matvec/matvec_mul.h new file mode 100644 index 0000000..5f5ea92 --- /dev/null +++ b/src/cml/matvec/matvec_mul.h @@ -0,0 +1,285 @@ +/* -*- C++ -*- ------------------------------------------------------------ + +Copyright (c) 2007 Jesse Anders and Demian Nave http://cmldev.net/ + +The Configurable Math Library (CML) is distributed under the terms of the +Boost Software License, v1.0 (see cml/LICENSE for details). + + *-----------------------------------------------------------------------*/ +/** @file + * @brief Multiply a matrix and a vector. + * + * @todo Implement smarter temporary generation. + * + * @todo Does it make sense to put mat-vec multiplication as a node into the + * expression tree? + * + * @internal This does not need to return an expression type, since the + * temporary generation for the matrix result is handled automatically by the + * compiler. i.e. when used in an expression, the result is automatically + * included in the expression tree as a temporary by the compiler. + */ + +#ifndef matvec_mul_h +#define matvec_mul_h + +#include +#include +#include +#include + +/* This is used below to create a more meaningful compile-time error when + * mat-vec mul is not provided with the right arguments: + */ +struct mvmul_expects_one_matrix_and_one_vector_arg_error; +struct mvmul_expects_one_vector_and_one_matrix_arg_error; + +namespace cml { +namespace detail { + +/* For choosing the proper multiplication order: */ +typedef true_type mul_Ax; +typedef false_type mul_xA; + +/** Compute y = A*x. */ +template inline +typename et::MatVecPromote< + typename et::ExprTraits::result_type, + typename et::ExprTraits::result_type +>::temporary_type +mul(const LeftT& A, const RightT& x, mul_Ax) +{ + /* Shorthand: */ + typedef et::ExprTraits left_traits; + typedef et::ExprTraits right_traits; + typedef typename left_traits::result_tag left_result; + typedef typename right_traits::result_tag right_result; + + /* mul()[A*x] requires a matrix and a vector expression: */ + CML_STATIC_REQUIRE_M( + (same_type::is_true + && same_type::is_true), + mvmul_expects_one_matrix_and_one_vector_arg_error); + /* Note: parens are required here so that the preprocessor ignores the + * commas. + */ + + /* Get result type: */ + typedef typename et::MatVecPromote< + typename left_traits::result_type, + typename right_traits::result_type + >::temporary_type result_type; + + /* Record size type: */ + typedef typename result_type::size_tag size_tag; + + /* Check the size: */ + size_t N = et::CheckedSize(A, x, size_tag()); + + /* Initialize the new vector: */ + result_type y; cml::et::detail::Resize(y, N); + + /* Compute y = A*x: */ + typedef typename result_type::value_type sum_type; + for(size_t i = 0; i < N; ++i) { + /* XXX This should be unrolled. */ + sum_type sum(A(i,0)*x[0]); + for(size_t k = 1; k < x.size(); ++k) { + sum += (A(i,k)*x[k]); + } + y[i] = sum; + } + + return y; +} + +/** Compute y = x*A. */ +template inline +typename et::MatVecPromote< + typename et::ExprTraits::result_type, + typename et::ExprTraits::result_type +>::temporary_type +mul(const LeftT& x, const RightT& A, mul_xA) +{ + /* Shorthand: */ + typedef et::ExprTraits left_traits; + typedef et::ExprTraits right_traits; + typedef typename left_traits::result_tag left_result; + typedef typename right_traits::result_tag right_result; + + /* mul()[x*A] requires a vector and a matrix expression: */ + CML_STATIC_REQUIRE_M( + (same_type::is_true + && same_type::is_true), + mvmul_expects_one_vector_and_one_matrix_arg_error); + /* Note: parens are required here so that the preprocessor ignores the + * commas. + */ + + /* Get result type: */ + typedef typename et::MatVecPromote< + typename left_traits::result_type, + typename right_traits::result_type + >::temporary_type result_type; + + /* Record size type: */ + typedef typename result_type::size_tag size_tag; + + /* Check the size: */ + size_t N = et::CheckedSize(x, A, size_tag()); + + /* Initialize the new vector: */ + result_type y; cml::et::detail::Resize(y, N); + + /* Compute y = x*A: */ + typedef typename result_type::value_type sum_type; + for(size_t i = 0; i < N; ++i) { + /* XXX This should be unrolled. */ + sum_type sum(x[0]*A(0,i)); + for(size_t k = 1; k < x.size(); ++k) { + sum += (x[k]*A(k,i)); + } + y[i] = sum; + } + + return y; +} + +} // namespace detail + + +/** operator*() for a matrix and a vector. */ +template +inline typename et::MatVecPromote< + matrix, vector +>::temporary_type +operator*(const matrix& left, + const vector& right) +{ + return detail::mul(left,right,detail::mul_Ax()); +} + +/** operator*() for a matrix and a VectorXpr. */ +template +inline typename et::MatVecPromote< + matrix, typename XprT::result_type +>::temporary_type +operator*(const matrix& left, + const et::VectorXpr& right) +{ + /* Generate a temporary, and compute the right-hand expression: */ + typename et::VectorXpr::temporary_type right_tmp; + cml::et::detail::Resize(right_tmp,right.size()); + right_tmp = right; + + return detail::mul(left,right_tmp,detail::mul_Ax()); +} + +/** operator*() for a MatrixXpr and a vector. */ +template +inline typename et::MatVecPromote< + typename XprT::result_type, vector +>::temporary_type +operator*(const et::MatrixXpr& left, + const vector& right) +{ + /* Generate a temporary, and compute the left-hand expression: */ + typename et::MatrixXpr::temporary_type left_tmp; + cml::et::detail::Resize(left_tmp,left.rows(),left.cols()); + left_tmp = left; + + return detail::mul(left_tmp,right,detail::mul_Ax()); +} + +/** operator*() for a MatrixXpr and a VectorXpr. */ +template +inline typename et::MatVecPromote< + typename XprT1::result_type, typename XprT2::result_type +>::temporary_type +operator*(const et::MatrixXpr& left, + const et::VectorXpr& right) +{ + /* Generate a temporary, and compute the left-hand expression: */ + typename et::MatrixXpr::temporary_type left_tmp; + cml::et::detail::Resize(left_tmp,left.rows(),left.cols()); + left_tmp = left; + + /* Generate a temporary, and compute the right-hand expression: */ + typename et::VectorXpr::temporary_type right_tmp; + cml::et::detail::Resize(right_tmp,right.size()); + right_tmp = right; + + return detail::mul(left_tmp,right_tmp,detail::mul_Ax()); +} + +/** operator*() for a vector and a matrix. */ +template +inline typename et::MatVecPromote< + vector, matrix +>::temporary_type +operator*(const vector& left, + const matrix& right) +{ + return detail::mul(left,right,detail::mul_xA()); +} + +/** operator*() for a vector and a MatrixXpr. */ +template +inline typename et::MatVecPromote< + typename XprT::result_type, vector +>::temporary_type +operator*(const vector& left, + const et::MatrixXpr& right) +{ + /* Generate a temporary, and compute the right-hand expression: */ + typename et::MatrixXpr::temporary_type right_tmp; + cml::et::detail::Resize(right_tmp,right.rows(),right.cols()); + right_tmp = right; + + return detail::mul(left,right_tmp,detail::mul_xA()); +} + +/** operator*() for a VectorXpr and a matrix. */ +template +inline typename et::MatVecPromote< + typename XprT::result_type, matrix +>::temporary_type +operator*(const et::VectorXpr& left, + const matrix& right) +{ + /* Generate a temporary, and compute the left-hand expression: */ + typename et::VectorXpr::temporary_type left_tmp; + cml::et::detail::Resize(left_tmp,left.size()); + left_tmp = left; + + return detail::mul(left_tmp,right,detail::mul_xA()); +} + +/** operator*() for a VectorXpr and a MatrixXpr. */ +template +inline typename et::MatVecPromote< + typename XprT1::result_type, typename XprT2::result_type +>::temporary_type +operator*(const et::VectorXpr& left, + const et::MatrixXpr& right) +{ + /* Generate a temporary, and compute the left-hand expression: */ + typename et::VectorXpr::temporary_type left_tmp; + cml::et::detail::Resize(left_tmp,left.size()); + left_tmp = left; + + /* Generate a temporary, and compute the right-hand expression: */ + typename et::MatrixXpr::temporary_type right_tmp; + cml::et::detail::Resize(right_tmp,right.rows(),right.cols()); + right_tmp = right; + + return detail::mul(left_tmp,right_tmp,detail::mul_xA()); +} + +} // namespace cml + +#endif + +// ------------------------------------------------------------------------- +// vim:ft=cpp