# Source code for wavesim.domain

import torch
from abc import ABCMeta, abstractmethod


class Domain(metaclass=ABCMeta):
    """Base class for all simulation domains.

    This base class defines the interface for operations that are common for
    all simulation types, and for MultiDomain.

    todo: the design using slots minimizes memory use, but it is a suboptimal
    design because it mixes mutable and immutable state. This design should be
    revisited so that the Domain is immutable, and the code that runs the
    algorithms performs the memory management.
    """

    def __init__(self, pixel_size: float, shape, device):
        """Store the grid geometry shared by all domain implementations.

        :param pixel_size: grid spacing; used as the sample spacing for both
            the real-space and the Fourier-space coordinates.
        :param shape: number of grid points along each dimension. The
            coordinate helpers of this class reshape their output to 3-D
            tensors, so they assume (at most) three dimensions.
        :param device: torch device on which coordinate tensors are created.
        """
        self.pixel_size = pixel_size
        # Not assigned in the base class; presumably set by subclasses
        # (TODO confirm against the concrete implementations).
        self.scale = None
        self.shift = None
        self.shape = shape
        self.device = device

    @abstractmethod
    def add_source(self, slot, weight: float):
        """Add the source term to the data in the specified slot.

        NOTE(review): undocumented in the original; ``weight`` is presumably
        a scale factor applied to the source set via ``set_source`` --
        confirm against the concrete implementations.
        """

    @abstractmethod
    def clear(self, slot):
        """Clears the data in the specified slot."""

    @abstractmethod
    def get(self, slot: int, copy=False):
        """Returns the data in the specified slot.

        :param slot: slot from which to return the data
        :param copy: if True, returns a copy of the data. Otherwise, may
            return the original data if possible. Note that this data may be
            overwritten by subsequent operations on this domain.
        """

    @abstractmethod
    def set(self, slot, data):
        """Copy the data into the specified slot."""

    @abstractmethod
    def inner_product(self, slot_a, slot_b):
        """Computes the inner product of two data vectors.

        Note: The vectors may be represented as multidimensional arrays, but
        these arrays must be contiguous for this operation to work. Although
        it would be possible to use flatten(), this would create a copy when
        the array is not contiguous, causing a hidden performance hit.
        """

    @abstractmethod
    def medium(self, slot_in, slot_out, mnum):
        """Applies the operator 1-Vscat.

        :param mnum: NOTE(review): meaning is not visible in the base class;
            see the concrete implementations.
        """

    @abstractmethod
    def mix(self, weight_a, slot_a, weight_b, slot_b, slot_out):
        """Mixes two data arrays and stores the result in the specified slot."""

    @abstractmethod
    def propagator(self, slot_in, slot_out):
        """Applies the operator (L+1)^-1 x."""

    @abstractmethod
    def inverse_propagator(self, slot_in, slot_out):
        """Applies the operator (L+1) x.

        This operation is not needed for the Wavesim algorithm, but is
        provided for testing purposes, and can be used to evaluate the
        residue of the solution.
        """

    @abstractmethod
    def set_source(self, source):
        """Sets the source term for this domain."""

    @abstractmethod
    def create_empty_vdot(self):
        """Create an empty tensor for the Vdot tensor."""

    @staticmethod
    def _axis_shape(dim):
        """Return the 3-D view shape that lays a 1-D vector along axis ``dim``."""
        view = [1, 1, 1]
        view[dim] = -1  # negative ``dim`` counts from the last axis
        return view

    def coordinates_f(self, dim) -> torch.Tensor:
        """Returns the Fourier-space coordinates along the specified dimension.

        The values are the angular frequencies
        ``2 * pi * fftfreq(n, pixel_size)`` in standard FFT order (zero
        frequency first), where ``n = shape[dim]``.

        :param dim: axis index (0, 1 or 2; negative values count from the end).
        :return: ``complex64`` tensor on ``self.device``, reshaped so that the
            vector lies along axis ``dim`` of a 3-D tensor.
        """
        n = self.shape[dim]
        # Build the frequencies in float64 and cast to complex64 only once, at the end.
        k = 2 * torch.pi * torch.fft.fftfreq(n, self.pixel_size, device=self.device, dtype=torch.float64)
        return k.reshape(self._axis_shape(dim)).to(torch.complex64)

    def coordinates(self, dim, type: str = 'linear') -> torch.Tensor:
        """Returns the real-space coordinates along the specified dimension.

        With ``n = shape[dim]`` and ``d = pixel_size``:

        - ``'linear'``: ``0, d, ..., (n-1)*d`` (starts at 0).
        - ``'centered'``: the linear grid shifted by ``-(n // 2) * d``, so the
          element at index ``n // 2`` is 0.
        - ``'periodic'``: the centered grid passed through ``ifftshift``, so
          0 is at index 0 and negative coordinates wrap around to the end
          (the same ordering as ``fftfreq``).

        :param dim: axis index (0, 1 or 2; negative values count from the end).
        :param type: one of ``'linear'``, ``'centered'``, ``'periodic'``. The
            name shadows the builtin but is kept for keyword compatibility.
        :return: ``float64`` tensor on ``self.device``, reshaped so that the
            vector lies along axis ``dim`` of a 3-D tensor.
        :raises ValueError: if ``type`` is not one of the values above.
        """
        n = self.shape[dim]
        x = torch.arange(n, device=self.device, dtype=torch.float64) * self.pixel_size
        if type == 'periodic':
            x -= self.pixel_size * (n // 2)
            # ifftshift (not fftshift) is the right choice: it moves index n // 2
            # (the zero coordinate) to index 0 for both even and odd n, whereas
            # fftshift would be off by one for odd n.
            x = torch.fft.ifftshift(x)
        elif type == 'centered':
            x -= self.pixel_size * (n // 2)
        elif type != 'linear':
            raise ValueError(f"Unknown type {type}")
        return x.reshape(self._axis_shape(dim))