# Copyright 2018 Max Shinn <maxwell.shinn@yale.edu>
# 2018 Norman Lam <norman.lam@yale.edu>
#
# This file is part of PyDDM, and is available under the MIT license.
# Please see LICENSE.txt in the root directory for more information.
__all__ = ["Drift", "DriftConstant", "DriftLinear"]
import numpy as np
from ..tridiag import TriDiagMatrix
from .base import Dependence
from paranoid import *
from .paranoid_types import Conditions
@paranoidclass
class Drift(Dependence):
    """Subclass this to specify how drift rate varies with position and time.

    This abstract class provides the methods which define a dependence
    of drift on x and t. To subclass it, implement get_drift. Since
    it inherits from Dependence, subclasses must also assign a `name`
    and `required_parameters` (see documentation for Dependence.)
    """
    depname = "Drift"

    def _uses_t(self):
        # Whether get_drift depends on `t`, as judged by Dependence._uses
        # (presumably from get_drift's argument list, per the convention
        # described in get_drift's docstring).
        return self._uses(self.get_drift, "t")

    def _uses_x(self):
        # Same as _uses_t, but for the position argument `x`.
        return self._uses(self.get_drift, "x")

    @accepts(Self, x=NDArray(d=1, t=Number), t=Positive0, dx=Positive, dt=Positive, conditions=Conditions, implicit=Boolean)
    @returns(TriDiagMatrix)
    @ensures("return.shape == (len(x), len(x))")
    def get_matrix(self, x, t, dx, dt, conditions, implicit=False, **kwargs):
        """The drift component of the implicit method diffusion matrix across the domain `x` at time `t`.

        `x` should be a length N ndarray of all positions in the grid.
        `t` should be the time in seconds at which to calculate drift.
        `dt` and `dx` should be the simulation's timestep and grid step.
        `conditions` should be the conditions at which to calculate drift.
        `implicit`, if True, moves the last entry of the upper diagonal
        and the first entry of the lower diagonal onto the corresponding
        ends of the main diagonal, and zeroes those off-diagonal entries.
        Otherwise the main diagonal is all zeros.

        Returns a sparse NxN matrix as a PyDDM TriDiagMatrix object.

        There is generally no need to redefine this method in
        subclasses.
        """
        n = len(x)
        drift = self.get_drift(x=x, t=t, dx=dx, dt=dt, conditions=conditions, **kwargs)
        coef = 0.5*dt/dx
        D = np.zeros(n)
        # get_drift returns a scalar when drift does not depend on x.
        # np.ndim(...) == 0 is used instead of np.isscalar so that 0-d
        # arrays (e.g. np.asarray(0.5)) are also treated as
        # position-independent rather than failing on drift[1:].
        if np.ndim(drift) == 0:
            UP = coef * drift * np.ones(n-1)
            DOWN = -coef * drift * np.ones(n-1)
        else:
            # UP[i] uses the drift at x[i+1]; DOWN[i] uses the drift at x[i].
            UP = coef * drift[1:]
            DOWN = -coef * drift[:-1]
        if implicit:
            # Fold the boundary-adjacent off-diagonal terms onto the ends
            # of the main diagonal.
            D[-1] = UP[-1]
            UP[-1] = 0
            D[0] = DOWN[0]
            DOWN[0] = 0
        return TriDiagMatrix(up=UP,
                             down=DOWN,
                             diag=D)

    # Amount of flux from the bound/end points to the correct and error
    # response probabilities, due to different parameters.
    @accepts(Self, x_bound=Number, t=Positive0, dx=Positive, dt=Positive, conditions=Conditions)
    @returns(Number)
    def get_flux(self, x_bound, t, dx, dt, conditions, **kwargs):
        """The drift component of flux across the boundary at position `x_bound` at time `t`.

        Flux here is essentially the amount of the mass of the PDF
        that is past the boundary point `x_bound`.  The result is
        0.5*dt/dx times the drift at `x_bound`, multiplied by the sign of
        `x_bound`.  Positive drift therefore gives positive flux at a
        boundary above zero and negative flux at a boundary below zero.

        There is generally no need to redefine this method in
        subclasses.
        """
        return 0.5*dt/dx * np.sign(x_bound) * self.get_drift(x=x_bound, t=t, dx=dx, dt=dt, conditions=conditions, **kwargs)

    def get_drift(self, t, x, conditions, **kwargs):
        """Calculate the instantaneous drift rate.

        This function must be redefined in subclasses.

        It may take several arguments:

        - `t` - The time at which drift should be calculated
        - `x` - The particle position (or 1-dimensional NDArray of
          particle positions) at which drift should be calculated
        - `conditions` - A dictionary describing the task conditions

        It should return a number or an NDArray (the same shape as `x`)
        indicating the drift rate at that particular time,
        position(s), and task conditions.

        Definitions of this method in subclasses should only have
        arguments for needed variables and should always be followed
        by "**kwargs".  For example, if the function does not depend
        on `t` or `x` but does depend on task conditions, this should
        be:

          | def get_drift(self, conditions, **kwargs):

        Of course, the function would still work properly if `x` were
        included as an argument, but this convention allows PyDDM to
        automatically select the best simulation methods for the
        model.

        If a function depends on `x`, it should return a scalar if `x`
        is a scalar, or an NDArray of the same size as `x` if `x` is
        an NDArray.  If the function does not depend on `x`, it should
        return a scalar.  (The purpose of this is a dramatic speed
        increase by using numpy vectorization.)
        """
        raise NotImplementedError("Drift model %s invalid: must define the get_drift function" % self.__class__.__name__)
@paranoidclass
class DriftConstant(Drift):
    """Drift dependence: drift rate is constant throughout the simulation.

    Takes a single parameter, `drift`, which is the constant drift rate.

    This is a special case of DriftLinear in which the `x` and `t`
    coefficients are both zero.

    Example usage:

      | drift = DriftConstant(drift=0.3)
    """
    name = "constant"
    required_parameters = ["drift"]

    @staticmethod
    def _test(v):
        # Invariant check: the drift rate must be numeric.
        assert v.drift in Number()

    @staticmethod
    def _generate():
        # Example instances (presumably consumed by paranoid's automated
        # testing): zero, positive, negative and large drift rates.
        for rate in (0, 1, -1, 100):
            yield DriftConstant(drift=rate)

    @accepts(Self)
    @returns(Number)
    def get_drift(self, **kwargs):
        # Time, position and conditions are all ignored; the drift is
        # simply the stored parameter.
        return self.drift
@paranoidclass
class DriftLinear(Drift):
    """Drift dependence: drift rate varies linearly with position and time.

    The drift rate is computed as ``drift + x*position + t*time``.

    Takes three parameters:

    - `drift` - The starting drift rate
    - `x` - The coefficient by which drift varies with x
    - `t` - The coefficient by which drift varies with t

    Example usage:

      | drift = DriftLinear(drift=0.5, t=0, x=-1) # Leaky integrator
      | drift = DriftLinear(drift=0.8, t=0, x=0.4) # Unstable integrator
      | drift = DriftLinear(drift=0, t=1, x=0.4) # Urgency function
    """
    name = "linear_xt"
    required_parameters = ["drift", "x", "t"]

    @staticmethod
    def _test(v):
        # Invariant check: every parameter must be numeric.
        for param in ("drift", "x", "t"):
            assert getattr(v, param) in Number()

    @staticmethod
    def _generate():
        # Example instances as (drift, x, t) triples.
        samples = ((0, 0, 0),
                   (1, -1, 1),
                   (10, 10, 10),
                   (1, -10, -.5))
        for base, cx, ct in samples:
            yield DriftLinear(drift=base, x=cx, t=ct)

    # get_drift accepts either a scalar or a 1-D array for x.  Relying on
    # native numpy arithmetic here, instead of a list comprehension in
    # get_matrix, is around 100x faster.
    @accepts(Self, Or(Number, NDArray(d=1, t=Number)), Positive0)
    @returns(Or(Number, NDArray(d=1, t=Number)))
    @ensures("np.isscalar(x) <--> np.isscalar(return)")
    def get_drift(self, x, t, **kwargs):
        intercept = self.drift
        position_term = self.x*x
        time_term = self.t*t
        return intercept + position_term + time_term

    def _uses_t(self):
        # get_drift always takes `t`, but drift only depends on it when
        # the time coefficient is nonzero.
        return self.t != 0

    def _uses_x(self):
        # Likewise, drift only depends on position when the x
        # coefficient is nonzero.
        return self.x != 0