Skip to content

slisemap.local_models

Module that contains the built-in alternatives for local white box models.

These functions can also be used as templates for implementing your own.

local_predict(X, B, local_model)

Get individual predictions when every data item has a separate model.

Parameters:

Name Type Description Default
X Tensor

Data matrix [n, m].

required
B Tensor

Coefficient matrix [n, q].

required
local_model Callable[[Tensor, Tensor], Tensor]

Prediction function: [1, m], [1, q] -> [1, 1, o].

required

Returns:

Type Description
Tensor

Matrix of local predictions [n, o].

Source code in slisemap/local_models.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def local_predict(
    X: torch.Tensor,
    B: torch.Tensor,
    local_model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Predict each data item with its own local model.

    Row ``i`` of ``X`` is paired with row ``i`` of ``B`` and passed to
    ``local_model`` as a pair of one-row matrices.

    Args:
        X: Data matrix [n, m].
        B: Coefficient matrix [n, q].
        local_model: Prediction function: [1, m], [1, q] -> [1, 1, o].

    Returns:
        Matrix of local predictions [n, o].
    """
    n_rows = X.shape[0]
    _assert(n_rows == B.shape[0], "X and B must have the same number of rows", local_predict)
    # The first row is evaluated up front so the output buffer can copy the
    # shape, dtype and device of a single prediction.
    first = local_model(X[:1, :], B[:1, :])[0, 0, ...]
    out = torch.empty((n_rows, *first.shape), dtype=first.dtype, device=first.device)
    out[0, ...] = first
    for row in range(1, n_rows):
        sel = slice(row, row + 1)
        out[row, ...] = local_model(X[sel, :], B[sel, :])[0, 0, ...]
    return out

ALocalModel

Bases: ABC

Abstract class for gathering all the functions needed for a local model (predict, loss, coefficients, and regularisation).

Source code in slisemap/local_models.py
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
class ALocalModel(ABC):
    """Abstract class for gathering all the functions needed for a local model.

    A local model bundles a prediction function, a loss function, a function
    giving the number of coefficients, and an optional regularisation term.
    All of them are static methods.
    """

    @staticmethod
    @abstractmethod
    def predict(X: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """Prediction function.

        Args:
            X: Data matrix.
            B: Coefficient matrix.

        Returns:
            Y: Prediction matrix.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def loss(Ytilde: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Loss function.

        Args:
            Ytilde: Prediction matrix.
            Y: Target matrix.

        Returns:
            L: Loss matrix.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def coefficients(
        X: Union[torch.Tensor, np.ndarray],
        Y: Union[torch.Tensor, np.ndarray],
        intercept: bool = False,
    ) -> int:
        """Get the number of columns of B.

        Args:
            X: Data matrix.
            Y: Target matrix.
            intercept: Add an (additional) intercept to X. Defaults to False,
                matching the built-in implementations.

        Returns:
            Number of columns.
        """
        raise NotImplementedError

    @staticmethod
    def regularisation(
        X: torch.Tensor,
        Y: torch.Tensor,
        B: torch.Tensor,
        Z: torch.Tensor,
        Ytilde: torch.Tensor,
    ) -> Union[float, torch.Tensor]:
        """Regularisation function.

        Args:
            X: Data matrix.
            Y: Target matrix.
            B: Coefficient matrix.
            Z: Embedding matrix.
            Ytilde: Prediction matrix.

        Returns:
            Additional loss term. This default implementation adds no
            regularisation and returns the plain float ``0.0`` (hence the
            ``float`` in the return annotation).
        """
        return 0.0

predict(X, B) abstractmethod staticmethod

Prediction function.

Parameters:

Name Type Description Default
X Tensor

Data matrix.

required
B Tensor

Coefficient matrix.

required

Returns:

Name Type Description
Y Tensor

Prediction matrix.

Source code in slisemap/local_models.py
44
45
46
47
48
49
50
51
52
53
54
55
56
@staticmethod
@abstractmethod
def predict(X: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Predict targets from data and local-model coefficients.

    Abstract: concrete local models provide the implementation.

    Args:
        X: Data matrix.
        B: Coefficient matrix.

    Returns:
        Y: Prediction matrix.
    """
    raise NotImplementedError()

loss(Ytilde, Y) abstractmethod staticmethod

Loss function.

Parameters:

Name Type Description Default
Ytilde Tensor

Prediction matrix.

required
Y Tensor

Target matrix.

required

Returns:

Name Type Description
L Tensor

Loss matrix.

Source code in slisemap/local_models.py
58
59
60
61
62
63
64
65
66
67
68
69
70
@staticmethod
@abstractmethod
def loss(Ytilde: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Compare predictions with targets.

    Abstract: concrete local models provide the implementation.

    Args:
        Ytilde: Prediction matrix.
        Y: Target matrix.

    Returns:
        L: Loss matrix.
    """
    raise NotImplementedError()

coefficients(X, Y, intercept) abstractmethod staticmethod

Get the number of columns of B.

Parameters:

Name Type Description Default
X Union[Tensor, ndarray]

Data matrix.

required
Y Union[Tensor, ndarray]

Target matrix.

required
intercept bool

Add intercept.

required

Returns:

Type Description
int

Number of columns.

Source code in slisemap/local_models.py
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
@staticmethod
@abstractmethod
def coefficients(
    X: Union[torch.Tensor, np.ndarray],
    Y: Union[torch.Tensor, np.ndarray],
    intercept: bool = False,
) -> int:
    """Get the number of columns of B.

    Args:
        X: Data matrix.
        Y: Target matrix.
        intercept: Add an (additional) intercept to X. Defaults to False,
            matching the built-in implementations.

    Returns:
        Number of columns.
    """
    raise NotImplementedError

regularisation(X, Y, B, Z, Ytilde) staticmethod

Regularisation function.

Parameters:

Name Type Description Default
X Tensor

Data matrix.

required
Y Tensor

Target matrix.

required
B Tensor

Coefficient matrix.

required
Z Tensor

Embedding matrix.

required
Ytilde Tensor

Prediction matrix.

required

Returns:

Type Description
Tensor

Additional loss term.

Source code in slisemap/local_models.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
@staticmethod
def regularisation(
    X: torch.Tensor,
    Y: torch.Tensor,
    B: torch.Tensor,
    Z: torch.Tensor,
    Ytilde: torch.Tensor,
) -> Union[float, torch.Tensor]:
    """Regularisation function.

    Args:
        X: Data matrix.
        Y: Target matrix.
        B: Coefficient matrix.
        Z: Embedding matrix.
        Ytilde: Prediction matrix.

    Returns:
        Additional loss term. This default adds no regularisation and returns
        the plain float ``0.0`` (hence ``float`` in the return annotation).
    """
    return 0.0

linear_regression(X, B)

Prediction function for (multiple) linear regression.

Parameters:

Name Type Description Default
X Tensor

Data matrix [n_x, m].

required
B Tensor

Coefficient Matrix [n_b, m * p].

required

Returns:

Type Description
Tensor

Prediction tensor [n_b, n_x, p]

Source code in slisemap/local_models.py
114
115
116
117
118
119
120
121
122
123
124
125
def linear_regression(X: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Predict with (multiple) linear regression for every model/data-item pair.

    Each row of ``B`` holds ``p`` consecutive coefficient vectors of length
    ``m``, one per output column. Single-output regression is the case ``p == 1``.

    Args:
        X: Data matrix [n_x, m].
        B: Coefficient Matrix [n_b, m * p].

    Returns:
        Prediction tensor [n_b, n_x, p]
    """
    n_models = B.shape[0]
    n_features = X.shape[1]
    # [n_b, p, m] @ [m, n_x] -> [n_b, p, n_x]
    per_output = B.view(n_models, -1, n_features) @ X.T
    return per_output.transpose(1, 2)

multiple_linear_regression(X, B)

Prediction function for multiple linear regression. DEPRECATED.

Parameters:

Name Type Description Default
X Tensor

Data matrix [n_x, m].

required
B Tensor

Coefficient Matrix [n_b, m*p].

required

Returns:

Type Description
Tensor

Prediction tensor [n_b, n_x, p]

Deprecated

1.4: In favour of a combined linear_regression

Source code in slisemap/local_models.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def multiple_linear_regression(X: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Prediction function for multiple linear regression. **DEPRECATED**.

    Emits a deprecation notice and then delegates to `linear_regression`,
    which handles both single and multiple outputs.

    Args:
        X: Data matrix [n_x, m].
        B: Coefficient Matrix [n_b, m*p].

    Returns:
        Prediction tensor [n_b, n_x, p]

    Deprecated:
        1.4: In favour of a combined `linear_regression`
    """
    _deprecated(multiple_linear_regression, linear_regression)
    prediction = linear_regression(X, B)
    return prediction

linear_regression_loss(Ytilde, Y, B=None)

Least squares loss function for (multiple) linear regression.

Parameters:

Name Type Description Default
Ytilde Tensor

Predicted values [n_b, n_x, p].

required
Y Tensor

Ground truth values [n_x, p].

required
B Optional[Tensor]

Coefficient matrix. Deprecated. Defaults to None.

None

Returns:

Type Description
Tensor

Loss values [n_b, n_x].

Deprecated

1.6: B

Source code in slisemap/local_models.py
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def linear_regression_loss(
    Ytilde: torch.Tensor, Y: torch.Tensor, B: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Least squares loss function for (multiple) linear regression.

    The squared residuals are summed over the last (output) dimension.

    Args:
        Ytilde: Predicted values [n_b, n_x, p].
        Y: Ground truth values [n_x, p].
        B: Coefficient matrix. **Deprecated** and ignored. Defaults to None.

    Returns:
        Loss values [n_b, n_x].

    Deprecated:
        1.6: B
    """
    # `expand` (rather than implicit broadcasting) requires Y to fit Ytilde's shape.
    residual = Ytilde - Y.expand(Ytilde.shape)
    return residual.pow(2).sum(dim=-1)

linear_regression_coefficients(X, Y, intercept=False)

Get the number of coefficients for a (multiple) linear regression.

Parameters:

Name Type Description Default
X Union[Tensor, ndarray]

Data matrix.

required
Y Union[Tensor, ndarray]

Target matrix.

required
intercept bool

Add an (additional) intercept to X. Defaults to False.

False

Returns:

Type Description
int

Number of coefficients (columns of B).

Source code in slisemap/local_models.py
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def linear_regression_coefficients(
    X: Union[torch.Tensor, np.ndarray],
    Y: Union[torch.Tensor, np.ndarray],
    intercept: bool = False,
) -> int:
    """Count the coefficients of a (multiple) linear regression.

    There is one coefficient per column of X (plus one for the intercept)
    for every output column. A 1-D Y counts as a single output.

    Args:
        X: Data matrix.
        Y: Target matrix.
        intercept: Add an (additional) intercept to X. Defaults to False.

    Returns:
        Number of coefficients (columns of B).
    """
    n_inputs = X.shape[1] + intercept
    n_outputs = Y.shape[1] if len(Y.shape) >= 2 else 1
    return n_inputs * n_outputs

LinearRegression

Bases: ALocalModel

A class that contains all the functions needed for linear regression.

Source code in slisemap/local_models.py
182
183
184
185
186
187
class LinearRegression(ALocalModel):
    """A class that contains all the functions needed for linear regression.

    The functions are wrapped in ``staticmethod`` so they stay plain functions
    whether accessed through the class or through an instance. A bare function
    stored as a class attribute would become a bound method on instances and
    receive the instance as its first argument.
    """

    predict = staticmethod(linear_regression)
    loss = staticmethod(linear_regression_loss)
    coefficients = staticmethod(linear_regression_coefficients)

absolute_error(Ytilde, Y, B=None)

Absolute error function for (multiple) linear regression.

Parameters:

Name Type Description Default
Ytilde Tensor

Predicted values [n_b, n_x, p].

required
Y Tensor

Ground truth values [n_x, p].

required
B Optional[Tensor]

Coefficient matrix. Deprecated. Defaults to None.

None

Returns:

Type Description
Tensor

Loss values [n_b, n_x].

Deprecated

1.6: B

Source code in slisemap/local_models.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def absolute_error(
    Ytilde: torch.Tensor, Y: torch.Tensor, B: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Absolute error function for (multiple) linear regression.

    The absolute residuals are summed over the last (output) dimension.

    Args:
        Ytilde: Predicted values [n_b, n_x, p].
        Y: Ground truth values [n_x, p].
        B: Coefficient matrix. **Deprecated** and ignored. Defaults to None.

    Returns:
        Loss values [n_b, n_x].

    Deprecated:
        1.6: B
    """
    residual = Ytilde - Y.expand(Ytilde.shape)
    return residual.abs().sum(dim=-1)

LinearAbsoluteRegression

Bases: ALocalModel

A class that contains all the functions needed for linear regression with absolute errors.

Source code in slisemap/local_models.py
209
210
211
212
213
214
class LinearAbsoluteRegression(ALocalModel):
    """A class that contains all the functions needed for linear regression with absolute errors.

    The functions are wrapped in ``staticmethod`` so they stay plain functions
    whether accessed through the class or through an instance. A bare function
    would otherwise become a bound method on instances.
    """

    predict = staticmethod(linear_regression)
    loss = staticmethod(absolute_error)
    coefficients = staticmethod(linear_regression_coefficients)

logistic_regression(X, B)

Prediction function for (multinomial) logistic regression.

Note that the number of coefficients is m * (p-1) due to the normalisation of softmax.

Parameters:

Name Type Description Default
X Tensor

Data matrix [n_x, m].

required
B Tensor

Coefficient Matrix [n_b, m*(p-1)].

required

Returns:

Type Description
Tensor

Prediction tensor [n_b, n_x, p]

Source code in slisemap/local_models.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def logistic_regression(X: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Predict class probabilities with (multinomial) logistic regression.

    Note that the number of coefficients is `m * (p-1)` due to the normalisation of softmax.
    The last class acts as the reference class, with its logit fixed at zero.

    Args:
        X: Data matrix [n_x, m].
        B: Coefficient Matrix [n_b, m*(p-1)].

    Returns:
        Prediction tensor [n_b, n_x, p]
    """
    n_x, m = X.shape
    n_b, o = B.shape
    # NOTE(review): torch.div instead of `//`, presumably to keep the shape
    # arithmetic traceable — confirm before changing.
    p = 1 + torch.div(o, m, rounding_mode="trunc")
    logits = torch.zeros([n_b, n_x, p], device=B.device, dtype=B.dtype)
    X_t = X.T
    for k in range(p - 1):
        start = k * m
        logits[:, :, k] = B[:, start : start + m] @ X_t
    return softmax(logits, 2)

logistic_regression_loss(Ytilde, Y, B=None)

Squared Hellinger distance function for (multinomial) logistic regression.

Parameters:

Name Type Description Default
Ytilde Tensor

Predicted values [n_b, n_x, p].

required
Y Tensor

Ground truth values [n_x, p].

required
B Optional[Tensor]

Coefficient matrix. Deprecated. Defaults to None.

None

Returns:

Type Description
Tensor

Loss values [n_b, n_x].

Deprecated

1.6: B

Source code in slisemap/local_models.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def logistic_regression_loss(
    Ytilde: torch.Tensor, Y: torch.Tensor, B: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Squared Hellinger distance function for (multinomial) logistic regression.

    Computes ``0.5 * sum((sqrt(Ytilde) - sqrt(Y))**2)`` over the last dimension.

    Args:
        Ytilde: Predicted values [n_b, n_x, p].
        Y: Ground truth values [n_x, p].
        B: Coefficient matrix. **Deprecated** and ignored. Defaults to None.

    Returns:
        Loss values [n_b, n_x].

    Deprecated:
        1.6: B
    """

    def _enough_columns():
        # Passed lazily to `_assert_no_trace` as (condition, message).
        return (
            Ytilde.shape[-1] <= Y.shape[-1],
            f"Too few columns in Y: {Y.shape[-1]} < {Ytilde.shape[-1]}",
        )

    _assert_no_trace(_enough_columns, logistic_regression_loss)
    root_diff = Ytilde.sqrt() - Y.sqrt().expand(Ytilde.shape)
    return 0.5 * root_diff.pow(2).sum(dim=-1)

logistic_regression_coefficients(X, Y, intercept=False)

Get the number of coefficients for a (multinomial) logistic regression.

Parameters:

Name Type Description Default
X Union[Tensor, ndarray]

Data matrix.

required
Y Union[Tensor, ndarray]

Target matrix.

required
intercept bool

Add an (additional) intercept to X. Defaults to False.

False

Returns:

Type Description
int

Number of coefficients (columns of B).

Source code in slisemap/local_models.py
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
def logistic_regression_coefficients(
    X: Union[torch.Tensor, np.ndarray],
    Y: Union[torch.Tensor, np.ndarray],
    intercept: bool = False,
) -> int:
    """Get the number of coefficients for a (multinomial) logistic regression.

    With ``p`` target columns, softmax normalisation leaves ``p - 1`` free
    coefficient vectors (at least one). A 1-D ``Y`` counts as a single column,
    consistent with `linear_regression_coefficients`; previously it raised an
    IndexError.

    Args:
        X: Data matrix.
        Y: Target matrix.
        intercept: Add an (additional) intercept to X. Defaults to False.

    Returns:
        Number of coefficients (columns of B).
    """
    n_classes = Y.shape[1] if len(Y.shape) > 1 else 1
    return (X.shape[1] + intercept) * max(1, n_classes - 1)

LogisticRegression

Bases: ALocalModel

A class that contains all the functions needed for logistic regression.

Source code in slisemap/local_models.py
282
283
284
285
286
287
class LogisticRegression(ALocalModel):
    """A class that contains all the functions needed for logistic regression.

    The functions are wrapped in ``staticmethod`` so they stay plain functions
    whether accessed through the class or through an instance. A bare function
    would otherwise become a bound method on instances.
    """

    predict = staticmethod(logistic_regression)
    loss = staticmethod(logistic_regression_loss)
    coefficients = staticmethod(logistic_regression_coefficients)

logistic_regression_log(X, B)

Prediction function for (multinomial) logistic regression that returns the log of the prediction.

Note that the number of coefficients is m * (p-1) due to the normalisation of softmax.

Parameters:

Name Type Description Default
X Tensor

Data matrix [n_x, m].

required
B Tensor

Coefficient Matrix [n_b, m*(p-1)].

required

Returns:

Type Description
Tensor

Prediction tensor [n_b, n_x, p]

Source code in slisemap/local_models.py
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
def logistic_regression_log(X: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Predict **log** class probabilities with (multinomial) logistic regression.

    Note that the number of coefficients is `m * (p-1)` due to the normalisation of softmax.
    The last class acts as the reference class, with its logit fixed at zero.

    Args:
        X: Data matrix [n_x, m].
        B: Coefficient Matrix [n_b, m*(p-1)].

    Returns:
        Prediction tensor [n_b, n_x, p]
    """
    n_x, m = X.shape
    n_b, o = B.shape
    # NOTE(review): torch.div instead of `//`, presumably to keep the shape
    # arithmetic traceable — confirm before changing.
    p = 1 + torch.div(o, m, rounding_mode="trunc")
    logits = torch.zeros([n_b, n_x, p], device=B.device, dtype=B.dtype)
    X_t = X.T
    for k in range(p - 1):
        start = k * m
        logits[:, :, k] = B[:, start : start + m] @ X_t
    # Log-softmax written as logits minus their log-normaliser.
    log_norm = torch.logsumexp(logits, dim=2, keepdim=True)
    return logits - log_norm

logistic_regression_log_loss(Ytilde, Y, B=None)

Cross entropy loss function for (multinomial) logistic regression.

Note that this loss function expects Ytilde to be the log of the predicted probabilities.

Parameters:

Name Type Description Default
Ytilde Tensor

Predicted log-probabilities [n_b, n_x, p].

required
Y Tensor

Ground truth values [n_x, p].

required
B Optional[Tensor]

Coefficient matrix. Deprecated. Defaults to None.

None

Returns:

Type Description
Tensor

Loss values [n_b, n_x].

Deprecated

1.6: B

Source code in slisemap/local_models.py
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
def logistic_regression_log_loss(
    Ytilde: torch.Tensor, Y: torch.Tensor, B: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Cross entropy loss function for (multinomial) logistic regression.

    Note that this loss function expects `Ytilde` to be the **log of the predicted probabilities**.
    With ``q = exp(Ytilde)`` the loss is
    ``sum(-Y * log(q) - (1 - Y) * log(1 - q))`` over the last dimension.

    Args:
        Ytilde: Predicted log-probabilities [n_b, n_x, p].
        Y: Ground truth values [n_x, p].
        B: Coefficient matrix. **Deprecated** and ignored. Defaults to None.

    Returns:
        Loss values [n_b, n_x].

    Deprecated:
        1.6: B
    """
    _assert_no_trace(
        lambda: (
            Ytilde.shape[-1] <= Y.shape[-1],
            f"Too few columns in Y: {Y.shape[-1]} < {Ytilde.shape[-1]}",
        ),
        # Report this function (not `logistic_regression_loss`) as the source.
        logistic_regression_log_loss,
    )
    # log1p(-exp(log q)) == log(1 - q).
    # NOTE(review): a log-probability of exactly 0 gives log1p(-1) = -inf here.
    return torch.sum(-Y * Ytilde - (1 - Y) * torch.log1p(-torch.exp(Ytilde)), -1)

LogisticLogRegression

Bases: ALocalModel

A class that contains all the functions needed for logistic regression.

The predictions are in log-space rather than probabilities for numerical stability.

Source code in slisemap/local_models.py
339
340
341
342
343
344
345
346
347
class LogisticLogRegression(ALocalModel):
    """A class that contains all the functions needed for logistic regression.

    The predictions are in log-space rather than probabilities for numerical stability.
    The functions are wrapped in ``staticmethod`` so they stay plain functions
    whether accessed through the class or through an instance.
    """

    predict = staticmethod(logistic_regression_log)
    loss = staticmethod(logistic_regression_log_loss)
    coefficients = staticmethod(logistic_regression_coefficients)

identify_local_model(local_model, local_loss=None, coefficients=None, regularisation=None)

Identify the "predict", "loss", "coefficients", and "regularisation" functions for a local model.

Parameters:

Name Type Description Default
local_model Union[LocalModelCollection, CallableLike[predict]]

An instance/subclass of ALocalModel, a predict function, or a tuple of functions.

required
local_loss Optional[CallableLike[loss]]

A loss function or None if it is part of local_model. Defaults to None.

None
coefficients Union[None, int, CallableLike[coefficients]]

The number of coefficients, or a function giving that number, or None if it is part of local_model. Defaults to None.

None
regularisation Union[None, CallableLike[regularisation]]

Additional regularisation function. Defaults to None.

None

Returns:

Name Type Description
predict Callable

"prediction" function (takes X and B and returns predicted Y for every X and B combination).

loss Callable

"loss" function (takes predicted Y and real Y and returns the loss).

coefficients Callable

"coefficients" function (takes X and Y and returns the number of coefficients for B).

regularisation Callable

"regularisation" function (takes X, Y, B, Z, and Ytilde and returns an additional loss scalar).

Source code in slisemap/local_models.py
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
def identify_local_model(
    local_model: Union[LocalModelCollection, CallableLike[ALocalModel.predict]],
    local_loss: Optional[CallableLike[ALocalModel.loss]] = None,
    coefficients: Union[None, int, CallableLike[ALocalModel.coefficients]] = None,
    regularisation: Union[None, CallableLike[ALocalModel.regularisation]] = None,
) -> Tuple[Callable, Callable, Callable, Callable]:
    """Identify the "predict", "loss", "coefficients", and "regularisation" functions for a local model.

    Functions are first taken from `local_model`. Any explicit `local_loss`,
    `coefficients`, or `regularisation` that is not None then overrides them.

    Args:
        local_model: An instance/subclass of `ALocalModel`, a predict function, or a tuple of functions.
        local_loss: A loss function or None if it is part of `local_model`. Defaults to None.
        coefficients: The number of coefficients, or a function giving that number, or None if it is part of `local_model`. Defaults to None.
        regularisation: Additional regularisation function. Defaults to None.

    Returns:
        predict: "prediction" function (takes X and B and returns predicted Y for every X and B combination).
        loss: "loss" function (takes predicted Y and real Y and returns the loss).
        coefficients: "coefficients" function (takes X and Y and returns the number of coefficients for B).
            This is None when `local_model` is an unrecognised prediction function
            and `coefficients` is not given.
        regularisation: "regularisation" function (takes X, Y, B, Z, and Ytilde and returns an additional loss scalar).
    """
    pred_fn = None
    loss_fn = None
    # Fallbacks for the sequence and unidentified branches below.
    coef_fn = linear_regression_coefficients
    regu_fn = ALocalModel.regularisation
    if isinstance(local_model, ALocalModel) or (
        isinstance(local_model, type) and issubclass(local_model, ALocalModel)
    ):
        pred_fn = local_model.predict
        loss_fn = local_model.loss
        coef_fn = local_model.coefficients
        regu_fn = local_model.regularisation
    elif callable(local_model):
        pred_fn = local_model
        if local_model in (linear_regression, multiple_linear_regression):
            loss_fn = linear_regression_loss
            coef_fn = linear_regression_coefficients
        elif local_model is logistic_regression:
            loss_fn = logistic_regression_loss
            coef_fn = logistic_regression_coefficients
        elif local_model is logistic_regression_log:
            loss_fn = logistic_regression_log_loss
            coef_fn = logistic_regression_coefficients
        else:
            # Unrecognised prediction function: loss and coefficients come only
            # from the explicit arguments (either may still be None).
            loss_fn = local_loss
            coef_fn = coefficients
    elif isinstance(local_model, Sequence):
        # Layout: (predict[, loss[, coefficients[, regularisation]]])
        pred_fn = local_model[0]
        loss_fn = local_model[1] if len(local_model) > 1 else loss_fn
        coef_fn = local_model[2] if len(local_model) > 2 else coef_fn
        regu_fn = local_model[3] if len(local_model) > 3 else regu_fn
    else:
        _warn("Could not identify the local model", identify_local_model)
        pred_fn = local_model
    # Explicit arguments take precedence over whatever was identified above.
    if local_loss is not None:
        loss_fn = local_loss
    if coefficients is not None:
        coef_fn = coefficients
    if regularisation is not None:
        regu_fn = regularisation
    if isinstance(coef_fn, int):
        i_coef = coef_fn
        # A fixed count ignores its inputs. The optional `intercept` matches the
        # `ALocalModel.coefficients` signature.
        coef_fn = lambda X, Y, intercept=False: i_coef  # noqa: E731
    _assert(pred_fn is not None, "`local_model` function missing")
    _assert(loss_fn is not None, "`local_loss` function missing")
    return pred_fn, loss_fn, coef_fn, regu_fn