slisemap.utils

Module that contains various useful functions.

softmax_row_kernel(D)

Kernel function that applies softmax on the rows.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `D` | `Tensor` | Distance matrix. | *required* |

Returns:

| Type | Description |
| ---- | ----------- |
| `Tensor` | Weight matrix. |

Source code in slisemap/utils.py
def softmax_row_kernel(D: torch.Tensor) -> torch.Tensor:
    """Kernel function that applies softmax on the rows.

    Args:
        D: Distance matrix.

    Returns:
        Weight matrix.
    """
    return torch.softmax(-D, 1)
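
For example, each row of the returned weight matrix sums to one, so row `i` gives a softmax weighting of all items by their distance to item `i`. A minimal sketch (using `squared_distance` from this module):

```python
import torch
from slisemap.utils import softmax_row_kernel, squared_distance

Z = torch.randn(5, 2)       # five points in a 2D embedding
D = squared_distance(Z, Z)  # pairwise squared distances [5, 5]
W = softmax_row_kernel(D)   # smaller distance => larger weight
print(W.sum(dim=1))         # every row sums to 1.0
```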

softmax_column_kernel(D)

Kernel function that applies softmax on the columns.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `D` | `Tensor` | Distance matrix. | *required* |

Returns:

| Type | Description |
| ---- | ----------- |
| `Tensor` | Weight matrix. |

Source code in slisemap/utils.py
def softmax_column_kernel(D: torch.Tensor) -> torch.Tensor:
    """Kernel function that applies softmax on the columns.

    Args:
        D: Distance matrix.

    Returns:
        Weight matrix.
    """
    return torch.softmax(-D, 0)

squared_distance(A, B)

Distance function that returns the squared euclidean distances.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `A` | `Tensor` | The first matrix `[n1, d]`. | *required* |
| `B` | `Tensor` | The second matrix `[n2, d]`. | *required* |

Returns:

| Type | Description |
| ---- | ----------- |
| `Tensor` | Distance matrix `[n1, n2]`. |

Source code in slisemap/utils.py
def squared_distance(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Distance function that returns the squared euclidean distances.

    Args:
        A: The first matrix [n1, d].
        B: The second matrix [n2, d].

    Returns:
        Distance matrix [n1, n2].
    """
    return torch.sum((A[:, None, ...] - B[None, ...]) ** 2, -1)
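The broadcasting here computes the same quantity as squaring `torch.cdist`; a quick sanity check (sketch):

```python
import torch
from slisemap.utils import squared_distance

A, B = torch.randn(4, 3), torch.randn(6, 3)
D = squared_distance(A, B)  # shape [4, 6]
assert torch.allclose(D, torch.cdist(A, B) ** 2, atol=1e-5)
```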

SlisemapException

Bases: Exception

Custom Exception type (for filtering).

Source code in slisemap/utils.py
class SlisemapException(Exception):  # noqa: N818
    """Custom Exception type (for filtering)."""

    pass

SlisemapWarning

Bases: Warning

Custom Warning type (for filtering).

Source code in slisemap/utils.py
class SlisemapWarning(Warning):
    """Custom Warning type (for filtering)."""

    pass

CallableLike

Bases: Generic[_F]

Type annotation for functions matching the signature of a given function.

Source code in slisemap/utils.py
class CallableLike(Generic[_F]):
    """Type annotation for functions matching the signature of a given function."""

    @staticmethod
    def __class_getitem__(fn: _F) -> _F:
        return fn
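This makes it possible to annotate a parameter as "any function with the same signature as ...". A sketch (the `weights` function is hypothetical):

```python
import torch
from slisemap.utils import CallableLike, softmax_row_kernel

def weights(D: torch.Tensor, kernel: CallableLike[softmax_row_kernel]) -> torch.Tensor:
    # `kernel` is expected to match the signature of `softmax_row_kernel`
    return kernel(D)
```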

tonp(x)

Convert a torch.Tensor to a numpy.ndarray.

If x is not a torch.Tensor then np.asarray is used instead.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `x` | `Union[Tensor, object]` | Input `torch.Tensor`. | *required* |

Returns:

| Type | Description |
| ---- | ----------- |
| `ndarray` | Output `numpy.ndarray`. |

Source code in slisemap/utils.py
def tonp(x: Union[torch.Tensor, object]) -> np.ndarray:
    """Convert a `torch.Tensor` to a `numpy.ndarray`.

    If `x` is not a `torch.Tensor` then `np.asarray` is used instead.

    Args:
        x: Input `torch.Tensor`.

    Returns:
        Output `numpy.ndarray`.
    """
    if isinstance(x, torch.Tensor):
        return x.cpu().detach().numpy()
    else:
        return np.asarray(x)
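For example, gradients and device placement are handled automatically (sketch):

```python
import torch
from slisemap.utils import tonp

t = torch.arange(4.0, requires_grad=True).reshape(2, 2)
a = tonp(t * 2)          # detaches, moves to CPU, and converts
print(type(a), a.dtype)  # <class 'numpy.ndarray'> float32
print(tonp([1, 2, 3]))   # non-tensors go through np.asarray
```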

CheckConvergence

An object that tries to estimate when an optimisation has converged.

Use it for, e.g., escape+optimisation cycles in Slisemap.

Source code in slisemap/utils.py
class CheckConvergence:
    """An object that tries to estimate when an optimisation has converged.

    Use it for, e.g., escape+optimisation cycles in Slisemap.
    """

    __slots__ = {
        "current": "Current loss value.",
        "best": "Best loss value, so far.",
        "counter": "Number of steps since the best loss value.",
        "patience": "Number of steps allowed without improvement.",
        "optimal": "Cache for storing the state that produced the best loss value.",
        "max_iter": "The maximum number of iterations.",
        "iter": "The current number of iterations.",
        "rel": "Minimum relative error for convergence check",
    }

    def __init__(
        self, patience: float = 3, max_iter: int = 1 << 20, rel: float = 1e-4
    ) -> None:
        """Create a `CheckConvergence` object.

        Args:
            patience: How long should the optimisation continue without improvement. Defaults to 3.
            max_iter: The maximum number of iterations. Defaults to `2**20`.
            rel: Minimum relative error change that is considered an improvement. Defaults to `1e-4`.
        """
        self.current = np.inf
        self.best = np.asarray(np.inf)
        self.counter = 0.0
        self.patience = patience
        self.optimal = None
        self.max_iter = max_iter
        self.iter = 0
        self.rel = rel

    def has_converged(
        self,
        loss: Union[float, Sequence[float], np.ndarray],
        store: Optional[Callable[[], Any]] = None,
        verbose: bool = False,
    ) -> bool:
        """Check if the optimisation has converged.

        If more than one loss value is provided, then only the first one is checked when storing the `optimal_state`.
        The other losses are only used for checking convergence.

        Args:
            loss: The latest loss value(s).
            store: Function that returns the current state for storing in `self.optimal_state`. Defaults to None.
            verbose: Print debug messages. Defaults to False.

        Returns:
            True if the optimisation has converged.
        """
        self.iter += 1
        loss = np.asarray(loss)
        if np.any(np.isnan(loss)):
            _warn("Loss is `nan`", CheckConvergence.has_converged)
            return True
        if np.any(loss + np.abs(loss) * self.rel < self.best):
            self.counter = 0.0  # Reset the counter if a new best
            if store is not None and loss.item(0) < self.best.item(0):
                self.optimal = store()
            self.best = np.minimum(loss, self.best)
        else:
            # Increase the counter if no improvement
            self.counter += np.mean(self.current <= loss)
        self.current = loss
        if verbose:
            print(
                f"CheckConvergence: patience={self.patience-self.counter:g}/{self.patience:g}   iter={self.iter}/{self.max_iter}"
            )
        return self.counter >= self.patience or self.iter >= self.max_iter
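
A typical usage pattern is a loop that stops when the loss plateaus. A minimal sketch with a toy loss:

```python
from slisemap.utils import CheckConvergence

cc = CheckConvergence(patience=2, max_iter=100)
loss = 10.0
while not cc.has_converged(loss, store=lambda: loss):
    loss *= 0.5 if cc.iter < 5 else 1.0  # improvement stalls after 5 steps
print(cc.iter, cc.best, cc.optimal)      # cc.optimal holds the best stored state
```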

__init__(patience=3, max_iter=1 << 20, rel=0.0001)

Create a CheckConvergence object.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `patience` | `float` | How long should the optimisation continue without improvement. Defaults to 3. | `3` |
| `max_iter` | `int` | The maximum number of iterations. Defaults to `2**20`. | `1 << 20` |
| `rel` | `float` | Minimum relative error change that is considered an improvement. Defaults to `1e-4`. | `0.0001` |

Source code in slisemap/utils.py
def __init__(
    self, patience: float = 3, max_iter: int = 1 << 20, rel: float = 1e-4
) -> None:
    """Create a `CheckConvergence` object.

    Args:
        patience: How long should the optimisation continue without improvement. Defaults to 3.
        max_iter: The maximum number of iterations. Defaults to `2**20`.
        rel: Minimum relative error change that is considered an improvement. Defaults to `1e-4`.
    """
    self.current = np.inf
    self.best = np.asarray(np.inf)
    self.counter = 0.0
    self.patience = patience
    self.optimal = None
    self.max_iter = max_iter
    self.iter = 0
    self.rel = rel

has_converged(loss, store=None, verbose=False)

Check if the optimisation has converged.

If more than one loss value is provided, then only the first one is checked when storing the optimal_state. The other losses are only used for checking convergence.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `loss` | `Union[float, Sequence[float], ndarray]` | The latest loss value(s). | *required* |
| `store` | `Optional[Callable[[], Any]]` | Function that returns the current state for storing in `self.optimal_state`. Defaults to None. | `None` |
| `verbose` | `bool` | Print debug messages. Defaults to False. | `False` |

Returns:

| Type | Description |
| ---- | ----------- |
| `bool` | True if the optimisation has converged. |

Source code in slisemap/utils.py
def has_converged(
    self,
    loss: Union[float, Sequence[float], np.ndarray],
    store: Optional[Callable[[], Any]] = None,
    verbose: bool = False,
) -> bool:
    """Check if the optimisation has converged.

    If more than one loss value is provided, then only the first one is checked when storing the `optimal_state`.
    The other losses are only used for checking convergence.

    Args:
        loss: The latest loss value(s).
        store: Function that returns the current state for storing in `self.optimal_state`. Defaults to None.
        verbose: Print debug messages. Defaults to False.

    Returns:
        True if the optimisation has converged.
    """
    self.iter += 1
    loss = np.asarray(loss)
    if np.any(np.isnan(loss)):
        _warn("Loss is `nan`", CheckConvergence.has_converged)
        return True
    if np.any(loss + np.abs(loss) * self.rel < self.best):
        self.counter = 0.0  # Reset the counter if a new best
        if store is not None and loss.item(0) < self.best.item(0):
            self.optimal = store()
        self.best = np.minimum(loss, self.best)
    else:
        # Increase the counter if no improvement
        self.counter += np.mean(self.current <= loss)
    self.current = loss
    if verbose:
        print(
            f"CheckConvergence: patience={self.patience-self.counter:g}/{self.patience:g}   iter={self.iter}/{self.max_iter}"
        )
    return self.counter >= self.patience or self.iter >= self.max_iter

LBFGS(loss_fn, variables, max_iter=500, max_eval=None, line_search_fn='strong_wolfe', time_limit=None, increase_tolerance=False, verbose=False, **kwargs)

Optimise a function using LBFGS.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `loss_fn` | `Callable[[], Tensor]` | Function that returns a value to be minimised. | *required* |
| `variables` | `List[Tensor]` | List of variables to optimise (must have `requires_grad=True`). | *required* |
| `max_iter` | `int` | Maximum number of LBFGS iterations. Defaults to 500. | `500` |
| `max_eval` | `Optional[int]` | Maximum number of function evaluations. Defaults to `1.25 * max_iter`. | `None` |
| `line_search_fn` | `Optional[str]` | Line search method (None or "strong_wolfe"). Defaults to "strong_wolfe". | `'strong_wolfe'` |
| `time_limit` | `Optional[float]` | Optional time limit for the optimisation (in seconds). Defaults to None. | `None` |
| `increase_tolerance` | `bool` | Increase the tolerances for convergence checking. Defaults to False. | `False` |
| `verbose` | `bool` | Print status messages. Defaults to False. | `False` |

Other Parameters:

| Name | Type | Description |
| ---- | ---- | ----------- |
| `**kwargs` | `Any` | Arguments forwarded to `torch.optim.LBFGS`. |

Returns:

| Type | Description |
| ---- | ----------- |
| `LBFGS` | The LBFGS optimiser. |

Source code in slisemap/utils.py
def LBFGS(
    loss_fn: Callable[[], torch.Tensor],
    variables: List[torch.Tensor],
    max_iter: int = 500,
    max_eval: Optional[int] = None,
    line_search_fn: Optional[str] = "strong_wolfe",
    time_limit: Optional[float] = None,
    increase_tolerance: bool = False,
    verbose: bool = False,
    **kwargs: Any,
) -> torch.optim.LBFGS:
    """Optimise a function using [LBFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS).

    Args:
        loss_fn: Function that returns a value to be minimised.
        variables: List of variables to optimise (must have `requires_grad=True`).
        max_iter: Maximum number of LBFGS iterations. Defaults to 500.
        max_eval: Maximum number of function evaluations. Defaults to `1.25 * max_iter`.
        line_search_fn: Line search method (None or "strong_wolfe"). Defaults to "strong_wolfe".
        time_limit: Optional time limit for the optimisation (in seconds). Defaults to None.
        increase_tolerance: Increase the tolerances for convergence checking. Defaults to False.
        verbose: Print status messages. Defaults to False.

    Keyword Args:
        **kwargs: Arguments forwarded to [`torch.optim.LBFGS`](https://pytorch.org/docs/stable/generated/torch.optim.LBFGS.html).

    Returns:
        The LBFGS optimiser.
    """
    if increase_tolerance:
        kwargs["tolerance_grad"] = 100 * kwargs.get("tolerance_grad", 1e-7)
        kwargs["tolerance_change"] = 100 * kwargs.get("tolerance_change", 1e-9)
    optimiser = torch.optim.LBFGS(
        variables,
        max_iter=max_iter if time_limit is None else 20,
        max_eval=max_eval,
        line_search_fn=line_search_fn,
        **kwargs,
    )

    def closure() -> torch.Tensor:
        optimiser.zero_grad()
        loss = loss_fn()
        loss.backward()
        return loss

    if time_limit is None:
        loss = optimiser.step(closure)
    else:
        start = timer()
        prev_evals = 0
        for _ in range((max_iter - 1) // 20 + 1):
            loss = optimiser.step(closure)
            if not torch.all(torch.isfinite(loss)).cpu().detach().item():
                break
            if timer() - start > time_limit:
                if verbose:
                    print("LBFGS: Time limit exceeded!")
                break
            tot_evals = optimiser.state_dict()["state"][0]["func_evals"]
            if prev_evals + 1 == tot_evals:
                break  # LBFGS has converged if it returns after one evaluation
            prev_evals = tot_evals
            if max_eval is not None:
                if tot_evals >= max_eval:
                    break  # Number of evaluations exceeded max_eval
                optimiser.param_groups[0]["max_eval"] -= tot_evals
            # The number of steps is limited by ceiling(max_iter/20) with 20 iterations per step

    if verbose:
        iters = optimiser.state_dict()["state"][0]["n_iter"]
        evals = optimiser.state_dict()["state"][0]["func_evals"]
        loss = loss.mean().cpu().detach().item()
        if not np.isfinite(loss):
            print("LBFGS: Loss is not finite {}!")
        elif iters >= max_iter:
            print("LBFGS: Maximum number of iterations exceeded!")
        elif max_eval is not None and evals >= max_eval:
            print("LBFGS: Maximum number of evaluations exceeded!")
        else:
            print(f"LBFGS: Converged in {iters} iterations")

    return optimiser
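
As a minimal sketch, minimising a quadratic in one tensor:

```python
import torch
from slisemap.utils import LBFGS

x = torch.tensor([3.0, -2.0], requires_grad=True)
LBFGS(lambda: torch.sum((x - 1.0) ** 2), [x])
print(x.detach())  # close to [1.0, 1.0]
```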

PCA_rotation(X, components=-1, center=True, full=True, niter=10)

Calculate the rotation matrix from PCA.

If the PCA fails (e.g. if the original matrix is not full rank), this shows a warning instead of raising an error and returns a dummy rotation.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `X` | `Tensor` | The original matrix. | *required* |
| `components` | `int` | The maximum number of components in the embedding. Defaults to `min(*X.shape)`. | `-1` |
| `center` | `bool` | Center the matrix before calculating the PCA. | `True` |
| `full` | `bool` | Use a full SVD for the PCA (slower). Defaults to True. | `True` |
| `niter` | `int` | The number of iterations when a randomised approach is used. Defaults to 10. | `10` |

Returns:

| Type | Description |
| ---- | ----------- |
| `Tensor` | Rotation matrix that turns the original matrix into the embedded space. |

Source code in slisemap/utils.py
def PCA_rotation(
    X: torch.Tensor,
    components: int = -1,
    center: bool = True,
    full: bool = True,
    niter: int = 10,
) -> torch.Tensor:
    """Calculate the rotation matrix from PCA.

    If the PCA fails (e.g. if the original matrix is not full rank), this shows a warning instead of raising an error and returns a dummy rotation.

    Args:
        X: The original matrix.
        components: The maximum number of components in the embedding. Defaults to `min(*X.shape)`.
        center: Center the matrix before calculating the PCA.
        full: Use a full SVD for the PCA (slower). Defaults to True.
        niter: The number of iterations when a randomised approach is used. Defaults to 10.

    Returns:
        Rotation matrix that turns the original matrix into the embedded space.
    """
    try:
        components = min(*X.shape, components) if components > 0 else min(*X.shape)
        if full:
            if center:
                X = X - X.mean(dim=(-2,), keepdim=True)
            return torch.linalg.svd(X, full_matrices=False)[2].T[:, :components]
        else:
            return torch.pca_lowrank(X, components, center=center, niter=niter)[2]
    except Exception:
        _warn("Could not perform PCA", PCA_rotation)
        z = torch.zeros((X.shape[1], components), dtype=X.dtype, device=X.device)
        z.fill_diagonal_(1.0, True)
        return z
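
For example, projecting data onto its first two principal directions (sketch):

```python
import torch
from slisemap.utils import PCA_rotation

X = torch.randn(100, 5)
R = PCA_rotation(X, components=2)  # rotation matrix [5, 2]
Z = X @ R                          # project onto the first two components
print(R.shape, Z.shape)            # torch.Size([5, 2]) torch.Size([100, 2])
```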

global_model(X, Y, local_model, local_loss, coefficients=None, lasso=0.0, ridge=0.0)

Find coefficients for a global model.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `X` | `Tensor` | Data matrix. | *required* |
| `Y` | `Tensor` | Target matrix. | *required* |
| `local_model` | `Callable[[Tensor, Tensor], Tensor]` | Prediction function for the model. | *required* |
| `local_loss` | `Callable[[Tensor, Tensor, Tensor], Tensor]` | Loss function for the model. | *required* |
| `coefficients` | `Optional[int]` | Number of coefficients. Defaults to `X.shape[1]`. | `None` |
| `lasso` | `float` | Lasso-regularisation coefficient for B ($\lambda_{lasso} * \|B\|_1$). Defaults to 0.0. | `0.0` |
| `ridge` | `float` | Ridge-regularisation coefficient for B ($\lambda_{ridge} * \|B\|_2^2$). Defaults to 0.0. | `0.0` |

Returns:

| Type | Description |
| ---- | ----------- |
| `Tensor` | Global model coefficients. |

Source code in slisemap/utils.py
def global_model(
    X: torch.Tensor,
    Y: torch.Tensor,
    local_model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    local_loss: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    coefficients: Optional[int] = None,
    lasso: float = 0.0,
    ridge: float = 0.0,
) -> torch.Tensor:
    r"""Find coefficients for a global model.

    Args:
        X: Data matrix.
        Y: Target matrix.
        local_model: Prediction function for the model.
        local_loss: Loss function for the model.
        coefficients: Number of coefficients. Defaults to X.shape[1].
        lasso: Lasso-regularisation coefficient for B ($\lambda_{lasso} * ||B||_1$). Defaults to 0.0.
        ridge: Ridge-regularisation coefficient for B ($\lambda_{ridge} * ||B||_2^2$). Defaults to 0.0.

    Returns:
        Global model coefficients.
    """
    shape = (1, X.shape[1] * Y.shape[1] if coefficients is None else coefficients)
    B = torch.zeros(shape, dtype=X.dtype, device=X.device).requires_grad_(True)

    def loss() -> torch.Tensor:
        loss = local_loss(local_model(X, B), Y).mean()
        if lasso > 0:
            loss += lasso * torch.sum(B.abs())
        if ridge > 0:
            loss += ridge * torch.sum(B**2)
        return loss

    LBFGS(loss, [B])
    return B.detach()
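
For instance, fitting a single linear model to all of the data (a sketch; `linear` and `mse` are ad hoc stand-ins for Slisemap's local model and loss functions):

```python
import torch
from slisemap.utils import global_model

def linear(X, B):  # predict with one coefficient vector for all rows
    return X @ B.T

def mse(Ytilde, Y, B=None):  # per-row squared error
    return ((Ytilde - Y) ** 2).sum(-1)

X = torch.randn(50, 3)
true_b = torch.tensor([[1.0, -2.0, 0.5]])
Y = X @ true_b.T
B = global_model(X, Y, linear, mse)
print(B)  # close to [[1.0, -2.0, 0.5]]
```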

dict_array(dict)

Turn a dictionary of various values into a dictionary of numpy arrays with equal lengths, in place.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `dict` | `Dict[str, Any]` | Dictionary. | *required* |

Returns:

| Type | Description |
| ---- | ----------- |
| `Dict[str, ndarray]` | The same dictionary where the values are numpy arrays with equal lengths. |

Source code in slisemap/utils.py
def dict_array(dict: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """Turn a dictionary of various values to a dictionary of numpy arrays with equal length inplace.

    Args:
        dict: Dictionary.

    Returns:
        The same dictionary where the values are numpy arrays with equal length.
    """
    n = 1
    for k, v in dict.items():
        v = np.asarray(v).ravel()
        dict[k] = v
        n = max(n, len(v))
    for k, v in dict.items():
        if len(v) == 1:
            dict[k] = np.repeat(v, n)
        elif len(v) != n:
            _warn(f"Uneven lengths in dictionary ({k}: {len(v)} != {n})", dict_array)
    return dict

dict_append(df, d)

Append a dictionary of values to a dictionary of numpy arrays (see dict_array), in place.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `df` | `Dict[str, ndarray]` | Dictionary of numpy arrays. | *required* |
| `d` | `Dict[str, Any]` | Dictionary to append. | *required* |

Returns:

| Type | Description |
| ---- | ----------- |
| `Dict[str, ndarray]` | The same dictionary as `df` with the values from `d` appended. |

Source code in slisemap/utils.py
def dict_append(df: Dict[str, np.ndarray], d: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """Append a dictionary of values to a dictionary of numpy arrays (see `dict_array`) inplace.

    Args:
        df: Dictionary of numpy arrays.
        d: Dictionary to append.

    Returns:
        The same dictionary as `df` with the values from `d` appended.
    """
    d = dict_array(d)
    for k in df:
        df[k] = np.concatenate((df[k], d[k]), 0)
    return df

dict_concat(dicts)

Combine multiple dictionaries into one by concatenating the values.

Calls dict_array to pre-process the dictionaries.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `dicts` | `Union[Sequence[Dict[str, Any]], Iterator[Dict[str, Any]]]` | Sequence or Generator with dictionaries (all must have the same keys). | *required* |

Returns:

| Type | Description |
| ---- | ----------- |
| `Dict[str, ndarray]` | Combined dictionary. |

Source code in slisemap/utils.py
def dict_concat(
    dicts: Union[Sequence[Dict[str, Any]], Iterator[Dict[str, Any]]],
) -> Dict[str, np.ndarray]:
    """Combine multiple dictionaries into one by concatenating the values.

    Calls `dict_array` to pre-process the dictionaries.

    Args:
        dicts: Sequence or Generator with dictionaries (all must have the same keys).

    Returns:
        Combined dictionary.
    """
    if isinstance(dicts, Sequence):
        dicts = (d for d in dicts)
    df = dict_array(next(dicts))
    for d in dicts:
        dict_append(df, d)
    return df
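
Together, these helpers collect column-oriented results, e.g. metrics over repeated runs (sketch):

```python
from slisemap.utils import dict_concat

runs = ({"run": i, "loss": [0.5 / (i + 1), 0.4 / (i + 1)]} for i in range(3))
df = dict_concat(runs)
print(df["run"])   # [0 0 1 1 2 2] -- the scalar is repeated to match
print(df["loss"])  # the six loss values concatenated in order
```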

`ToTensor = Union[float, np.ndarray, torch.Tensor, 'pandas.DataFrame', Dict[str, Sequence[float]], Sequence[float]]` (module attribute)

Type annotations for objects that can be turned into a torch.Tensor with the to_tensor function.

to_tensor(input, **tensorargs)

Convert the input into a torch.Tensor (via numpy.ndarray if necessary).

This function wraps torch.as_tensor (and numpy.asarray) and tries to extract row and column names. It can handle arbitrary objects (such as pandas.DataFrame) if they implement .to_numpy() and, optionally, .index and .columns.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `input` | `ToTensor` | Input data. | *required* |

Keyword Args:

| Name | Type | Description |
| ---- | ---- | ----------- |
| `**tensorargs` | `object` | Additional arguments to `torch.as_tensor`. |

Returns:

| Name | Type | Description |
| ---- | ---- | ----------- |
| `output` | `Tensor` | Output tensor. |
| `rows` | `Optional[Sequence[object]]` | Row names or `None`. |
| `columns` | `Optional[Sequence[object]]` | Column names or `None`. |

Source code in slisemap/utils.py
def to_tensor(
    input: ToTensor, **tensorargs: object
) -> Tuple[torch.Tensor, Optional[Sequence[object]], Optional[Sequence[object]]]:
    """Convert the input into a `torch.Tensor` (via `numpy.ndarray` if necessary).

    This function wraps `torch.as_tensor` (and `numpy.asarray`) and tries to extract row and column names.
    This function can handle arbitrary objects (such as `pandas.DataFrame`) if they implement `.to_numpy()` and, optionally, `.index` and `.columns`.

    Args:
        input: input data
    Keyword Args:
        **tensorargs: additional arguments to `torch.as_tensor`

    Returns:
        output: output tensor
        rows: row names or `None`
        columns: column names or `None`
    """
    if isinstance(input, dict):
        output = torch.as_tensor(np.asarray(tuple(input.values())).T, **tensorargs)
        return output, None, list(input.keys())
    elif isinstance(input, (np.ndarray, torch.Tensor)):
        return (torch.as_tensor(input, **tensorargs), None, None)
    else:
        # Check if X is similar to a Pandas DataFrame
        try:
            output = torch.as_tensor(input.to_numpy(), **tensorargs)
        except (AttributeError, TypeError):
            try:
                output = torch.as_tensor(input.numpy(), **tensorargs)
            except (AttributeError, TypeError):
                try:
                    output = torch.as_tensor(input, **tensorargs)
                except (TypeError, RuntimeError):
                    output = torch.as_tensor(np.asarray(input), **tensorargs)
        try:
            columns = input.columns if len(input.columns) == output.shape[1] else None
        except (AttributeError, TypeError):
            columns = None
        try:
            rows = input.index if len(input.index) == output.shape[0] else None
        except (AttributeError, TypeError):
            rows = None
        return output, rows, columns
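
For example, with a `pandas.DataFrame` the names are picked up automatically (sketch):

```python
import pandas as pd
import torch
from slisemap.utils import to_tensor

df = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]}, index=["r1", "r2"])
output, rows, columns = to_tensor(df, dtype=torch.float32)
print(output.shape)               # torch.Size([2, 2])
print(list(rows), list(columns))  # ['r1', 'r2'] ['a', 'b']
```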

Metadata

Bases: dict

Metadata for Slisemap objects.

Primarily row names, column names, and scaling information about the matrices (these are used when plotting). But other arbitrary information can also be stored in this dictionary (The main Slisemap class has predefined "slots").

Source code in slisemap/utils.py
class Metadata(dict):
    """Metadata for Slisemap objects.

    Primarily row names, column names, and scaling information about the matrices (these are used when plotting).
    But other arbitrary information can also be stored in this dictionary (The main Slisemap class has predefined "slots").
    """

    def __init__(self, root: "Slisemap", **kwargs: Any) -> None:  # noqa: F821
        """Create a Metadata dictionary."""
        super().__init__(**kwargs)
        self.root = root

    def set_rows(self, *rows: Optional[Sequence[object]]) -> None:
        """Set the row names with checks to avoid saving ranges.

        Args:
            *rows: row names
        """
        for row in rows:
            if row is not None:
                try:
                    # Check if row is `range(0, self.root.n, 1)`-like (duck typing)
                    if row.start == 0 and row.step == 1 and row.stop == self.root.n:
                        continue
                except AttributeError:
                    pass
                _assert(
                    len(row) == self.root.n,
                    f"Wrong number of row names {len(row)} != {self.root.n}",
                    Metadata.set_rows,
                )
                if all(i == j for i, j in enumerate(row)):
                    continue
                self["rows"] = list(row)
                break

    def set_variables(
        self,
        variables: Optional[Sequence[Any]] = None,
        add_intercept: Optional[bool] = None,
    ) -> None:
        """Set the variable names with checks.

        Args:
            variables: variable names
            add_intercept: add "Intercept" to the variable names. Defaults to `self.root.intercept`,
        """
        if add_intercept is None:
            add_intercept = self.root.intercept
        if variables is not None:
            variables = list(variables)
            if add_intercept:
                variables.append("Intercept")
            _assert(
                len(variables) == self.root.m,
                f"Wrong number of variables {len(variables)} != {self.root.m} ({variables})",
                Metadata.set_variables,
            )
            self["variables"] = variables

    def set_targets(self, targets: Union[None, str, Sequence[Any]] = None) -> None:
        """Set the target names with checks.

        Args:
            targets: target names
        """
        if targets is not None:
            targets = [targets] if isinstance(targets, str) else list(targets)
            _assert(
                len(targets) == self.root.o,
                f"Wrong number of targets {len(targets)} != {self.root.o}",
                Metadata.set_targets,
            )
            self["targets"] = targets

    def set_coefficients(self, coefficients: Optional[Sequence[Any]] = None) -> None:
        """Set the coefficient names with checks.

        Args:
            coefficients: coefficient names
        """
        if coefficients is not None:
            _assert(
                len(coefficients) == self.root.q,
                f"Wrong number of targets {len(coefficients)} != {self.root.q}",
                Metadata.set_coefficients,
            )
            self["coefficients"] = list(coefficients)

    def set_dimensions(self, dimensions: Optional[Sequence[Any]] = None) -> None:
        """Set the dimension names with checks.

        Args:
            dimensions: dimension names
        """
        if dimensions is not None:
            _assert(
                len(dimensions) == self.root.d,
                f"Wrong number of targets {len(dimensions)} != {self.root.d}",
                Metadata.set_dimensions,
            )
            self["dimensions"] = list(dimensions)

    def get_coefficients(self, fallback: bool = True) -> Optional[List[str]]:
        """Get a list of coefficient names.

        Args:
            fallback: If metadata for coefficients is missing, return a new list instead of None. Defaults to True.

        Returns:
            list of coefficient names
        """
        if "coefficients" in self:
            return self["coefficients"]
        if "variables" in self:
            if self.root.m == self.root.q:
                return self["variables"]
            if "targets" in self and self.root.m * self.root.o >= self.root.q:
                return [
                    f"{t}: {v}" for t in self["targets"] for v in self["variables"]
                ][: self.root.q]
        if fallback:
            return [f"B_{i}" for i in range(self.root.q)]
        else:
            return None

    def get_targets(self, fallback: bool = True) -> Optional[List[str]]:
        """Get a list of target names.

        Args:
            fallback: If metadata for targets is missing, return a new list instead of None. Defaults to True.

        Returns:
            list of target names
        """
        if "targets" in self:
            return self["targets"]
        elif fallback:
            return [f"Y_{i}" for i in range(self.root.o)] if self.root.o > 1 else ["Y"]
        else:
            return None

    def get_variables(
        self, intercept: bool = True, fallback: bool = True
    ) -> Optional[List[str]]:
        """Get a list of variable names.

        Args:
            intercept: include the intercept in the list. Defaults to True.
            fallback: If metadata for variables is missing, return a new list instead of None. Defaults to True.


        Returns:
            list of variable names
        """
        if "variables" in self:
            if self.root.intercept and not intercept:
                return self["variables"][:-1]
            else:
                return self["variables"]
        elif fallback:
            if self.root.intercept:
                if not intercept:
                    return [f"X_{i}" for i in range(self.root.m - 1)]
                else:
                    return [f"X_{i}" for i in range(self.root.m - 1)] + ["X_Intercept"]
            else:
                return [f"X_{i}" for i in range(self.root.m)]
        else:
            return None

    def get_dimensions(
        self, fallback: bool = True, long: bool = False
    ) -> Optional[List[str]]:
        """Get a list of dimension names.

        Args:
            fallback: If metadata for dimensions is missing, return a new list instead of None. Defaults to True.
            long: Use "SLISEMAP 1",... as fallback instead of "Z_0",...

        Returns:
            list of dimension names
        """
        if "dimensions" in self:
            return self["dimensions"]
        elif fallback:
            if long:
                cls = "Slisemap" if self.root is None else type(self.root).__name__
                return [f"{cls} {i+1}" for i in range(self.root.d)]
            else:
                return [f"Z_{i}" for i in range(self.root.d)]
        else:
            return None

    def get_rows(self, fallback: bool = True) -> Optional[Sequence[Any]]:
        """Get a list of row names.

        Args:
            fallback: If metadata for rows is missing, return a range instead of None. Defaults to True.

        Returns:
            list (or range) of row names
        """
        if "rows" in self:
            return self["rows"]
        elif fallback:
            return range(self.root.n)
        else:
            return None

    def set_scale_X(
        self,
        center: Union[None, torch.Tensor, np.ndarray, Sequence[float]] = None,
        scale: Union[None, torch.Tensor, np.ndarray, Sequence[float]] = None,
    ) -> None:
        """Set scaling information with checks.

        Use if `X` has been scaled before being input to Slisemap.
        Assuming the scaling can be converted to the form `X = (X_unscaled - center) / scale`.
        This allows some plots to (temporarily) revert the scaling (for more intuitive units).

        Args:
            center: The constant offset of `X`. Defaults to None.
            scale: The scaling factor of `X`. Defaults to None.
        """
        if center is not None:
            center = tonp(center).ravel()
            assert center.size == self.root.m - self.root.intercept
            self["X_center"] = center
        if scale is not None:
            scale = tonp(scale).ravel()
            assert scale.size == self.root.m - self.root.intercept
            self["X_scale"] = scale

    def set_scale_Y(
        self,
        center: Union[None, torch.Tensor, np.ndarray, Sequence[float]] = None,
        scale: Union[None, torch.Tensor, np.ndarray, Sequence[float]] = None,
    ) -> None:
        """Set scaling information with checks.

        Use if `Y` has been scaled before being input to Slisemap.
        Assuming the scaling can be converted to the form `Y = (Y_unscaled - center) / scale`.
        This allows some plots to (temporarily) revert the scaling (for more intuitive units).

        Args:
            center: The constant offset of `Y`. Defaults to None.
            scale: The scaling factor of `Y`. Defaults to None.
        """
        if center is not None:
            center = tonp(center).ravel()
            assert center.size == self.root.o
            self["Y_center"] = center
        if scale is not None:
            scale = tonp(scale).ravel()
            assert scale.size == self.root.o
            self["Y_scale"] = scale

    def unscale_X(self, X: Optional[np.ndarray] = None) -> np.ndarray:
        """Unscale X if the scaling information has been given (see `set_scale_X`).

        Args:
            X: The data matrix X (or `self.root.get_X(intercept=False)` if None).

        Returns:
            Possibly scaled X.
        """
        if X is None:
            X = self.root.get_X(intercept=False)
        if "X_scale" in self:
            X = X * self["X_scale"][None, :]
        if "X_center" in self:
            X = X + self["X_center"][None, :]
        return X

    def unscale_Y(self, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """Unscale Y if the scaling information has been given (see `set_scale_Y`).

        Args:
            Y: The response matrix Y (or `self.root.get_Y()` if None).

        Returns:
            Possibly scaled Y.
        """
        if Y is None:
            Y = self.root.get_Y()
        if "Y_scale" in self:
            Y = Y * self["Y_scale"][None, :]
        if "Y_center" in self:
            Y = Y + self["Y_center"][None, :]
        return Y
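
Normally `root` is a Slisemap object; as a self-contained sketch, a stand-in with the attributes that the checks read (`n`, `m`, `o`, `d`, `q`, `intercept`) illustrates the name handling:

```python
from types import SimpleNamespace
from slisemap.utils import Metadata

# Hypothetical stand-in for a Slisemap object
root = SimpleNamespace(n=3, m=3, o=1, d=2, q=3, intercept=True)
meta = Metadata(root)
meta.set_variables(["age", "income"])  # "Intercept" is appended
print(meta.get_variables())            # ['age', 'income', 'Intercept']
print(meta.get_coefficients())         # m == q, so the variable names are reused
print(meta.get_dimensions())           # fallback: ['Z_0', 'Z_1']
```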

__init__(root, **kwargs)

Create a Metadata dictionary.

Source code in slisemap/utils.py
def __init__(self, root: "Slisemap", **kwargs: Any) -> None:  # noqa: F821
    """Create a Metadata dictionary."""
    super().__init__(**kwargs)
    self.root = root

set_rows(*rows)

Set the row names with checks to avoid saving ranges.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `*rows` | `Optional[Sequence[object]]` | Row names. | `()` |

Source code in slisemap/utils.py
def set_rows(self, *rows: Optional[Sequence[object]]) -> None:
    """Set the row names with checks to avoid saving ranges.

    Args:
        *rows: row names
    """
    for row in rows:
        if row is not None:
            try:
                # Check if row is `range(0, self.root.n, 1)`-like (duck typing)
                if row.start == 0 and row.step == 1 and row.stop == self.root.n:
                    continue
            except AttributeError:
                pass
            _assert(
                len(row) == self.root.n,
                f"Wrong number of row names {len(row)} != {self.root.n}",
                Metadata.set_rows,
            )
            if all(i == j for i, j in enumerate(row)):
                continue
            self["rows"] = list(row)
            break

set_variables(variables=None, add_intercept=None)

Set the variable names with checks.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `variables` | `Optional[Sequence[Any]]` | Variable names. | `None` |
| `add_intercept` | `Optional[bool]` | Add "Intercept" to the variable names. Defaults to `self.root.intercept`. | `None` |

Source code in slisemap/utils.py
def set_variables(
    self,
    variables: Optional[Sequence[Any]] = None,
    add_intercept: Optional[bool] = None,
) -> None:
    """Set the variable names with checks.

    Args:
        variables: variable names
        add_intercept: add "Intercept" to the variable names. Defaults to `self.root.intercept`,
    """
    if add_intercept is None:
        add_intercept = self.root.intercept
    if variables is not None:
        variables = list(variables)
        if add_intercept:
            variables.append("Intercept")
        _assert(
            len(variables) == self.root.m,
            f"Wrong number of variables {len(variables)} != {self.root.m} ({variables})",
            Metadata.set_variables,
        )
        self["variables"] = variables

set_targets(targets=None)

Set the target names with checks.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `targets` | `Union[None, str, Sequence[Any]]` | Target names. | `None` |

Source code in slisemap/utils.py
def set_targets(self, targets: Union[None, str, Sequence[Any]] = None) -> None:
    """Set the target names with checks.

    Args:
        targets: target names
    """
    if targets is not None:
        targets = [targets] if isinstance(targets, str) else list(targets)
        _assert(
            len(targets) == self.root.o,
            f"Wrong number of targets {len(targets)} != {self.root.o}",
            Metadata.set_targets,
        )
        self["targets"] = targets

set_coefficients(coefficients=None)

Set the coefficient names with checks.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `coefficients` | `Optional[Sequence[Any]]` | Coefficient names. | `None` |

Source code in slisemap/utils.py
def set_coefficients(self, coefficients: Optional[Sequence[Any]] = None) -> None:
    """Set the coefficient names with checks.

    Args:
        coefficients: coefficient names
    """
    if coefficients is not None:
        _assert(
            len(coefficients) == self.root.q,
            f"Wrong number of targets {len(coefficients)} != {self.root.q}",
            Metadata.set_coefficients,
        )
        self["coefficients"] = list(coefficients)

set_dimensions(dimensions=None)

Set the dimension names with checks.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `dimensions` | `Optional[Sequence[Any]]` | Dimension names. | `None` |

Source code in slisemap/utils.py
def set_dimensions(self, dimensions: Optional[Sequence[Any]] = None) -> None:
    """Set the dimension names with checks.

    Args:
        dimensions: dimension names
    """
    if dimensions is not None:
        _assert(
            len(dimensions) == self.root.d,
            f"Wrong number of targets {len(dimensions)} != {self.root.d}",
            Metadata.set_dimensions,
        )
        self["dimensions"] = list(dimensions)

get_coefficients(fallback=True)

Get a list of coefficient names.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `fallback` | `bool` | If metadata for coefficients is missing, return a new list instead of None. Defaults to True. | `True` |

Returns:

| Type | Description |
| ---- | ----------- |
| `Optional[List[str]]` | List of coefficient names. |

Source code in slisemap/utils.py
def get_coefficients(self, fallback: bool = True) -> Optional[List[str]]:
    """Get a list of coefficient names.

    Args:
        fallback: If metadata for coefficients is missing, return a new list instead of None. Defaults to True.

    Returns:
        list of coefficient names
    """
    if "coefficients" in self:
        return self["coefficients"]
    if "variables" in self:
        if self.root.m == self.root.q:
            return self["variables"]
        if "targets" in self and self.root.m * self.root.o >= self.root.q:
            return [
                f"{t}: {v}" for t in self["targets"] for v in self["variables"]
            ][: self.root.q]
    if fallback:
        return [f"B_{i}" for i in range(self.root.q)]
    else:
        return None

get_targets(fallback=True)

Get a list of target names.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `fallback` | `bool` | If metadata for targets is missing, return a new list instead of None. Defaults to True. | `True` |

Returns:

| Type | Description |
| ---- | ----------- |
| `Optional[List[str]]` | List of target names. |

Source code in slisemap/utils.py
def get_targets(self, fallback: bool = True) -> Optional[List[str]]:
    """Get a list of target names.

    Args:
        fallback: If metadata for targets is missing, return a new list instead of None. Defaults to True.

    Returns:
        list of target names
    """
    if "targets" in self:
        return self["targets"]
    elif fallback:
        return [f"Y_{i}" for i in range(self.root.o)] if self.root.o > 1 else ["Y"]
    else:
        return None

get_variables(intercept=True, fallback=True)

Get a list of variable names.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `intercept` | `bool` | Include the intercept in the list. Defaults to True. | `True` |
| `fallback` | `bool` | If metadata for variables is missing, return a new list instead of None. Defaults to True. | `True` |

Returns:

| Type | Description |
| ---- | ----------- |
| `Optional[List[str]]` | List of variable names. |

Source code in slisemap/utils.py
def get_variables(
    self, intercept: bool = True, fallback: bool = True
) -> Optional[List[str]]:
    """Get a list of variable names.

    Args:
        intercept: include the intercept in the list. Defaults to True.
        fallback: If metadata for variables is missing, return a new list instead of None. Defaults to True.


    Returns:
        list of variable names
    """
    if "variables" in self:
        if self.root.intercept and not intercept:
            return self["variables"][:-1]
        else:
            return self["variables"]
    elif fallback:
        if self.root.intercept:
            if not intercept:
                return [f"X_{i}" for i in range(self.root.m - 1)]
            else:
                return [f"X_{i}" for i in range(self.root.m - 1)] + ["X_Intercept"]
        else:
            return [f"X_{i}" for i in range(self.root.m)]
    else:
        return None

get_dimensions(fallback=True, long=False)

Get a list of dimension names.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `fallback` | `bool` | If metadata for dimensions is missing, return a new list instead of None. Defaults to True. | `True` |
| `long` | `bool` | Use "SLISEMAP 1", ... as fallback instead of "Z_0", ... | `False` |

Returns:

| Type | Description |
| ---- | ----------- |
| `Optional[List[str]]` | List of dimension names. |

Source code in slisemap/utils.py
def get_dimensions(
    self, fallback: bool = True, long: bool = False
) -> Optional[List[str]]:
    """Get a list of dimension names.

    Args:
        fallback: If metadata for dimensions is missing, return a new list instead of None. Defaults to True.
        long: Use "SLISEMAP 1",... as fallback instead of "Z_0",...

    Returns:
        list of dimension names
    """
    if "dimensions" in self:
        return self["dimensions"]
    elif fallback:
        if long:
            cls = "Slisemap" if self.root is None else type(self.root).__name__
            return [f"{cls} {i+1}" for i in range(self.root.d)]
        else:
            return [f"Z_{i}" for i in range(self.root.d)]
    else:
        return None

get_rows(fallback=True)

Get a list of row names.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `fallback` | `bool` | If metadata for rows is missing, return a range instead of None. Defaults to True. | `True` |

Returns:

| Type | Description |
| ---- | ----------- |
| `Optional[Sequence[Any]]` | List (or range) of row names. |

Source code in slisemap/utils.py
def get_rows(self, fallback: bool = True) -> Optional[Sequence[Any]]:
    """Get a list of row names.

    Args:
        fallback: If metadata for rows is missing, return a range instead of None. Defaults to True.

    Returns:
        list (or range) of row names
    """
    if "rows" in self:
        return self["rows"]
    elif fallback:
        return range(self.root.n)
    else:
        return None

set_scale_X(center=None, scale=None)

Set scaling information with checks.

Use if X has been scaled before being input to Slisemap. Assuming the scaling can be converted to the form X = (X_unscaled - center) / scale. This allows some plots to (temporarily) revert the scaling (for more intuitive units).

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `center` | `Union[None, Tensor, ndarray, Sequence[float]]` | The constant offset of `X`. Defaults to None. | `None` |
| `scale` | `Union[None, Tensor, ndarray, Sequence[float]]` | The scaling factor of `X`. Defaults to None. | `None` |

Source code in slisemap/utils.py
def set_scale_X(
    self,
    center: Union[None, torch.Tensor, np.ndarray, Sequence[float]] = None,
    scale: Union[None, torch.Tensor, np.ndarray, Sequence[float]] = None,
) -> None:
    """Set scaling information with checks.

    Use if `X` has been scaled before being input to Slisemap.
    Assuming the scaling can be converted to the form `X = (X_unscaled - center) / scale`.
    This allows some plots to (temporarily) revert the scaling (for more intuitive units).

    Args:
        center: The constant offset of `X`. Defaults to None.
        scale: The scaling factor of `X`. Defaults to None.
    """
    if center is not None:
        center = tonp(center).ravel()
        assert center.size == self.root.m - self.root.intercept
        self["X_center"] = center
    if scale is not None:
        scale = tonp(scale).ravel()
        assert scale.size == self.root.m - self.root.intercept
        self["X_scale"] = scale

set_scale_Y(center=None, scale=None)

Set scaling information with checks.

Use if Y has been scaled before being input to Slisemap. Assuming the scaling can be converted to the form Y = (Y_unscaled - center) / scale. This allows some plots to (temporarily) revert the scaling (for more intuitive units).

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `center` | `Union[None, Tensor, ndarray, Sequence[float]]` | The constant offset of `Y`. Defaults to None. | `None` |
| `scale` | `Union[None, Tensor, ndarray, Sequence[float]]` | The scaling factor of `Y`. Defaults to None. | `None` |

Source code in slisemap/utils.py
def set_scale_Y(
    self,
    center: Union[None, torch.Tensor, np.ndarray, Sequence[float]] = None,
    scale: Union[None, torch.Tensor, np.ndarray, Sequence[float]] = None,
) -> None:
    """Set scaling information with checks.

    Use if `Y` has been scaled before being input to Slisemap.
    Assuming the scaling can be converted to the form `Y = (Y_unscaled - center) / scale`.
    This allows some plots to (temporarily) revert the scaling (for more intuitive units).

    Args:
        center: The constant offset of `Y`. Defaults to None.
        scale: The scaling factor of `Y`. Defaults to None.
    """
    if center is not None:
        center = tonp(center).ravel()
        assert center.size == self.root.o
        self["Y_center"] = center
    if scale is not None:
        scale = tonp(scale).ravel()
        assert scale.size == self.root.o
        self["Y_scale"] = scale

unscale_X(X=None)

Unscale X if the scaling information has been given (see set_scale_X).

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `X` | `Optional[ndarray]` | The data matrix X (or `self.root.get_X(intercept=False)` if None). | `None` |

Returns:

| Type | Description |
| ---- | ----------- |
| `ndarray` | Possibly scaled X. |

Source code in slisemap/utils.py
def unscale_X(self, X: Optional[np.ndarray] = None) -> np.ndarray:
    """Unscale X if the scaling information has been given (see `set_scale_X`).

    Args:
        X: The data matrix X (or `self.root.get_X(intercept=False)` if None).

    Returns:
        Possibly scaled X.
    """
    if X is None:
        X = self.root.get_X(intercept=False)
    if "X_scale" in self:
        X = X * self["X_scale"][None, :]
    if "X_center" in self:
        X = X + self["X_center"][None, :]
    return X

unscale_Y(Y=None)

Unscale Y if the scaling information has been given (see set_scale_Y).

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `Y` | `Optional[ndarray]` | The response matrix Y (or `self.root.get_Y()` if None). | `None` |

Returns:

| Type | Description |
| ---- | ----------- |
| `ndarray` | Possibly scaled Y. |

Source code in slisemap/utils.py
def unscale_Y(self, Y: Optional[np.ndarray] = None) -> np.ndarray:
    """Unscale Y if the scaling information has been given (see `set_scale_Y`).

    Args:
        Y: The response matrix Y (or `self.root.get_Y()` if None).

    Returns:
        Possibly scaled Y.
    """
    if Y is None:
        Y = self.root.get_Y()
    if "Y_scale" in self:
        Y = Y * self["Y_scale"][None, :]
    if "Y_center" in self:
        Y = Y + self["Y_center"][None, :]
    return Y

make_grid(num=50, d=2, hex=True)

Create a circular grid of points with radius 1.0.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `num` | `int` | The approximate number of points. Defaults to 50. | `50` |
| `d` | `int` | The number of dimensions. Defaults to 2. | `2` |
| `hex` | `bool` | If `d == 2`, produce a hexagonal grid instead of a rectangular grid. Defaults to True. | `True` |

Returns:

| Type | Description |
| ---- | ----------- |
| `ndarray` | A matrix of coordinates `[num, d]`. |

Source code in slisemap/utils.py
def make_grid(num: int = 50, d: int = 2, hex: bool = True) -> np.ndarray:
    """Create a circular grid of points with radius 1.0.

    Args:
        num: The approximate number of points. Defaults to 50.
        d: The number of dimensions. Defaults to 2.
        hex: If ``d == 2`` produce a hexagonal grid instead of a rectangular grid. Defaults to True.

    Returns:
        A matrix of coordinates `[num, d]`.
    """
    _assert(d > 0, "The number of dimensions must be positive", make_grid)
    if d == 1:
        return np.linspace(-1, 1, num)[:, None]
    elif d == 2 and hex:
        return make_hex_grid(num)
    else:
        nball_frac = np.pi ** (d / 2) / np.math.gamma(d / 2 + 1) / 2**d
        if 4**d * nball_frac > num:
            _warn(
                "Too few grid points per dimension. Try reducing the number of dimensions or increase the number of points in the grid.",
                make_grid,
            )
        proto_1d = int(np.ceil((num / nball_frac) ** (1 / d))) // 2 * 2 + 2
        grid_1d = np.linspace(-0.9999, 0.9999, proto_1d)
        grid = np.stack(np.meshgrid(*(grid_1d for _ in range(d))), -1).reshape((-1, d))
        dist = np.sum(grid**2, 1)
        q = np.quantile(dist, num / len(dist)) + np.finfo(dist.dtype).eps ** 0.5
        grid = grid[dist <= q]
        return grid / np.quantile(grid, 0.99)
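
For example, to lay out prototype points for a 2D embedding (sketch):

```python
import numpy as np
from slisemap.utils import make_grid

grid = make_grid(num=50, d=2)  # hexagonal 2D grid by default
print(grid.shape)              # roughly (50, 2)
print(np.linalg.norm(grid, axis=1).max())  # close to 1.0
```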

make_hex_grid(num=52)

Create a circular grid of 2D points with a hexagon pattern and radius 1.0.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `num` | `int` | The approximate number of points. Defaults to 52. | `52` |

Returns:

| Type | Description |
| ---- | ----------- |
| `ndarray` | A matrix of coordinates `[num, 2]`. |

Source code in slisemap/utils.py
def make_hex_grid(num: int = 52) -> np.ndarray:
    """Create a circular grid of 2D points with a hexagon pattern and radius 1.0.

    Args:
        num: The approximate number of points. Defaults to 52.

    Returns:
        A matrix of coordinates `[num, 2]`.
    """
    num_h = int(np.ceil(np.sqrt(num * 4 / np.pi))) // 2 * 2 + 3
    grid_h, height = np.linspace(-0.9999, 0.9999, num_h, retstep=True)
    width = height * 2 / 3 * np.sqrt(3)
    num_w = int(np.ceil(1.0 / width))
    grid_w = np.arange(-num_w, num_w + 1) * width
    grid = np.stack(np.meshgrid(grid_w, grid_h), -1)
    grid[(1 - num_h // 2 % 2) :: 2, :, 0] += width / 2
    grid = grid.reshape((-1, 2))
    best = None
    for origo in (0.0, 0.5 * width):
        if origo != 0.0:
            grid[:, 0] += origo
        dist = np.sum(grid**2, 1)
        q = np.quantile(dist, num / len(dist))
        for epsilon in (-1e-4, 1e-4):
            grid2 = grid[dist <= q + epsilon]
            if best is None or abs(best.shape[0] - num) > abs(grid2.shape[0] - num):
                best = grid2.copy()
    return best / np.quantile(best, 0.99)