8#include <kernel/solver/linesearch.hpp>
49 template<
typename Functional_,
typename Filter_>
58 typedef typename Functional_::DataType
DataType;
86 explicit MQCLinesearch(Functional_& functional, Filter_& filter,
bool keep_iterates =
false) :
87 BaseClass(
"MQC-LS", functional, filter, keep_iterates),
112 Functional_& functional, Filter_& filter) :
113 BaseClass(
"MQC-LS", section_name, section, functional, filter),
129 return "MQCLinesearch";
203 Statistics::add_solver_expression(std::make_shared<ExpressionStartSolve>(this->
name()));
205 static constexpr DataType extrapolation_width =
DataType(4);
252 DataType alpha_hi(0);
261 bool interval_known(
false);
263 bool min_in_interval(
false);
266 bool drive_to_bndry(
false);
291 vec_sol.axpy(this->
_vec_pn, alpha);
345 interval_known =
true;
351 if(!interval_known && (fval <= fval_lo)
364 alpha, fval_m, df_m, alpha_lo, fval_lo_m, df_lo_m,
365 alpha_hi, fval_hi_m, df_hi_m, min_in_interval, drive_to_bndry);
375 _polynomial_fit(alpha, fval, df, alpha_lo, fval_lo, df_lo, alpha_hi, fval_hi, df_hi, min_in_interval,
385 alpha = alpha_lo +
DataType(0.5)*(alpha_hi - alpha_lo);
411 vec_sol.axpy(this->
_vec_pn, alpha_lo);
420 Statistics::add_solver_expression(std::make_shared<ExpressionEndSolve>(this->
name(), status, this->
get_num_iter()));
474 bool& min_in_interval,
bool& drive_to_bndry)
484 alpha_c =
_argmin_cubic(alpha_lo, fval_lo, df_lo, alpha, fval, df);
487 min_in_interval =
true;
488 drive_to_bndry =
true;
497 alpha_new = (alpha_q + alpha_c)/
DataType(2);
505 alpha_c =
_argmin_cubic(alpha, fval, df, alpha_lo, fval_lo, df_lo);
508 min_in_interval =
true;
509 drive_to_bndry =
false;
513 Math::abs(alpha - alpha_c) >
Math::abs(alpha - alpha_q) ? alpha_new = alpha_c : alpha_new = alpha_q;
523 drive_to_bndry =
true;
531 min_in_interval ? alpha_new = alpha_c : alpha_new = alpha_q;
535 min_in_interval ? alpha_new = alpha_q : alpha_new = alpha_c;
544 drive_to_bndry =
false;
550 alpha_c =
_argmin_cubic(alpha, fval , df, alpha_hi, fval_hi, df_hi);
584 if(min_in_interval && drive_to_bndry)
588 if(alpha_lo < alpha_hi)
589 alpha_new =
Math::min(alpha_lo +
DataType(0.66)*(alpha_hi - alpha_lo), alpha_new);
591 alpha_new =
Math::max(alpha_lo +
DataType(0.66)*(alpha_hi - alpha_lo), alpha_new);
657 if(interpolate_derivative)
658 alpha += df_lo /(df_lo - df_hi) * (alpha_hi - alpha_lo);
661 alpha += df_lo/( (fval_hi - fval_lo)/( alpha_hi - alpha_lo) - df_lo)/
DataType(2)*(alpha_lo - alpha_hi);
702 DataType d1 =
DataType(3)*(fval_lo - fval_hi)/(alpha_hi - alpha_lo) + df_lo + df_hi;
725 df_hi*df_lo >
DataType(0) ? q = d2 +(df_hi-df_lo) + d2 : q = d2 - df_lo + d2 + df_hi;
730 alpha += (alpha_hi - alpha_lo)*(p/q);
763 if(alpha_lo == alpha_hi)
768 DataType d1 =
DataType(3)*(fval_hi - fval_lo)/(alpha_lo - alpha_hi) + df_hi + df_lo;
777 DataType q(d2 +(df_hi - df_lo) + d2);
788 alpha_c += p/q*(alpha_hi - alpha_lo);
815 template<
typename Functional_,
typename Filter_>
817 Functional_& functional, Filter_& filter,
bool keep_iterates =
false)
819 return std::make_shared<MQCLinesearch<Functional_, Filter_>>(functional, filter, keep_iterates);
840 template<
typename Functional_,
typename Filter_>
842 const String& section_name,
const PropertyMap* section, Functional_& functional, Filter_& filter)
844 return std::make_shared<MQCLinesearch<Functional_, Filter_>> (section_name, section, functional, filter);
A class organizing a tree of key-value pairs.
Helper class for iteration statistics collection.
Index get_num_iter() const
Returns number of performed iterations.
Status _status
current status of the solver
Index _num_iter
number of performed iterations
virtual void plot_summary() const
Plot a summary of the last solver run.
VectorType _vec_grad
Gradient vector.
Functional_ & _functional
The (nonlinear) functional.
VectorType _vec_initial_sol
Initial solution.
virtual void trim_func_grad(DataType &func)
Trims the function value and gradient according to some threshold.
DataType _fval_0
Initial functional value.
DataType _tol_curvature
Tolerance for the curvature condition, i.e. sufficient decrease in the magnitude of the directional derivative (strong Wolfe conditions)
virtual Status _startup(DataType &alpha, DataType &fval, DataType &delta, const VectorType &vec_sol, const VectorType &vec_dir)
Performs the startup of the iteration.
VectorType _vec_pn
Descent direction vector, normalized for better numerical stability.
virtual void reset()
Resets various member variables in case the solver is reused.
DataType _delta_0
Initial directional derivative <vec_dir, vec_grad>.
DataType _fval_min
Minimal functional value.
Filter_ & _filter
The filter to be applied to the functional's gradient.
DataType _tol_decrease
Tolerance for sufficient decrease in the functional value (Wolfe conditions)
DataType _alpha_0
Initial line search parameter.
DataType _alpha_min
Line search parameter.
DataType _tol_step
Tolerance for the update step.
VectorType _vec_tmp
temporary vector
virtual Status _check_convergence(const DataType fval, const DataType df, const DataType alpha)
Performs the line search convergence checks using the strong Wolfe conditions.
Mixed quadratic-cubic line search.
DataType _alpha_hard_max
Hard maximum for the step length.
virtual String name() const override
Returns a descriptive string.
MQCLinesearch(Functional_ &functional, Filter_ &filter, bool keep_iterates=false)
Standard constructor.
DataType _argmin_cubic_case_3(const DataType alpha_lo, const DataType fval_lo, const DataType df_lo, const DataType alpha_hi, const DataType fval_hi, DataType df_hi) const
Computes the minimum of a cubic interpolation polynomial.
DataType _alpha_soft_max
Upper bound of the interval of uncertainty.
Functional_::DataType DataType
Underlying floating point type.
MQCLinesearch(const String &section_name, const PropertyMap *section, Functional_ &functional, Filter_ &filter)
Constructor using a PropertyMap.
DataType _argmin_quadratic(const DataType alpha_lo, const DataType fval_lo, const DataType df_lo, const DataType alpha_hi, const DataType fval_hi, const DataType df_hi, const bool interpolate_derivative) const
Computes the minimum of a quadratic interpolation polynomial.
DataType _alpha_soft_min
Lower bound of the interval of uncertainty.
virtual Status correct(VectorType &vec_sol, const VectorType &vec_dir) override
Applies the solver, making use of an initial guess.
Filter_ FilterType
Filter type to be applied to the gradient of the functional.
DataType _argmin_cubic(DataType alpha_lo, DataType fval_lo, DataType df_lo, DataType alpha_hi, DataType fval_hi, DataType df_hi) const
Computes the minimum of a cubic interpolation polynomial.
DataType _alpha_hard_min
Hard minimum for the step length.
Linesearch< Functional_, Filter_ > BaseClass
Our base class.
virtual Status apply(VectorType &vec_cor, const VectorType &vec_dir) override
Applies the solver, setting the initial guess to zero.
Status _polynomial_fit(DataType &alpha, DataType &fval, DataType &df, DataType &alpha_lo, DataType &fval_lo, DataType &df_lo, DataType &alpha_hi, DataType &fval_hi, DataType &df_hi, bool &min_in_interval, bool &drive_to_bndry)
Computes the next trial step for minimizing the 1D function by quadratic/cubic polynomial interpolation and updates the interval of uncertainty.
Functional_::VectorTypeR VectorType
Input vector type for the functional's gradient.
void _clamp_step(DataType &alpha_new) const
Enforces hard and soft step limits, adjusting the soft limits if necessary.
virtual Status _apply_intern(VectorType &vec_sol, const VectorType &vec_dir)
Internal function: Applies the solver.
virtual void reset() override
Resets various member variables in case the solver is reused.
String class implementation.
T_ sqrt(T_ x)
Returns the square-root of a value.
T_ abs(T_ x)
Returns the absolute value.
T_ pow(T_ x, T_ y)
Returns x raised to the power of y.
T_ sqr(T_ x)
Returns the square of a value.
T_ min(T_ a, T_ b)
Returns the minimum of two values.
T_ signum(T_ x)
Returns the sign of a value.
T_ max(T_ a, T_ b)
Returns the maximum of two values.
Status
Solver status return codes enumeration.
@ success
solving successful (convergence criterion fulfilled)
@ progress
continue iteration (internal use only)
@ stagnated
solver stagnated (stagnation criterion fulfilled)
std::shared_ptr< MQCLinesearch< Functional_, Filter_ > > new_mqc_linesearch(Functional_ &functional, Filter_ &filter, bool keep_iterates=false)
Creates a new MQCLinesearch object.