octave-maintainers
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

Full matrix inversion for triangular and


From: David Bateman
Subject: Full matrix inversion for triangular and
Date: Wed, 06 Dec 2006 01:57:59 +0100
User-agent: Thunderbird 1.5.0.7 (X11/20060921)

The attached patch adds probing of the matrix type, and special-casing of
positive definite and triangular matrices, to the inverse methods of the
Matrix and ComplexMatrix classes, and adjusts the Finv and xpow functions
accordingly. As an example, consider

N=1000;
A = 10*eye(N) + randn(N,N);
t0 = cputime(); BA = inv(A); t(1) = cputime() - t0; ## Normal case
P = A*A';
t0 = cputime(); BP = inv(P); t(2) = cputime() - t0; ## PD case
U = triu(A);
t0 = cputime(); BU = inv(U); t(3) = cputime() - t0; ## Upper Triangular
L = tril(A);
t0 = cputime(); BL = inv(L); t(4) = cputime() - t0; ## Lower Triangular
t

which returns

   7.7648   7.7618   3.8974   4.0144

for 2.9.9, and for 2.9.9+ with the attached patch returns

   7.8098   3.7634   1.3958   1.4018

which is a significant gain in speed. Note that this essentially makes cholinv
redundant, though the same does not appear to be true for chol2inv. Consider

t0 = cputime(); BP = cholinv(P); cputime() - t0
ans =  3.7494

R = chol(P); t0 = cputime(); BP = chol2inv(R); cputime() - t0
ans =  2.6386
t0 = cputime(); BP = inv(R); BP = BP * BP'; cputime() - t0
ans =  5.3012

Now to check correctness, I did

N = 100;
n = 10;
for i=1:n
  A = 10*eye(N) + randn(N,N);
  P = A*A';
  assert(A*inv(A), eye(N), 1e-10);
endfor

for i=1:n
  A = 10*eye(N) + randn(N,N);
  U = triu(A);
  assert(U*inv(U), eye(N), 1e-10);
endfor

for i=1:n
  A = 10*eye(N) + randn(N,N);
  L = tril(A);
  assert(L*inv(L), eye(N), 1e-10);
endfor

which ran through without assert signaling any failures. This patch
excludes the toeplitz solver code I previously sent (cf.
http://www.cae.wisc.edu/pipermail/octave-maintainers/2006-November/001354.html).
I have had second thoughts about the usefulness of the Toeplitz solver
code: as Paul points out, if the matrix is Toeplitz we shouldn't form the
matrix at all in most cases, and the solver can then work directly from
the vectors making up the Toeplitz matrix, gaining both speed and memory.
It would probably be better to have a specialized function for this case
(probably in octave-forge), though perhaps Gorazd would like to comment
on that.

D.


*** ./liboctave/CMatrix.cc.orig6        2006-12-06 01:45:56.283165269 +0100
--- ./liboctave/CMatrix.cc      2006-12-06 00:37:47.849374855 +0100
***************
*** 41,46 ****
--- 41,47 ----
  #include "CmplxDET.h"
  #include "CmplxSCHUR.h"
  #include "CmplxSVD.h"
+ #include "CmplxCHOL.h"
  #include "f77-fcn.h"
  #include "lo-error.h"
  #include "lo-ieee.h"
***************
*** 142,147 ****
--- 143,155 ----
                             F77_CHAR_ARG_LEN_DECL);
  
    F77_RET_T
+   F77_FUNC (ztrtri, ZTRTRI) (F77_CONST_CHAR_ARG_DECL, 
F77_CONST_CHAR_ARG_DECL, 
+                            const octave_idx_type&, const Complex*, 
+                            const octave_idx_type&, octave_idx_type& 
+                            F77_CHAR_ARG_LEN_DECL
+                            F77_CHAR_ARG_LEN_DECL);
+ 
+   F77_RET_T
    F77_FUNC (ztrcon, ZTRCON) (F77_CONST_CHAR_ARG_DECL, 
F77_CONST_CHAR_ARG_DECL, 
                             F77_CONST_CHAR_ARG_DECL, const octave_idx_type&, 
                             const Complex*, const octave_idx_type&, double&,
***************
*** 959,977 ****
  {
    octave_idx_type info;
    double rcond;
!   return inverse (info, rcond, 0, 0);
  }
  
  ComplexMatrix
! ComplexMatrix::inverse (octave_idx_type& info) const
  {
    double rcond;
!   return inverse (info, rcond, 0, 0);
  }
  
  ComplexMatrix
! ComplexMatrix::inverse (octave_idx_type& info, double& rcond, int force, 
!                       int calc_cond) const
  {
    ComplexMatrix retval;
  
--- 967,1060 ----
  {
    octave_idx_type info;
    double rcond;
!   MatrixType mattype (*this);
!   return inverse (mattype, info, rcond, 0, 0);
! }
! 
! ComplexMatrix
! ComplexMatrix::inverse (MatrixType &mattype) const
! {
!   octave_idx_type info;
!   double rcond;
!   return inverse (mattype, info, rcond, 0, 0);
  }
  
  ComplexMatrix
! ComplexMatrix::inverse (MatrixType &mattype, octave_idx_type& info) const
  {
    double rcond;
!   return inverse (mattype, info, rcond, 0, 0);
  }
  
  ComplexMatrix
! ComplexMatrix::tinverse (MatrixType &mattype, octave_idx_type& info,
!                        double& rcond, int force, int calc_cond) const
! {
!   ComplexMatrix retval;
! 
!   octave_idx_type nr = rows ();
!   octave_idx_type nc = cols ();
! 
!   if (nr != nc || nr == 0 || nc == 0)
!     (*current_liboctave_error_handler) ("inverse requires square matrix");
!   else
!     {
!       int typ = mattype.type ();
!       char uplo = (typ == MatrixType::Lower ? 'L' : 'U');
!       char udiag = 'N';
!       retval = *this;
!       Complex *tmp_data = retval.fortran_vec ();
! 
!       F77_XFCN (ztrtri, ZTRTRI, (F77_CONST_CHAR_ARG2 (&uplo, 1),
!                                F77_CONST_CHAR_ARG2 (&udiag, 1),
!                                nr, tmp_data, nr, info 
!                                F77_CHAR_ARG_LEN (1)
!                                F77_CHAR_ARG_LEN (1)));
! 
!       if (f77_exception_encountered)
!       (*current_liboctave_error_handler) ("unrecoverable error in ztrtri");
!       else
!       {
!         // Throw-away extra info LAPACK gives so as to not change output.
!         rcond = 0.0;
!         if (info != 0) 
!           info = -1;
!         else if (calc_cond) 
!           {
!             octave_idx_type ztrcon_info = 0;
!             char job = '1';
! 
!             OCTAVE_LOCAL_BUFFER (Complex, cwork, 2 * nr);
!             OCTAVE_LOCAL_BUFFER (double, rwork, nr);
! 
!             F77_XFCN (ztrcon, ZTRCON, (F77_CONST_CHAR_ARG2 (&job, 1),
!                                        F77_CONST_CHAR_ARG2 (&uplo, 1),
!                                        F77_CONST_CHAR_ARG2 (&udiag, 1),
!                                        nr, tmp_data, nr, rcond, 
!                                        cwork, rwork, ztrcon_info 
!                                        F77_CHAR_ARG_LEN (1)
!                                        F77_CHAR_ARG_LEN (1)
!                                        F77_CHAR_ARG_LEN (1)));
! 
!             if (f77_exception_encountered)
!               (*current_liboctave_error_handler) 
!                 ("unrecoverable error in ztrcon");
! 
!             if (ztrcon_info != 0) 
!               info = -1;
!           }
!       }
! 
!       if (info == -1 && ! force)
!       retval = *this; // Restore matrix contents.
!     }
! 
!   return retval;
! }
! 
! ComplexMatrix
! ComplexMatrix::finverse (MatrixType &mattype, octave_idx_type& info,
!                        double& rcond, int force, int calc_cond) const
  {
    ComplexMatrix retval;
  
***************
*** 1062,1073 ****
--- 1145,1189 ----
                info = -1;
            }
        }
+ 
+       if (info != 0)
+       mattype.mark_as_rectangular();
      }
    
    return retval;
  }
  
  ComplexMatrix
+ ComplexMatrix::inverse (MatrixType &mattype, octave_idx_type& info,
+                       double& rcond, int force, int calc_cond) const
+ { // Invert using the (cached or probed) type: triangular, Hermitian, general.
+   int typ = mattype.type (false);
+   ComplexMatrix ret;
+   // Probe the structure only if the caller has not cached it already.
+   if (typ == MatrixType::Unknown)
+     typ = mattype.type (*this);
+ 
+   if (typ == MatrixType::Upper || typ == MatrixType::Lower)
+     ret = tinverse (mattype, info, rcond, force, calc_cond);
+   else  // Includes Rectangular, which finverse uses to tag failed input.
+     {
+       if (mattype.is_hermitian ())
+         {
+           ComplexCHOL chol (*this, info);
+           if (info == 0)
+             ret = chol.inverse ();  // NOTE: rcond is not updated here.
+           else
+             mattype.mark_as_unsymmetric ();  // Cholesky failed.
+         }
+       // Non-Hermitian, or Cholesky failed above: general inverse.
+       if (!mattype.is_hermitian ())
+         ret = finverse (mattype, info, rcond, force, calc_cond);
+     }
+ 
+   return ret;
+ }
+ 
+ ComplexMatrix
  ComplexMatrix::pseudo_inverse (double tol) const
  {
    ComplexMatrix retval;
*** ./liboctave/CMatrix.h.orig6 2006-12-06 01:45:56.284165222 +0100
--- ./liboctave/CMatrix.h       2006-12-06 00:28:42.085166364 +0100
***************
*** 135,143 ****
  
    ComplexColumnVector column (octave_idx_type i) const;
  
    ComplexMatrix inverse (void) const;
!   ComplexMatrix inverse (octave_idx_type& info) const;
!   ComplexMatrix inverse (octave_idx_type& info, double& rcond, int force = 0,
                         int calc_cond = 1) const;
  
    ComplexMatrix pseudo_inverse (double tol = 0.0) const;
--- 135,153 ----
  
    ComplexColumnVector column (octave_idx_type i) const;
  
+ private:
+   ComplexMatrix tinverse (MatrixType &mattype, octave_idx_type& info,
+                         double& rcond, int force, int calc_cond) const;
+ 
+   ComplexMatrix finverse (MatrixType &mattype, octave_idx_type& info,
+                         double& rcond, int force, int calc_cond) const;
+ 
+ public:
    ComplexMatrix inverse (void) const;
!   ComplexMatrix inverse (MatrixType &mattype) const;
!   ComplexMatrix inverse (MatrixType &mattype, octave_idx_type& info) const;
!   ComplexMatrix inverse (MatrixType &mattype, octave_idx_type& info,
!                        double& rcond, int force = 0, 
                         int calc_cond = 1) const;
  
    ComplexMatrix pseudo_inverse (double tol = 0.0) const;
*** ./liboctave/dMatrix.h.orig6 2006-12-06 01:45:56.284165222 +0100
--- ./liboctave/dMatrix.h       2006-12-06 00:28:29.140778085 +0100
***************
*** 107,116 ****
  
    ColumnVector column (octave_idx_type i) const;
  
    Matrix inverse (void) const;
!   Matrix inverse (octave_idx_type& info) const;
!   Matrix inverse (octave_idx_type& info, double& rcond, int force = 0, 
!                 int calc_cond = 1) const;
  
    Matrix pseudo_inverse (double tol = 0.0) const;
  
--- 107,125 ----
  
    ColumnVector column (octave_idx_type i) const;
  
+ private:
+   Matrix tinverse (MatrixType &mattype, octave_idx_type& info, double& rcond, 
+                  int force, int calc_cond) const;
+ 
+   Matrix finverse (MatrixType &mattype, octave_idx_type& info, double& rcond, 
+                  int force, int calc_cond) const;
+ 
+ public:
    Matrix inverse (void) const;
!   Matrix inverse (MatrixType &mattype) const;
!   Matrix inverse (MatrixType &mattype, octave_idx_type& info) const;
!   Matrix inverse (MatrixType &mattype, octave_idx_type& info, double& rcond,
!                 int force = 0, int calc_cond = 1) const;
  
    Matrix pseudo_inverse (double tol = 0.0) const;
  
*** ./liboctave/dMatrix.cc.orig6        2006-12-06 01:45:56.286165128 +0100
--- ./liboctave/dMatrix.cc      2006-12-06 00:40:32.059614673 +0100
***************
*** 37,42 ****
--- 37,43 ----
  #include "dbleDET.h"
  #include "dbleSCHUR.h"
  #include "dbleSVD.h"
+ #include "dbleCHOL.h"
  #include "f77-fcn.h"
  #include "lo-error.h"
  #include "lo-ieee.h"
***************
*** 138,143 ****
--- 139,150 ----
                             F77_CHAR_ARG_LEN_DECL);
  
    F77_RET_T
+   F77_FUNC (dtrtri, DTRTRI) (F77_CONST_CHAR_ARG_DECL, 
F77_CONST_CHAR_ARG_DECL, 
+                            const octave_idx_type&, const double*, 
+                            const octave_idx_type&, octave_idx_type&
+                            F77_CHAR_ARG_LEN_DECL
+                            F77_CHAR_ARG_LEN_DECL);
+   F77_RET_T
    F77_FUNC (dtrcon, DTRCON) (F77_CONST_CHAR_ARG_DECL, 
F77_CONST_CHAR_ARG_DECL, 
                             F77_CONST_CHAR_ARG_DECL, const octave_idx_type&, 
                             const double*, const octave_idx_type&, double&,
***************
*** 628,645 ****
  {
    octave_idx_type info;
    double rcond;
!   return inverse (info, rcond, 0, 0);
  }
  
  Matrix
! Matrix::inverse (octave_idx_type& info) const
  {
    double rcond;
!   return inverse (info, rcond, 0, 0);
  }
  
  Matrix
! Matrix::inverse (octave_idx_type& info, double& rcond, int force, int 
calc_cond) const
  {
    Matrix retval;
  
--- 635,729 ----
  {
    octave_idx_type info;
    double rcond;
!   MatrixType mattype (*this);
!   return inverse (mattype, info, rcond, 0, 0);
! }
! 
! Matrix
! Matrix::inverse (MatrixType& mattype) const
! {
!   octave_idx_type info;
!   double rcond;
!   return inverse (mattype, info, rcond, 0, 0);
  }
  
  Matrix
! Matrix::inverse (MatrixType &mattype, octave_idx_type& info) const
  {
    double rcond;
!   return inverse (mattype, info, rcond, 0, 0);
  }
  
  Matrix
! Matrix::tinverse (MatrixType &mattype, octave_idx_type& info, double& rcond, 
!                 int force, int calc_cond) const
! {
!   Matrix retval;
! 
!   octave_idx_type nr = rows ();
!   octave_idx_type nc = cols ();
! 
!   if (nr != nc || nr == 0 || nc == 0)
!     (*current_liboctave_error_handler) ("inverse requires square matrix");
!   else
!     {
!       int typ = mattype.type ();
!       char uplo = (typ == MatrixType::Lower ? 'L' : 'U');
!       char udiag = 'N';
!       retval = *this;
!       double *tmp_data = retval.fortran_vec ();
! 
!       F77_XFCN (dtrtri, DTRTRI, (F77_CONST_CHAR_ARG2 (&uplo, 1),
!                                F77_CONST_CHAR_ARG2 (&udiag, 1),
!                                nr, tmp_data, nr, info 
!                                F77_CHAR_ARG_LEN (1)
!                                F77_CHAR_ARG_LEN (1)));
! 
!       if (f77_exception_encountered)
!       (*current_liboctave_error_handler) ("unrecoverable error in dtrtri");
!       else
!       {
!         // Throw-away extra info LAPACK gives so as to not change output.
!         rcond = 0.0;
!         if (info != 0) 
!           info = -1;
!         else if (calc_cond) 
!           {
!             octave_idx_type dtrcon_info = 0;
!             char job = '1';
! 
!             OCTAVE_LOCAL_BUFFER (double, work, 3 * nr);
!             OCTAVE_LOCAL_BUFFER (octave_idx_type, iwork, nr);
! 
!             F77_XFCN (dtrcon, DTRCON, (F77_CONST_CHAR_ARG2 (&job, 1),
!                                        F77_CONST_CHAR_ARG2 (&uplo, 1),
!                                        F77_CONST_CHAR_ARG2 (&udiag, 1),
!                                        nr, tmp_data, nr, rcond, 
!                                        work, iwork, dtrcon_info 
!                                        F77_CHAR_ARG_LEN (1)
!                                        F77_CHAR_ARG_LEN (1)
!                                        F77_CHAR_ARG_LEN (1)));
! 
!             if (f77_exception_encountered)
!               (*current_liboctave_error_handler) 
!                 ("unrecoverable error in dtrcon");
! 
!             if (dtrcon_info != 0) 
!               info = -1;
!           }
!       }
! 
!       if (info == -1 && ! force)
!       retval = *this; // Restore matrix contents.
!     }
! 
!   return retval;
! }
! 
! 
! Matrix
! Matrix::finverse (MatrixType &mattype, octave_idx_type& info, double& rcond, 
!                 int force, int calc_cond) const
  {
    Matrix retval;
  
***************
*** 730,741 ****
--- 814,858 ----
                info = -1;
            }
        }
+ 
+       if (info != 0)
+       mattype.mark_as_rectangular();
      }
  
    return retval;
  }
  
  Matrix
+ Matrix::inverse (MatrixType &mattype, octave_idx_type& info, double& rcond,
+                int force, int calc_cond) const
+ { // Invert using the (cached or probed) type: triangular, Hermitian, general.
+   int typ = mattype.type (false);
+   Matrix ret;
+   // Probe the structure only if the caller has not cached it already.
+   if (typ == MatrixType::Unknown)
+     typ = mattype.type (*this);
+ 
+   if (typ == MatrixType::Upper || typ == MatrixType::Lower)
+     ret = tinverse (mattype, info, rcond, force, calc_cond);
+   else  // Includes Rectangular, which finverse uses to tag failed input.
+     {
+       if (mattype.is_hermitian ())
+         {
+           CHOL chol (*this, info);
+           if (info == 0)
+             ret = chol.inverse ();  // NOTE: rcond is not updated here.
+           else
+             mattype.mark_as_unsymmetric ();  // Cholesky failed.
+         }
+       // Non-Hermitian, or Cholesky failed above: general inverse.
+       if (!mattype.is_hermitian ())
+         ret = finverse (mattype, info, rcond, force, calc_cond);
+     }
+ 
+   return ret;
+ }
+ 
+ Matrix
  Matrix::pseudo_inverse (double tol) const
  {
    SVD result (*this, SVD::economy);
*** ./src/xpow.cc.orig6 2006-12-06 01:42:35.007677067 +0100
--- ./src/xpow.cc       2006-12-06 00:23:04.933099363 +0100
***************
*** 192,199 ****
  
                  octave_idx_type info;
                  double rcond = 0.0;
  
!                 atmp = a.inverse (info, rcond, 1);
  
                  if (info == -1)
                    warning ("inverse: matrix singular to machine\
--- 192,200 ----
  
                  octave_idx_type info;
                  double rcond = 0.0;
+                 MatrixType mattype (a);
  
!                 atmp = a.inverse (mattype, info, rcond, 1);
  
                  if (info == -1)
                    warning ("inverse: matrix singular to machine\
***************
*** 388,395 ****
  
                  octave_idx_type info;
                  double rcond = 0.0;
  
!                 atmp = a.inverse (info, rcond, 1);
  
                  if (info == -1)
                    warning ("inverse: matrix singular to machine\
--- 389,397 ----
  
                  octave_idx_type info;
                  double rcond = 0.0;
+                 MatrixType mattype (a);
  
!                 atmp = a.inverse (mattype, info, rcond, 1);
  
                  if (info == -1)
                    warning ("inverse: matrix singular to machine\
*** ./src/DLD-FUNCTIONS/inv.cc.orig6    2006-12-06 01:42:49.121010104 +0100
--- ./src/DLD-FUNCTIONS/inv.cc  2006-12-06 01:07:56.926882187 +0100
***************
*** 74,83 ****
  
        if (! error_state)
        {
          octave_idx_type info;
          double rcond = 0.0;
  
!         Matrix result = m.inverse (info, rcond, 1);
  
          if (nargout > 1)
            retval(1) = rcond;
--- 74,87 ----
  
        if (! error_state)
        {
+         MatrixType mattyp = args(0).matrix_type ();
+ 
          octave_idx_type info;
          double rcond = 0.0;
  
!         Matrix result = m.inverse (mattyp, info, rcond, 1);
! 
!         args(0).matrix_type (mattyp);
  
          if (nargout > 1)
            retval(1) = rcond;
***************
*** 97,106 ****
  
        if (! error_state)
        {
          octave_idx_type info;
          double rcond = 0.0;
  
!         ComplexMatrix result = m.inverse (info, rcond, 1);
  
          if (nargout > 1)
            retval(1) = rcond;
--- 101,114 ----
  
        if (! error_state)
        {
+         MatrixType mattyp = args(0).matrix_type ();
+ 
          octave_idx_type info;
          double rcond = 0.0;
  
!         ComplexMatrix result = m.inverse (mattyp, info, rcond, 1);
! 
!         args(0).matrix_type (mattyp);
  
          if (nargout > 1)
            retval(1) = rcond;
2006-12-06  David Bateman  <address@hidden>

        * dMatrix.cc (finverse): Old inverse method renamed finverse.
        (tinverse): New method for triangular matrices.
        (inverse): New function with matrix type probing.
        * dMatrix.h (finverse, tinverse, inverse): New and modified
        declarations.
        * CMatrix.cc: ditto.
        * CMatrix.h: ditto.

2006-12-06  David Bateman  <address@hidden>

        * xpow.cc (xpow (const Matrix&, double)): Add matrix type probing
        to matrix inverse.
        (xpow (const ComplexMatrix&, double)): ditto.
        * DLD-FUNCTIONS/inv.cc (Finv): Add matrix type probing.

reply via email to

[Prev in Thread] Current Thread [Next in Thread]