ergo
mat_utils.h
Go to the documentation of this file.
1/* Ergo, version 3.8.2, a program for linear scaling electronic structure
2 * calculations.
3 * Copyright (C) 2023 Elias Rudberg, Emanuel H. Rubensson, Pawel Salek,
4 * and Anastasia Kruchinina.
5 *
6 * This program is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 *
19 * Primary academic reference:
20 * Ergo: An open-source program for linear-scaling electronic structure
21 * calculations,
22 * Elias Rudberg, Emanuel H. Rubensson, Pawel Salek, and Anastasia
23 * Kruchinina,
24 * SoftwareX 7, 107 (2018),
25 * <http://dx.doi.org/10.1016/j.softx.2018.03.005>
26 *
27 * For further information about Ergo, see <http://www.ergoscf.org>.
28 */
29
35
36#ifndef MAT_UTILS_HEADER
37#define MAT_UTILS_HEADER
#include "Interval.h"
#include "matrix_proxy.h"
#include <algorithm>
#include <cassert>
#include <random>
41namespace mat {
42
/** Randomly permute the elements of the range [first, last).
 *
 *  Drop-in replacement for std::random_shuffle(), which was deprecated in
 *  C++14 and removed in C++17 (see
 *  https://meetingcpp.com/blog/items/stdrandom_shuffle-is-deprecated.html).
 *  A std::mt19937 engine seeded from std::random_device is created on every
 *  call, so the resulting permutation is not reproducible between calls.
 *
 *  @param first Iterator to the first element of the range.
 *  @param last  Iterator one past the last element of the range.
 *  RandomIt must meet the requirements of std::shuffle (random-access
 *  iterator whose value type is swappable). An empty or one-element range
 *  is left unchanged.
 */
template<class RandomIt>
inline void
random_shuffle(RandomIt first, RandomIt last) {
  std::random_device seedSource;
  std::mt19937 engine(seedSource());
  std::shuffle(first, last, engine);
}
58
59 template<typename Tmatrix, typename Treal>
60 struct DiffMatrix {
61 typedef typename Tmatrix::VectorType VectorType;
62 void getCols(SizesAndBlocks & colsCopy) const {
63 A.getCols(colsCopy);
64 }
65 int get_nrows() const {
66 assert( A.get_nrows() == B.get_nrows() );
67 return A.get_nrows();
68 }
69 Treal frob() const {
70 return Tmatrix::frob_diff(A, B);
71 }
72 void quickEuclBounds(Treal & euclLowerBound,
73 Treal & euclUpperBound) const {
74 Treal frobTmp = frob();
75 euclLowerBound = frobTmp / template_blas_sqrt( (Treal)get_nrows() );
76 euclUpperBound = frobTmp;
77 }
78
79 Tmatrix const & A;
80 Tmatrix const & B;
81 DiffMatrix(Tmatrix const & A_, Tmatrix const & B_)
82 : A(A_), B(B_) {}
83 template<typename Tvector>
84 void matVecProd(Tvector & y, Tvector const & x) const {
85 Tvector tmp(y);
86 tmp = (Treal)-1.0 * B * x; // -B * x
87 y = (Treal)1.0 * A * x; // A * x
88 y += (Treal)1.0 * tmp; // A * x - B * x => (A - B) * x
89 }
90 };
91
92
93 // ATAMatrix AT*A
94 template<typename Tmatrix, typename Treal>
95 struct ATAMatrix {
96 typedef typename Tmatrix::VectorType VectorType;
97 Tmatrix const & A;
98 explicit ATAMatrix(Tmatrix const & A_)
99 : A(A_) {}
100 void getCols(SizesAndBlocks & colsCopy) const {
101 A.getRows(colsCopy);
102 }
103 void quickEuclBounds(Treal & euclLowerBound,
104 Treal & euclUpperBound) const {
105 Treal frobA = A.frob();
106 euclLowerBound = 0;
107 euclUpperBound = frobA * frobA;
108 }
109
110 // y = AT*A*x
111 template<typename Tvector>
112 void matVecProd(Tvector & y, Tvector const & x) const {
113 y = x;
114 y = A * y;
115 y = transpose(A) * y;
116 }
117 // Number of rows of A^T * A is the number of columns of A
118 int get_nrows() const { return A.get_ncols(); }
119 };
120
121
122 template<typename Tmatrix, typename Tmatrix2, typename Treal>
124 typedef typename Tmatrix::VectorType VectorType;
125 void getCols(SizesAndBlocks & colsCopy) const {
126 A.getCols(colsCopy);
127 }
128 int get_nrows() const {
129 assert( A.get_nrows() == Z.get_nrows() );
130 return A.get_nrows();
131 }
132 void quickEuclBounds(Treal & euclLowerBound,
133 Treal & euclUpperBound) const {
134 Treal frobA = A.frob();
135 Treal frobZ = Z.frob();
136 euclLowerBound = 0;
137 euclUpperBound = frobA * frobZ * frobZ;
138 }
139
140 Tmatrix const & A;
141 Tmatrix2 const & Z;
142 TripleMatrix(Tmatrix const & A_, Tmatrix2 const & Z_)
143 : A(A_), Z(Z_) {}
144 void matVecProd(VectorType & y, VectorType const & x) const {
145 VectorType tmp(x);
146 tmp = Z * tmp; // Z * x
147 y = (Treal)1.0 * A * tmp; // A * Z * x
148 y = transpose(Z) * y; // Z^T * A * Z * x
149 }
150 };
151
152
153 template<typename Tmatrix, typename Tmatrix2, typename Treal>
155 typedef typename Tmatrix::VectorType VectorType;
156 void getCols(SizesAndBlocks & colsCopy) const {
157 E.getRows(colsCopy);
158 }
159 int get_nrows() const {
160 return E.get_ncols();
161 }
162 void quickEuclBounds(Treal & euclLowerBound,
163 Treal & euclUpperBound) const {
164 Treal frobA = A.frob();
165 Treal frobZ = Zt.frob();
166 Treal frobE = E.frob();
167 euclLowerBound = 0;
168 euclUpperBound = frobA * frobE * frobE + 2 * frobA * frobE * frobZ;
169 }
170
171 Tmatrix const & A;
172 Tmatrix2 const & Zt;
173 Tmatrix2 const & E;
174
175 CongrTransErrorMatrix(Tmatrix const & A_,
176 Tmatrix2 const & Z_,
177 Tmatrix2 const & E_)
178 : A(A_), Zt(Z_), E(E_) {}
179 void matVecProd(VectorType & y, VectorType const & x) const {
180
181 VectorType tmp(x);
182 tmp = E * tmp; // E * x
183 y = (Treal)-1.0 * A * tmp; // -A * E * x
184 y = transpose(E) * y; // -E^T * A * E * x
185
186 VectorType tmp1;
187 tmp = x;
188 tmp = Zt * tmp; // Zt * x
189 tmp1 = (Treal)1.0 * A * tmp; // A * Zt * x
190 tmp1 = transpose(E) * tmp1; // E^T * A * Zt * x
191 y += (Treal)1.0 * tmp1;
192
193 tmp = x;
194 tmp = E * tmp; // E * x
195 tmp1 = (Treal)1.0 * A * tmp; // A * E * x
196 tmp1 = transpose(Zt) * tmp1; // Zt^T * A * E * x
197 y += (Treal)1.0 * tmp1;
198 }
199 };
200
201
202
203} /* end namespace mat */
204#endif
Interval class.
Describes dimensions of matrix and its blocks on all levels.
Definition SizesAndBlocks.h:45
Proxy structs used by the matrix API.
Definition allocate.cc:39
static void random_shuffle(RandomIt first, RandomIt last)
Definition mat_utils.h:45
Xtrans< TX > transpose(TX const &A)
Transposition.
Definition matrix_proxy.h:131
Tmatrix const & A
Definition mat_utils.h:97
int get_nrows() const
Definition mat_utils.h:118
void matVecProd(Tvector &y, Tvector const &x) const
Definition mat_utils.h:112
ATAMatrix(Tmatrix const &A_)
Definition mat_utils.h:98
Tmatrix::VectorType VectorType
Definition mat_utils.h:96
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition mat_utils.h:103
void getCols(SizesAndBlocks &colsCopy) const
Definition mat_utils.h:100
void getCols(SizesAndBlocks &colsCopy) const
Definition mat_utils.h:156
Tmatrix2 const & Zt
Definition mat_utils.h:172
Tmatrix const & A
Definition mat_utils.h:171
Tmatrix2 const & E
Definition mat_utils.h:173
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition mat_utils.h:162
int get_nrows() const
Definition mat_utils.h:159
CongrTransErrorMatrix(Tmatrix const &A_, Tmatrix2 const &Z_, Tmatrix2 const &E_)
Definition mat_utils.h:175
void matVecProd(VectorType &y, VectorType const &x) const
Definition mat_utils.h:179
Tmatrix::VectorType VectorType
Definition mat_utils.h:155
Treal frob() const
Definition mat_utils.h:69
DiffMatrix(Tmatrix const &A_, Tmatrix const &B_)
Definition mat_utils.h:81
Tmatrix const & A
Definition mat_utils.h:79
Tmatrix::VectorType VectorType
Definition mat_utils.h:61
void getCols(SizesAndBlocks &colsCopy) const
Definition mat_utils.h:62
Tmatrix const & B
Definition mat_utils.h:80
int get_nrows() const
Definition mat_utils.h:65
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition mat_utils.h:72
void matVecProd(Tvector &y, Tvector const &x) const
Definition mat_utils.h:84
void getCols(SizesAndBlocks &colsCopy) const
Definition mat_utils.h:125
TripleMatrix(Tmatrix const &A_, Tmatrix2 const &Z_)
Definition mat_utils.h:142
Tmatrix const & A
Definition mat_utils.h:140
void matVecProd(VectorType &y, VectorType const &x) const
Definition mat_utils.h:144
int get_nrows() const
Definition mat_utils.h:128
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition mat_utils.h:132
Tmatrix::VectorType VectorType
Definition mat_utils.h:124
Tmatrix2 const & Z
Definition mat_utils.h:141
Treal template_blas_sqrt(Treal x)