! Fragment of what appears to be the solver's initialisation routine (the
! procedure header is not visible in this excerpt). The visible statements
! declare the arguments, associate the host work arrays with device handles,
! build device-side tables of per-column device pointers, forward the optional
! tolerances / monitor flag to ksp_init, and create an event.
! NOTE(review): the leading numbers look like line numbers from a source
! listing; gaps between them suggest lines (else / end if / end do, further
! declarations) are missing here -- confirm against the full file.
!
! Arguments visible in this fragment:
!   n        - length of each work vector
!   max_iter - iteration limit handed to ksp_init
!   M        - optional object of class pc_t (presumably the preconditioner)
!   rel_tol  - optional relative tolerance, forwarded to ksp_init
!   abs_tol  - optional absolute tolerance, forwarded to ksp_init
!   monitor  - optional logical flag, forwarded to ksp_init
139 integer,
intent(in) :: n
140 integer,
intent(in) :: max_iter
141 class(pc_t),
optional,
intent(in),
target :: M
142 real(kind=rp),
optional,
intent(in) :: rel_tol
143 real(kind=rp),
optional,
intent(in) :: abs_tol
144 logical,
optional,
intent(in) :: monitor
145 type(device_ident_t),
target :: M_ident
! Byte count used below for the device-side pointer tables.
147 integer(c_size_t) :: z_size
! Associate the host vectors w and r (n entries each) with their device handles.
160 call device_map(this%w, this%w_d, n)
161 call device_map(this%r, this%r_d, n)
! c and s have m_restart entries, gam has m_restart + 1; each gets a device handle.
163 allocate(this%c(this%m_restart))
164 allocate(this%s(this%m_restart))
165 allocate(this%gam(this%m_restart + 1))
166 call device_map(this%c, this%c_d, this%m_restart)
167 call device_map(this%s, this%s_d, this%m_restart)
168 call device_map(this%gam, this%gam_d, this%m_restart+1)
! z and v hold m_restart columns of length n; h is m_restart x m_restart.
! z_d, v_d and h_d hold one device handle per column.
170 allocate(this%z(n, this%m_restart))
171 allocate(this%v(n, this%m_restart))
172 allocate(this%h(this%m_restart, this%m_restart))
173 allocate(this%z_d(this%m_restart))
174 allocate(this%v_d(this%m_restart))
175 allocate(this%h_d(this%m_restart))
! Per column: reset the handle to a null pointer, then map that column.
176 do i = 1, this%m_restart
177 this%z_d(i) = c_null_ptr
178 call device_map(this%z(:,i), this%z_d(i), n)
180 this%v_d(i) = c_null_ptr
181 call device_map(this%v(:,i), this%v_d(i), n)
183 this%h_d(i) = c_null_ptr
184 call device_map(this%h(:,i), this%h_d(i), this%m_restart)
! Size in bytes of an array of m_restart C pointers; the same size is used for
! all three device-side pointer tables (z_d_d, v_d_d, h_d_d).
187 z_size = c_sizeof(c_null_ptr) * (this%m_restart)
188 call device_alloc(this%z_d_d, z_size)
189 call device_alloc(this%v_d_d, z_size)
190 call device_alloc(this%h_d_d, z_size)
! Copy each host array of column handles into its device table. The copies are
! issued with sync = .false., i.e. without requesting synchronisation here.
191 ptr = c_loc(this%z_d)
192 call device_memcpy(ptr, this%z_d_d, z_size, &
193 host_to_device, sync = .false.)
194 ptr = c_loc(this%v_d)
195 call device_memcpy(ptr, this%v_d_d, z_size, &
196 host_to_device, sync = .false.)
197 ptr = c_loc(this%h_d)
198 call device_memcpy(ptr, this%h_d_d, z_size, &
199 host_to_device, sync = .false.)
! Call ksp_init with exactly those optional arguments that are present. The
! branches are ordered from the most specific combination to the least.
202 if (
present(rel_tol) .and.
present(abs_tol) .and.
present(monitor))
then
203 call this%ksp_init(max_iter, rel_tol, abs_tol, monitor = monitor)
204 else if (
present(rel_tol) .and.
present(abs_tol))
then
205 call this%ksp_init(max_iter, rel_tol, abs_tol)
206 else if (
present(monitor) .and.
present(abs_tol))
then
207 call this%ksp_init(max_iter, abs_tol = abs_tol, monitor = monitor)
208 else if (
present(rel_tol) .and.
present(monitor))
then
209 call this%ksp_init(max_iter, rel_tol, monitor = monitor)
210 else if (
present(rel_tol))
then
211 call this%ksp_init(max_iter, rel_tol = rel_tol)
212 else if (
present(abs_tol))
then
213 call this%ksp_init(max_iter, abs_tol = abs_tol)
214 else if (
present(monitor))
then
215 call this%ksp_init(max_iter, monitor = monitor)
! Presumably the final else branch (the else / end if lines are not visible):
! none of the optional arguments is present.
217 call this%ksp_init(max_iter)
! Create the event stored in this%gs_event. The meaning of the flag value 2 is
! not visible in this excerpt -- TODO confirm.
220 call device_event_create(this%gs_event, 2)
! Fragment of what appears to be the solver's cleanup routine (header, end if /
! end do lines and any deallocate statements are not visible in this excerpt).
! Pattern used throughout: a host array is unmapped from the device only when
! it is allocated and its device handle is associated.
! Vector w.
231 if (
allocated(this%w))
then
232 if (c_associated(this%w_d))
then
233 call device_unmap(this%w, this%w_d)
! Vector c.
238 if (
allocated(this%c))
then
239 if (c_associated(this%c_d))
then
240 call device_unmap(this%c, this%c_d)
! Vector r.
245 if (
allocated(this%r))
then
246 if (c_associated(this%r_d))
then
247 call device_unmap(this%r, this%r_d)
! z: one device handle per column, so both the data array and the handle array
! must be allocated before walking the m_restart columns.
252 if (
allocated(this%z))
then
253 if (
allocated(this%z_d))
then
254 do i = 1, this%m_restart
255 if (c_associated(this%z_d(i)))
then
256 call device_unmap(this%z(:,i), this%z_d(i))
! h: same per-column handling as z.
263 if (
allocated(this%h))
then
264 if (
allocated(this%h_d))
then
265 do i = 1, this%m_restart
266 if (c_associated(this%h_d(i)))
then
267 call device_unmap(this%h(:,i), this%h_d(i))
! v: same per-column handling as z.
274 if (
allocated(this%v))
then
275 if (
allocated(this%v_d))
then
276 do i = 1, this%m_restart
277 if (c_associated(this%v_d(i)))
then
278 call device_unmap(this%v(:,i), this%v_d(i))
! Vector s.
285 if (
allocated(this%s))
then
286 if (c_associated(this%s_d))
then
287 call device_unmap(this%s, this%s_d)
! Vector gam.
291 if (
allocated(this%gam))
then
292 if (c_associated(this%gam_d))
then
293 call device_unmap(this%gam, this%gam_d)
! Release the three device-side pointer tables if they are associated.
298 if (c_associated(this%z_d_d))
then
299 call device_free(this%z_d_d)
301 if (c_associated(this%v_d_d))
then
302 call device_free(this%v_d_d)
304 if (c_associated(this%h_d_d))
then
305 call device_free(this%h_d_d)
! Destroy the event only if its handle is associated.
310 if (c_associated(this%gs_event))
then
311 call device_event_destroy(this%gs_event)
! Fragment of the solve routine (header and many control-flow lines are not
! visible in this excerpt). From the visible statements it forms a residual
! from f and ax%compute applied to x, builds columns of v, z and h inside a
! loop of at most m_restart steps, and then updates x from the z columns. It
! reports progress through monitor_start / monitor_iter / monitor_stop under
! the label 'GMRES' and returns a ksp_monitor_t record.
!
! Arguments visible in this fragment:
!   ax    - operator object; ax%compute is applied to x and to columns of z
!   x     - field holding the solution; its device data is zeroed at 358
!   n     - number of entries in f and in the work vectors
!   f     - right-hand side
!   coef  - supplies volume (residual scaling) and mult_d (used in device_glsc3)
!   blst  - boundary-condition list applied to w after each operator call
!   gs_h  - object whose op(..., gs_op_add, ...) is applied to w
!   niter - optional; presumably overrides this%max_iter (branch body not visible)
! Result: ksp_results with res_start, res_final, iter and converged filled in.
320 class(ax_t),
intent(in) :: ax
321 type(field_t),
intent(inout) :: x
322 integer,
intent(in) :: n
323 real(kind=rp),
dimension(n),
intent(in) :: f
324 type(coef_t),
intent(inout) :: coef
325 type(bc_list_t),
intent(inout) :: blst
326 type(gs_t),
intent(inout) :: gs_h
327 type(ksp_monitor_t) :: ksp_results
328 integer,
optional,
intent(in) :: niter
329 integer :: iter, max_iter
331 real(kind=rp) :: rnorm, alpha, temp, lr, alpha2, norm_fac
! Device handle associated with the right-hand side f.
335 f_d = device_get_ptr(f)
! Iteration limit: the niter branch body is not visible; the visible assignment
! takes the limit from this%max_iter (presumably the else branch).
341 if (
present(niter))
then
344 max_iter = this%max_iter
! Short local aliases for the solver's work arrays and device handles.
347 associate(w => this%w, c => this%c, r => this%r, z => this%z, h => this%h, &
348 v => this%v, s => this%s, gam => this%gam, v_d => this%v_d, &
349 w_d => this%w_d, r_d => this%r_d, h_d => this%h_d, &
350 v_d_d => this%v_d_d, x_d => x%x_d, z_d_d => this%z_d_d, &
! Factor applied to residual norms before they are stored or reported.
353 norm_fac = 1.0_rp / sqrt(coef%volume)
! Reset the host work arrays (rzero / rone presumably fill with zeros / ones).
354 call rzero(gam, this%m_restart + 1)
355 call rone(s, this%m_restart)
356 call rone(c, this%m_restart)
357 call rzero(h, this%m_restart * this%m_restart)
! Same reset on the device side; the device data of x is reset as well.
358 call device_rzero(x%x_d, n)
359 call device_rzero(this%gam_d, this%m_restart + 1)
360 call device_rone(this%s_d, this%m_restart)
361 call device_rone(this%c_d, this%m_restart)
! NOTE(review): h aliases this%h and was already reset at 357 with the same
! element count, so this call looks redundant -- confirm before removing.
363 call rzero(this%h, this%m_restart**2)
368 call this%monitor_start(
'GMRES')
! Outer loop: runs until conv is set or iter reaches max_iter.
369 do while (.not. conv .and. iter .lt. max_iter)
! Residual setup. First pass: r is a copy of f. The second copy plus the
! operator application and subtraction presumably belong to the else branch
! (r = f - A x); the else line is not visible here.
371 if (iter .eq. 0)
then
372 call device_copy(r_d, f_d, n)
374 call device_copy(r_d, f_d, n)
375 call ax%compute(w, x%x, coef, x%msh, x%Xh)
376 call gs_h%op(w, n, gs_op_add, this%gs_event)
377 call device_event_sync(this%gs_event)
378 call blst%apply_scalar(w, n)
379 call device_sub2(r_d, w_d, n)
! gam(1): square root of device_glsc3(r, r, mult) -- presumably the
! multiplicity-weighted norm of r.
382 gam(1) = sqrt(device_glsc3(r_d, r_d, coef%mult_d, n))
! Record the scaled starting residual on the first pass only.
383 if (iter .eq. 0)
then
384 ksp_results%res_start = gam(1) * norm_fac
! Leave the outer loop when abscmp reports gam(1) equal to zero; this also
! guards the division just below.
387 if (abscmp(gam(1), 0.0_rp))
exit
! First column of v from r scaled by 1 / gam(1) (device_cmult2 presumably
! computes a = c * b -- TODO confirm).
390 temp = 1.0_rp / gam(1)
391 call device_cmult2(v_d(1), r_d, temp, n)
! Inner loop over at most m_restart steps.
392 do j = 1, this%m_restart
! z(:,j) obtained from v(:,j) through this%M%solve (presumably the
! preconditioner application).
395 call this%M%solve(z(1,j), v(1,j), n)
! w = operator applied to z(:,j), followed by gs_h%op with gs_op_add, a wait
! on the event, and the boundary-condition list.
397 call ax%compute(w, z(1,j), coef, x%msh, x%Xh)
398 call gs_h%op(w, n, gs_op_add, this%gs_event)
399 call device_event_sync(this%gs_event)
400 call blst%apply_scalar(w, n)
! OpenCL / Metal backends: per-column path (the enclosing loop over i is not
! visible). h(i,j) is the weighted product of w with v(:,i); that component is
! then removed from w, and alpha2 is the weighted product of w with itself.
402 if (neko_bcknd_opencl .eq. 1 .or. neko_bcknd_metal .eq. 1)
then
404 h(i,j) = device_glsc3(w_d, v_d(i), coef%mult_d, n)
406 call device_add2s2(w_d, v_d(i), -h(i,j), n)
408 alpha2 = device_glsc3(w_d, w_d, coef%mult_d, n)
! Presumably the other backends (else not visible): one batched call fills
! h(1:j, j) using the device table of v column handles.
411 call device_glsc3_many(h(1,j), w_d, v_d_d, coef%mult_d, j, n)
! Push column j of h to its device copy without requesting synchronisation.
413 call device_memcpy(h(:,j), h_d(j), j, &
414 host_to_device, sync = .false.)
! Plane rotation with coefficients c(i), s(i) applied to rows i and i+1 of
! column j (temp presumably holds the old h(i,j); its assignment is not visible).
424 h(i,j) = c(i)*temp + s(i) * h(i+1,j)
425 h(i+1,j) = -s(i)*temp + c(i) * h(i+1,j)
429 if (abscmp(alpha, 0.0_rp))
then
434 lr = sqrt(h(j,j) * h(j,j) + alpha2)
439 call device_memcpy(h(:,j), h_d(j), j, &
440 host_to_device, sync = .false.)
! Apply rotation j to gam: gam(j+1) is formed from the old gam(j) before
! gam(j) itself is overwritten, so the statement order matters.
441 gam(j+1) = -s(j) * gam(j)
442 gam(j) = c(j) * gam(j)
! Scaled residual estimate reported to the monitor and tested against abs_tol.
444 rnorm = abs(gam(j+1)) * norm_fac
445 call this%monitor_iter(iter, rnorm)
446 if (rnorm .lt. this%abs_tol)
then
! Leave the inner loop when one more iteration would exceed max_iter.
451 if (iter + 1 .gt. max_iter)
exit
! Next column of v from w scaled by 1 / alpha, only while j < m_restart so
! that v(:, j+1) stays within the allocated columns.
453 if (j .lt. this%m_restart)
then
454 temp = 1.0_rp / alpha
455 call device_cmult2(v_d(j+1), w_d, temp, n)
! Clamp j to m_restart (after a completed do loop the index is m_restart + 1).
460 j = min(j, this%m_restart)
! Part of computing the coefficients c from h (surrounding loops not visible;
! presumably a back substitution -- TODO confirm).
464 temp = temp - h(k,i) * c(i)
! Update x from the z columns with coefficients c: per column on
! OpenCL / Metal, otherwise c is copied to the device and one batched call is
! used (else line not visible).
469 if (neko_bcknd_opencl .eq. 1 .or. neko_bcknd_metal .eq. 1)
then
471 call device_add2s2(x_d, this%z_d(i), c(i), n)
474 call device_memcpy(c, c_d, j, host_to_device, sync = .false.)
475 call device_add2s2_many(x_d, z_d_d, c_d, j, n)
! Finish monitoring and fill in the result record.
! NOTE(review): in the visible lines rnorm is assigned only at 444; if the
! early exit at 387 is taken on the first pass it may be unassigned here --
! confirm it is initialised in the lines not shown.
480 call this%monitor_stop()
481 ksp_results%res_final = rnorm
482 ksp_results%iter = iter
483 ksp_results%converged = this%is_converged(iter, rnorm)
489 n, coef, blstx, blsty, blstz, gs_h, niter)
result(ksp_results)
! Three-component solve (only the tail of the function header is visible in
! this excerpt). The components are not coupled inside this routine: it calls
! this%solve once per component, one after the other, and returns the three
! monitor records.
!
! Arguments visible in this fragment:
!   ax                  - operator object passed to every this%solve call
!   x, y, z             - fields receiving the three component solutions
!   n                   - number of entries in each right-hand side
!   fx, fy, fz          - right-hand sides for x, y and z
!   coef                - passed unchanged to every this%solve call
!   blstx, blsty, blstz - boundary-condition list for each component
!   gs_h                - passed unchanged to every this%solve call
!   niter               - optional; forwarded as-is to this%solve
! Result: ksp_results(1:3), one record per component in the order x, y, z.
491 class(ax_t),
intent(in) :: ax
492 type(field_t),
intent(inout) :: x
493 type(field_t),
intent(inout) :: y
494 type(field_t),
intent(inout) :: z
495 integer,
intent(in) :: n
496 real(kind=rp),
dimension(n),
intent(in) :: fx
497 real(kind=rp),
dimension(n),
intent(in) :: fy
498 real(kind=rp),
dimension(n),
intent(in) :: fz
499 type(coef_t),
intent(inout) :: coef
500 type(bc_list_t),
intent(inout) :: blstx
501 type(bc_list_t),
intent(inout) :: blsty
502 type(bc_list_t),
intent(inout) :: blstz
503 type(gs_t),
intent(inout) :: gs_h
504 type(ksp_monitor_t),
dimension(3) :: ksp_results
505 integer,
optional,
intent(in) :: niter
! One scalar solve per component, each with its own field, right-hand side
! and boundary-condition list.
507 ksp_results(1) = this%solve(ax, x, fx, n, coef, blstx, gs_h, niter)
508 ksp_results(2) = this%solve(ax, y, fy, n, coef, blsty, gs_h, niter)
509 ksp_results(3) = this%solve(ax, z, fz, n, coef, blstz, gs_h, niter)