WarpX
NewtonSolver.H
Go to the documentation of this file.
1 /* Copyright 2024 Justin Angus
2  *
3  * This file is part of WarpX.
4  *
5  * License: BSD-3-Clause-LBNL
6  */
7 #ifndef NEWTON_SOLVER_H_
8 #define NEWTON_SOLVER_H_
9 
#include "NonlinearSolver.H"
#include "JacobianFunctionMF.H"
#include "Utils/TextMsg.H"

#include <AMReX_GMRES.H>
#include <AMReX_ParmParse.H>
#include <AMReX_Print.H>

#include <iomanip>
#include <memory>
#include <sstream>
#include <vector>
18 
26 template<class Vec, class Ops>
27 class NewtonSolver : public NonlinearSolver<Vec,Ops>
28 {
29 public:
30 
31  NewtonSolver<Vec,Ops>() = default;
32 
33  ~NewtonSolver<Vec,Ops>() override = default;
34 
35  // Prohibit Move and Copy operations
36  NewtonSolver(const NewtonSolver&) = delete;
37  NewtonSolver& operator=(const NewtonSolver&) = delete;
38  NewtonSolver(NewtonSolver&&) noexcept = delete;
39  NewtonSolver& operator=(NewtonSolver&&) noexcept = delete;
40 
41  void Define ( const Vec& a_U,
42  Ops* a_ops ) override;
43 
44  void Solve ( Vec& a_U,
45  const Vec& a_b,
46  amrex::Real a_time,
47  amrex::Real a_dt ) const override;
48 
49  void GetSolverParams ( amrex::Real& a_rtol,
50  amrex::Real& a_atol,
51  int& a_maxits ) override
52  {
53  a_rtol = m_rtol;
54  a_atol = m_atol;
55  a_maxits = m_maxits;
56  }
57 
58  inline void CurTime ( amrex::Real a_time ) const
59  {
60  m_cur_time = a_time;
61  m_linear_function->curTime( a_time );
62  }
63 
64  inline void CurTimeStep ( amrex::Real a_dt ) const
65  {
66  m_dt = a_dt;
67  m_linear_function->curTimeStep( a_dt );
68  }
69 
70  void PrintParams () const override
71  {
72  amrex::Print() << "Newton verbose: " << (this->m_verbose?"true":"false") << std::endl;
73  amrex::Print() << "Newton max iterations: " << m_maxits << std::endl;
74  amrex::Print() << "Newton relative tolerance: " << m_rtol << std::endl;
75  amrex::Print() << "Newton absolute tolerance: " << m_atol << std::endl;
76  amrex::Print() << "Newton require convergence: " << (m_require_convergence?"true":"false") << std::endl;
77  amrex::Print() << "GMRES verbose: " << m_gmres_verbose_int << std::endl;
78  amrex::Print() << "GMRES restart length: " << m_gmres_restart_length << std::endl;
79  amrex::Print() << "GMRES max iterations: " << m_gmres_maxits << std::endl;
80  amrex::Print() << "GMRES relative tolerance: " << m_gmres_rtol << std::endl;
81  amrex::Print() << "GMRES absolute tolerance: " << m_gmres_atol << std::endl;
82  }
83 
84 private:
85 
89  mutable Vec m_dU, m_F, m_R;
90 
94  Ops* m_ops = nullptr;
95 
99  bool m_require_convergence = true;
100 
104  amrex::Real m_rtol = 1.0e-6;
105 
109  amrex::Real m_atol = 0.;
110 
114  int m_maxits = 100;
115 
119  amrex::Real m_gmres_rtol = 1.0e-4;
120 
124  amrex::Real m_gmres_atol = 0.;
125 
129  int m_gmres_maxits = 1000;
130 
135 
140 
141  mutable amrex::Real m_cur_time, m_dt;
142  mutable bool m_update_pc = false;
143  mutable bool m_update_pc_init = false;
144 
149  std::unique_ptr<JacobianFunctionMF<Vec,Ops>> m_linear_function;
150 
154  std::unique_ptr<amrex::GMRES<Vec,JacobianFunctionMF<Vec,Ops>>> m_linear_solver;
155 
156  void ParseParameters ();
157 
161  void EvalResidual ( Vec& a_F,
162  const Vec& a_U,
163  const Vec& a_b,
164  amrex::Real a_time,
165  amrex::Real a_dt,
166  int a_iter ) const;
167 
168 };
169 
170 template <class Vec, class Ops>
171 void NewtonSolver<Vec,Ops>::Define ( const Vec& a_U,
172  Ops* a_ops )
173 {
175  !this->m_is_defined,
176  "Newton nonlinear solver object is already defined!");
177 
178  ParseParameters();
179 
180  m_dU.Define(a_U);
181  m_F.Define(a_U); // residual function F(U) = U - b - R(U) = 0
182  m_R.Define(a_U); // right hand side function R(U)
183 
184  m_ops = a_ops;
185 
186  m_linear_function = std::make_unique<JacobianFunctionMF<Vec,Ops>>();
187  m_linear_function->define(m_F, m_ops);
188 
189  m_linear_solver = std::make_unique<amrex::GMRES<Vec,JacobianFunctionMF<Vec,Ops>>>();
190  m_linear_solver->define(*m_linear_function);
191  m_linear_solver->setVerbose( m_gmres_verbose_int );
192  m_linear_solver->setRestartLength( m_gmres_restart_length );
193  m_linear_solver->setMaxIters( m_gmres_maxits );
194 
195  this->m_is_defined = true;
196 
197 }
198 
199 template <class Vec, class Ops>
201 {
202  const amrex::ParmParse pp_newton("newton");
203  pp_newton.query("verbose", this->m_verbose);
204  pp_newton.query("absolute_tolerance", m_atol);
205  pp_newton.query("relative_tolerance", m_rtol);
206  pp_newton.query("max_iterations", m_maxits);
207  pp_newton.query("require_convergence", m_require_convergence);
208 
209  const amrex::ParmParse pp_gmres("gmres");
210  pp_gmres.query("verbose_int", m_gmres_verbose_int);
211  pp_gmres.query("restart_length", m_gmres_restart_length);
212  pp_gmres.query("absolute_tolerance", m_gmres_atol);
213  pp_gmres.query("relative_tolerance", m_gmres_rtol);
214  pp_gmres.query("max_iterations", m_gmres_maxits);
215 }
216 
217 template <class Vec, class Ops>
219  const Vec& a_b,
220  amrex::Real a_time,
221  amrex::Real a_dt ) const
222 {
223  BL_PROFILE("NewtonSolver::Solve()");
225  this->m_is_defined,
226  "NewtonSolver::Solve() called on undefined object");
227  using namespace amrex::literals;
228 
229  //
230  // Newton routine to solve nonlinear equation of form:
231  // F(U) = U - b - R(U) = 0
232  //
233 
234  CurTime(a_time);
235  CurTimeStep(a_dt);
236 
237  amrex::Real norm_abs = 0.;
238  amrex::Real norm0 = 1._rt;
239  amrex::Real norm_rel = 0.;
240 
241  int iter;
242  for (iter = 0; iter < m_maxits;) {
243 
244  // Compute residual: F(U) = U - b - R(U)
245  EvalResidual(m_F, a_U, a_b, a_time, a_dt, iter);
246 
247  // Compute norm of the residual
248  norm_abs = m_F.norm2();
249  if (iter == 0) {
250  if (norm_abs > 0.) { norm0 = norm_abs; }
251  else { norm0 = 1._rt; }
252  }
253  norm_rel = norm_abs/norm0;
254 
255  // Check for convergence criteria
256  if (this->m_verbose || iter == m_maxits) {
257  amrex::Print() << "Newton: iteration = " << std::setw(3) << iter << ", norm = "
258  << std::scientific << std::setprecision(5) << norm_abs << " (abs.), "
259  << std::scientific << std::setprecision(5) << norm_rel << " (rel.)" << "\n";
260  }
261 
262  if (norm_abs < m_rtol) {
263  amrex::Print() << "Newton: exiting at iteration = " << std::setw(3) << iter
264  << ". Satisfied absolute tolerance " << m_atol << std::endl;
265  break;
266  }
267 
268  if (norm_rel < m_rtol) {
269  amrex::Print() << "Newton: exiting at iteration = " << std::setw(3) << iter
270  << ". Satisfied relative tolerance " << m_rtol << std::endl;
271  break;
272  }
273 
274  if (norm_abs > 100._rt*norm0) {
275  amrex::Print() << "Newton: exiting at iteration = " << std::setw(3) << iter
276  << ". SOLVER DIVERGED! relative tolerance = " << m_rtol << std::endl;
277  std::stringstream convergenceMsg;
278  convergenceMsg << "Newton: exiting at iteration " << std::setw(3) << iter <<
279  ". SOLVER DIVERGED! absolute norm = " << norm_abs <<
280  " has increased by 100X from that after first iteration.";
281  WARPX_ABORT_WITH_MESSAGE(convergenceMsg.str());
282  }
283 
284  // Solve linear system for Newton step [Jac]*dU = F
285  m_dU.zero();
286  m_linear_solver->solve( m_dU, m_F, m_gmres_rtol, m_gmres_atol );
287 
288  // Update solution
289  a_U -= m_dU;
290 
291  iter++;
292  if (iter >= m_maxits) {
293  amrex::Print() << "Newton: exiting at iter = " << std::setw(3) << iter
294  << ". Maximum iteration reached: iter = " << m_maxits << std::endl;
295  break;
296  }
297 
298  }
299 
300  if (m_rtol > 0. && iter == m_maxits) {
301  std::stringstream convergenceMsg;
302  convergenceMsg << "Newton solver failed to converge after " << iter <<
303  " iterations. Relative norm is " << norm_rel <<
304  " and the relative tolerance is " << m_rtol <<
305  ". Absolute norm is " << norm_abs <<
306  " and the absolute tolerance is " << m_atol;
307  if (this->m_verbose) { amrex::Print() << convergenceMsg.str() << std::endl; }
308  if (m_require_convergence) {
309  WARPX_ABORT_WITH_MESSAGE(convergenceMsg.str());
310  } else {
311  ablastr::warn_manager::WMRecordWarning("NewtonSolver", convergenceMsg.str());
312  }
313  }
314 
315 }
316 
317 template <class Vec, class Ops>
319  const Vec& a_U,
320  const Vec& a_b,
321  amrex::Real a_time,
322  amrex::Real a_dt,
323  int a_iter ) const
324 {
325 
326  m_ops->ComputeRHS( m_R, a_U, a_time, a_dt, a_iter, false );
327 
328  // set base U and R(U) for matrix-free Jacobian action calculation
329  m_linear_function->setBaseSolution(a_U);
330  m_linear_function->setBaseRHS(m_R);
331 
332  // update preconditioner
333  if (m_update_pc || m_update_pc_init) {
334  m_linear_function->updatePreCondMat(a_U);
335  }
336  m_update_pc_init = false;
337 
338  // Compute residual: F(U) = U - b - R(U)
339  a_F.Copy(a_U);
340  a_F -= m_R;
341  a_F -= a_b;
342 
343 }
344 
345 #endif
#define BL_PROFILE(a)
#define WARPX_ABORT_WITH_MESSAGE(MSG)
Definition: TextMsg.H:15
#define WARPX_ALWAYS_ASSERT_WITH_MESSAGE(EX, MSG)
Definition: TextMsg.H:13
Newton method to solve nonlinear equation of form: F(U) = U - b - R(U) = 0. U is the solution vector,...
Definition: NewtonSolver.H:28
Vec m_dU
Intermediate Vec containers used by the solver.
Definition: NewtonSolver.H:89
int m_maxits
Maximum iterations for the Newton solver.
Definition: NewtonSolver.H:114
int m_gmres_maxits
Maximum iterations for GMRES.
Definition: NewtonSolver.H:129
void Define(const Vec &a_U, Ops *a_ops) override
Read user-provided parameters that control the nonlinear solver. Allocate intermediate data container...
Definition: NewtonSolver.H:171
void GetSolverParams(amrex::Real &a_rtol, amrex::Real &a_atol, int &a_maxits) override
Return the convergence parameters used by the nonlinear solver.
Definition: NewtonSolver.H:49
void ParseParameters()
Definition: NewtonSolver.H:200
void PrintParams() const override
Print parameters used by the nonlinear solver.
Definition: NewtonSolver.H:70
std::unique_ptr< JacobianFunctionMF< Vec, Ops > > m_linear_function
The linear function used by GMRES to compute A*v. In the context of JFNK, A = dF/dU (i....
Definition: NewtonSolver.H:149
amrex::Real m_gmres_atol
Absolute tolerance for GMRES.
Definition: NewtonSolver.H:124
amrex::Real m_atol
Absolute tolerance for the Newton solver.
Definition: NewtonSolver.H:109
amrex::Real m_gmres_rtol
Relative tolerance for GMRES.
Definition: NewtonSolver.H:119
bool m_require_convergence
Flag to determine whether convergence is required.
Definition: NewtonSolver.H:99
NewtonSolver(const NewtonSolver &)=delete
amrex::Real m_cur_time
Definition: NewtonSolver.H:141
std::unique_ptr< amrex::GMRES< Vec, JacobianFunctionMF< Vec, Ops > > > m_linear_solver
The linear solver (GMRES) object.
Definition: NewtonSolver.H:154
void EvalResidual(Vec &a_F, const Vec &a_U, const Vec &a_b, amrex::Real a_time, amrex::Real a_dt, int a_iter) const
Compute the nonlinear residual: F(U) = U - b - R(U).
Definition: NewtonSolver.H:318
int m_gmres_verbose_int
Verbosity flag for GMRES.
Definition: NewtonSolver.H:134
NewtonSolver(NewtonSolver &&) noexcept=delete
bool m_update_pc
Definition: NewtonSolver.H:142
Vec m_F
Definition: NewtonSolver.H:89
bool m_update_pc_init
Definition: NewtonSolver.H:143
NewtonSolver & operator=(const NewtonSolver &)=delete
Vec m_R
Definition: NewtonSolver.H:89
amrex::Real m_dt
Definition: NewtonSolver.H:141
void CurTime(amrex::Real a_time) const
Definition: NewtonSolver.H:58
void CurTimeStep(amrex::Real a_dt) const
Definition: NewtonSolver.H:64
int m_gmres_restart_length
Restart iteration for GMRES.
Definition: NewtonSolver.H:139
void Solve(Vec &a_U, const Vec &a_b, amrex::Real a_time, amrex::Real a_dt) const override
Solve the specified nonlinear equation for U. Picard: U = b + R(U). Newton: F(U) = U - b - R(U) = 0.
Definition: NewtonSolver.H:218
amrex::Real m_rtol
Relative tolerance for the Newton solver.
Definition: NewtonSolver.H:104
Ops * m_ops
Pointer to Ops class.
Definition: NewtonSolver.H:94
Top-level class for the nonlinear solver.
Definition: NonlinearSolver.H:28
bool m_verbose
Definition: NonlinearSolver.H:83
int query(const char *name, bool &ref, int ival=FIRST) const
void WMRecordWarning(const std::string &topic, const std::string &text, const WarnPriority &priority=WarnPriority::medium)
Helper function to abbreviate the call to WarnManager::GetInstance().RecordWarning (recording a warni...
Definition: WarnManager.cpp:318