/* -*- C++ -*- ------------------------------------------------------------ Copyright (c) 2007 Jesse Anders and Demian Nave http://cmldev.net/ The Configurable Math Library (CML) is distributed under the terms of the Boost Software License, v1.0 (see cml/LICENSE for details). *-----------------------------------------------------------------------*/ /** @file * @brief Multiply a matrix and a vector. * * @todo Implement smarter temporary generation. * * @todo Does it make sense to put mat-vec multiplication as a node into the * expression tree? * * @internal This does not need to return an expression type, since the * temporary generation for the matrix result is handled automatically by the * compiler. i.e. when used in an expression, the result is automatically * included in the expression tree as a temporary by the compiler. */ #ifndef matvec_mul_h #define matvec_mul_h #include #include #include #include /* This is used below to create a more meaningful compile-time error when * mat-vec mul is not provided with the right arguments: */ struct mvmul_expects_one_matrix_and_one_vector_arg_error; struct mvmul_expects_one_vector_and_one_matrix_arg_error; namespace cml { namespace detail { /* For choosing the proper multiplication order: */ typedef true_type mul_Ax; typedef false_type mul_xA; /** Compute y = A*x. */ template inline typename et::MatVecPromote< typename et::ExprTraits::result_type, typename et::ExprTraits::result_type >::temporary_type mul(const LeftT& A, const RightT& x, mul_Ax) { /* Shorthand: */ typedef et::ExprTraits left_traits; typedef et::ExprTraits right_traits; typedef typename left_traits::result_tag left_result; typedef typename right_traits::result_tag right_result; /* mul()[A*x] requires a matrix and a vector expression: */ CML_STATIC_REQUIRE_M( (same_type::is_true && same_type::is_true), mvmul_expects_one_matrix_and_one_vector_arg_error); /* Note: parens are required here so that the preprocessor ignores the * commas. */ /* Get result type: */ typedef typename et::MatVecPromote< typename left_traits::result_type, typename right_traits::result_type >::temporary_type result_type; /* Record size type: */ typedef typename result_type::size_tag size_tag; /* Check the size: */ size_t N = et::CheckedSize(A, x, size_tag()); /* Initialize the new vector: */ result_type y; cml::et::detail::Resize(y, N); /* Compute y = A*x: */ typedef typename result_type::value_type sum_type; for(size_t i = 0; i < N; ++i) { /* XXX This should be unrolled. */ sum_type sum(A(i,0)*x[0]); for(size_t k = 1; k < x.size(); ++k) { sum += (A(i,k)*x[k]); } y[i] = sum; } return y; } /** Compute y = x*A. */ template inline typename et::MatVecPromote< typename et::ExprTraits::result_type, typename et::ExprTraits::result_type >::temporary_type mul(const LeftT& x, const RightT& A, mul_xA) { /* Shorthand: */ typedef et::ExprTraits left_traits; typedef et::ExprTraits right_traits; typedef typename left_traits::result_tag left_result; typedef typename right_traits::result_tag right_result; /* mul()[x*A] requires a vector and a matrix expression: */ CML_STATIC_REQUIRE_M( (same_type::is_true && same_type::is_true), mvmul_expects_one_vector_and_one_matrix_arg_error); /* Note: parens are required here so that the preprocessor ignores the * commas. */ /* Get result type: */ typedef typename et::MatVecPromote< typename left_traits::result_type, typename right_traits::result_type >::temporary_type result_type; /* Record size type: */ typedef typename result_type::size_tag size_tag; /* Check the size: */ size_t N = et::CheckedSize(x, A, size_tag()); /* Initialize the new vector: */ result_type y; cml::et::detail::Resize(y, N); /* Compute y = x*A: */ typedef typename result_type::value_type sum_type; for(size_t i = 0; i < N; ++i) { /* XXX This should be unrolled. */ sum_type sum(x[0]*A(0,i)); for(size_t k = 1; k < x.size(); ++k) { sum += (x[k]*A(k,i)); } y[i] = sum; } return y; } } // namespace detail /** operator*() for a matrix and a vector. */ template inline typename et::MatVecPromote< matrix, vector >::temporary_type operator*(const matrix& left, const vector& right) { return detail::mul(left,right,detail::mul_Ax()); } /** operator*() for a matrix and a VectorXpr. */ template inline typename et::MatVecPromote< matrix, typename XprT::result_type >::temporary_type operator*(const matrix& left, const et::VectorXpr& right) { /* Generate a temporary, and compute the right-hand expression: */ typename et::VectorXpr::temporary_type right_tmp; cml::et::detail::Resize(right_tmp,right.size()); right_tmp = right; return detail::mul(left,right_tmp,detail::mul_Ax()); } /** operator*() for a MatrixXpr and a vector. */ template inline typename et::MatVecPromote< typename XprT::result_type, vector >::temporary_type operator*(const et::MatrixXpr& left, const vector& right) { /* Generate a temporary, and compute the left-hand expression: */ typename et::MatrixXpr::temporary_type left_tmp; cml::et::detail::Resize(left_tmp,left.rows(),left.cols()); left_tmp = left; return detail::mul(left_tmp,right,detail::mul_Ax()); } /** operator*() for a MatrixXpr and a VectorXpr. */ template inline typename et::MatVecPromote< typename XprT1::result_type, typename XprT2::result_type >::temporary_type operator*(const et::MatrixXpr& left, const et::VectorXpr& right) { /* Generate a temporary, and compute the left-hand expression: */ typename et::MatrixXpr::temporary_type left_tmp; cml::et::detail::Resize(left_tmp,left.rows(),left.cols()); left_tmp = left; /* Generate a temporary, and compute the right-hand expression: */ typename et::VectorXpr::temporary_type right_tmp; cml::et::detail::Resize(right_tmp,right.size()); right_tmp = right; return detail::mul(left_tmp,right_tmp,detail::mul_Ax()); } /** operator*() for a vector and a matrix. */ template inline typename et::MatVecPromote< vector, matrix >::temporary_type operator*(const vector& left, const matrix& right) { return detail::mul(left,right,detail::mul_xA()); } /** operator*() for a vector and a MatrixXpr. */ template inline typename et::MatVecPromote< typename XprT::result_type, vector >::temporary_type operator*(const vector& left, const et::MatrixXpr& right) { /* Generate a temporary, and compute the right-hand expression: */ typename et::MatrixXpr::temporary_type right_tmp; cml::et::detail::Resize(right_tmp,right.rows(),right.cols()); right_tmp = right; return detail::mul(left,right_tmp,detail::mul_xA()); } /** operator*() for a VectorXpr and a matrix. */ template inline typename et::MatVecPromote< typename XprT::result_type, matrix >::temporary_type operator*(const et::VectorXpr& left, const matrix& right) { /* Generate a temporary, and compute the left-hand expression: */ typename et::VectorXpr::temporary_type left_tmp; cml::et::detail::Resize(left_tmp,left.size()); left_tmp = left; return detail::mul(left_tmp,right,detail::mul_xA()); } /** operator*() for a VectorXpr and a MatrixXpr. */ template inline typename et::MatVecPromote< typename XprT1::result_type, typename XprT2::result_type >::temporary_type operator*(const et::VectorXpr& left, const et::MatrixXpr& right) { /* Generate a temporary, and compute the left-hand expression: */ typename et::VectorXpr::temporary_type left_tmp; cml::et::detail::Resize(left_tmp,left.size()); left_tmp = left; /* Generate a temporary, and compute the right-hand expression: */ typename et::MatrixXpr::temporary_type right_tmp; cml::et::detail::Resize(right_tmp,right.rows(),right.cols()); right_tmp = right; return detail::mul(left_tmp,right_tmp,detail::mul_xA()); } } // namespace cml #endif // ------------------------------------------------------------------------- // vim:ft=cpp