summaryrefslogtreecommitdiff
path: root/passes
diff options
context:
space:
mode:
authorClifford Wolf <clifford@clifford.at>2015-09-24 22:16:37 +0200
committerClifford Wolf <clifford@clifford.at>2015-09-24 22:16:37 +0200
commitec92c8965960fa814c3663e987bc2a7eb80965e5 (patch)
tree5c884c7730593fa80d17cafa8276ff0b851a8ff0 /passes
parent69071bbc5f67057fc5902a43b13c39acf63c8bd6 (diff)
Added pivoting to qwp solver
Diffstat (limited to 'passes')
-rw-r--r--passes/cmds/qwp.cc57
1 files changed, 43 insertions, 14 deletions
diff --git a/passes/cmds/qwp.cc b/passes/cmds/qwp.cc
index eb4c10a7..f76de326 100644
--- a/passes/cmds/qwp.cc
+++ b/passes/cmds/qwp.cc
@@ -255,33 +255,62 @@ struct QwpWorker
// (least squares fit for "A*x = y")
//
// Using gaussian elimination to get M := [Id x]
- // (no pivoting, so let's hope for the best..)
- // eliminate to upper triangular matrix
+ vector<int> pivot_cache;
+ vector<int> queue;
+
+ for (int i = 0; i < N; i++)
+ queue.push_back(i);
+
+ // gaussian elimination
for (int i = 0; i < N; i++)
{
+ // find best row
+ int best_row = queue.front();
+ int best_row_queue_idx = 0;
+ double best_row_absval = 0;
+
+ for (int k = 0; k < GetSize(queue); k++) {
+ int row = queue[k];
+ double absval = fabs(M[i + row*N1]);
+ if (absval > best_row_absval) {
+ best_row = row;
+ best_row_queue_idx = k;
+ best_row_absval = absval;
+ }
+ }
+
+ int row = best_row;
+ pivot_cache.push_back(row);
+
+ queue[best_row_queue_idx] = queue.back();
+ queue.pop_back();
+
// normalize row
- for (int j = i+1; j < N+1; j++)
- M[(N+1)*i + j] /= M[(N+1)*i + i];
- M[(N+1)*i + i] = 1.0;
+ for (int k = i+1; k < N1; k++)
+ M[k + row*N1] /= M[i + row*N1];
+ M[i + row*N1] = 1.0;
// elimination
- for (int j = i+1; j < N; j++) {
- double d = M[(N+1)*j + i];
- for (int k = 0; k < N+1; k++)
- if (k > i)
- M[(N+1)*j + k] -= d*M[(N+1)*i + k];
- else
- M[(N+1)*j + k] = 0.0;
+ for (int other_row : queue) {
+ double d = M[i + other_row*N1];
+ for (int k = i+1; k < N1; k++)
+ M[k + other_row*N1] -= d*M[k + row*N1];
+ M[i + other_row*N1] = 0.0;
}
}
+ log_assert(queue.empty());
+ log_assert(GetSize(pivot_cache) == N);
+
// back substitution
for (int i = N-1; i >= 0; i--)
for (int j = i+1; j < N; j++)
{
- M[(N+1)*i + N] -= M[(N+1)*i + j] * M[(N+1)*j + N];
- M[(N+1)*i + j] = 0.0;
+ int row = pivot_cache[i];
+ int other_row = pivot_cache[j];
+ M[N + row*N1] -= M[j + row*N1] * M[N + other_row*N1];
+ M[j + row*N1] = 0.0;
}
#ifdef LOG_MATRICES