FEAT 3
Finite Element Analysis Toolbox
Loading...
Searching...
No Matches
linesearch.hpp
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
6#pragma once
8#include <kernel/solver/iterative.hpp>
9
10#include <deque>
11
12namespace FEAT
13{
14 namespace Solver
15 {
32 template<typename Functional_, typename Filter_>
33 class Linesearch : public IterativeSolver<typename Functional_::VectorTypeR>
34 {
35 public:
37 typedef Filter_ FilterType;
39 typedef typename Functional_::VectorTypeR VectorType;
41 typedef typename Functional_::DataType DataType;
44
45 protected:
47 // Note that this cannot be const, as the functional saves its state and thus changes
48 Functional_& _functional;
50 Filter_& _filter;
51
60
67
78
88
89 public:
91 std::deque<VectorType>* iterates;
92
93 public:
110 explicit Linesearch(const String& plot_name, Functional_& functional, Filter_& filter,
111 bool keep_iterates = false) :
112 BaseClass(plot_name),
113 _functional(functional),
114 _filter(filter),
115 _fval_min(Math::huge<DataType>()),
116 _fval_0(Math::huge<DataType>()),
117 _trim_threshold(Math::huge<DataType>()),
118 _alpha_0(1),
119 _alpha_min(0),
120 _delta_0(Math::huge<DataType>()),
121 _norm_dir(0),
122 _norm_sol(0),
124 _tol_decrease(DataType(1e-3)),
125 _tol_step(DataType(Math::pow(Math::eps<DataType>(), DataType(0.85)))),
126 _dir_scaling(false),
127 iterates(nullptr)
128 {
129 // set communicator by functional (same interface as matrix)
130 this->_set_comm_by_matrix(functional);
131
132 if(keep_iterates)
133 {
134 iterates = new std::deque<VectorType>;
135 }
136 }
137
157 explicit Linesearch(const String& plot_name, const String& section_name, const PropertyMap* section,
158 Functional_& functional, Filter_& filter) :
159 BaseClass(plot_name, section_name, section),
160 _functional(functional),
161 _filter(filter),
162 _fval_min(Math::huge<DataType>()),
163 _fval_0(Math::huge<DataType>()),
164 _trim_threshold(Math::huge<DataType>()),
165 _alpha_0(1),
167 _delta_0(Math::huge<DataType>()),
170 _tol_curvature(0.3),
171 _tol_decrease(1e-3),
172 _tol_step(DataType(Math::pow(Math::eps<DataType>(), DataType(0.85)))),
173 _dir_scaling(false),
174 iterates(nullptr)
175 {
176 // Check if we have to keep the iterates
177 auto keep_iterates_p = section->query("keep_iterates");
178 if(keep_iterates_p.second)
179 {
180 iterates = new std::deque<VectorType>;
181 }
182
183 auto tol_curvature_p = section->query("tol_curvature");
184 if(tol_curvature_p.second)
185 {
186 set_tol_curvature(DataType(std::stod(tol_curvature_p.first)));
187 }
188
189 auto tol_decrease_p = section->query("tol_decrease");
190 if(tol_decrease_p.second)
191 {
192 set_tol_decrease(DataType(std::stod(tol_decrease_p.first)));
193 }
194 // Check if we have to keep the iterates
195 auto tol_step_p = section->query("tol_step");
196 if(tol_step_p.second)
197 {
198 set_tol_step(DataType(std::stod(tol_step_p.first)));
199 }
200 }
201
205 virtual ~Linesearch()
206 {
207 if(iterates != nullptr)
208 {
209 delete iterates;
210 }
211 }
212
213 virtual String get_summary() const override
214 {
215 String msg(this->get_plot_name()+ ": its: "+stringify(this->get_num_iter())
216 +" ("+ stringify(this->get_status())+")"
217 +", evals: "+stringify(_functional.get_num_func_evals())+" (func) "
218 + stringify(_functional.get_num_grad_evals()) + " (grad) "
219 + stringify(_functional.get_num_hess_evals()) + " (hess)"
220 +" last step: "+stringify_fp_sci(_alpha_min)+"\n");
221 msg +=this->get_plot_name()+": fval: "+stringify_fp_sci(_fval_0)
223 + ", factor "+stringify_fp_sci(_fval_min/_fval_0)+"\n";
224 msg += this->get_plot_name() +": <dir, grad>: "+stringify_fp_sci(this->_def_init)
225 + " -> "+stringify_fp_sci(this->_def_cur)
226 + ", factor " +stringify_fp_sci(this->_def_cur/this->_def_init);
227 return msg;
228 }
229
233 void set_tol_curvature(DataType tol_curvature)
234 {
235 XASSERT(tol_curvature > DataType(0));
236
237 _tol_curvature = tol_curvature;
238 }
239
243 void set_tol_decrease(DataType tol_decrease)
244 {
245 XASSERT(tol_decrease > DataType(0));
246
247 _tol_decrease = tol_decrease;
248 }
249
254 void set_tol_step(DataType tol_step)
255 {
256 XASSERT(tol_step > DataType(0));
257
258 _tol_step = tol_step;
259 }
260
269 void set_dir_scaling(const bool b)
270 {
271 _dir_scaling = b;
272 }
273
275 virtual void init_symbolic() override
276 {
277 // Create temporary vectors
278 _vec_initial_sol = this->_functional.create_vector_r();
279 _vec_tmp = this->_functional.create_vector_r();
280 _vec_pn = this->_functional.create_vector_r();
281 _vec_grad = this->_functional.create_vector_r();
283 }
284
286 virtual void done_symbolic() override
287 {
288 // Clear temporary vectors
289 _vec_initial_sol.clear();
290 _vec_pn.clear();
291 _vec_tmp.clear();
292 _vec_grad.clear();
294 }
295
299 virtual void reset()
300 {
301 _fval_min = Math::huge<DataType>();
302 _fval_0 = Math::huge<DataType>();
303 _trim_threshold = Math::huge<DataType>();
304 _alpha_0 = DataType(1),
305 _alpha_min = DataType(0);
306 _delta_0 = Math::huge<DataType>();
307 _norm_dir = DataType(0);
308 _norm_sol = DataType(0);
309
310 if(iterates != nullptr)
311 {
312 iterates->clear();
313 }
314 }
315
334 virtual DataType get_rel_update() const
335 {
336 return Math::abs(this->_alpha_min);
337 }
338
345 {
346 return _fval_min;
347 }
348
360 {
361 _fval_0 = f0;
362 if(_trim_threshold == Math::huge<DataType>())
363 {
365 }
366 }
367
378 void get_defect_from_grad(VectorType& vec_def) const
379 {
380 vec_def.scale(this->_vec_grad, DataType(-1));
381 }
382
392 void set_grad_from_defect(const VectorType& vec_def)
393 {
394 this->_vec_grad.scale(vec_def, DataType(-1));
395 }
396
405 {
406 return _vec_initial_sol;
407 }
408
430 virtual void trim_func_grad(DataType& func)
431 {
432 if(func > _trim_threshold)
433 {
434 func = _trim_threshold;
435 this->_vec_grad.format(DataType(0));
436 }
437 }
438
439 protected:
440
464 virtual Status _startup(DataType& alpha, DataType& fval, DataType& delta, const VectorType& vec_sol, const VectorType& vec_dir)
465 {
466 Status status(Status::progress);
467
468 this->_num_iter = Index(0);
469 this->_vec_initial_sol.copy(vec_sol);
470
471 this->_vec_pn.copy(vec_dir);
472 this->_norm_dir = this->_vec_pn.norm2();
473
474 // First scale so that all entries are |.| < 1
475 this->_vec_pn.scale(this->_vec_pn, DataType(1)/this->_vec_pn.max_abs_element());
476 // Now scale so that vec_pn.norm2() == 1. Note that this will only be approximately true
477 this->_vec_pn.scale(this->_vec_pn, DataType(1)/this->_vec_pn.norm2());
478
479 // It is critical that fval_0 was set from outside using set_initial_fval!
480 this->_fval_min = this->_fval_0;
481 this->_delta_0 = this->_vec_pn.dot(this->_vec_grad);
482
483 if(this->_delta_0 > DataType(0))
484 {
485 XABORTM("Search direction is not a descent direction: " +stringify_fp_sci(this->_delta_0));
486 }
487
488 // Compute initial defect. We want to minimize d^T * grad(_functional)
489 this->_def_init = Math::abs(this->_delta_0);
490
491 // Norm of the initial guess
492 this->_norm_sol = vec_sol.norm2();
493
494 this->_alpha_min = DataType(0);
495
496 // plot?
497 if(this->_plot_iter())
498 {
499 String msg = this->_plot_name
500 + ": " + stringify(this->_num_iter).pad_front(this->_iter_digits)
501 + " : " + stringify_fp_sci(this->_def_init)
502 + " : " + stringify_fp_sci(this->_fval_0)
503 + ", ||dir|| = " + stringify_fp_sci(this->_norm_dir);
504 this->_print_line(msg);
505 }
506
507 if(!Math::isfinite(this->_fval_0) || !Math::isfinite(this->_delta_0) || !Math::isfinite(this->_norm_dir))
508 {
509 status = Status::aborted;
510 }
511
512 // Set initial step size
513 alpha = _alpha_0;
514 // Other intitial values
515 fval = _fval_0;
516 delta = _delta_0;
517
518 return status;
519 }
520
543 virtual Status _check_convergence(const DataType fval, const DataType df, const DataType alpha)
544 {
545 Status status(Status::progress);
546
547 this->_def_cur = Math::abs(df);
548
549 Statistics::add_solver_expression(
550 std::make_shared<ExpressionDefect>(this->name(), this->_def_cur, this->get_num_iter()));
551
552 // plot?
553 if(this->_plot_iter())
554 {
555 String msg = this->_plot_name
556 + ": " + stringify(this->_num_iter).pad_front(this->_iter_digits)
557 + " : " + stringify_fp_sci(this->_def_cur)
558 + " / " + stringify_fp_sci(this->_def_cur / this->_def_init)
559 + " : " + stringify_fp_sci(fval) + " : " + stringify_fp_sci(alpha);
560 this->_print_line(msg);
561 }
562
563 // ensure that the defect is neither NaN nor infinity
564 if(!Math::isfinite(this->_def_cur))
565 {
566 status = Status::aborted;
567 }
568
569 // is diverged?
570 if(this->is_diverged())
571 {
572 status = Status::diverged;
573 }
574
575 // If the maximum number of iterations was performed, return the iterate for the best step so far
576 if(this->_num_iter >= this->_max_iter)
577 {
578 status = Status::max_iter;
579 }
580
581 // If the strong Wolfe conditions hold, we are successful
582 if((fval < this->_fval_0 +_tol_decrease*alpha*_delta_0) && (Math::abs(df) < -_tol_curvature*_delta_0))
583 {
584 status = Status::success;
585 }
586
587 return status;
588 }
589
590 }; // class Linesearch
591 } // namespace Solver
592} // namespace FEAT
#define XABORTM(msg)
Abort macro definition with custom message.
Definition: assertion.hpp:192
#define XASSERT(expr)
Assertion macro definition.
Definition: assertion.hpp:262
FEAT Kernel base header.
A class organizing a tree of key-value pairs.
std::pair< String, bool > query(String key_path) const
Queries a value by its key path.
Abstract base-class for iterative solvers.
Definition: iterative.hpp:198
String get_plot_name() const
Returns the plot name of the solver.
Definition: iterative.hpp:523
Index _iter_digits
iteration count digits for plotting
Definition: iterative.hpp:243
Index get_num_iter() const
Returns number of performed iterations.
Definition: iterative.hpp:462
void _set_comm_by_matrix(const Matrix_ &matrix)
Sets the communicator for the solver from a matrix.
Definition: iterative.hpp:680
void _print_line(const String &line) const
Prints a line.
Definition: iterative.hpp:752
Index _num_iter
number of performed iterations
Definition: iterative.hpp:231
bool _plot_iter(Status st=Status::progress) const
Plot the current iteration?
Definition: iterative.hpp:720
Linesearch base class.
Definition: linesearch.hpp:34
virtual ~Linesearch()
Virtual destructor.
Definition: linesearch.hpp:205
VectorType _vec_grad
Gradient vector.
Definition: linesearch.hpp:53
void set_tol_curvature(DataType tol_curvature)
Sets the tolerance for the sufficient decrease in curvature.
Definition: linesearch.hpp:233
virtual void init_symbolic() override
Symbolic initialization method.
Definition: linesearch.hpp:275
DataType _norm_dir
The 2-norm of the search direction.
Definition: linesearch.hpp:75
Functional_ & _functional
The (nonlinear) functional.
Definition: linesearch.hpp:48
IterativeSolver< typename Functional_::VectorTypeR > BaseClass
Our base class.
Definition: linesearch.hpp:43
VectorType _vec_initial_sol
Initial solution.
Definition: linesearch.hpp:55
void set_grad_from_defect(const VectorType &vec_def)
Sets the initial gradient from a defect vector.
Definition: linesearch.hpp:392
void get_defect_from_grad(VectorType &vec_def) const
Gets a defect vector from the final gradient.
Definition: linesearch.hpp:378
virtual void trim_func_grad(DataType &func)
Trims the function value and gradient according to some threshold.
Definition: linesearch.hpp:430
Functional_::VectorTypeR VectorType
Input vector type for the functional's gradient.
Definition: linesearch.hpp:39
DataType _trim_threshold
Threshold for trimming function value and gradient.
Definition: linesearch.hpp:66
DataType _fval_0
Initial functional value.
Definition: linesearch.hpp:64
void set_initial_fval(DataType f0)
Sets the initial functional value.
Definition: linesearch.hpp:359
void set_tol_decrease(DataType tol_decrease)
Sets the tolerance for the sufficient decrease in functional value.
Definition: linesearch.hpp:243
const VectorType & get_initial_sol() const
Gets the initial solution the linesearch started with.
Definition: linesearch.hpp:404
DataType _tol_curvature
Tolerance for the curvature condition, i.e. sufficient decrease of |<dir, grad>| (strong Wolfe conditions)
Definition: linesearch.hpp:80
void set_tol_step(DataType tol_step)
Sets the step length tolerance.
Definition: linesearch.hpp:254
DataType _norm_sol
The 2-norm of the iterate.
Definition: linesearch.hpp:77
virtual void done_symbolic() override
Symbolic finalization method.
Definition: linesearch.hpp:286
virtual Status _startup(DataType &alpha, DataType &fval, DataType &delta, const VectorType &vec_sol, const VectorType &vec_dir)
Performs the startup of the iteration.
Definition: linesearch.hpp:464
VectorType _vec_pn
Descent direction vector, normalized for better numerical stability
Definition: linesearch.hpp:59
Linesearch(const String &plot_name, Functional_ &functional, Filter_ &filter, bool keep_iterates=false)
Standard constructor.
Definition: linesearch.hpp:110
virtual DataType get_rel_update() const
Get the relative update of the solver application.
Definition: linesearch.hpp:334
std::deque< VectorType > * iterates
For debugging purposes, it is possible to save all iterates to this.
Definition: linesearch.hpp:91
virtual void reset()
Resets various member variables in case the solver is reused.
Definition: linesearch.hpp:299
DataType _delta_0
Initial <vec_dir, vec_grad>
Definition: linesearch.hpp:73
DataType get_final_fval() const
Gets the functional value of the last iteration.
Definition: linesearch.hpp:344
DataType _fval_min
Final functional value.
Definition: linesearch.hpp:62
Filter_ FilterType
Filter type to be applied to the gradient of the functional.
Definition: linesearch.hpp:37
virtual String get_summary() const override
Returns a summary string.
Definition: linesearch.hpp:213
Filter_ & _filter
The filter to be applied to the functional's gradient.
Definition: linesearch.hpp:50
DataType _tol_decrease
Tolerance for sufficient decrease in the functional value (Wolfe conditions)
Definition: linesearch.hpp:82
DataType _alpha_0
Initial line search parameter.
Definition: linesearch.hpp:69
DataType _alpha_min
Line search parameter.
Definition: linesearch.hpp:71
void set_dir_scaling(const bool b)
Determines if search direction scaling is to be used.
Definition: linesearch.hpp:269
DataType _tol_step
Tolerance for the update step.
Definition: linesearch.hpp:84
Functional_::DataType DataType
Underlying floating point type.
Definition: linesearch.hpp:41
VectorType _vec_tmp
temporary vector
Definition: linesearch.hpp:57
virtual Status _check_convergence(const DataType fval, const DataType df, const DataType alpha)
Performs the line search convergence checks using the strong Wolfe conditions.
Definition: linesearch.hpp:543
Linesearch(const String &plot_name, const String &section_name, const PropertyMap *section, Functional_ &functional, Filter_ &filter)
Constructor using a PropertyMap.
Definition: linesearch.hpp:157
virtual void init_symbolic()
Symbolic initialization method.
Definition: base.hpp:227
virtual void done_symbolic()
Symbolic finalization method.
Definition: base.hpp:255
virtual String name() const =0
Returns a descriptive string.
String class implementation.
Definition: string.hpp:46
String pad_front(size_type len, char c=' ') const
Pads the front of the string up to a desired length.
Definition: string.hpp:392
T_ abs(T_ x)
Returns the absolute value.
Definition: math.hpp:275
bool isfinite(T_ x)
Checks whether a value is finite.
Status
Solver status return codes enumeration.
Definition: base.hpp:47
@ success
solving successful (convergence criterion fulfilled)
@ progress
continue iteration (internal use only)
@ max_iter
solver reached maximum iterations
@ diverged
solver diverged (divergence criterion fulfilled)
@ aborted
premature abort (solver aborted due to internal errors or preconditioner failure)
FEAT namespace.
Definition: adjactor.hpp:12
String stringify(const T_ &item)
Converts an item into a String.
Definition: string.hpp:944
String stringify_fp_sci(DataType_ value, int precision=0, int width=0, bool sign=false)
Prints a floating point value to a string in scientific notation.
Definition: string.hpp:1088
std::uint64_t Index
Index data type.