IT++ Logo

ls_solve.cpp

Go to the documentation of this file.
00001 
00030 #ifndef _MSC_VER
00031 #  include <itpp/config.h>
00032 #else
00033 #  include <itpp/config_msvc.h>
00034 #endif
00035 
00036 #if defined(HAVE_LAPACK)
00037 #  include <itpp/base/algebra/lapack.h>
00038 #endif
00039 
00040 #include <itpp/base/algebra/ls_solve.h>
00041 
00042 
00043 namespace itpp
00044 {
00045 
00046 // ----------- ls_solve_chol -----------------------------------------------------------
00047 
00048 #if defined(HAVE_LAPACK)
00049 
00050 bool ls_solve_chol(const mat &A, const vec &b, vec &x)
00051 {
00052   int n, lda, ldb, nrhs, info;
00053   n = lda = ldb = A.rows();
00054   nrhs = 1;
00055   char uplo = 'U';
00056 
00057   it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00058   it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00059 
00060   ivec ipiv(n);
00061   x = b;
00062   mat Chol = A;
00063 
00064   dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info);
00065 
00066   return (info == 0);
00067 }
00068 
00069 
00070 bool ls_solve_chol(const mat &A, const mat &B, mat &X)
00071 {
00072   int n, lda, ldb, nrhs, info;
00073   n = lda = ldb = A.rows();
00074   nrhs = B.cols();
00075   char uplo = 'U';
00076 
00077   it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00078   it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00079 
00080   ivec ipiv(n);
00081   X = B;
00082   mat Chol = A;
00083 
00084   dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info);
00085 
00086   return (info == 0);
00087 }
00088 
00089 bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x)
00090 {
00091   int n, lda, ldb, nrhs, info;
00092   n = lda = ldb = A.rows();
00093   nrhs = 1;
00094   char uplo = 'U';
00095 
00096   it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00097   it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00098 
00099   ivec ipiv(n);
00100   x = b;
00101   cmat Chol = A;
00102 
00103   zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info);
00104 
00105   return (info == 0);
00106 }
00107 
00108 bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X)
00109 {
00110   int n, lda, ldb, nrhs, info;
00111   n = lda = ldb = A.rows();
00112   nrhs = B.cols();
00113   char uplo = 'U';
00114 
00115   it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00116   it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00117 
00118   ivec ipiv(n);
00119   X = B;
00120   cmat Chol = A;
00121 
00122   zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info);
00123 
00124   return (info == 0);
00125 }
00126 
00127 #else
00128 
00129 bool ls_solve_chol(const mat &A, const vec &b, vec &x)
00130 {
00131   it_error("LAPACK library is needed to use ls_solve_chol() function");
00132   return false;
00133 }
00134 
00135 bool ls_solve_chol(const mat &A, const mat &B, mat &X)
00136 {
00137   it_error("LAPACK library is needed to use ls_solve_chol() function");
00138   return false;
00139 }
00140 
00141 bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x)
00142 {
00143   it_error("LAPACK library is needed to use ls_solve_chol() function");
00144   return false;
00145 }
00146 
00147 bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X)
00148 {
00149   it_error("LAPACK library is needed to use ls_solve_chol() function");
00150   return false;
00151 }
00152 
00153 #endif // HAVE_LAPACK
00154 
00155 vec ls_solve_chol(const mat &A, const vec &b)
00156 {
00157   vec x;
00158   bool info;
00159   info = ls_solve_chol(A, b, x);
00160   it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00161   return x;
00162 }
00163 
00164 mat ls_solve_chol(const mat &A, const mat &B)
00165 {
00166   mat X;
00167   bool info;
00168   info = ls_solve_chol(A, B, X);
00169   it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00170   return X;
00171 }
00172 
00173 cvec ls_solve_chol(const cmat &A, const cvec &b)
00174 {
00175   cvec x;
00176   bool info;
00177   info = ls_solve_chol(A, b, x);
00178   it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00179   return x;
00180 }
00181 
00182 cmat ls_solve_chol(const cmat &A, const cmat &B)
00183 {
00184   cmat X;
00185   bool info;
00186   info = ls_solve_chol(A, B, X);
00187   it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00188   return X;
00189 }
00190 
00191 
00192 // --------- ls_solve ---------------------------------------------------------------
00193 #if defined(HAVE_LAPACK)
00194 
00195 bool ls_solve(const mat &A, const vec &b, vec &x)
00196 {
00197   int n, lda, ldb, nrhs, info;
00198   n = lda = ldb = A.rows();
00199   nrhs = 1;
00200 
00201   it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00202   it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00203 
00204   ivec ipiv(n);
00205   x = b;
00206   mat LU = A;
00207 
00208   dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info);
00209 
00210   return (info == 0);
00211 }
00212 
00213 bool ls_solve(const mat &A, const mat &B, mat &X)
00214 {
00215   int n, lda, ldb, nrhs, info;
00216   n = lda = ldb = A.rows();
00217   nrhs = B.cols();
00218 
00219   it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00220   it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00221 
00222   ivec ipiv(n);
00223   X = B;
00224   mat LU = A;
00225 
00226   dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info);
00227 
00228   return (info == 0);
00229 }
00230 
00231 bool ls_solve(const cmat &A, const cvec &b, cvec &x)
00232 {
00233   int n, lda, ldb, nrhs, info;
00234   n = lda = ldb = A.rows();
00235   nrhs = 1;
00236 
00237   it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00238   it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00239 
00240   ivec ipiv(n);
00241   x = b;
00242   cmat LU = A;
00243 
00244   zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info);
00245 
00246   return (info == 0);
00247 }
00248 
00249 bool ls_solve(const cmat &A, const cmat &B, cmat &X)
00250 {
00251   int n, lda, ldb, nrhs, info;
00252   n = lda = ldb = A.rows();
00253   nrhs = B.cols();
00254 
00255   it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00256   it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00257 
00258   ivec ipiv(n);
00259   X = B;
00260   cmat LU = A;
00261 
00262   zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info);
00263 
00264   return (info == 0);
00265 }
00266 
00267 #else
00268 
00269 bool ls_solve(const mat &A, const vec &b, vec &x)
00270 {
00271   it_error("LAPACK library is needed to use ls_solve() function");
00272   return false;
00273 }
00274 
00275 bool ls_solve(const mat &A, const mat &B, mat &X)
00276 {
00277   it_error("LAPACK library is needed to use ls_solve() function");
00278   return false;
00279 }
00280 
00281 bool ls_solve(const cmat &A, const cvec &b, cvec &x)
00282 {
00283   it_error("LAPACK library is needed to use ls_solve() function");
00284   return false;
00285 }
00286 
00287 bool ls_solve(const cmat &A, const cmat &B, cmat &X)
00288 {
00289   it_error("LAPACK library is needed to use ls_solve() function");
00290   return false;
00291 }
00292 
00293 #endif // HAVE_LAPACK
00294 
00295 vec ls_solve(const mat &A, const vec &b)
00296 {
00297   vec x;
00298   bool info;
00299   info = ls_solve(A, b, x);
00300   it_assert_debug(info, "ls_solve: Failed solving the system");
00301   return x;
00302 }
00303 
00304 mat ls_solve(const mat &A, const mat &B)
00305 {
00306   mat X;
00307   bool info;
00308   info = ls_solve(A, B, X);
00309   it_assert_debug(info, "ls_solve: Failed solving the system");
00310   return X;
00311 }
00312 
00313 cvec ls_solve(const cmat &A, const cvec &b)
00314 {
00315   cvec x;
00316   bool info;
00317   info = ls_solve(A, b, x);
00318   it_assert_debug(info, "ls_solve: Failed solving the system");
00319   return x;
00320 }
00321 
00322 cmat ls_solve(const cmat &A, const cmat &B)
00323 {
00324   cmat X;
00325   bool info;
00326   info = ls_solve(A, B, X);
00327   it_assert_debug(info, "ls_solve: Failed solving the system");
00328   return X;
00329 }
00330 
00331 
00332 // ----------------- ls_solve_od ------------------------------------------------------------------
00333 #if defined(HAVE_LAPACK)
00334 
00335 bool ls_solve_od(const mat &A, const vec &b, vec &x)
00336 {
00337   int m, n, lda, ldb, nrhs, lwork, info;
00338   char trans = 'N';
00339   m = lda = ldb = A.rows();
00340   n = A.cols();
00341   nrhs = 1;
00342   lwork = n + std::max(m, nrhs);
00343 
00344   it_assert_debug(m >= n, "The system is under-determined!");
00345   it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00346 
00347   vec work(lwork);
00348   x = b;
00349   mat QR = A;
00350 
00351   dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00352   x.set_size(n, true);
00353 
00354   return (info == 0);
00355 }
00356 
00357 bool ls_solve_od(const mat &A, const mat &B, mat &X)
00358 {
00359   int m, n, lda, ldb, nrhs, lwork, info;
00360   char trans = 'N';
00361   m = lda = ldb = A.rows();
00362   n = A.cols();
00363   nrhs = B.cols();
00364   lwork = n + std::max(m, nrhs);
00365 
00366   it_assert_debug(m >= n, "The system is under-determined!");
00367   it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00368 
00369   vec work(lwork);
00370   X = B;
00371   mat QR = A;
00372 
00373   dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00374   X.set_size(n, nrhs, true);
00375 
00376   return (info == 0);
00377 }
00378 
00379 bool ls_solve_od(const cmat &A, const cvec &b, cvec &x)
00380 {
00381   int m, n, lda, ldb, nrhs, lwork, info;
00382   char trans = 'N';
00383   m = lda = ldb = A.rows();
00384   n = A.cols();
00385   nrhs = 1;
00386   lwork = n + std::max(m, nrhs);
00387 
00388   it_assert_debug(m >= n, "The system is under-determined!");
00389   it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00390 
00391   cvec work(lwork);
00392   x = b;
00393   cmat QR = A;
00394 
00395   zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00396   x.set_size(n, true);
00397 
00398   return (info == 0);
00399 }
00400 
00401 bool ls_solve_od(const cmat &A, const cmat &B, cmat &X)
00402 {
00403   int m, n, lda, ldb, nrhs, lwork, info;
00404   char trans = 'N';
00405   m = lda = ldb = A.rows();
00406   n = A.cols();
00407   nrhs = B.cols();
00408   lwork = n + std::max(m, nrhs);
00409 
00410   it_assert_debug(m >= n, "The system is under-determined!");
00411   it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00412 
00413   cvec work(lwork);
00414   X = B;
00415   cmat QR = A;
00416 
00417   zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00418   X.set_size(n, nrhs, true);
00419 
00420   return (info == 0);
00421 }
00422 
00423 #else
00424 
00425 bool ls_solve_od(const mat &A, const vec &b, vec &x)
00426 {
00427   it_error("LAPACK library is needed to use ls_solve_od() function");
00428   return false;
00429 }
00430 
00431 bool ls_solve_od(const mat &A, const mat &B, mat &X)
00432 {
00433   it_error("LAPACK library is needed to use ls_solve_od() function");
00434   return false;
00435 }
00436 
00437 bool ls_solve_od(const cmat &A, const cvec &b, cvec &x)
00438 {
00439   it_error("LAPACK library is needed to use ls_solve_od() function");
00440   return false;
00441 }
00442 
00443 bool ls_solve_od(const cmat &A, const cmat &B, cmat &X)
00444 {
00445   it_error("LAPACK library is needed to use ls_solve_od() function");
00446   return false;
00447 }
00448 
00449 #endif // HAVE_LAPACK
00450 
00451 vec ls_solve_od(const mat &A, const vec &b)
00452 {
00453   vec x;
00454   bool info;
00455   info = ls_solve_od(A, b, x);
00456   it_assert_debug(info, "ls_solve_od: Failed solving the system");
00457   return x;
00458 }
00459 
00460 mat ls_solve_od(const mat &A, const mat &B)
00461 {
00462   mat X;
00463   bool info;
00464   info = ls_solve_od(A, B, X);
00465   it_assert_debug(info, "ls_solve_od: Failed solving the system");
00466   return X;
00467 }
00468 
00469 cvec ls_solve_od(const cmat &A, const cvec &b)
00470 {
00471   cvec x;
00472   bool info;
00473   info = ls_solve_od(A, b, x);
00474   it_assert_debug(info, "ls_solve_od: Failed solving the system");
00475   return x;
00476 }
00477 
00478 cmat ls_solve_od(const cmat &A, const cmat &B)
00479 {
00480   cmat X;
00481   bool info;
00482   info = ls_solve_od(A, B, X);
00483   it_assert_debug(info, "ls_solve_od: Failed solving the system");
00484   return X;
00485 }
00486 
00487 // ------------------- ls_solve_ud -----------------------------------------------------------
00488 #if defined(HAVE_LAPACK)
00489 
00490 bool ls_solve_ud(const mat &A, const vec &b, vec &x)
00491 {
00492   int m, n, lda, ldb, nrhs, lwork, info;
00493   char trans = 'N';
00494   m = lda = A.rows();
00495   n = A.cols();
00496   ldb = n;
00497   nrhs = 1;
00498   lwork = m + std::max(n, nrhs);
00499 
00500   it_assert_debug(m < n, "The system is over-determined!");
00501   it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00502 
00503   vec work(lwork);
00504   x = b;
00505   x.set_size(n, true);
00506   mat QR = A;
00507 
00508   dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00509 
00510   return (info == 0);
00511 }
00512 
00513 bool ls_solve_ud(const mat &A, const mat &B, mat &X)
00514 {
00515   int m, n, lda, ldb, nrhs, lwork, info;
00516   char trans = 'N';
00517   m = lda = A.rows();
00518   n = A.cols();
00519   ldb = n;
00520   nrhs = B.cols();
00521   lwork = m + std::max(n, nrhs);
00522 
00523   it_assert_debug(m < n, "The system is over-determined!");
00524   it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00525 
00526   vec work(lwork);
00527   X = B;
00528   X.set_size(n, std::max(m, nrhs), true);
00529   mat QR = A;
00530 
00531   dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00532   X.set_size(n, nrhs, true);
00533 
00534   return (info == 0);
00535 }
00536 
00537 bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x)
00538 {
00539   int m, n, lda, ldb, nrhs, lwork, info;
00540   char trans = 'N';
00541   m = lda = A.rows();
00542   n = A.cols();
00543   ldb = n;
00544   nrhs = 1;
00545   lwork = m + std::max(n, nrhs);
00546 
00547   it_assert_debug(m < n, "The system is over-determined!");
00548   it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00549 
00550   cvec work(lwork);
00551   x = b;
00552   x.set_size(n, true);
00553   cmat QR = A;
00554 
00555   zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00556 
00557   return (info == 0);
00558 }
00559 
00560 bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X)
00561 {
00562   int m, n, lda, ldb, nrhs, lwork, info;
00563   char trans = 'N';
00564   m = lda = A.rows();
00565   n = A.cols();
00566   ldb = n;
00567   nrhs = B.cols();
00568   lwork = m + std::max(n, nrhs);
00569 
00570   it_assert_debug(m < n, "The system is over-determined!");
00571   it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00572 
00573   cvec work(lwork);
00574   X = B;
00575   X.set_size(n, std::max(m, nrhs), true);
00576   cmat QR = A;
00577 
00578   zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00579   X.set_size(n, nrhs, true);
00580 
00581   return (info == 0);
00582 }
00583 
00584 #else
00585 
00586 bool ls_solve_ud(const mat &A, const vec &b, vec &x)
00587 {
00588   it_error("LAPACK library is needed to use ls_solve_ud() function");
00589   return false;
00590 }
00591 
00592 bool ls_solve_ud(const mat &A, const mat &B, mat &X)
00593 {
00594   it_error("LAPACK library is needed to use ls_solve_ud() function");
00595   return false;
00596 }
00597 
00598 bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x)
00599 {
00600   it_error("LAPACK library is needed to use ls_solve_ud() function");
00601   return false;
00602 }
00603 
00604 bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X)
00605 {
00606   it_error("LAPACK library is needed to use ls_solve_ud() function");
00607   return false;
00608 }
00609 
00610 #endif // HAVE_LAPACK
00611 
00612 
00613 vec ls_solve_ud(const mat &A, const vec &b)
00614 {
00615   vec x;
00616   bool info;
00617   info = ls_solve_ud(A, b, x);
00618   it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00619   return x;
00620 }
00621 
00622 mat ls_solve_ud(const mat &A, const mat &B)
00623 {
00624   mat X;
00625   bool info;
00626   info = ls_solve_ud(A, B, X);
00627   it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00628   return X;
00629 }
00630 
00631 cvec ls_solve_ud(const cmat &A, const cvec &b)
00632 {
00633   cvec x;
00634   bool info;
00635   info = ls_solve_ud(A, b, x);
00636   it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00637   return x;
00638 }
00639 
00640 cmat ls_solve_ud(const cmat &A, const cmat &B)
00641 {
00642   cmat X;
00643   bool info;
00644   info = ls_solve_ud(A, B, X);
00645   it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00646   return X;
00647 }
00648 
00649 
00650 // ---------------------- backslash -----------------------------------------
00651 
00652 bool backslash(const mat &A, const vec &b, vec &x)
00653 {
00654   int m = A.rows(), n = A.cols();
00655   bool info;
00656 
00657   if (m == n)
00658     info = ls_solve(A, b, x);
00659   else if (m > n)
00660     info = ls_solve_od(A, b, x);
00661   else
00662     info = ls_solve_ud(A, b, x);
00663 
00664   return info;
00665 }
00666 
00667 
00668 vec backslash(const mat &A, const vec &b)
00669 {
00670   vec x;
00671   bool info;
00672   info = backslash(A, b, x);
00673   it_assert_debug(info, "backslash(): solution was not found");
00674   return x;
00675 }
00676 
00677 
00678 bool backslash(const mat &A, const mat &B, mat &X)
00679 {
00680   int m = A.rows(), n = A.cols();
00681   bool info;
00682 
00683   if (m == n)
00684     info = ls_solve(A, B, X);
00685   else if (m > n)
00686     info = ls_solve_od(A, B, X);
00687   else
00688     info = ls_solve_ud(A, B, X);
00689 
00690   return info;
00691 }
00692 
00693 
00694 mat backslash(const mat &A, const mat &B)
00695 {
00696   mat X;
00697   bool info;
00698   info = backslash(A, B, X);
00699   it_assert_debug(info, "backslash(): solution was not found");
00700   return X;
00701 }
00702 
00703 
00704 bool backslash(const cmat &A, const cvec &b, cvec &x)
00705 {
00706   int m = A.rows(), n = A.cols();
00707   bool info;
00708 
00709   if (m == n)
00710     info = ls_solve(A, b, x);
00711   else if (m > n)
00712     info = ls_solve_od(A, b, x);
00713   else
00714     info = ls_solve_ud(A, b, x);
00715 
00716   return info;
00717 }
00718 
00719 
00720 cvec backslash(const cmat &A, const cvec &b)
00721 {
00722   cvec x;
00723   bool info;
00724   info = backslash(A, b, x);
00725   it_assert_debug(info, "backslash(): solution was not found");
00726   return x;
00727 }
00728 
00729 
00730 bool backslash(const cmat &A, const cmat &B, cmat &X)
00731 {
00732   int m = A.rows(), n = A.cols();
00733   bool info;
00734 
00735   if (m == n)
00736     info = ls_solve(A, B, X);
00737   else if (m > n)
00738     info = ls_solve_od(A, B, X);
00739   else
00740     info = ls_solve_ud(A, B, X);
00741 
00742   return info;
00743 }
00744 
00745 cmat backslash(const cmat &A, const cmat &B)
00746 {
00747   cmat X;
00748   bool info;
00749   info = backslash(A, B, X);
00750   it_assert_debug(info, "backslash(): solution was not found");
00751   return X;
00752 }
00753 
00754 
00755 // --------------------------------------------------------------------------
00756 
00757 vec forward_substitution(const mat &L, const vec &b)
00758 {
00759   int n = L.rows();
00760   vec x(n);
00761 
00762   forward_substitution(L, b, x);
00763 
00764   return x;
00765 }
00766 
00767 void forward_substitution(const mat &L, const vec &b, vec &x)
00768 {
00769   it_assert(L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size(),
00770             "forward_substitution: dimension mismatch");
00771   int n = L.rows(), i, j;
00772   double temp;
00773 
00774   x(0) = b(0) / L(0, 0);
00775   for (i = 1;i < n;i++) {
00776     // Should be: x(i)=((b(i)-L(i,i,0,i-1)*x(0,i-1))/L(i,i))(0); but this is to slow.
00777     //i_pos=i*L._row_offset();
00778     temp = 0;
00779     for (j = 0; j < i; j++) {
00780       temp += L._elem(i, j) * x(j);
00781       //temp+=L._data()[i_pos+j]*x(j);
00782     }
00783     x(i) = (b(i) - temp) / L._elem(i, i);
00784     //x(i)=(b(i)-temp)/L._data()[i_pos+i];
00785   }
00786 }
00787 
00788 vec forward_substitution(const mat &L, int p, const vec &b)
00789 {
00790   int n = L.rows();
00791   vec x(n);
00792 
00793   forward_substitution(L, p, b, x);
00794 
00795   return x;
00796 }
00797 
00798 void forward_substitution(const mat &L, int p, const vec &b, vec &x)
00799 {
00800   it_assert(L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size() && p <= L.rows() / 2,
00801             "forward_substitution: dimension mismatch");
00802   int n = L.rows(), i, j;
00803 
00804   x = b;
00805 
00806   for (j = 0;j < n;j++) {
00807     x(j) /= L(j, j);
00808     for (i = j + 1;i < std::min(j + p + 1, n);i++) {
00809       x(i) -= L(i, j) * x(j);
00810     }
00811   }
00812 }
00813 
00814 vec backward_substitution(const mat &U, const vec &b)
00815 {
00816   vec x(U.rows());
00817   backward_substitution(U, b, x);
00818 
00819   return x;
00820 }
00821 
00822 void backward_substitution(const mat &U, const vec &b, vec &x)
00823 {
00824   it_assert(U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size(),
00825             "backward_substitution: dimension mismatch");
00826   int n = U.rows(), i, j;
00827   double temp;
00828 
00829   x(n - 1) = b(n - 1) / U(n - 1, n - 1);
00830   for (i = n - 2; i >= 0; i--) {
00831     // Should be: x(i)=((b(i)-U(i,i,i+1,n-1)*x(i+1,n-1))/U(i,i))(0); but this is too slow.
00832     temp = 0;
00833     //i_pos=i*U._row_offset();
00834     for (j = i + 1; j < n; j++) {
00835       temp += U._elem(i, j) * x(j);
00836       //temp+=U._data()[i_pos+j]*x(j);
00837     }
00838     x(i) = (b(i) - temp) / U._elem(i, i);
00839     //x(i)=(b(i)-temp)/U._data()[i_pos+i];
00840   }
00841 }
00842 
00843 vec backward_substitution(const mat &U, int q, const vec &b)
00844 {
00845   vec x(U.rows());
00846   backward_substitution(U, q, b, x);
00847 
00848   return x;
00849 }
00850 
00851 void backward_substitution(const mat &U, int q, const vec &b, vec &x)
00852 {
00853   it_assert(U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size() && q <= U.rows() / 2,
00854             "backward_substitution: dimension mismatch");
00855   int n = U.rows(), i, j;
00856 
00857   x = b;
00858 
00859   for (j = n - 1; j >= 0; j--) {
00860     x(j) /= U(j, j);
00861     for (i = std::max(0, j - q); i < j; i++) {
00862       x(i) -= U(i, j) * x(j);
00863     }
00864   }
00865 }
00866 
00867 } // namespace itpp
SourceForge Logo

Generated on Fri Jul 25 12:42:57 2008 for IT++ by Doxygen 1.5.4