WarpX
PicardSolver.H
Go to the documentation of this file.
1 /* Copyright 2024 Justin Angus
2  *
3  * This file is part of WarpX.
4  *
5  * License: BSD-3-Clause-LBNL
6  */
7 #ifndef PICARD_SOLVER_H_
8 #define PICARD_SOLVER_H_
9 
#include "NonlinearSolver.H"

#include <AMReX_ParmParse.H>
#include "Utils/TextMsg.H"

#include <iomanip>
#include <sstream>
#include <vector>
16 
24 template<class Vec, class Ops>
25 class PicardSolver : public NonlinearSolver<Vec,Ops>
26 {
27 public:
28 
29  PicardSolver<Vec,Ops>() = default;
30 
31  ~PicardSolver<Vec,Ops>() override = default;
32 
33  // Prohibit Move and Copy operations
34  PicardSolver(const PicardSolver&) = delete;
35  PicardSolver& operator=(const PicardSolver&) = delete;
36  PicardSolver(PicardSolver&&) noexcept = delete;
37  PicardSolver& operator=(PicardSolver&&) noexcept = delete;
38 
39  void Define ( const Vec& a_U,
40  Ops* a_ops ) override;
41 
42  void Solve ( Vec& a_U,
43  const Vec& a_b,
44  amrex::Real a_time,
45  amrex::Real a_dt ) const override;
46 
47  void GetSolverParams ( amrex::Real& a_rtol,
48  amrex::Real& a_atol,
49  int& a_maxits ) override
50  {
51  a_rtol = m_rtol;
52  a_atol = m_atol;
53  a_maxits = m_maxits;
54  }
55 
56  void PrintParams () const override
57  {
58  amrex::Print() << "Picard max iterations: " << m_maxits << std::endl;
59  amrex::Print() << "Picard relative tolerance: " << m_rtol << std::endl;
60  amrex::Print() << "Picard absolute tolerance: " << m_atol << std::endl;
61  amrex::Print() << "Picard require convergence: " << (m_require_convergence?"true":"false") << std::endl;
62  }
63 
64 private:
65 
69  mutable Vec m_Usave, m_R;
70 
74  Ops* m_ops = nullptr;
75 
79  bool m_require_convergence = true;
80 
84  amrex::Real m_rtol = 1.0e-6;
85 
89  amrex::Real m_atol = 0.;
90 
94  int m_maxits = 100;
95 
96  void ParseParameters( );
97 
98 };
99 
100 template <class Vec, class Ops>
101 void PicardSolver<Vec,Ops>::Define ( const Vec& a_U,
102  Ops* a_ops )
103 {
105  !this->m_is_defined,
106  "Picard nonlinear solver object is already defined!");
107 
108  ParseParameters();
109 
110  m_Usave.Define(a_U);
111  m_R.Define(a_U);
112 
113  m_ops = a_ops;
114 
115  this->m_is_defined = true;
116 
117 }
118 
119 template <class Vec, class Ops>
121 {
122  const amrex::ParmParse pp_picard("picard");
123  pp_picard.query("verbose", this->m_verbose);
124  pp_picard.query("absolute_tolerance", m_atol);
125  pp_picard.query("relative_tolerance", m_rtol);
126  pp_picard.query("max_iterations", m_maxits);
127  pp_picard.query("require_convergence", m_require_convergence);
128 
129 }
130 
131 template <class Vec, class Ops>
133  const Vec& a_b,
134  amrex::Real a_time,
135  amrex::Real a_dt ) const
136 {
137  BL_PROFILE("PicardSolver::Solve()");
139  this->m_is_defined,
140  "PicardSolver::Solve() called on undefined object");
141  using namespace amrex::literals;
142 
143  //
144  // Picard fixed-point iteration method to solve nonlinear
145  // equation of form: U = b + R(U)
146  //
147 
148  amrex::Real norm_abs = 0.;
149  amrex::Real norm0 = 1._rt;
150  amrex::Real norm_rel = 0.;
151 
152  int iter;
153  for (iter = 0; iter < m_maxits;) {
154 
155  // Save previous state for norm calculation
156  m_Usave.Copy(a_U);
157 
158  // Update the solver state (a_U = a_b + m_R)
159  m_ops->ComputeRHS( m_R, a_U, a_time, a_dt, iter, false );
160  a_U.Copy(a_b);
161  a_U += m_R;
162 
163  // Compute the step norm and update iter
164  m_Usave -= a_U;
165  norm_abs = m_Usave.norm2();
166  if (iter == 0) {
167  if (norm_abs > 0.) { norm0 = norm_abs; }
168  else { norm0 = 1._rt; }
169  }
170  norm_rel = norm_abs/norm0;
171  iter++;
172 
173  // Check for convergence criteria
174  if (this->m_verbose || iter == m_maxits) {
175  amrex::Print() << "Picard: iter = " << std::setw(3) << iter << ", norm = "
176  << std::scientific << std::setprecision(5) << norm_abs << " (abs.), "
177  << std::scientific << std::setprecision(5) << norm_rel << " (rel.)" << "\n";
178  }
179 
180  if (norm_abs < m_atol) {
181  amrex::Print() << "Picard: exiting at iter = " << std::setw(3) << iter
182  << ". Satisfied absolute tolerance " << m_atol << std::endl;
183  break;
184  }
185 
186  if (norm_rel < m_rtol) {
187  amrex::Print() << "Picard: exiting at iter = " << std::setw(3) << iter
188  << ". Satisfied relative tolerance " << m_rtol << std::endl;
189  break;
190  }
191 
192  if (iter >= m_maxits) {
193  amrex::Print() << "Picard: exiting at iter = " << std::setw(3) << iter
194  << ". Maximum iteration reached: iter = " << m_maxits << std::endl;
195  break;
196  }
197 
198  }
199 
200  if (m_rtol > 0. && iter == m_maxits) {
201  std::stringstream convergenceMsg;
202  convergenceMsg << "Picard solver failed to converge after " << iter <<
203  " iterations. Relative norm is " << norm_rel <<
204  " and the relative tolerance is " << m_rtol <<
205  ". Absolute norm is " << norm_abs <<
206  " and the absolute tolerance is " << m_atol;
207  if (this->m_verbose) { amrex::Print() << convergenceMsg.str() << std::endl; }
208  if (m_require_convergence) {
209  WARPX_ABORT_WITH_MESSAGE(convergenceMsg.str());
210  } else {
211  ablastr::warn_manager::WMRecordWarning("PicardSolver", convergenceMsg.str());
212  }
213  }
214 
215 }
216 
217 #endif
#define BL_PROFILE(a)
#define WARPX_ABORT_WITH_MESSAGE(MSG)
Definition: TextMsg.H:15
#define WARPX_ALWAYS_ASSERT_WITH_MESSAGE(EX, MSG)
Definition: TextMsg.H:13
Top-level class for the nonlinear solver.
Definition: NonlinearSolver.H:28
Picard fixed-point iteration method to solve a nonlinear equation of the form U = b + R(U).
Definition: PicardSolver.H:26
void ParseParameters()
Definition: PicardSolver.H:120
PicardSolver(const PicardSolver &)=delete
bool m_require_convergence
Flag to determine whether convergence is required.
Definition: PicardSolver.H:79
Vec m_R
Definition: PicardSolver.H:69
void GetSolverParams(amrex::Real &a_rtol, amrex::Real &a_atol, int &a_maxits) override
Return the convergence parameters used by the nonlinear solver.
Definition: PicardSolver.H:47
amrex::Real m_rtol
Relative tolerance for the Picard nonlinear solver.
Definition: PicardSolver.H:84
Ops * m_ops
Pointer to Ops class.
Definition: PicardSolver.H:74
int m_maxits
Maximum iterations for the Picard nonlinear solver.
Definition: PicardSolver.H:94
amrex::Real m_atol
Absolute tolerance for the Picard nonlinear solver.
Definition: PicardSolver.H:89
PicardSolver & operator=(const PicardSolver &)=delete
Vec m_Usave
Intermediate Vec containers used by the solver.
Definition: PicardSolver.H:69
PicardSolver(PicardSolver &&) noexcept=delete
void Define(const Vec &a_U, Ops *a_ops) override
Read user-provided parameters that control the nonlinear solver. Allocate intermediate data containers used by the solver.
Definition: PicardSolver.H:101
void Solve(Vec &a_U, const Vec &a_b, amrex::Real a_time, amrex::Real a_dt) const override
Solve the specified nonlinear equation for U. Picard: U = b + R(U). Newton: F(U) = U - b - R(U) = 0.
Definition: PicardSolver.H:132
void PrintParams() const override
Print parameters used by the nonlinear solver.
Definition: PicardSolver.H:56
int query(const char *name, bool &ref, int ival=FIRST) const
void WMRecordWarning(const std::string &topic, const std::string &text, const WarnPriority &priority=WarnPriority::medium)
Helper function to abbreviate the call to WarnManager::GetInstance().RecordWarning (recording a warning).
Definition: WarnManager.cpp:318