
API Reference

The architecture of the package is shown in the UML diagram of the package structure.

HyperSpectral Image

HSI

A dataclass for hyperspectral image data, including the image, wavelengths, and binary mask.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `image` | `Tensor` | The hyperspectral image data as a PyTorch tensor. |
| `wavelengths` | `Tensor` | The wavelengths present in the image. |
| `orientation` | `tuple[str, str, str]` | The orientation of the image data. |
| `device` | `device` | The device to be used for inference. |
| `binary_mask` | `Tensor` | A binary mask used to cover unimportant parts of the image. |
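
Example (a minimal construction sketch; assumes `HSI` is importable from `meteors.hsi`, matching the source path below, and uses illustrative shapes and wavelengths):

>>> import torch
>>> import numpy as np
>>> from meteors.hsi import HSI
>>> hsi_image = HSI(image=torch.rand(10, 100, 100), wavelengths=np.linspace(400, 800, 10))
>>> hsi_image.orientation
('C', 'H', 'W')
>>> hsi_image.spectral_axis
0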

Source code in src/meteors/hsi.py
class HSI(BaseModel):
    """A dataclass for hyperspectral image data, including the image, wavelengths, and binary mask.

    Attributes:
        image (torch.Tensor): The hyperspectral image data as a PyTorch tensor.
        wavelengths (torch.Tensor): The wavelengths present in the image.
        orientation (tuple[str, str, str]): The orientation of the image data.
        device (torch.device): The device to be used for inference.
        binary_mask (torch.Tensor): A binary mask used to cover unimportant parts of the image.
    """

    image: Annotated[  # Should always be a first field
        torch.Tensor,
        PlainValidator(ensure_image_tensor),
        Field(description="Hyperspectral image. Converted to torch tensor."),
    ]
    wavelengths: Annotated[
        torch.Tensor,
        PlainValidator(ensure_wavelengths_tensor),
        Field(description="Wavelengths present in the image. Defaults to None."),
    ]
    orientation: Annotated[
        tuple[str, str, str],
        PlainValidator(validate_orientation),
        Field(
            description=(
                'Orientation of the image - sequence of three one-letter strings in any order: "C", "H", "W" '
                'meaning respectively channels, height and width of the image. Defaults to ("C", "H", "W").'
            ),
        ),
    ] = ("C", "H", "W")
    device: Annotated[
        torch.device,
        PlainValidator(resolve_inference_device_hsi),
        Field(
            validate_default=True,
            exclude=True,
            description="Device to be used for inference. If None, the device of the input image will be used. Defaults to None.",
        ),
    ] = None
    binary_mask: Annotated[
        torch.Tensor,
        PlainValidator(process_and_validate_binary_mask),
        Field(
            validate_default=True,
            description=(
                "Binary mask used to cover not important parts of the base image, masked parts have values equals to 0. "
                "Converted to torch tensor. Defaults to None."
            ),
        ),
    ] = None

    @property
    def spectral_axis(self) -> int:
        """Returns the index of the spectral (wavelength) axis based on the current data orientation.

        In hyperspectral imaging, the spectral axis represents the dimension along which
        different spectral bands or wavelengths are arranged. This property dynamically
        determines the index of this axis based on the current orientation of the data.

        Returns:
            int: The index of the spectral axis in the current data structure.
                - 0 for 'CHW' or 'CWH' orientations (Channel/Wavelength first)
                - 2 for 'HWC' or 'WHC' orientations (Channel/Wavelength last)
                - 1 for 'HCW' or 'WCH' orientations (Channel/Wavelength in the middle)

        Note:
            The orientation is typically represented as a string where:
            - 'C' represents the spectral/wavelength dimension
            - 'H' represents the height (rows) of the image
            - 'W' represents the width (columns) of the image

        Examples:
            >>> hsi_image = HSI(image=torch.rand(10, 100, 100), wavelengths=np.linspace(400, 1000, 10), orientation="CHW")
            >>> hsi_image.spectral_axis
            0
            >>> hsi_image = HSI(image=torch.rand(100, 100, 10), wavelengths=np.linspace(400, 1000, 10), orientation="HWC")
            >>> hsi_image.spectral_axis
            2
        """
        return get_channel_axis(self.orientation)

    @property
    def spatial_binary_mask(self) -> torch.Tensor:
        """Returns a 2D spatial representation of the binary mask.

        This property extracts a single 2D slice from the 3D binary mask, assuming that
        the mask is identical across all spectral bands. It handles different data
        orientations by first ensuring the spectral dimension is the last dimension
        before extracting the 2D spatial mask.

        Returns:
            torch.Tensor: A 2D tensor representing the spatial binary mask.
                The shape will be (H, W) where H is height and W is width of the image.

        Note:
            - This assumes that the binary mask is consistent across all spectral bands.
            - The returned mask is always 2D, regardless of the original data orientation.

        Examples:
            >>> # If self.binary_mask has shape (100, 100, 5) with spectral_axis=2:
            >>> hsi_image = HSI(image=torch.rand(100, 100, 5), wavelengths=np.linspace(400, 800, 5),
            ...                 binary_mask=torch.ones(100, 100, 5), orientation=("H", "W", "C"))
            >>> hsi_image.spatial_binary_mask.shape
            torch.Size([100, 100])
            >>> # If self.binary_mask has shape (5, 100, 100) with spectral_axis=0:
            >>> hsi_image = HSI(image=torch.rand(5, 100, 100), wavelengths=np.linspace(400, 800, 5),
            ...                 binary_mask=torch.ones(5, 100, 100), orientation=("C", "H", "W"))
            >>> hsi_image.spatial_binary_mask.shape
            torch.Size([100, 100])
        """
        mask = self.binary_mask if self.binary_mask is not None else torch.ones_like(self.image)
        return mask.select(dim=self.spectral_axis, index=0)

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @model_validator(mode="after")
    def validate_image_data(self) -> Self:
        """Validates the image data by checking the shape of the wavelengths, image, and spectral_axis.

        Returns:
            Self: The instance of the class.
        """
        validate_shapes(self.wavelengths, self.image, self.spectral_axis)
        return self

    def to(self, device: str | torch.device) -> Self:
        """Moves the image and binary mask (if available) to the specified device.

        Args:
            device (str or torch.device): The device to move the image and binary mask to.

        Returns:
            Self: The updated HSI object.

        Examples:
            >>> # Create an HSI object
            >>> hsi_image = HSI(image=torch.rand(10, 10, 10), wavelengths=np.arange(10))
            >>> # Move the image to cpu
            >>> hsi_image = hsi_image.to("cpu")
            >>> hsi_image.device
            device(type='cpu')
            >>> # Move the image to cuda
            >>> hsi_image = hsi_image.to("cuda")
            >>> hsi_image.device
            device(type='cuda', index=0)
        """
        self.image = self.image.to(device)
        if self.binary_mask is not None:
            self.binary_mask = self.binary_mask.to(device)
        self.device = self.image.device
        return self

    def get_image(self, apply_mask: bool = True) -> torch.Tensor:
        """Returns the hyperspectral image data with optional masking applied.

        Args:
            apply_mask (bool, optional): Whether to apply the binary mask to the image.
                Defaults to True.

        Returns:
            torch.Tensor: The hyperspectral image data.

        Notes:
            - If apply_mask is True, the binary mask will be applied to the image based on the `binary_mask` attribute.

        Examples:
            >>> hsi_image = HSI(image=torch.rand(10, 100, 100), wavelengths=np.linspace(400, 1000, 10))
            >>> image = hsi_image.get_image()
            >>> image.shape
            torch.Size([10, 100, 100])
            >>> image = hsi_image.get_image(apply_mask=False)
            >>> image.shape
            torch.Size([10, 100, 100])
        """
        if apply_mask and self.binary_mask is not None:
            return self.image * self.binary_mask
        return self.image

    def get_rgb_image(
        self,
        apply_mask: bool = True,
        apply_min_cutoff: bool = False,
        output_channel_axis: int | None = None,
        normalize: bool = True,
    ) -> torch.Tensor:
        """Extracts an RGB representation from the hyperspectral image data.

        This method creates a 3-channel RGB image by selecting appropriate bands
        corresponding to red, green, and blue wavelengths from the hyperspectral data.

        Args:
            apply_mask (bool, optional): Whether to apply the binary mask to the image.
                Defaults to True.
            apply_min_cutoff (bool, optional): Whether to apply a minimum intensity
                cutoff to the image. Defaults to False.
            output_channel_axis (int | None, optional): The axis where the RGB channels
                should be placed in the output tensor. If None, uses the current spectral
                axis of the hyperspectral data. Defaults to None.
            normalize (bool, optional): Whether to normalize the band values to the [0, 1] range.
                Defaults to True.

        Returns:
            torch.Tensor: The RGB representation of the hyperspectral image.
                Shape will be either (H, W, 3), (3, H, W), or (H, 3, W) depending on
                the specified output_channel_axis, where H is height and W is width.

        Notes:
            - The RGB bands are extracted using predefined wavelength ranges for R, G, and B.
            - Each band is normalized independently before combining into the RGB image.
            - If apply_mask is True, masked areas will be set to zero in the output.
            - If apply_min_cutoff is True, a minimum intensity threshold is applied to each band.

        Examples:
            >>> hsi_image = HSI(image=torch.rand(10, 100, 100), wavelengths=np.linspace(400, 1000, 10))
            >>> rgb_image = hsi_image.get_rgb_image()
            >>> rgb_image.shape
            torch.Size([100, 100, 3])

            >>> rgb_image = hsi_image.get_rgb_image(output_channel_axis=0)
            >>> rgb_image.shape
            torch.Size([3, 100, 100])

            >>> rgb_image = hsi_image.get_rgb_image(apply_mask=False, apply_min_cutoff=True)
            >>> rgb_image.shape
            torch.Size([100, 100, 3])
        """
        if output_channel_axis is None:
            output_channel_axis = self.spectral_axis

        rgb_img = torch.stack(
            [
                self.extract_band_by_name(
                    band, apply_mask=apply_mask, apply_min_cutoff=apply_min_cutoff, normalize=normalize
                )
                for band in ["R", "G", "B"]
            ],
            dim=self.spectral_axis,
        )

        return (
            rgb_img
            if output_channel_axis == self.spectral_axis
            else torch.moveaxis(rgb_img, self.spectral_axis, output_channel_axis)
        )

    def _extract_central_slice_from_band(
        self,
        band_wavelengths: torch.Tensor,
        apply_mask: bool = True,
        apply_min_cutoff: bool = False,
        normalize: bool = True,
    ) -> torch.Tensor:
        """Extracts and processes the central wavelength band from a given range in the hyperspectral image.

        This method selects the central band from a specified range of wavelengths,
        applies optional processing steps (masking, normalization, and minimum cutoff),
        and returns the resulting 2D image slice.

        Args:
            band_wavelengths (torch.Tensor): The selected wavelengths that define the whole band
                from which the central slice will be extracted.
                All of the passed wavelengths must be present in the image.
            apply_mask (bool, optional): Whether to apply the binary mask to the extracted band.
                Defaults to True.
            apply_min_cutoff (bool, optional): Whether to apply a minimum intensity cutoff.
                If True, sets the minimum non-zero value to zero after normalization.
                Defaults to False.
            normalize (bool, optional): Whether to normalize the band values to [0, 1] range.
                Defaults to True.

        Returns:
            torch.Tensor: A 2D tensor representing the processed central wavelength band.
                Shape will be (H, W), where H is height and W is width of the image.

        Notes:
            - The central wavelength is determined as the middle index of the provided wavelengths list.
            - If normalization is applied, it's done before masking and cutoff operations.
            - The binary mask, if applied, is expected to have the same spatial dimensions as the image.

        Examples:
            >>> hsi_image = HSI(image=torch.rand(13, 100, 100), wavelengths=np.linspace(400, 1000, 13))
            >>> band_wavelengths = torch.tensor([500, 600, 650, 700])
            >>> central_slice = hsi_image._extract_central_slice_from_band(band_wavelengths)
            >>> central_slice.shape
            torch.Size([100, 100])

            >>> # Extract a slice without normalization or masking
            >>> raw_band = hsi_image._extract_central_slice_from_band(band_wavelengths, apply_mask=False, normalize=False)
        """
        # check if all wavelengths from the `band_wavelengths` are present in the image
        if not all(wave in self.wavelengths for wave in band_wavelengths):
            raise ValueError("All of the passed wavelengths must be present in the image")

        # sort the `band_wavelengths` to ensure the central band is selected
        band_wavelengths = torch.sort(band_wavelengths).values

        start_index = np.where(self.wavelengths == band_wavelengths[0])[0][0]
        relative_center_band_index = len(band_wavelengths) // 2
        central_band_index = start_index + relative_center_band_index

        # Ensure the spectral dimension is the last
        image = self.image if self.spectral_axis == 2 else torch.moveaxis(self.image, self.spectral_axis, 2)

        slice = image[..., central_band_index]

        if normalize:
            if apply_min_cutoff:
                slice_min = slice[slice != 0].min()
            else:
                slice_min = slice.min()

            slice_max = slice.max()
            if slice_max > slice_min:  # Avoid division by zero
                slice = (slice - slice_min) / (slice_max - slice_min)

            if apply_min_cutoff:
                slice[slice == slice.min()] = 0  # Set minimum values to zero

        if apply_mask and self.binary_mask is not None:
            mask = (
                self.binary_mask if self.spectral_axis == 2 else torch.moveaxis(self.binary_mask, self.spectral_axis, 2)
            )
            slice = slice * mask[..., central_band_index]

        return slice

    def extract_band_by_name(
        self,
        band_name: str,
        selection_method: str = "center",
        apply_mask: bool = True,
        apply_min_cutoff: bool = False,
        normalize: bool = True,
    ) -> torch.Tensor:
        """Extracts a single spectral band from the hyperspectral image based on a standardized band name.

        This method uses the spyndex library to map standardized band names to wavelength ranges,
        then extracts the corresponding band from the hyperspectral data.

        Args:
            band_name (str): The standardized name of the band to extract (e.g., "Red", "NIR", "SWIR1").
            selection_method (str, optional): The method to use for selecting the band within the wavelength range.
                Currently, only "center" is supported, which selects the central wavelength.
                Defaults to "center".
            apply_mask (bool, optional): Whether to apply the binary mask to the extracted band.
                Defaults to True.
            apply_min_cutoff (bool, optional): Whether to apply a minimum intensity cutoff after normalization.
                If True, sets the minimum non-zero value to zero. Defaults to False.
            normalize (bool, optional): Whether to normalize the band values to the [0, 1] range.
                Defaults to True.

        Returns:
            torch.Tensor: A 2D tensor representing the extracted and processed spectral band.
                Shape will be (H, W), where H is height and W is width of the image.

        Raises:
            BandSelectionError: If the specified band name is not found in the spyndex library.
            NotImplementedError: If a selection method other than "center" is specified.

        Notes:
            - The spyndex library is used to map band names to wavelength ranges.
            - Currently, only the "center" selection method is implemented, which chooses
            the central wavelength within the specified range.
            - Processing steps are applied in the order: normalization, cutoff, masking.

        Examples:
            >>> hsi_image = HSI(image=torch.rand(200, 100, 100), wavelengths=np.linspace(400, 2500, 200))
            >>> red_band = hsi_image.extract_band_by_name("Red")
            >>> red_band.shape
            torch.Size([100, 100])

            >>> # Extract NIR band without normalization or masking
            >>> nir_band = hsi_image.extract_band_by_name("NIR", apply_mask=False, normalize=False)
        """
        band_info = spyndex.bands.get(band_name)
        if band_info is None:
            raise BandSelectionError(f"Band name '{band_name}' not found in the spyndex library")

        min_wave, max_wave = band_info.min_wavelength, band_info.max_wavelength
        selected_wavelengths = self.wavelengths[(self.wavelengths >= min_wave) & (self.wavelengths <= max_wave)]

        if selection_method == "center":
            return self._extract_central_slice_from_band(
                selected_wavelengths, apply_mask=apply_mask, apply_min_cutoff=apply_min_cutoff, normalize=normalize
            )
        else:
            raise NotImplementedError(
                f"Selection method '{selection_method}' is not supported. Only 'center' is currently available."
            )

    def change_orientation(self, target_orientation: tuple[str, str, str] | list[str] | str, inplace: bool = False) -> Self:
        """Changes the orientation of the hsi data to the target orientation.

        Args:
            target_orientation (tuple[str, str, str], list[str], str): The target orientation for the hsi data.
                This should be a tuple of three one-letter strings in any order: "C", "H", "W".
            inplace (bool, optional): Whether to modify the hsi data in place or return a new object.

        Returns:
            Self: The updated HSI object with the new orientation.

        Raises:
            ValueError: If the target orientation is not a valid tuple of three one-letter strings.
        """
        target_orientation = validate_orientation(target_orientation)

        if inplace:
            hsi = self
        else:
            hsi = self.model_copy()

        if target_orientation == self.orientation:
            return hsi

        permute_dims = [hsi.orientation.index(dim) for dim in target_orientation]

        # permute the image
        hsi.image = hsi.image.permute(permute_dims)

        # permute the binary mask
        if hsi.binary_mask is not None:
            hsi.binary_mask = hsi.binary_mask.permute(permute_dims)

        hsi.orientation = target_orientation

        return hsi

spatial_binary_mask: torch.Tensor property

Returns a 2D spatial representation of the binary mask.

This property extracts a single 2D slice from the 3D binary mask, assuming that the mask is identical across all spectral bands. It handles different data orientations by first ensuring the spectral dimension is the last dimension before extracting the 2D spatial mask.

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | A 2D tensor representing the spatial binary mask. The shape will be (H, W) where H is height and W is width of the image. |

Note
  • This assumes that the binary mask is consistent across all spectral bands.
  • The returned mask is always 2D, regardless of the original data orientation.

Examples:

>>> # If self.binary_mask has shape (100, 100, 5) with spectral_axis=2:
>>> hsi_image = HSI(image=torch.rand(100, 100, 5), wavelengths=np.linspace(400, 800, 5),
...                 binary_mask=torch.ones(100, 100, 5), orientation=("H", "W", "C"))
>>> hsi_image.spatial_binary_mask.shape
torch.Size([100, 100])
>>> # If self.binary_mask has shape (5, 100, 100) with spectral_axis=0:
>>> hsi_image = HSI(image=torch.rand(5, 100, 100), wavelengths=np.linspace(400, 800, 5),
...                 binary_mask=torch.ones(5, 100, 100), orientation=("C", "H", "W"))
>>> hsi_image.spatial_binary_mask.shape
torch.Size([100, 100])

spectral_axis: int property

Returns the index of the spectral (wavelength) axis based on the current data orientation.

In hyperspectral imaging, the spectral axis represents the dimension along which different spectral bands or wavelengths are arranged. This property dynamically determines the index of this axis based on the current orientation of the data.

Returns:

| Type | Description |
|------|-------------|
| `int` | The index of the spectral axis in the current data structure: 0 for "CHW"/"CWH" orientations (channel first), 1 for "HCW"/"WCH" (channel in the middle), 2 for "HWC"/"WHC" (channel last). |

Note

The orientation is typically represented as a string where:
  • 'C' represents the spectral/wavelength dimension
  • 'H' represents the height (rows) of the image
  • 'W' represents the width (columns) of the image

Examples:

>>> hsi_image = HSI(image=torch.rand(10, 100, 100), wavelengths=np.linspace(400, 1000, 10), orientation="CHW")
>>> hsi_image.spectral_axis
0
>>> hsi_image = HSI(image=torch.rand(100, 100, 10), wavelengths=np.linspace(400, 1000, 10), orientation="HWC")
>>> hsi_image.spectral_axis
2

change_orientation(target_orientation, inplace=False)

Changes the orientation of the hsi data to the target orientation.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `target_orientation` | `tuple[str, str, str] \| list[str] \| str` | The target orientation for the hsi data. This should be a tuple of three one-letter strings in any order: "C", "H", "W". | required |
| `inplace` | `bool` | Whether to modify the hsi data in place or return a new object. | `False` |

Returns:

| Type | Description |
|------|-------------|
| `Self` | The updated HSI object with the new orientation. |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If the target orientation is not a valid tuple of three one-letter strings. |
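
Example (a usage sketch with illustrative shapes; with inplace=False a permuted copy is returned and the original object is left unchanged):

>>> hsi_image = HSI(image=torch.rand(10, 100, 100), wavelengths=np.linspace(400, 800, 10))
>>> hsi_hwc = hsi_image.change_orientation("HWC")
>>> hsi_hwc.image.shape
torch.Size([100, 100, 10])
>>> hsi_image.orientation
('C', 'H', 'W')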

Source code in src/meteors/hsi.py
def change_orientation(self, target_orientation: tuple[str, str, str] | list[str] | str, inplace: bool = False) -> Self:
    """Changes the orientation of the hsi data to the target orientation.

    Args:
        target_orientation (tuple[str, str, str], list[str], str): The target orientation for the hsi data.
            This should be a tuple of three one-letter strings in any order: "C", "H", "W".
        inplace (bool, optional): Whether to modify the hsi data in place or return a new object.

    Returns:
        Self: The updated HSI object with the new orientation.

    Raises:
        ValueError: If the target orientation is not a valid tuple of three one-letter strings.
    """
    target_orientation = validate_orientation(target_orientation)

    if inplace:
        hsi = self
    else:
        hsi = self.model_copy()

    if target_orientation == self.orientation:
        return hsi

    permute_dims = [hsi.orientation.index(dim) for dim in target_orientation]

    # permute the image
    hsi.image = hsi.image.permute(permute_dims)

    # permute the binary mask
    if hsi.binary_mask is not None:
        hsi.binary_mask = hsi.binary_mask.permute(permute_dims)

    hsi.orientation = target_orientation

    return hsi

extract_band_by_name(band_name, selection_method='center', apply_mask=True, apply_min_cutoff=False, normalize=True)

Extracts a single spectral band from the hyperspectral image based on a standardized band name.

This method uses the spyndex library to map standardized band names to wavelength ranges, then extracts the corresponding band from the hyperspectral data.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `band_name` | `str` | The standardized name of the band to extract (e.g., "Red", "NIR", "SWIR1"). | required |
| `selection_method` | `str` | The method to use for selecting the band within the wavelength range. Currently, only "center" is supported, which selects the central wavelength. | `'center'` |
| `apply_mask` | `bool` | Whether to apply the binary mask to the extracted band. | `True` |
| `apply_min_cutoff` | `bool` | Whether to apply a minimum intensity cutoff after normalization. If True, sets the minimum non-zero value to zero. | `False` |
| `normalize` | `bool` | Whether to normalize the band values to the [0, 1] range. | `True` |

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | A 2D tensor representing the extracted and processed spectral band. Shape will be (H, W), where H is height and W is width of the image. |

Raises:

| Type | Description |
|------|-------------|
| `BandSelectionError` | If the specified band name is not found in the spyndex library. |
| `NotImplementedError` | If a selection method other than "center" is specified. |

Notes
  • The spyndex library is used to map band names to wavelength ranges.
  • Currently, only the "center" selection method is implemented, which chooses the central wavelength within the specified range.
  • Processing steps are applied in the order: normalization, cutoff, masking.

Examples:

>>> hsi_image = HSI(image=torch.rand(200, 100, 100), wavelengths=np.linspace(400, 2500, 200))
>>> red_band = hsi_image.extract_band_by_name("Red")
>>> red_band.shape
torch.Size([100, 100])
>>> # Extract NIR band without normalization or masking
>>> nir_band = hsi_image.extract_band_by_name("NIR", apply_mask=False, normalize=False)
Source code in src/meteors/hsi.py
def extract_band_by_name(
    self,
    band_name: str,
    selection_method: str = "center",
    apply_mask: bool = True,
    apply_min_cutoff: bool = False,
    normalize: bool = True,
) -> torch.Tensor:
    """Extracts a single spectral band from the hyperspectral image based on a standardized band name.

    This method uses the spyndex library to map standardized band names to wavelength ranges,
    then extracts the corresponding band from the hyperspectral data.

    Args:
        band_name (str): The standardized name of the band to extract (e.g., "Red", "NIR", "SWIR1").
        selection_method (str, optional): The method to use for selecting the band within the wavelength range.
            Currently, only "center" is supported, which selects the central wavelength.
            Defaults to "center".
        apply_mask (bool, optional): Whether to apply the binary mask to the extracted band.
            Defaults to True.
        apply_min_cutoff (bool, optional): Whether to apply a minimum intensity cutoff after normalization.
            If True, sets the minimum non-zero value to zero. Defaults to False.
        normalize (bool, optional): Whether to normalize the band values to the [0, 1] range.
            Defaults to True.

    Returns:
        torch.Tensor: A 2D tensor representing the extracted and processed spectral band.
            Shape will be (H, W), where H is height and W is width of the image.

    Raises:
        BandSelectionError: If the specified band name is not found in the spyndex library.
        NotImplementedError: If a selection method other than "center" is specified.

    Notes:
        - The spyndex library is used to map band names to wavelength ranges.
        - Currently, only the "center" selection method is implemented, which chooses
        the central wavelength within the specified range.
        - Processing steps are applied in the order: normalization, cutoff, masking.

    Examples:
        >>> hsi_image = HSI(image=torch.rand(200, 100, 100), wavelengths=np.linspace(400, 2500, 200))
        >>> red_band = hsi_image.extract_band_by_name("Red")
        >>> red_band.shape
        torch.Size([100, 100])

        >>> # Extract NIR band without normalization or masking
        >>> nir_band = hsi_image.extract_band_by_name("NIR", apply_mask=False, normalize=False)
    """
    band_info = spyndex.bands.get(band_name)
    if band_info is None:
        raise BandSelectionError(f"Band name '{band_name}' not found in the spyndex library")

    min_wave, max_wave = band_info.min_wavelength, band_info.max_wavelength
    selected_wavelengths = self.wavelengths[(self.wavelengths >= min_wave) & (self.wavelengths <= max_wave)]

    if selection_method == "center":
        return self._extract_central_slice_from_band(
            selected_wavelengths, apply_mask=apply_mask, apply_min_cutoff=apply_min_cutoff, normalize=normalize
        )
    else:
        raise NotImplementedError(
            f"Selection method '{selection_method}' is not supported. Only 'center' is currently available."
        )

get_image(apply_mask=True)

Returns the hyperspectral image data with optional masking applied.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `apply_mask` | `bool` | Whether to apply the binary mask to the image. | `True` |

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | The hyperspectral image data. |

Notes
  • If apply_mask is True, the binary mask will be applied to the image based on the binary_mask attribute.

Examples:

>>> hsi_image = HSI(image=torch.rand(10, 100, 100), wavelengths=np.linspace(400, 1000, 10))
>>> image = hsi_image.get_image()
>>> image.shape
torch.Size([10, 100, 100])
>>> image = hsi_image.get_image(apply_mask=False)
>>> image.shape
torch.Size([10, 100, 100])
Source code in src/meteors/hsi.py
def get_image(self, apply_mask: bool = True) -> torch.Tensor:
    """Returns the hyperspectral image data with optional masking applied.

    Args:
        apply_mask (bool, optional): Whether to apply the binary mask to the image.
            Defaults to True.

    Returns:
        torch.Tensor: The hyperspectral image data.

    Notes:
        - If apply_mask is True, the binary mask will be applied to the image based on the `binary_mask` attribute.

    Examples:
        >>> hsi_image = HSI(image=torch.rand(10, 100, 100), wavelengths=np.linspace(400, 1000, 10))
        >>> image = hsi_image.get_image()
        >>> image.shape
        torch.Size([10, 100, 100])
        >>> image = hsi_image.get_image(apply_mask=False)
        >>> image.shape
        torch.Size([10, 100, 100])
    """
    if apply_mask and self.binary_mask is not None:
        return self.image * self.binary_mask
    return self.image

get_rgb_image(apply_mask=True, apply_min_cutoff=False, output_channel_axis=None, normalize=True)

Extracts an RGB representation from the hyperspectral image data.

This method creates a 3-channel RGB image by selecting appropriate bands corresponding to red, green, and blue wavelengths from the hyperspectral data.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `apply_mask` | `bool` | Whether to apply the binary mask to the image. | `True` |
| `apply_min_cutoff` | `bool` | Whether to apply a minimum intensity cutoff to the image. | `False` |
| `output_channel_axis` | `int \| None` | The axis where the RGB channels should be placed in the output tensor. If None, uses the current spectral axis of the hyperspectral data. | `None` |
| `normalize` | `bool` | Whether to normalize the band values to the [0, 1] range. | `True` |

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | The RGB representation of the hyperspectral image. Shape will be either (H, W, 3), (3, H, W), or (H, 3, W) depending on the specified output_channel_axis, where H is height and W is width. |

Notes
  • The RGB bands are extracted using predefined wavelength ranges for R, G, and B.
  • Each band is normalized independently before combining into the RGB image.
  • If apply_mask is True, masked areas will be set to zero in the output.
  • If apply_min_cutoff is True, a minimum intensity threshold is applied to each band.

Examples:

>>> hsi_image = HSI(image=torch.rand(10, 100, 100), wavelengths=np.linspace(400, 1000, 10))
>>> rgb_image = hsi_image.get_rgb_image()
>>> rgb_image.shape
torch.Size([100, 100, 3])
>>> rgb_image = hsi_image.get_rgb_image(output_channel_axis=0)
>>> rgb_image.shape
torch.Size([3, 100, 100])
>>> rgb_image = hsi_image.get_rgb_image(apply_mask=False, apply_min_cutoff=True)
>>> rgb_image.shape
torch.Size([100, 100, 3])
Source code in src/meteors/hsi.py
def get_rgb_image(
    self,
    apply_mask: bool = True,
    apply_min_cutoff: bool = False,
    output_channel_axis: int | None = None,
    normalize: bool = True,
) -> torch.Tensor:
    """Extracts an RGB representation from the hyperspectral image data.

    This method creates a 3-channel RGB image by selecting appropriate bands
    corresponding to red, green, and blue wavelengths from the hyperspectral data.

    Args:
        apply_mask (bool, optional): Whether to apply the binary mask to the image.
            Defaults to True.
        apply_min_cutoff (bool, optional): Whether to apply a minimum intensity
            cutoff to the image. Defaults to False.
        output_channel_axis (int | None, optional): The axis where the RGB channels
            should be placed in the output tensor. If None, uses the current spectral
            axis of the hyperspectral data. Defaults to None.
        normalize (bool, optional): Whether to normalize the band values to the [0, 1] range.
            Defaults to True.

    Returns:
        torch.Tensor: The RGB representation of the hyperspectral image.
            Shape will be either (H, W, 3), (3, H, W), or (H, 3, W) depending on
            the specified output_channel_axis, where H is height and W is width.

    Notes:
        - The RGB bands are extracted using predefined wavelength ranges for R, G, and B.
        - Each band is normalized independently before combining into the RGB image.
        - If apply_mask is True, masked areas will be set to zero in the output.
        - If apply_min_cutoff is True, a minimum intensity threshold is applied to each band.

    Examples:
        >>> hsi_image = HSI(image=torch.rand(10, 100, 100), wavelengths=np.linspace(400, 1000, 10))
        >>> rgb_image = hsi_image.get_rgb_image()
        >>> rgb_image.shape
        torch.Size([100, 100, 3])

        >>> rgb_image = hsi_image.get_rgb_image(output_channel_axis=0)
        >>> rgb_image.shape
        torch.Size([3, 100, 100])

        >>> rgb_image = hsi_image.get_rgb_image(apply_mask=False, apply_min_cutoff=True)
        >>> rgb_image.shape
        torch.Size([100, 100, 3])
    """
    if output_channel_axis is None:
        output_channel_axis = self.spectral_axis

    rgb_img = torch.stack(
        [
            self.extract_band_by_name(
                band, apply_mask=apply_mask, apply_min_cutoff=apply_min_cutoff, normalize=normalize
            )
            for band in ["R", "G", "B"]
        ],
        dim=self.spectral_axis,
    )

    return (
        rgb_img
        if output_channel_axis == self.spectral_axis
        else torch.moveaxis(rgb_img, self.spectral_axis, output_channel_axis)
    )

to(device)

Moves the image and binary mask (if available) to the specified device.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `device` | `str \| torch.device` | The device to move the image and binary mask to. | required |

Returns:

| Type | Description |
|------|-------------|
| `Self` | The updated HSI object. |

Examples:

>>> # Create an HSI object
>>> hsi_image = HSI(image=torch.rand(10, 10, 10), wavelengths=np.arange(10))
>>> # Move the image to cpu
>>> hsi_image = hsi_image.to("cpu")
>>> hsi_image.device
device(type='cpu')
>>> # Move the image to cuda
>>> hsi_image = hsi_image.to("cuda")
>>> hsi_image.device
device(type='cuda', index=0)
Source code in src/meteors/hsi.py
def to(self, device: str | torch.device) -> Self:
    """Moves the image and binary mask (if available) to the specified device.

    Args:
        device (str or torch.device): The device to move the image and binary mask to.

    Returns:
        Self: The updated HSI object.

    Examples:
        >>> # Create an HSI object
        >>> hsi_image = HSI(image=torch.rand(10, 10, 10), wavelengths=np.arange(10))
        >>> # Move the image to cpu
        >>> hsi_image = hsi_image.to("cpu")
        >>> hsi_image.device
        device(type='cpu')
        >>> # Move the image to cuda
        >>> hsi_image = hsi_image.to("cuda")
        >>> hsi_image.device
        device(type='cuda', index=0)
    """
    self.image = self.image.to(device)
    if self.binary_mask is not None:
        self.binary_mask = self.binary_mask.to(device)
    self.device = self.image.device
    return self

Visualizations

visualize_hsi(hsi_or_attributes, ax=None, use_mask=True)

Visualizes a hyperspectral image object on the given axes. It accepts either an HSI object or an HSIAttributes object, in which case its hsi field is visualized.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `hsi_or_attributes` | `HSI \| HSIAttributes` | The hyperspectral image, or the attributes to be visualized. | required |
| `ax` | `Axes \| None` | The axes on which the image will be plotted. If None, the current axes will be used. | `None` |
| `use_mask` | `bool` | Whether to use the image mask, if provided, for the visualization. | `True` |

Returns:

| Type | Description |
|------|-------------|
| `Axes` | The axes object with the plotted RGB representation of the image. |

Raises:

| Type | Description |
|------|-------------|
| `TypeError` | If hsi_or_attributes is not an instance of HSI or HSIAttributes. |
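
Example (a usage sketch; assumes the function is importable from `meteors.visualize.hsi_visualize`, matching the source path below, and that the wavelengths cover the visible range so that the R, G, and B bands can be extracted):

>>> import torch
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from meteors.hsi import HSI
>>> from meteors.visualize.hsi_visualize import visualize_hsi
>>> hsi_image = HSI(image=torch.rand(20, 100, 100), wavelengths=np.linspace(400, 800, 20))
>>> ax = visualize_hsi(hsi_image)
>>> plt.show()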

Source code in src/meteors/visualize/hsi_visualize.py
def visualize_hsi(hsi_or_attributes: HSI | HSIAttributes, ax: Axes | None = None, use_mask: bool = True) -> Axes:
    """Visualizes a Hyperspectral image object on the given axes. It uses either the object from HSI class or a field
    from the HSIAttributes class.

    Parameters:
        hsi_or_attributes (HSI | HSIAttributes): The hyperspectral image, or the attributes to be visualized.
        ax (matplotlib.axes.Axes | None): The axes on which the image will be plotted.
            If None, the current axes will be used.
        use_mask (bool): Whether to use the image mask if provided for the visualization.

    Returns:
        matplotlib.axes.Axes: The axes object with the plotted RGB representation of the image.

    Raises:
        TypeError: If hsi_or_attributes is not an instance of HSI or HSIAttributes.
    """
    if isinstance(hsi_or_attributes, HSIAttributes):
        hsi = hsi_or_attributes.hsi
    else:
        hsi = hsi_or_attributes

    if not isinstance(hsi, HSI):
        raise TypeError("hsi_or_attributes must be an instance of HSI or HSIAttributes.")

    hsi = hsi.change_orientation("HWC", inplace=False)

    rgb = hsi.get_rgb_image(output_channel_axis=2, apply_mask=use_mask, normalize=True).cpu().numpy()
    ax = ax or plt.gca()
    ax.imshow(rgb)

    return ax

visualize_attributes(image_attributes, ax=None, use_pyplot=False)

Visualizes the attributes of an image on the given axes.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `image_attributes` | `HSIAttributes` | The image attributes to be visualized. | required |
| `ax` | `Axes \| None` | The axes to visualize the image on. If None, creates a new figure and axes. | `None` |
| `use_pyplot` | `bool` | If True, uses pyplot to display the image. If False, returns the figure and axes objects. If ax is not None, use_pyplot is ignored. | `False` |

Returns:

| Type | Description |
|------|-------------|
| `tuple[Figure, Axes] \| Axes \| None` | If use_pyplot is False and ax is None, returns the figure and axes objects. If use_pyplot is True and ax is None, displays the image using pyplot and returns None. If ax is not None, returns the axes object. If all the attributions are zero, returns None. |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If the axes have fewer than 2 rows and 2 columns. |
| `ValueError` | If the axes object is not a list of axes objects. |
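
Example (a usage sketch; assumes `attrs` is an HSIAttributes instance produced elsewhere, e.g. by one of the package's attribution methods):

>>> fig, ax = visualize_attributes(attrs, use_pyplot=False)
>>> fig.savefig("hsi_attributes.png")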

Source code in src/meteors/visualize/attr_visualize.py
def visualize_attributes(
    image_attributes: HSIAttributes, ax: Axes | None = None, use_pyplot: bool = False
) -> tuple[Figure, Axes] | Axes | None:
    """Visualizes the attributes of an image on the given axes.

    Parameters:
        image_attributes (HSIAttributes): The image attributes to be visualized.
        ax (Axes | None): The axes to visualize the image on. If None, creates a new figure and axes.
        use_pyplot (bool): If True, uses pyplot to display the image. If False, returns the figure and axes objects.
            if ax is not None, use_pyplot is ignored.

    Returns:
        matplotlib.figure.Figure | matplotlib.axes.Axes | None: The figure and axes objects.
            If use_pyplot is False and ax is None, returns the figure and axes objects.
            If use_pyplot is True and ax is None, returns None, and displays the image using pyplot.
            if ax is not None, returns the axes object.
            If all the attributions are zero, returns None.

    Raises:
        ValueError: If the axes have less than 2 rows and 2 columns
        ValueError: If the axes object is not a list of axes objects
    """
    if image_attributes.hsi.orientation != ("H", "W", "C"):
        logger.info(
            f"The orientation of the image is not (H, W, C): {image_attributes.hsi.orientation}. "
            f"Changing it to (H, W, C) for visualization."
        )
        rotated_attributes_dataclass = image_attributes.change_orientation("HWC", inplace=False)
    else:
        rotated_attributes_dataclass = image_attributes

    rotated_attributes = rotated_attributes_dataclass.attributes.detach().cpu().numpy()
    if np.all(rotated_attributes == 0):
        warnings.warn("All the attributions are zero. There is nothing to visualize.")
        return None

    used_ax = True
    if ax is None:
        used_ax = False
        fig, ax = plt.subplots(2, 2, figsize=(9, 7))

    if not hasattr(ax, "shape"):
        raise ValueError("Provided ax parameter is only one axes object, but it should be a list of axes objects")
    elif len(ax.shape) != 2 or ax.shape[0] < 2 or ax.shape[1] < 2:
        raise ValueError("The axes should have at least 2 rows and 2 columns.")
    else:
        fig = ax[0, 0].get_figure()

    ax[0, 0].set_title("Attribution Heatmap")
    ax[0, 0].grid(False)
    ax[0, 0].axis("off")

    fig.suptitle(f"HSI Attributes of: {rotated_attributes_dataclass.attribution_method}")

    _ = viz.visualize_image_attr(
        rotated_attributes,
        method="heat_map",
        sign="all",
        plt_fig_axis=(fig, ax[0, 0]),
        show_colorbar=True,
        use_pyplot=False,
    )

    ax[0, 1].set_title("Attribution Module Values")
    ax[0, 1].grid(False)
    ax[0, 1].axis("off")

    # Attributions module values
    _ = viz.visualize_image_attr(
        rotated_attributes,
        method="heat_map",
        sign="absolute_value",
        plt_fig_axis=(fig, ax[0, 1]),
        show_colorbar=True,
        use_pyplot=False,
    )

    attr_all = rotated_attributes.sum(axis=(0, 1))
    ax[1, 0].scatter(rotated_attributes_dataclass.hsi.wavelengths, attr_all, c="r")
    ax[1, 0].set_title("Spectral Attribution")
    ax[1, 0].set_xlabel("Wavelength")
    ax[1, 0].set_ylabel("Attribution")
    ax[1, 0].grid(True)

    attr_abs = np.abs(rotated_attributes).sum(axis=(0, 1))
    ax[1, 1].scatter(rotated_attributes_dataclass.hsi.wavelengths, attr_abs, c="b")
    ax[1, 1].set_title("Spectral Attribution Absolute Values")
    ax[1, 1].set_xlabel("Wavelength")
    ax[1, 1].set_ylabel("Attribution Absolute Value")
    ax[1, 1].grid(True)

    plt.tight_layout()

    if used_ax:
        return ax

    if use_pyplot:
        plt.show()  # pragma: no cover
        return None  # pragma: no cover

    return fig, ax

visualize_spatial_aggregated_attributes(attributes, aggregated_mask, ax=None, use_pyplot=False, aggregate_func=torch.mean)

Visualizes the spatial attributes of an hsi object aggregated by a custom mask.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `attributes` | `HSIAttributes` | The spatial attributes of the hsi object to visualize. | required |
| `aggregated_mask` | `Tensor \| ndarray` | The mask used to aggregate the spatial attributes. | required |
| `ax` | `Axes \| None` | The axes object to plot the visualization on. If None, a new axes will be created. | `None` |
| `use_pyplot` | `bool` | If True, displays the visualization using pyplot. If ax is not None, use_pyplot is ignored. If False, returns the figure and axes objects. | `False` |
| `aggregate_func` | `Callable[[Tensor], Tensor]` | The aggregation function to be applied. The function should take a tensor as input and return a tensor as output. We recommend using torch functions. | `torch.mean` |

Raises:

| Type | Description |
|------|-------------|
| `ShapeMismatchError` | If the shape of the aggregated mask does not match the shape of the spatial attributes. |

Returns:

| Type | Description |
|------|-------------|
| `tuple[Figure, Axes] \| Axes \| None` | If ax is not None, returns the axes object. If use_pyplot is True, returns None. If use_pyplot is False, returns the figure and axes objects. |
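
Example (a usage sketch; assumes `attrs` is an HSIAttributes instance whose image has shape (10, 100, 100); the random segmentation mask is purely illustrative and is expanded over the spectral dimension by the function):

>>> import torch
>>> segment_mask = torch.randint(0, 4, (100, 100))  # four spatial segments
>>> fig, ax = visualize_spatial_aggregated_attributes(attrs, segment_mask, aggregate_func=torch.mean)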

Source code in src/meteors/visualize/attr_visualize.py
def visualize_spatial_aggregated_attributes(
    attributes: HSIAttributes,
    aggregated_mask: torch.Tensor | np.ndarray,
    ax: Axes | None = None,
    use_pyplot: bool = False,
    aggregate_func: Callable[[torch.Tensor], torch.Tensor] = torch.mean,
) -> tuple[Figure, Axes] | Axes | None:
    """Visualizes the spatial attributes of an hsi object aggregated by a custom mask.

    Args:
        attributes (HSIAttributes): The spatial attributes of the hsi object to visualize.
        aggregated_mask (torch.Tensor | np.ndarray): The mask used to aggregate the spatial attributes.
        ax (Axes | None, optional): The axes object to plot the visualization on. If None, a new axes will be created.
        use_pyplot (bool, optional): If True, displays the visualization using pyplot.
            If ax is not None, use_pyplot is ignored.
            If False, returns the figure and axes objects. Defaults to False.
        aggregate_func (Callable[[torch.Tensor], torch.Tensor], optional): The aggregation function to be applied.
            The function should take a tensor as input and return a tensor as output.
            We recommend using torch functions. Defaults to torch.mean.

    Raises:
        ShapeMismatchError: If the shape of the aggregated mask does not match the shape of the spatial attributes.

    Returns:
        tuple[Figure, Axes] | Axes | None: If ax is not None, returns the axes object.
            If use_pyplot is True, returns None. If use_pyplot is False, returns the figure and axes objects.
    """
    if isinstance(aggregated_mask, np.ndarray):
        aggregated_mask = torch.from_numpy(aggregated_mask)

    if aggregated_mask.shape != attributes.hsi.image.shape:
        aggregated_mask = aggregated_mask.expand_as(attributes.attributes)

    new_attrs = aggregate_by_mask(attributes.attributes, aggregated_mask, aggregate_func)

    new_spatial_attributes = HSISpatialAttributes(
        hsi=attributes.hsi,
        attributes=new_attrs,
        mask=aggregated_mask,
        score=attributes.score,
    )

    out = visualize_spatial_attributes(new_spatial_attributes, ax=ax, use_pyplot=False)
    if ax is not None:
        return out

    fig, ax = out  # type: ignore
    fig.suptitle("Spatial Attributes Visualization Aggregated")

    if use_pyplot:
        plt.show()  # pragma: no cover
        return None  # pragma: no cover

    return fig, ax

visualize_spectral_aggregated_attributes(attributes, band_names, band_mask, ax=None, use_pyplot=False, color_palette=None, show_not_included=True, aggregate_func=torch.mean)

Visualizes the spectral attributes of an hsi object aggregated by a custom band mask.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `attributes` | `HSIAttributes \| list[HSIAttributes]` | The spectral attributes of the hsi object to visualize. | required |
| `band_names` | `dict[str \| tuple[str, ...], int]` | A dictionary mapping band names to their indices. | required |
| `band_mask` | `Tensor \| ndarray` | The mask used to aggregate the spectral attributes. | required |
| `ax` | `Axes \| None` | The axes object to plot the visualization on. If None, a new axes will be created. | `None` |
| `use_pyplot` | `bool` | If True, displays the visualization using pyplot. If ax is not None, use_pyplot is ignored. If False, returns the figure and axes objects. | `False` |
| `color_palette` | `list[str] \| None` | The color palette to use for visualizing different spectral bands. If None, a default color palette is used. | `None` |
| `show_not_included` | `bool` | If True, also shows the spectral bands that are not assigned to any named group. If False, shows only the named bands. | `True` |
| `aggregate_func` | `Callable[[Tensor], Tensor]` | The aggregation function to be applied. The function should take a tensor as input and return a tensor as output. We recommend using torch functions. | `torch.mean` |

Raises:

| Type | Description |
|------|-------------|
| `ShapeMismatchError` | If the shape of the band mask does not match the shape of the spectral attributes. |

Returns:

| Type | Description |
|------|-------------|
| `tuple[Figure, Axes] \| Axes \| None` | If ax is not None, returns the axes object. If use_pyplot is True, returns None. If use_pyplot is False, returns the figure and axes objects. |
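
Example (a usage sketch; assumes `attrs` is an HSIAttributes instance with 10 spectral bands; the band groups are illustrative, and the per-band mask is expanded over the spatial dimensions by the function):

>>> import torch
>>> band_names = {"VIS": 1, "NIR": 2}
>>> band_mask = torch.ones(10, dtype=torch.long)  # bands 0-4 -> group 1 ("VIS")
>>> band_mask[5:] = 2                             # bands 5-9 -> group 2 ("NIR")
>>> fig, ax = visualize_spectral_aggregated_attributes(attrs, band_names, band_mask)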

Source code in src/meteors/visualize/attr_visualize.py
def visualize_spectral_aggregated_attributes(
    attributes: HSIAttributes | list[HSIAttributes],
    band_names: dict[str | tuple[str, ...], int],
    band_mask: torch.Tensor | np.ndarray,
    ax: Axes | None = None,
    use_pyplot: bool = False,
    color_palette: list[str] | None = None,
    show_not_included: bool = True,
    aggregate_func: Callable[[torch.Tensor], torch.Tensor] = torch.mean,
) -> tuple[Figure, Axes] | Axes | None:
    """Visualizes the spectral attributes of an hsi object aggregated by a custom band mask.

    Args:
        attributes (HSIAttributes | list[HSIAttributes]): The spectral attributes of the hsi object to visualize.
        band_names (dict[str | tuple[str, ...], int]): A dictionary mapping band names to their indices.
        band_mask (torch.Tensor | np.ndarray): The mask used to aggregate the spectral attributes.
        ax (Axes | None, optional): The axes object to plot the visualization on. If None, a new axes will be created.
        use_pyplot (bool, optional): If True, displays the visualization using pyplot.
            If ax is not None, use_pyplot is ignored. If False, returns the figure and axes objects. Defaults to False.
        color_palette (list[str] | None, optional): The color palette to use for visualizing different spectral bands.
            If None, a default color palette is used. Defaults to None.
        show_not_included (bool, optional): If True, includes the spectral bands that are not included in the visualization.
            If False, only includes the spectral bands that are included in the visualization. Defaults to True.
        aggregate_func (Callable[[torch.Tensor], torch.Tensor], optional): The aggregation function to be applied.
            The function should take a tensor as input and return a tensor as output.
            We recommend using torch functions. Defaults to torch.mean.

    Raises:
        ShapeMismatchError: If the shape of the band mask does not match the shape of the spectral attributes.

    Returns:
        tuple[Figure, Axes] | Axes | None: If ax is not None, returns the axes object.
            If use_pyplot is True, returns None. If use_pyplot is False, returns the figure and axes objects
    """
    attributes_example = attributes if isinstance(attributes, HSIAttributes) else attributes[0]
    if isinstance(band_mask, np.ndarray):
        band_mask = torch.from_numpy(band_mask)

    if band_mask.shape != attributes_example.hsi.image.shape:
        band_mask = expand_spectral_mask(attributes_example.hsi, band_mask, repeat_dimensions=True)

    band_names = align_band_names_with_mask(band_names, band_mask)

    new_attrs = aggregate_by_mask(attributes_example.attributes, band_mask, aggregate_func)

    new_spectral_attributes: HSISpectralAttributes | list[HSISpectralAttributes]
    if isinstance(attributes, HSIAttributes):
        new_spectral_attributes = HSISpectralAttributes(
            hsi=attributes.hsi,
            attributes=new_attrs,
            mask=band_mask,
            band_names=band_names,
            score=attributes.score,
        )
    else:
        new_spectral_attributes = [
            HSISpectralAttributes(
                hsi=attr.hsi,
                attributes=new_attrs,
                mask=band_mask,
                band_names=band_names,
                score=attr.score,
            )
            for attr in attributes
        ]

    out = visualize_spectral_attributes(
        new_spectral_attributes,
        ax=ax,
        use_pyplot=False,
        color_palette=color_palette,
        show_not_included=show_not_included,
    )  # type: ignore
    if ax is not None:
        return out

    if use_pyplot:
        plt.show()  # pragma: no cover
        return None  # pragma: no cover

    return out

visualize_aggregated_attributes(attributes, mask, band_names=None, ax=None, use_pyplot=False, color_palette=None, show_not_included=True, aggregate_func=torch.mean)

Visualizes the aggregated attributes of an hsi object.

Parameters:

- attributes (HSIAttributes | list[HSIAttributes]): The attributes of the hsi object to visualize. Required.
- mask (Tensor | ndarray): The mask used to aggregate the attributes. Required.
- band_names (dict[str | tuple[str, ...], int] | None): A dictionary mapping band names to their indices. If None, the visualization will be spatially aggregated. Defaults to None.
- ax (Axes | None): The axes object to plot the visualization on. If None, a new axes will be created. Defaults to None.
- use_pyplot (bool): If True, displays the visualization using pyplot. If ax is not None, use_pyplot is ignored. If False, returns the figure and axes objects. Defaults to False.
- color_palette (list[str] | None): The color palette to use for visualizing different spectral bands. If None, a default color palette is used. Defaults to None.
- show_not_included (bool): If True, the bands not assigned to any named group (the "not_included" segment) are also plotted; if False, they are omitted. Defaults to True.
- aggregate_func (Callable[[Tensor], Tensor]): The aggregation function to be applied. It should take a tensor as input and return a tensor as output; torch functions are recommended. Defaults to torch.mean.

Raises:

- ValueError: If the shape of the mask does not match the shape of the attributes.
- AssertionError: If band_names is None and attributes is a list of HSIAttributes objects.

Returns:

- tuple[Figure, Axes] | Axes | None: If ax is not None, returns the axes object. If use_pyplot is True, returns None. If use_pyplot is False, returns the figure and axes objects.

Source code in src/meteors/visualize/attr_visualize.py
def visualize_aggregated_attributes(
    attributes: HSIAttributes | list[HSIAttributes],
    mask: torch.Tensor | np.ndarray,
    band_names: dict[str | tuple[str, ...], int] | None = None,
    ax: Axes | None = None,
    use_pyplot: bool = False,
    color_palette: list[str] | None = None,
    show_not_included: bool = True,
    aggregate_func: Callable[[torch.Tensor], torch.Tensor] = torch.mean,
) -> tuple[Figure, Axes] | Axes | None:
    """Visualizes the aggregated attributes of an hsi object.

    Args:
        attributes (HSIAttributes | list[HSIAttributes]): The attributes of the hsi object to visualize.
        mask (torch.Tensor | np.ndarray): The mask used to aggregate the attributes.
        band_names (dict[str | tuple[str, ...], int] | None, optional): A dictionary mapping band names to their indices.
            If None, the visualization will be spatially aggregated. Defaults to None.
        ax (Axes | None, optional): The axes object to plot the visualization on. If None, a new axes will be created.
        use_pyplot (bool, optional): If True, displays the visualization using pyplot.
            If ax is not None, use_pyplot is ignored. If False, returns the figure and axes objects. Defaults to False.
        color_palette (list[str] | None, optional): The color palette to use for visualizing different spectral bands.
            If None, a default color palette is used. Defaults to None.
        show_not_included (bool, optional): If True, the bands not assigned to any named group
            (the "not_included" segment) are also plotted; if False, they are omitted. Defaults to True.
        aggregate_func (Callable[[torch.Tensor], torch.Tensor], optional): The aggregation function to be applied.
            The function should take a tensor as input and return a tensor as output.
            We recommend using torch functions. Defaults to torch.mean.

    Raises:
        ValueError: If the shape of the mask does not match the shape of the attributes.
        AssertionError: If band_names is None and attributes is a list of HSIAttributes objects.

    Returns:
        tuple[Figure, Axes] | Axes | None: If ax is not None, returns the axes object.
            If use_pyplot is True, returns None. If use_pyplot is False, returns the figure and axes objects.
    """
    agg = not isinstance(attributes, HSIAttributes)
    if band_names is None:
        logger.info("Band names not provided. Using Spatial Analysis.")
        assert not agg, "In Spatial Analysis, attributes must be a single HSIAttributes object."
        return visualize_spatial_aggregated_attributes(attributes, mask, ax, use_pyplot, aggregate_func)  # type: ignore
    else:
        logger.info("Band names provided. Using Spectral Analysis.")
        return visualize_spectral_aggregated_attributes(
            attributes, band_names, mask, ax, use_pyplot, color_palette, show_not_included, aggregate_func
        )
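
A short sketch of the dispatch behavior: the presence of band_names decides between spatial and spectral analysis. All objects here (`attrs`, `segmentation_mask`, `band_mask`) are assumed to exist already:

# Spatial analysis: no band_names, so attributes must be a single HSIAttributes object.
fig, ax = visualize_aggregated_attributes(attrs, segmentation_mask)

# Spectral analysis: providing band_names routes to visualize_spectral_aggregated_attributes.
fig, ax = visualize_aggregated_attributes(
    attrs, band_mask, band_names={"not_included": 0, "red_edge": 1}
)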

visualize_spectral_attributes_by_waveband(spectral_attributes, ax, color_palette=None, show_not_included=True, show_legend=True)

Visualizes spectral attributes by waveband.

Parameters:

- spectral_attributes (HSISpectralAttributes | list[HSISpectralAttributes]): The spectral attributes to visualize. Required.
- ax (Axes | None): The matplotlib axes to plot the visualization on. If None, a new axes will be created. Required.
- color_palette (list[str] | None): The color palette to use for plotting. If None, a default color palette will be used. Defaults to None.
- show_not_included (bool): Whether to show the "not_included" band in the visualization. Defaults to True.
- show_legend (bool): Whether to show the legend in the visualization. Defaults to True.

Returns:

- Axes: The matplotlib axes object containing the visualization.

Raises:

- TypeError: If the spectral attributes are not an HSISpectralAttributes object or a list of HSISpectralAttributes objects.

Source code in src/meteors/visualize/attr_visualize.py
def visualize_spectral_attributes_by_waveband(
    spectral_attributes: HSISpectralAttributes | list[HSISpectralAttributes],
    ax: Axes | None,
    color_palette: list[str] | None = None,
    show_not_included: bool = True,
    show_legend: bool = True,
) -> Axes:
    """Visualizes spectral attributes by waveband.

    Args:
        spectral_attributes (HSISpectralAttributes | list[HSISpectralAttributes]):
            The spectral attributes to visualize.
        ax (Axes | None): The matplotlib axes to plot the visualization on.
            If None, a new axes will be created.
        color_palette (list[str] | None): The color palette to use for plotting.
            If None, a default color palette will be used.
        show_not_included (bool): Whether to show the "not_included" band in the visualization.
            Default is True.
        show_legend (bool): Whether to show the legend in the visualization.

    Returns:
        Axes: The matplotlib axes object containing the visualization.

    Raises:
        TypeError: If the spectral attributes are not an HSISpectralAttributes object or a list of HSISpectralAttributes objects.
    """
    if isinstance(spectral_attributes, HSISpectralAttributes):
        spectral_attributes = [spectral_attributes]
    if not (
        isinstance(spectral_attributes, list)
        and all(isinstance(attr, HSISpectralAttributes) for attr in spectral_attributes)
    ):
        raise TypeError(
            "spectral_attributes parameter must be an HSISpectralAttributes object or a list of HSISpectralAttributes objects."
        )

    aggregate_results = len(spectral_attributes) > 1
    band_names = dict(spectral_attributes[0].band_names)
    wavelengths = spectral_attributes[0].hsi.wavelengths
    validate_consistent_band_and_wavelengths(band_names, wavelengths, spectral_attributes)

    ax = setup_visualization(ax, "Attributions by Waveband", "Wavelength (nm)", "Correlation with Output")

    if not show_not_included and band_names.get("not_included") is not None:
        band_names.pop("not_included")

    band_names = _merge_band_names_segments(band_names)  # type: ignore

    if color_palette is None:
        color_palette = sns.color_palette("hsv", len(band_names.keys()))

    band_mask = spectral_attributes[0].band_mask.cpu()
    attribution_map = torch.stack([attr.flattened_attributes.cpu() for attr in spectral_attributes])

    for idx, (band_name, segment_id) in enumerate(band_names.items()):
        current_wavelengths = wavelengths[band_mask == segment_id]
        current_attribution_map = attribution_map[:, band_mask == segment_id]

        current_mean = current_attribution_map.numpy().mean(axis=0)
        if aggregate_results:
            lolims = current_attribution_map.numpy().min(axis=0)
            uplims = current_attribution_map.numpy().max(axis=0)

            ax.errorbar(
                current_wavelengths.numpy(),
                current_mean,
                yerr=[current_mean - lolims, uplims - current_mean],
                label=band_name,
                color=color_palette[idx],
                linestyle="--",
                marker="o",
                markersize=5,
            )
        else:
            ax.scatter(
                current_wavelengths.numpy(),
                current_mean,
                label=band_name,
                color=color_palette[idx],
            )

    if show_legend:
        ax.legend(title="SuperBand")

    return ax
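
A minimal sketch of plotting onto an existing axes; `spectral_attrs` is assumed to be an HSISpectralAttributes object (or a list of them, in which case error bars summarize the spread across objects):

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 4))
visualize_spectral_attributes_by_waveband(spectral_attrs, ax, show_legend=True)
plt.show()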

visualize_spectral_attributes_by_magnitude(spectral_attributes, ax, color_palette=None, annotate_bars=True, show_not_included=True)

Visualizes the spectral attributes by magnitude.

Parameters:

- spectral_attributes (HSISpectralAttributes | list[HSISpectralAttributes]): The spectral attributes to visualize. Required.
- ax (Axes | None): The matplotlib Axes object to plot the visualization on. If None, a new Axes object will be created. Required.
- color_palette (list[str] | None): The color palette to use for the visualization. If None, a default color palette will be used. Defaults to None.
- annotate_bars (bool): Whether to annotate the bars with their magnitudes. Defaults to True.
- show_not_included (bool): Whether to show the 'not_included' band in the visualization. Defaults to True.

Returns:

- Axes: The matplotlib Axes object containing the visualization.

Raises:

- TypeError: If the spectral attributes are not an HSISpectralAttributes object or a list of HSISpectralAttributes objects.

Source code in src/meteors/visualize/attr_visualize.py
def visualize_spectral_attributes_by_magnitude(
    spectral_attributes: HSISpectralAttributes | list[HSISpectralAttributes],
    ax: Axes | None,
    color_palette: list[str] | None = None,
    annotate_bars: bool = True,
    show_not_included: bool = True,
) -> Axes:
    """Visualizes the spectral attributes by magnitude.

    Args:
        spectral_attributes (HSISpectralAttributes | list[HSISpectralAttributes]):
            The spectral attributes to visualize.
        ax (Axes | None): The matplotlib Axes object to plot the visualization on.
            If None, a new Axes object will be created.
        color_palette (list[str] | None): The color palette to use for the visualization.
            If None, a default color palette will be used.
        annotate_bars (bool): Whether to annotate the bars with their magnitudes.
            Defaults to True.
        show_not_included (bool): Whether to show the 'not_included' band in the visualization.
            Defaults to True.

    Returns:
        Axes: The matplotlib Axes object containing the visualization.

    Raises:
        TypeError: If the spectral attributes are not an HSISpectralAttributes object or a list of HSISpectralAttributes objects.
    """
    if isinstance(spectral_attributes, HSISpectralAttributes):
        spectral_attributes = [spectral_attributes]
    if not (
        isinstance(spectral_attributes, list)
        and all(isinstance(attr, HSISpectralAttributes) for attr in spectral_attributes)
    ):
        raise TypeError(
            "spectral_attributes parameter must be an HSISpectralAttributes object or a list of HSISpectralAttributes objects."
        )

    aggregate_results = len(spectral_attributes) > 1
    band_names = dict(spectral_attributes[0].band_names)
    wavelengths = spectral_attributes[0].hsi.wavelengths
    validate_consistent_band_and_wavelengths(band_names, wavelengths, spectral_attributes)

    ax = setup_visualization(ax, "Attributions by Magnitude", "Group", "Average Attribution Magnitude")
    ax.tick_params(axis="x", rotation=45)

    band_names = _merge_band_names_segments(band_names)  # type: ignore
    labels = list(band_names.keys())

    if not show_not_included and band_names.get("not_included") is not None:
        band_names.pop("not_included")
        labels = list(band_names.keys())

    if color_palette is None:
        color_palette = sns.color_palette("hsv", len(band_names.keys()))

    band_mask = spectral_attributes[0].band_mask.cpu()
    attribution_map = torch.stack([attr.flattened_attributes.cpu() for attr in spectral_attributes])
    avg_magnitudes = calculate_average_magnitudes(band_names, band_mask, attribution_map)

    if aggregate_results:
        boxplot = ax.boxplot(avg_magnitudes, labels=labels, patch_artist=True)
        for patch, color in zip(boxplot["boxes"], color_palette):
            patch.set_facecolor(color)

    else:
        bars = ax.bar(labels, avg_magnitudes, color=color_palette)
        if annotate_bars:
            for bar in bars:
                height = bar.get_height()
                ax.annotate(
                    f"{height:.2f}",
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                )
    return ax
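
A minimal sketch, again assuming an existing `spectral_attrs` object. A single object yields an annotated bar chart; a list of objects (e.g. several explanation runs) yields one box plot per band group:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4))
visualize_spectral_attributes_by_magnitude(spectral_attrs, ax, annotate_bars=True)
plt.show()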

visualize_spectral_attributes(spectral_attributes, ax=None, use_pyplot=False, color_palette=None, show_not_included=True)

Visualizes the spectral attributes of an hsi object or a list of hsi objects.

Parameters:

- spectral_attributes (HSISpectralAttributes | list[HSISpectralAttributes]): The spectral attributes of the image object to visualize. Required.
- ax (Axes | None): The axes object to plot the visualization on. If None, a new axes will be created. Defaults to None.
- use_pyplot (bool): If True, displays the visualization using pyplot; if ax is not None, use_pyplot is ignored. If False, returns the figure and axes objects. Defaults to False.
- color_palette (list[str] | None): The color palette to use for visualizing different spectral bands. If None, a default color palette is used. Defaults to None.
- show_not_included (bool): If True, the bands not assigned to any named group (the "not_included" segment) are also plotted; if False, they are omitted. Defaults to True.

Returns:

- tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] | matplotlib.axes.Axes | None: If ax is not None, returns the axes object. If use_pyplot is True, returns None. If use_pyplot is False, returns the figure and axes objects.

Raises:

- ValueError: If ax is provided as a single axes object and not a list of axes objects.
- ValueError: If agg is True and the axes array has fewer than 3 axes.
- ValueError: If agg is False and the axes array has fewer than 2 axes.

Source code in src/meteors/visualize/attr_visualize.py
def visualize_spectral_attributes(
    spectral_attributes: HSISpectralAttributes | list[HSISpectralAttributes],
    ax: Axes | None = None,
    use_pyplot: bool = False,
    color_palette: list[str] | None = None,
    show_not_included: bool = True,
) -> tuple[Figure, Axes] | Axes | None:
    """Visualizes the spectral attributes of an hsi object or a list of hsi objects.

    Args:
        spectral_attributes (HSISpectralAttributes | list[HSISpectralAttributes]):
            The spectral attributes of the image object to visualize.
        ax (Axes | None, optional):
            The axes object to plot the visualization on. If None, a new axes will be created.
        use_pyplot (bool, optional):
            If ax is not None, use_pyplot is ignored.
            If True, displays the visualization using pyplot.
            If False, returns the figure and axes objects. Defaults to False.
        color_palette (list[str] | None, optional):
            The color palette to use for visualizing different spectral bands.
            If None, a default color palette is used.
            Defaults to None.
        show_not_included (bool, optional):
            If True, the bands not assigned to any named group (the "not_included" segment)
            are also plotted; if False, they are omitted.
            Defaults to True.

    Returns:
        tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] | matplotlib.axes.Axes | None:
            If ax is not None, returns the axes object.
            If use_pyplot is True, returns None.
            If use_pyplot is False, returns the figure and axes objects.

    Raises:
        ValueError: If ax is provided as a single axes object and not a list of axes objects.
        ValueError: If agg is True and the axes array has fewer than 3 axes.
        ValueError: If agg is False and the axes array has fewer than 2 axes.
    """
    agg = isinstance(spectral_attributes, list)
    band_names = spectral_attributes[0].band_names if agg else spectral_attributes.band_names  # type: ignore

    color_palette = color_palette or sns.color_palette("hsv", len(band_names.keys()))

    use_ax = True
    if ax is None:
        use_ax = False
        fig, ax = plt.subplots(1, 3 if agg else 2, figsize=(15, 5))
        fig.suptitle("Spectral Attributes Visualization")

    if not hasattr(ax, "shape"):
        raise ValueError("A single axes object was provided, but a list (array) of axes objects is required")
    if agg and (len(ax.shape) != 1 or ax.shape[0] < 3):
        raise ValueError("The axes array should be one-dimensional and contain at least 3 axes if agg is True")
    if not agg and (len(ax.shape) != 1 or ax.shape[0] < 2):
        raise ValueError("The axes array should be one-dimensional and contain at least 2 axes if agg is False")

    visualize_spectral_attributes_by_waveband(
        spectral_attributes,
        ax[0],
        color_palette=color_palette,
        show_not_included=show_not_included,
        show_legend=False,
    )

    visualize_spectral_attributes_by_magnitude(
        spectral_attributes,
        ax[1],
        color_palette=color_palette,
        show_not_included=show_not_included,
    )

    if agg:
        scores = [attr.score for attr in spectral_attributes]  # type: ignore
        mean_score = sum(scores) / len(scores)  # type: ignore
        ax[2].hist(scores, bins=50, color="steelblue", alpha=0.7)
        ax[2].axvline(mean_score, color="darkred", linestyle="dashed")

        ax[2].set_title("Distribution of Score Values")
        ax[2].set_xlabel("Score")
        ax[2].set_ylabel("Frequency")

    if use_ax:
        return ax

    if use_pyplot:
        plt.show()  # pragma: no cover
        return None  # pragma: no cover

    return fig, ax
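
When supplying your own axes, note the contract enforced above: a one-dimensional array with at least 2 axes for a single HSISpectralAttributes object, or at least 3 for a list (the third panel shows the score distribution). A sketch with an assumed single `spectral_attrs` object:

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 5))  # 1D array of two axes
visualize_spectral_attributes(spectral_attrs, ax=axes)
plt.show()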

visualize_spatial_attributes(spatial_attributes, ax=None, use_pyplot=False)

Visualizes the spatial attributes of an hsi using Lime attribution.

Parameters:

- spatial_attributes (HSISpatialAttributes): The spatial attributes of the image object to visualize. Required.
- ax (Axes | None): The axes object to plot the visualization on. If None, a new axes will be created. Defaults to None.
- use_pyplot (bool): Whether to use pyplot for visualization. Defaults to False.

Returns:

- tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] | matplotlib.axes.Axes | None: If ax is not None, returns the axes object. If use_pyplot is True, returns None. If use_pyplot is False, returns the figure and axes objects.

Raises:

- ValueError: If the axes array has fewer than 3 axes.
- ValueError: If the axes object is not a list of axes objects.

Source code in src/meteors/visualize/attr_visualize.py
def visualize_spatial_attributes(
    spatial_attributes: HSISpatialAttributes, ax: Axes | None = None, use_pyplot: bool = False
) -> tuple[Figure, Axes] | Axes | None:
    """Visualizes the spatial attributes of an hsi using Lime attribution.

    Args:
        spatial_attributes (HSISpatialAttributes):
            The spatial attributes of the image object to visualize.
        ax (Axes | None, optional):
            The axes object to plot the visualization on. If None, a new axes will be created.
        use_pyplot (bool, optional):
            Whether to use pyplot for visualization. Defaults to False.

    Returns:
        tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] | matplotlib.axes.Axes | None:
            If ax is not None, returns the axes object.
            If use_pyplot is True, returns None.
            If use_pyplot is False, returns the figure and axes objects.

    Raises:
        ValueError: If the axes array has fewer than 3 axes.
        ValueError: If the axes object is not a list of axes objects.
    """
    mask_enabled = spatial_attributes.segmentation_mask is not None
    use_ax = True
    if ax is None:
        use_ax = False
        fig, ax = plt.subplots(1, 3 if mask_enabled else 2, figsize=(15, 5))
        fig.suptitle("Spatial Attributes Visualization")

    if not hasattr(ax, "shape"):
        raise ValueError("A single axes object was provided, but a list (array) of axes objects is required")
    elif len(ax.shape) != 1 or ax.shape[0] < 3:
        raise ValueError("The axes array should be one-dimensional and contain at least 3 axes")
    else:
        fig = ax[0].get_figure()

    spatial_attributes = spatial_attributes.change_orientation("HWC", inplace=False)

    if mask_enabled:
        mask = spatial_attributes.segmentation_mask.cpu()

        group_names = mask.unique().tolist()
        colors = sns.color_palette("hsv", len(group_names))
        color_map = dict(zip(group_names, colors))

        for unique in group_names:
            segment_indices = torch.argwhere(mask == unique)

            y_center, x_center = segment_indices.numpy().mean(axis=0).astype(int)
            ax[1].text(x_center, y_center, str(unique), color=color_map[unique], fontsize=8, ha="center", va="center")
            ax[2].text(x_center, y_center, str(unique), color=color_map[unique], fontsize=8, ha="center", va="center")

        ax[2].imshow(mask.numpy() / mask.max(), cmap="gray")
        ax[2].set_title("Mask")
        ax[2].grid(False)
        ax[2].axis("off")

    ax[0].imshow(spatial_attributes.hsi.get_rgb_image(output_channel_axis=2).cpu())
    ax[0].set_title("Original image")
    ax[0].grid(False)
    ax[0].axis("off")

    attrs = spatial_attributes.attributes.cpu().numpy()
    if np.all(attrs == 0):
        logger.warning("All spatial attributes are zero.")
        cmap = LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"])
        heat_map = ax[1].imshow(attrs.sum(axis=-1), cmap=cmap, vmin=-1, vmax=1)

        axis_separator = make_axes_locatable(ax[1])
        colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.1)
        fig.colorbar(heat_map, orientation="horizontal", cax=colorbar_axis)
    else:
        viz.visualize_image_attr(
            attrs,
            method="heat_map",
            sign="all",
            plt_fig_axis=(fig, ax[1]),
            show_colorbar=True,
            use_pyplot=False,
        )
    ax[1].set_title("Attribution Map")
    ax[1].axis("off")

    if use_ax:
        return ax

    if use_pyplot:
        plt.show()  # pragma: no cover
        return None  # pragma: no cover
    else:
        return fig, ax
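
A minimal sketch with an assumed existing `spatial_attrs` object; with ax=None the function builds its own figure containing the original image, the attribution map, and the segmentation mask panel:

fig, axes = visualize_spatial_attributes(spatial_attrs)
fig.savefig("spatial_attributes.png")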

Attribution Methods

HSIAttributes

Bases: BaseModel

Represents an object that contains Hyperspectral image attributes and explanations.

Attributes:

- hsi (HSI): Hyperspectral image object for which the explanations were created.
- attributes (Tensor): Attributions (explanations) for the hsi.
- score (float): The score provided by the interpretable model. Can be None if the method doesn't provide one.
- device (device): Device to be used for inference. If None, the device of the input hsi will be used. Defaults to None.
- attribution_method (str | None): The method used to generate the explanation. Defaults to None.

Source code in src/meteors/attr/attributes.py
class HSIAttributes(BaseModel):
    """Represents an object that contains Hyperspectral image attributes and explanations.

    Attributes:
        hsi (HSI): Hyperspectral image object for which the explanations were created.
        attributes (torch.Tensor): Attributions (explanations) for the hsi.
        score (float): The score provided by the interpretable model. Can be None if the method doesn't provide one.
        device (torch.device): Device to be used for inference. If None, the device of the input hsi will be used.
            Defaults to None.
        attribution_method (str | None): The method used to generate the explanation. Defaults to None.
    """

    hsi: Annotated[
        HSI,
        Field(
            description="Hyperspectral image object for which the explanations were created.",
        ),
    ]
    attributes: Annotated[
        torch.Tensor,
        BeforeValidator(validate_and_convert_attributes),
        Field(
            description="Attributions (explanations) for the hsi.",
        ),
    ]
    attribution_method: Annotated[
        str | None,
        AfterValidator(validate_attribution_method),
        Field(
            description="The method used to generate the explanation.",
        ),
    ] = None
    score: Annotated[
        float | None,
        Field(
            validate_default=True,
            description="The score provided by the interpretable model. Can be None if method don't provide one.",
        ),
    ] = None
    mask: Annotated[
        torch.Tensor | None,
        BeforeValidator(validate_and_convert_mask),
        Field(
            description="`superpixel` or `superband` mask used for the explanation.",
        ),
    ] = None
    device: Annotated[
        torch.device,
        BeforeValidator(resolve_inference_device_attributes),
        Field(
            validate_default=True,
            exclude=True,
            description=(
                "Device to be used for inference. If None, the device of the input hsi will be used. "
                "Defaults to None."
            ),
        ),
    ] = None

    @property
    def flattened_attributes(self) -> torch.Tensor:
        """Returns a flattened tensor of attributes.

        This method should be implemented in the subclass.

        Returns:
            torch.Tensor: A flattened tensor of attributes.
        """
        raise NotImplementedError("The `flattened_attributes` property must be implemented in the subclass")

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @property
    def orientation(self) -> tuple[str, str, str]:
        """Returns the orientation of the hsi.

        Returns:
            tuple[str, str, str]: The orientation of the hsi corresponding to the attributes.
        """
        return self.hsi.orientation

    def _validate_hsi_attributions_and_mask(self) -> None:
        """Validates the hsi attributions and performs necessary operations to ensure compatibility with the device.

        Raises:
            ValueError: If the shapes of the attributes and hsi tensors do not match.
        """
        validate_shapes(self.attributes, self.hsi)

        self.attributes = self.attributes.to(self.device)
        if self.device != self.hsi.device:
            self.hsi.to(self.device)

        if self.mask is not None:
            validate_shapes(self.mask, self.hsi)
            self.mask = self.mask.to(self.device)

    @model_validator(mode="after")
    def validate_hsi_attributions(self) -> Self:
        """Validates the hsi attributions.

        This method performs validation on the hsi attributions to ensure they are correct.

        Returns:
            Self: The current instance of the class.
        """
        self._validate_hsi_attributions_and_mask()
        return self

    def to(self, device: str | torch.device) -> Self:
        """Move the hsi and attributes tensors to the specified device.

        Args:
            device (str or torch.device): The device to move the tensors to.

        Returns:
            Self: The modified object with tensors moved to the specified device.

        Examples:
            >>> attrs = HSIAttributes(hsi, attributes, score=0.5)
            >>> attrs.to("cpu")
            >>> attrs.hsi.device
            device(type='cpu')
            >>> attrs.attributes.device
            device(type='cpu')
            >>> attrs.to("cuda")
            >>> attrs.hsi.device
            device(type='cuda')
            >>> attrs.attributes.device
            device(type='cuda')
        """
        self.hsi = self.hsi.to(device)
        self.attributes = self.attributes.to(device)
        self.device = self.hsi.device
        return self

    def change_orientation(self, target_orientation: tuple[str, str, str] | list[str] | str, inplace=False) -> Self:
        """Changes the orientation of the image data along with the attributions to the target orientation.

        Args:
            target_orientation (tuple[str, str, str] | list[str] | str): The target orientation for the attribution data.
                This should be a tuple of three one-letter strings in any order: "C", "H", "W".
            inplace (bool, optional): Whether to modify the data in place or return a new object.

        Returns:
            Self: The updated Image object with the new orientation.

        Raises:
            OrientationError: If the target orientation is not a valid tuple of three one-letter strings.
        """
        current_orientation = self.orientation
        hsi = self.hsi.change_orientation(target_orientation, inplace=inplace)
        if inplace:
            attrs = self
        else:
            attrs = self.model_copy()
            attrs.hsi = hsi

        # now change the orientation of the attributes
        if current_orientation == target_orientation:
            return attrs

        permute_dims = [current_orientation.index(dim) for dim in target_orientation]

        attrs.attributes = attrs.attributes.permute(permute_dims)

        if attrs.mask is not None:
            attrs.mask = attrs.mask.permute(permute_dims)
        return attrs
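
A self-contained sketch of the permutation logic used by change_orientation above: moving from ("C", "H", "W") to ("H", "W", "C") permutes the tensor dimensions with (1, 2, 0):

import torch

current_orientation = ("C", "H", "W")
target_orientation = ("H", "W", "C")
permute_dims = [current_orientation.index(dim) for dim in target_orientation]  # [1, 2, 0]

attributes = torch.zeros(5, 64, 32)            # C=5, H=64, W=32
print(attributes.permute(permute_dims).shape)  # torch.Size([64, 32, 5])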

flattened_attributes: torch.Tensor property

Returns a flattened tensor of attributes.

This method should be implemented in the subclass.

Returns:

- torch.Tensor: A flattened tensor of attributes.

orientation: tuple[str, str, str] property

Returns the orientation of the hsi.

Returns:

- tuple[str, str, str]: The orientation of the hsi corresponding to the attributes.

change_orientation(target_orientation, inplace=False)

Changes the orientation of the image data along with the attributions to the target orientation.

Parameters:

- target_orientation (tuple[str, str, str] | list[str] | str): The target orientation for the attribution data. This should be a tuple of three one-letter strings in any order: "C", "H", "W". Required.
- inplace (bool): Whether to modify the data in place or return a new object. Defaults to False.

Returns:

- Self: The updated Image object with the new orientation.

Raises:

- OrientationError: If the target orientation is not a valid tuple of three one-letter strings.

Source code in src/meteors/attr/attributes.py
def change_orientation(self, target_orientation: tuple[str, str, str] | list[str] | str, inplace=False) -> Self:
    """Changes the orientation of the image data along with the attributions to the target orientation.

    Args:
        target_orientation (tuple[str, str, str] | list[str] | str): The target orientation for the attribution data.
            This should be a tuple of three one-letter strings in any order: "C", "H", "W".
        inplace (bool, optional): Whether to modify the data in place or return a new object.

    Returns:
        Self: The updated Image object with the new orientation.

    Raises:
        OrientationError: If the target orientation is not a valid tuple of three one-letter strings.
    """
    current_orientation = self.orientation
    hsi = self.hsi.change_orientation(target_orientation, inplace=inplace)
    if inplace:
        attrs = self
    else:
        attrs = self.model_copy()
        attrs.hsi = hsi

    # now change the orientation of the attributes
    if current_orientation == target_orientation:
        return attrs

    permute_dims = [current_orientation.index(dim) for dim in target_orientation]

    attrs.attributes = attrs.attributes.permute(permute_dims)

    if attrs.mask is not None:
        attrs.mask = attrs.mask.permute(permute_dims)
    return attrs

to(device)

Move the hsi and attributes tensors to the specified device.

Parameters:

- device (str or device): The device to move the tensors to. Required.

Returns:

- Self: The modified object with tensors moved to the specified device.

Examples:

>>> attrs = HSIAttributes(hsi, attributes, score=0.5)
>>> attrs.to("cpu")
>>> attrs.hsi.device
device(type='cpu')
>>> attrs.attributes.device
device(type='cpu')
>>> attrs.to("cuda")
>>> attrs.hsi.device
device(type='cuda')
>>> attrs.attributes.device
device(type='cuda')
Source code in src/meteors/attr/attributes.py
def to(self, device: str | torch.device) -> Self:
    """Move the hsi and attributes tensors to the specified device.

    Args:
        device (str or torch.device): The device to move the tensors to.

    Returns:
        Self: The modified object with tensors moved to the specified device.

    Examples:
        >>> attrs = HSIAttributes(hsi, attributes, score=0.5)
        >>> attrs.to("cpu")
        >>> attrs.hsi.device
        device(type='cpu')
        >>> attrs.attributes.device
        device(type='cpu')
        >>> attrs.to("cuda")
        >>> attrs.hsi.device
        device(type='cuda')
        >>> attrs.attributes.device
        device(type='cuda')
    """
    self.hsi = self.hsi.to(device)
    self.attributes = self.attributes.to(device)
    self.device = self.hsi.device
    return self

HSISpatialAttributes

Bases: HSIAttributes

Represents spatial attributes of an hsi used for explanation.

Attributes:

- hsi (HSI): Hyperspectral image object for which the explanations were created.
- attributes (Tensor): Attributions (explanations) for the hsi.
- score (float): The score provided by the interpretable model. Can be None if the method doesn't provide one.
- device (device): Device to be used for inference. If None, the device of the input hsi will be used. Defaults to None.
- attribution_method (str | None): The method used to generate the explanation. Defaults to None.
- segmentation_mask (Tensor): Spatial (Segmentation) mask used for the explanation.
- flattened_attributes (Tensor): Spatial 2D attribution map.

Source code in src/meteors/attr/attributes.py
class HSISpatialAttributes(HSIAttributes):
    """Represents spatial attributes of an hsi used for explanation.

    Attributes:
        hsi (HSI): Hyperspectral image object for which the explanations were created.
        attributes (torch.Tensor): Attributions (explanations) for the hsi.
        score (float): The score provided by the interpretable model. Can be None if the method doesn't provide one.
        device (torch.device): Device to be used for inference. If None, the device of the input hsi will be used.
            Defaults to None.
        attribution_method (str | None): The method used to generate the explanation. Defaults to None.
        segmentation_mask (torch.Tensor): Spatial (Segmentation) mask used for the explanation.
        flattened_attributes (torch.Tensor): Spatial 2D attribution map.
    """

    @property
    def segmentation_mask(self) -> torch.Tensor:
        """Returns the 2D spatial segmentation mask that has the same size as the hsi image.

        Returns:
            torch.Tensor: The segmentation mask tensor.

        Raises:
            HSIAttributesError: If the segmentation mask is not provided in the attributes object.
        """
        if self.mask is None:
            raise HSIAttributesError("Segmentation mask is not provided in the attributes object")
        return self.mask.select(dim=self.hsi.spectral_axis, index=0)

    @property
    def flattened_attributes(self) -> torch.Tensor:
        """Returns a flattened tensor of attributes, with removed repeated dimensions.

        In the case of spatial attributes, the flattened attributes are 2D spatial attributes of shape (rows, columns) and the spectral dimension is removed.

        Examples:
            >>> segmentation_mask = torch.zeros((3, 2, 2))
            >>> attrs = HSISpatialAttributes(hsi, attributes, score=0.5, segmentation_mask=segmentation_mask)
            >>> attrs.flattened_attributes
                tensor([[0., 0.],
                        [0., 0.]])

        Returns:
            torch.Tensor: A flattened tensor of attributes.
        """
        return self.attributes.select(dim=self.hsi.spectral_axis, index=0)

    def _validate_hsi_attributions_and_mask(self) -> None:
        """Validates the hsi attributions and performs necessary operations to ensure compatibility with the device.

        Raises:
            HSIAttributesError: If the segmentation mask is not provided in the attributes object.
        """
        super()._validate_hsi_attributions_and_mask()
        if self.mask is None:
            raise HSIAttributesError("Segmentation mask is not provided in the attributes object")
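
A self-contained sketch of what flattened_attributes computes: spatial attributions are constant along the spectral axis, so selecting index 0 on that axis recovers the 2D (rows, columns) map:

import torch

attributes = torch.rand(1, 4, 4).expand(3, 4, 4)  # ("C", "H", "W") orientation, spectral_axis = 0
flattened = attributes.select(dim=0, index=0)
print(flattened.shape)                            # torch.Size([4, 4])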

flattened_attributes: torch.Tensor property

Returns a flattened tensor of attributes, with removed repeated dimensions.

In the case of spatial attributes, the flattened attributes are 2D spatial attributes of shape (rows, columns) and the spectral dimension is removed.

Examples:

>>> segmentation_mask = torch.zeros((3, 2, 2))
>>> attrs = HSISpatialAttributes(hsi, attributes, score=0.5, segmentation_mask=segmentation_mask)
>>> attrs.flattened_attributes
    tensor([[0., 0.],
            [0., 0.]])

Returns:

- torch.Tensor: A flattened tensor of attributes.

segmentation_mask: torch.Tensor property

Returns the 2D spatial segmentation mask that has the same size as the hsi image.

Returns:

- torch.Tensor: The segmentation mask tensor.

Raises:

- HSIAttributesError: If the segmentation mask is not provided in the attributes object.

HSISpectralAttributes

Bases: HSIAttributes

Represents an hsi with spectral attributes used for explanation.

Attributes:

- hsi (HSI): Hyperspectral image object for which the explanations were created.
- attributes (Tensor): Attributions (explanations) for the hsi.
- score (float): R^2 score of the interpretable model used for the explanation.
- device (device): Device to be used for inference. If None, the device of the input hsi will be used. Defaults to None.
- attribution_method (str | None): The method used to generate the explanation. Defaults to None.
- band_mask (Tensor): Band mask used for the explanation.
- band_names (dict[str | tuple[str, ...], int]): Dictionary that translates the band names into the band segment ids.
- flattened_attributes (Tensor): Spectral 1D attribution map.

Source code in src/meteors/attr/attributes.py
class HSISpectralAttributes(HSIAttributes):
    """Represents an hsi with spectral attributes used for explanation.

    Attributes:
        hsi (HSI): Hyperspectral image object for which the explanations were created.
        attributes (torch.Tensor): Attributions (explanations) for the hsi.
        score (float): R^2 score of the interpretable model used for the explanation.
        device (torch.device): Device to be used for inference. If None, the device of the input hsi will be used.
            Defaults to None.
        attribution_method (str | None): The method used to generate the explanation. Defaults to None.
        band_mask (torch.Tensor): Band mask used for the explanation.
        band_names (dict[str | tuple[str, ...], int]): Dictionary that translates the band names into the band segment ids.
        flattened_attributes (torch.Tensor): Spectral 1D attribution map.
    """

    band_names: Annotated[
        dict[str | tuple[str, ...], int],
        Field(
            description="Dictionary that translates the band names into the band segment ids.",
        ),
    ]

    @property
    def band_mask(self) -> torch.Tensor:
        """Returns a 1D band mask - a band mask with removed repeated dimensions (num_bands, ),
        where num_bands is the number of bands in the hsi image.

        The method selects the appropriate dimensions from the `band_mask` tensor
        based on the `axis_to_select` and returns a flattened version of the selected
        tensor.

        Returns:
            torch.Tensor: The flattened band mask tensor.

        Examples:
            >>> band_names = {"R": 0, "G": 1, "B": 2}
            >>> attrs = HSISpectralAttributes(hsi, attributes, score=0.5, mask=band_mask, band_names=band_names)
            >>> attrs.band_mask
            tensor([0, 1, 2])
        """
        if self.mask is None:
            raise ValueError("Band mask is not provided")
        axis_to_select = [i for i in range(self.hsi.image.ndim) if i != self.hsi.spectral_axis]
        return self.mask.select(dim=axis_to_select[0], index=0).select(dim=axis_to_select[1] - 1, index=0)

    @property
    def flattened_attributes(self) -> torch.Tensor:
        """Returns a flattened tensor of attributes with removed repeated dimensions.

        In the case of spectral attributes, the flattened attributes are 1D tensor of shape (num_bands, ), where num_bands is the number of bands in the hsi image.

        Returns:
            torch.Tensor: A flattened tensor of attributes.
        """
        axis = [i for i in range(self.attributes.ndim) if i != self.hsi.spectral_axis]
        return self.attributes.select(dim=axis[0], index=0).select(dim=axis[1] - 1, index=0)

    def _validate_hsi_attributions_and_mask(self) -> None:
        """Validates the hsi attributions and performs necessary operations to ensure compatibility with the device.

        Raises:
            HSIAttributesError: If the band mask is not provided in the attributes object
        """
        super()._validate_hsi_attributions_and_mask()
        if self.mask is None:
            raise HSIAttributesError("Band mask is not provided in the attributes object")

        self.band_names = align_band_names_with_mask(self.band_names, self.mask)
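
A self-contained sketch of the band_mask flattening: for a ("C", "H", "W") image (spectral_axis = 0) the non-spectral axes are (1, 2), and two successive selections drop the repeated spatial dimensions:

import torch

mask = torch.arange(3).reshape(3, 1, 1).expand(3, 4, 4)  # full-size band mask
flattened = mask.select(dim=1, index=0).select(dim=1, index=0)
print(flattened)  # tensor([0, 1, 2])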

band_mask: torch.Tensor property

Returns a 1D band mask - a band mask with removed repeated dimensions (num_bands, ), where num_bands is the number of bands in the hsi image.

The method selects the appropriate dimensions from the band_mask tensor based on the axis_to_select and returns a flattened version of the selected tensor.

Returns:

- torch.Tensor: The flattened band mask tensor.

Examples:

>>> band_names = {"R": 0, "G": 1, "B": 2}
>>> attrs = HSISpectralAttributes(hsi, attributes, score=0.5, mask=band_mask, band_names=band_names)
>>> attrs.band_mask
tensor([0, 1, 2])

flattened_attributes: torch.Tensor property

Returns a flattened tensor of attributes with removed repeated dimensions.

In the case of spectral attributes, the flattened attributes are a 1D tensor of shape (num_bands, ), where num_bands is the number of bands in the hsi image.

Returns:

- torch.Tensor: A flattened tensor of attributes.

Lime

Bases: Explainer

Lime is a subclass of Explainer that implements the Lime explainer. Lime is an interpretable, model-agnostic explanation method that explains the predictions of a black-box model by approximating it with a simpler interpretable model. The implementation builds on captum and follows the original paper on Lime, where more details about this method can be found.

Parameters:

- explainable_model (ExplainableModel): The explainable model to be explained. Required.
- interpretable_model (InterpretableModel): The interpretable model used to approximate the black-box model. Defaults to SkLearnLasso with the alpha parameter set to 0.08.
- similarity_func (Callable[[Tensor], Tensor] | None): The similarity function used by Lime. Defaults to None.
- perturb_func (Callable[[Tensor], Tensor] | None): The perturbation function used by Lime. Defaults to None.
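
Before the source, a short sketch of the two static mask helpers defined below. The hsi construction follows the doctest in get_segmentation_mask; the band_indices value is an illustrative assumption based on the parameter description:

import torch
import meteors

hsi = meteors.HSI(image=torch.ones((3, 240, 240)), wavelengths=[462.08, 465.27, 468.47])

# Segmentation mask for spatial explanations (patch method, as in the docstring example).
segmentation_mask = meteors.attr.Lime.get_segmentation_mask(hsi, segmentation_method="patch", patch_size=2)

# Band mask for spectral explanations, grouping the three bands under one hypothetical name.
band_mask, band_names = meteors.attr.Lime.get_band_mask(hsi, band_indices={"blue": [0, 1, 2]})
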
Source code in src/meteors/attr/lime.py
class Lime(Explainer):
    """Lime class is a subclass of Explainer and represents the Lime explainer. Lime is an interpretable model-agnostic
    explanation method that explains the predictions of a black-box model by approximating it with a simpler
    interpretable model. The Lime method is based on the [`captum` implementation](https://captum.ai/api/lime.html)
    and is an implementation of an idea coming from the [original paper on Lime](https://arxiv.org/abs/1602.04938),
    where more details about this method can be found.

    Args:
        explainable_model (ExplainableModel): The explainable model to be explained.
        interpretable_model (InterpretableModel): The interpretable model used to approximate the black-box model.
            Defaults to `SkLearnLasso` with alpha parameter set to 0.08.
        similarity_func (Callable[[torch.Tensor], torch.Tensor] | None, optional): The similarity function used by Lime.
            Defaults to None.
        perturb_func (Callable[[torch.Tensor], torch.Tensor] | None, optional): The perturbation function used by Lime.
            Defaults to None.
    """

    def __init__(
        self,
        explainable_model: ExplainableModel,
        interpretable_model: InterpretableModel = SkLearnLasso(alpha=0.08),
        similarity_func: Callable[[torch.Tensor], torch.Tensor] | None = None,
        perturb_func: Callable[[torch.Tensor], torch.Tensor] | None = None,
    ):
        super().__init__(explainable_model)
        self.interpretable_model = interpretable_model
        self._attribution_method: LimeBase = self._construct_lime(
            self.explainable_model.forward_func, interpretable_model, similarity_func, perturb_func
        )

    @staticmethod
    def _construct_lime(
        forward_func: Callable[[torch.Tensor], torch.Tensor],
        interpretable_model: InterpretableModel,
        similarity_func: Callable | None,
        perturb_func: Callable[[torch.Tensor], torch.Tensor] | None,
    ) -> LimeBase:
        """Constructs the LimeBase object.

        Args:
            forward_func (Callable[[torch.Tensor], torch.Tensor]): The forward function of the explainable model.
            interpretable_model (InterpretableModel): The interpretable model used to approximate the black-box model.
            similarity_func (Callable | None): The similarity function used by Lime.
            perturb_func (Callable[[torch.Tensor], torch.Tensor] | None): The perturbation function used by Lime.

        Returns:
            LimeBase: The constructed LimeBase object.
        """
        return LimeBase(
            forward_func=forward_func,
            interpretable_model=interpretable_model,
            similarity_func=similarity_func,
            perturb_func=perturb_func,
        )

    @staticmethod
    def get_segmentation_mask(
        hsi: HSI,
        segmentation_method: Literal["patch", "slic"] = "slic",
        **segmentation_method_params: Any,
    ) -> torch.Tensor:
        """Generates a segmentation mask for the given hsi using the specified segmentation method.

        Args:
            hsi (HSI): The input hyperspectral image for which the segmentation mask needs to be generated.
            segmentation_method (Literal["patch", "slic"], optional): The segmentation method to be used.
                Defaults to "slic".
            **segmentation_method_params (Any): Additional parameters specific to the chosen segmentation method.

        Returns:
            torch.Tensor: The segmentation mask as a tensor.

        Raises:
            TypeError: If the input hsi is not an instance of the HSI class.
            ValueError: If an unsupported segmentation method is specified.

        Examples:
            >>> hsi = meteors.HSI(image=torch.ones((3, 240, 240)), wavelengths=[462.08, 465.27, 468.47])
            >>> segmentation_mask = meteors.attr.Lime.get_segmentation_mask(hsi, segmentation_method="slic")
            >>> segmentation_mask.shape
            torch.Size([1, 240, 240])
            >>> segmentation_mask = meteors.attr.Lime.get_segmentation_mask(hsi, segmentation_method="patch", patch_size=2)
            >>> segmentation_mask.shape
            torch.Size([1, 240, 240])
            >>> segmentation_mask[0, :2, :2]
            tensor([[0, 0],
                    [0, 0]])
            >>> segmentation_mask[0, 2:4, :2]
            tensor([[120, 120],
                    [120, 120]])
        """
        if not isinstance(hsi, HSI):
            raise TypeError("hsi should be an instance of HSI class")

        try:
            if segmentation_method == "slic":
                return Lime._get_slic_segmentation_mask(hsi, **segmentation_method_params)
            elif segmentation_method == "patch":
                return Lime._get_patch_segmentation_mask(hsi, **segmentation_method_params)
            else:
                raise ValueError(f"Unsupported segmentation method: {segmentation_method}")
        except Exception as e:
            raise MaskCreationError(f"Error creating segmentation mask using method {segmentation_method}: {e}")

    @staticmethod
    def get_band_mask(
        hsi: HSI,
        band_names: None | list[str | list[str]] | dict[tuple[str, ...] | str, int] = None,
        band_indices: None | dict[str | tuple[str, ...], ListOfWavelengthsIndices] = None,
        band_wavelengths: None | dict[str | tuple[str, ...], ListOfWavelengths] = None,
        device: str | torch.device | None = None,
        repeat_dimensions: bool = False,
    ) -> tuple[torch.Tensor, dict[tuple[str, ...] | str, int]]:
        """Generates a band mask based on the provided hsi and band information.

        Remember that you need to provide either band_names, band_indices, or band_wavelengths to create the band mask.
        If more than one is provided, only one will be used, selected with the following priority:
        band_names > band_wavelengths > band_indices.

        Args:
            hsi (HSI): The input hyperspectral image.
            band_names (None | list[str | list[str]] | dict[tuple[str, ...] | str, int], optional):
                The names of the spectral bands to include in the mask. Defaults to None.
            band_indices (None | dict[str | tuple[str, ...], list[tuple[int, int]] | tuple[int, int] | list[int]], optional):
                The indices or ranges of indices of the spectral bands to include in the mask. Defaults to None.
            band_wavelengths (None | dict[str | tuple[str, ...], list[tuple[float, float]] | tuple[float, float] | list[float] | float], optional):
                The wavelengths or ranges of wavelengths of the spectral bands to include in the mask. Defaults to None.
            device (str | torch.device | None, optional):
                The device to use for computation. Defaults to None.
            repeat_dimensions (bool, optional):
                Whether to repeat the dimensions of the mask to match the input hsi shape. Defaults to False.

        Returns:
            tuple[torch.Tensor, dict[tuple[str, ...] | str, int]]: A tuple containing the band mask tensor and a dictionary
            mapping band names to segment IDs.

        Raises:
            TypeError: If the input hsi is not an instance of the HSI class.
            ValueError: If no band names, indices, or wavelengths are provided.

        Examples:
            >>> hsi = mt.HSI(image=torch.ones((len(wavelengths), 10, 10)), wavelengths=wavelengths)
            >>> band_names = ["R", "G"]
            >>> band_mask, dict_labels_to_segment_ids = mt_lime.Lime.get_band_mask(hsi, band_names=band_names)
            >>> dict_labels_to_segment_ids
            {"R": 1, "G": 2}
            >>> band_indices = {"RGB": [0, 1, 2]}
            >>> band_mask, dict_labels_to_segment_ids = mt_lime.Lime.get_band_mask(hsi, band_indices=band_indices)
            >>> dict_labels_to_segment_ids
            {"RGB": 1}
            >>> band_wavelengths = {"RGB": [(462.08, 465.27), (465.27, 468.47), (468.47, 471.68)]}
            >>> band_mask, dict_labels_to_segment_ids = mt_lime.Lime.get_band_mask(hsi, band_wavelengths=band_wavelengths)
            >>> dict_labels_to_segment_ids
            {"RGB": 1}
        """
        if not isinstance(hsi, HSI):
            raise TypeError("hsi should be an instance of HSI class")

        try:
            if band_names is None and band_indices is None and band_wavelengths is None:
                raise ValueError("No band names, indices, or wavelengths were provided.")

            # validate types
            dict_labels_to_segment_ids = None
            if band_names is not None:
                logger.debug("Getting band mask from band names of spectral bands")
                if band_wavelengths is not None or band_indices is not None:
                    # Collect the names of the redundant parameters that will be ignored
                    ignored_params = [
                        name
                        for name, value in (("band_wavelengths", band_wavelengths), ("band_indices", band_indices))
                        if value is not None
                    ]
                    ignored_params_str = " and ".join(ignored_params)
                    logger.info(
                        f"Only the band names will be used to create the band mask. "
                        f"The additional parameters {ignored_params_str} will be ignored."
                    )
                try:
                    validate_band_names(band_names)
                    band_groups, dict_labels_to_segment_ids = Lime._get_band_wavelengths_indices_from_band_names(
                        hsi.wavelengths, band_names
                    )
                except Exception as e:
                    raise BandSelectionError(f"Incorrect band names provided: {e}") from e
            elif band_wavelengths is not None:
                logger.debug("Getting band mask from band groups given by ranges of wavelengths")
                if band_indices is not None:
                    logger.info(
                        "Only the band wavelengths will be used to create the band mask. The band_indices will be ignored."
                    )
                validate_band_format(band_wavelengths, variable_name="band_wavelengths")
                try:
                    band_groups = Lime._get_band_indices_from_band_wavelengths(
                        hsi.wavelengths,
                        band_wavelengths,
                    )
                except Exception as e:
                    raise ValueError(
                        f"Incorrect band ranges wavelengths provided, please check if provided wavelengths are correct: {e}"
                    ) from e
            elif band_indices is not None:
                logger.debug("Getting band mask from band groups given by ranges of indices")
                validate_band_format(band_indices, variable_name="band_indices")
                try:
                    band_groups = Lime._get_band_indices_from_input_band_indices(hsi.wavelengths, band_indices)
                except Exception as e:
                    raise ValueError(
                        f"Incorrect band ranges indices provided, please check if provided indices are correct: {e}"
                    ) from e

            return Lime._create_tensor_band_mask(
                hsi,
                band_groups,
                dict_labels_to_segment_ids=dict_labels_to_segment_ids,
                device=device,
                repeat_dimensions=repeat_dimensions,
                return_dict_labels_to_segment_ids=True,
            )
        except Exception as e:
            raise MaskCreationError(f"Error creating band mask: {e}") from e

    @staticmethod
    def _make_band_names_indexable(segment_name: list[str] | tuple[str, ...] | str) -> tuple[str, ...] | str:
        """Converts a list of strings into a tuple of strings if necessary to make it indexable.

        Args:
            segment_name (list[str] | tuple[str, ...] | str): The segment name to be converted.

        Returns:
            tuple[str, ...] | str: The converted segment name.

        Raises:
            TypeError: If the segment_name is not of type list or string.
        """
        if (
            isinstance(segment_name, tuple) and all(isinstance(subitem, str) for subitem in segment_name)
        ) or isinstance(segment_name, str):
            return segment_name
        elif isinstance(segment_name, list) and all(isinstance(subitem, str) for subitem in segment_name):
            return tuple(segment_name)
        raise TypeError(f"Incorrect segment {segment_name} type. Should be either a list or string")

    @staticmethod
    # @lru_cache(maxsize=32) Can't use with lists as they are not hashable
    def _extract_bands_from_spyndex(segment_name: list[str] | tuple[str, ...] | str) -> tuple[str, ...] | str:
        """Extracts bands from the given segment name.

        Args:
            segment_name (list[str] | tuple[str, ...] | str): The name of the segment.
                Users may pass either band names or indices names, as in the spyndex library.

        Returns:
            tuple[str, ...] | str: A tuple of band names if multiple bands are extracted,
                or a single band name if only one band is extracted.

        Raises:
            BandSelectionError: If the provided band name is invalid.
                The band name must be either in `spyndex.indices` or `spyndex.bands`.
        """
        if isinstance(segment_name, str):
            segment_name = (segment_name,)
        elif isinstance(segment_name, list):
            segment_name = tuple(segment_name)

        band_names_segment: list[str] = []
        for band_name in segment_name:
            if band_name in spyndex.indices:
                band_names_segment += list(spyndex.indices[band_name].bands)
            elif band_name in spyndex.bands:
                band_names_segment.append(band_name)
            else:
                raise BandSelectionError(
                    f"Invalid band name {band_name}, band name must be either in `spyndex.indices` or `spyndex.bands`"
                )

        return tuple(set(band_names_segment)) if len(band_names_segment) > 1 else band_names_segment[0]
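    # Illustrative sketch, not part of the library source: assuming spyndex defines
    # the NDVI index over the near-infrared ("N") and red ("R") bands,
    # Lime._extract_bands_from_spyndex("NDVI") returns a tuple such as ("N", "R")
    # (order may vary, since duplicates are removed via a set), while a plain band
    # name such as "R" passes through unchanged.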

    @staticmethod
    def _get_indices_from_wavelength_indices_range(
        wavelengths: torch.Tensor, ranges: list[tuple[int, int]] | tuple[int, int]
    ) -> list[int]:
        """Converts wavelength indices ranges to list indices.

        Args:
            wavelengths (torch.Tensor): The tensor containing the wavelengths.
            ranges (list[tuple[int, int]] | tuple[int, int]): The wavelength indices ranges.

        Returns:
            list[int]: The indices of bands corresponding to the wavelength indices ranges.
        """
        validated_ranges_list = validate_segment_format(ranges)
        validated_ranges_list = adjust_and_validate_segment_ranges(wavelengths, validated_ranges_list)

        return list(
            set(
                chain.from_iterable(
                    range(int(start), int(end)) for start, end in validated_ranges_list
                )
            )
        )
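    # Illustrative sketch, not part of the library source: assuming the ranges pass
    # validation unchanged, ranges such as [(0, 2), (3, 4)] expand to the index set
    # {0, 1, 3}; range ends are treated as exclusive by range(start, end).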

    @staticmethod
    def _get_band_wavelengths_indices_from_band_names(
        wavelengths: torch.Tensor,
        band_names: list[str | list[str]] | dict[tuple[str, ...] | str, int],
    ) -> tuple[dict[tuple[str, ...] | str, list[int]], dict[tuple[str, ...] | str, int]]:
        """Extracts band wavelengths indices from the given band names.

        This function takes a list or dictionary of band names or segments and extracts the list of wavelengths indices
        associated with each segment. It returns a tuple containing a dictionary with mapping segment labels into
        wavelength indices and a dictionary mapping segment labels into segment ids.

        Args:
            wavelengths (torch.Tensor): The tensor containing the wavelengths.
            band_names (list[str | list[str]] | dict[tuple[str, ...] | str, int]):
                A list or dictionary with band names or segments.

        Returns:
            tuple[dict[tuple[str, ...] | str, list[int]], dict[tuple[str, ...] | str, int]]:
                A tuple containing the dictionary mapping segment labels to wavelength indices and the dictionary
                mapping segment labels to segment ids.

        Raises:
            TypeError: If the band names are not in the correct format.
        """
        if isinstance(band_names, str):
            band_names = [band_names]
        if isinstance(band_names, list):
            logger.debug("band_names is a list of segments, creating a dictionary of segments")
            band_names_hashed = [Lime._make_band_names_indexable(segment) for segment in band_names]
            dict_labels_to_segment_ids = {segment: idx + 1 for idx, segment in enumerate(band_names_hashed)}
            segments_list = band_names_hashed
        elif isinstance(band_names, dict):
            dict_labels_to_segment_ids = band_names.copy()
            segments_list = tuple(band_names.keys())  # type: ignore
        else:
            raise TypeError("Incorrect band_names type. It should be a dict or a list")
        segments_list_after_mapping = [Lime._extract_bands_from_spyndex(segment) for segment in segments_list]
        band_indices: dict[tuple[str, ...] | str, list[int]] = {}
        for original_segment, segment in zip(segments_list, segments_list_after_mapping):
            segment_indices_ranges: list[tuple[int, int]] = []
            if isinstance(segment, str):
                segment = (segment,)
            for band_name in segment:
                min_wavelength = spyndex.bands[band_name].min_wavelength
                max_wavelength = spyndex.bands[band_name].max_wavelength

                if min_wavelength > wavelengths.max() or max_wavelength < wavelengths.min():
                    logger.debug(
                        f"Band {band_name} is not present in the given wavelengths. "
                        f"Band ranges from {min_wavelength} nm to {max_wavelength} nm and the HSI wavelengths "
                        f"range from {wavelengths.min():.2f} nm to {wavelengths.max():.2f} nm. The given band will be skipped"
                    )
                else:
                    segment_indices_ranges += Lime._convert_wavelengths_to_indices(
                        wavelengths,
                        (spyndex.bands[band_name].min_wavelength, spyndex.bands[band_name].max_wavelength),
                    )

            segment_list = Lime._get_indices_from_wavelength_indices_range(wavelengths, segment_indices_ranges)
            band_indices[original_segment] = segment_list
        return band_indices, dict_labels_to_segment_ids
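    # Illustrative sketch, not part of the library source: assuming spyndex places
    # the "R" band inside the wavelength range of the HSI, band_names=["R"] yields
    # a pair like ({"R": [<indices of wavelengths inside the R band>]}, {"R": 1}),
    # where the second dictionary assigns 1-based segment ids in input order.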

    @staticmethod
    def _convert_wavelengths_to_indices(
        wavelengths: torch.Tensor, ranges: list[tuple[float, float]] | tuple[float, float]
    ) -> list[tuple[int, int]]:
        """Converts wavelength ranges to index ranges.

        Args:
            wavelengths (torch.Tensor): The tensor containing the wavelengths.
            ranges (list[tuple[float, float]] | tuple[float, float]): The wavelength ranges.

        Returns:
            list[tuple[int, int]]: The index ranges corresponding to the wavelength ranges.
        """
        indices = []
        if isinstance(ranges, tuple):
            ranges = [ranges]

        for start, end in ranges:
            start_idx = torch.searchsorted(wavelengths, start, side="left")
            end_idx = torch.searchsorted(wavelengths, end, side="right")
            indices.append((start_idx.item(), end_idx.item()))
        return indices
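    # Illustrative sketch, not part of the library source: for
    # wavelengths = torch.tensor([400.0, 450.0, 500.0, 550.0]), the range
    # (450.0, 500.0) maps to the half-open index pair (1, 3), since searchsorted
    # uses side="left" for the start and side="right" for the end of the range.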

    @staticmethod
    def _get_band_indices_from_band_wavelengths(
        wavelengths: torch.Tensor,
        band_wavelengths: dict[str | tuple[str, ...], ListOfWavelengths],
    ) -> dict[str | tuple[str, ...], list[int]]:
        """Converts the ranges or list of wavelengths into indices.

        Args:
            wavelengths (torch.Tensor): The tensor containing the wavelengths.
            band_wavelengths (dict): A dictionary mapping segment labels to wavelength list or ranges.

        Returns:
            dict: A dictionary mapping segment labels to index ranges.

        Raises:
            TypeError: If band_wavelengths is not a dictionary.
        """
        if not isinstance(band_wavelengths, dict):
            raise TypeError("band_wavelengths should be a dictionary")

        band_indices: dict[str | tuple[str, ...], list[int]] = {}
        for segment_label, segment in band_wavelengths.items():
            try:
                dtype = torch_dtype_to_python_dtype(wavelengths.dtype)
                if isinstance(segment, (float, int)):
                    segment = [dtype(segment)]  # type: ignore
                if isinstance(segment, list) and all(isinstance(x, (float, int)) for x in segment):
                    segment_dtype = change_dtype_of_list(segment, dtype)
                    indices = Lime._convert_wavelengths_list_to_indices(wavelengths, segment_dtype)  # type: ignore
                else:
                    if isinstance(segment, list):
                        segment_dtype = [
                            tuple(change_dtype_of_list(list(ranges), dtype))  # type: ignore
                            for ranges in segment
                        ]
                    else:
                        segment_dtype = tuple(change_dtype_of_list(segment, dtype))

                    valid_segment_range = validate_segment_format(segment_dtype, dtype)
                    range_indices = Lime._convert_wavelengths_to_indices(wavelengths, valid_segment_range)  # type: ignore
                    valid_indices_format = validate_segment_format(range_indices)
                    valid_range_indices = adjust_and_validate_segment_ranges(wavelengths, valid_indices_format)
                    indices = Lime._get_indices_from_wavelength_indices_range(wavelengths, valid_range_indices)
            except Exception as e:
                raise ValueError(f"Problem with segment {segment_label}: {e}") from e

            band_indices[segment_label] = indices

        return band_indices
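    # Illustrative sketch, not part of the library source: all of {"B": 465.27},
    # {"B": [462.08, 465.27]} and {"B": (462.08, 465.27)} are accepted; scalars and
    # lists are matched against the wavelengths exactly, while tuples are validated
    # as ranges and expanded through the range-to-index conversion above.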

    @staticmethod
    def _convert_wavelengths_list_to_indices(wavelengths: torch.Tensor, ranges: list[float]) -> list[int]:
        """Converts a list of wavelengths into indices.

        Args:
            wavelengths (torch.Tensor): The tensor containing the wavelengths.
            ranges (list[float]): The list of wavelengths.

        Returns:
            list[int]: The indices corresponding to the wavelengths.
        """
        indices = []
        for wavelength in ranges:
            index = (wavelengths == wavelength).nonzero(as_tuple=False)
            number_of_elements = torch.numel(index)
            if number_of_elements == 1:
                indices.append(index.item())
            elif number_of_elements == 0:
                raise ValueError(f"Couldn't find wavelength of value {wavelength} in the list of wavelengths")
            else:
                raise ValueError(f"Wavelength of value {wavelength} was present more than once in the list of wavelengths")
        return indices
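    # Illustrative sketch, not part of the library source: for
    # wavelengths = torch.tensor([400.0, 450.0, 500.0]), ranges=[450.0] returns
    # [1]; a value absent from the tensor (e.g. 460.0) raises a ValueError rather
    # than being matched approximately.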

    @staticmethod
    def _get_band_indices_from_input_band_indices(
        wavelengths: torch.Tensor,
        input_band_indices: dict[str | tuple[str, ...], ListOfWavelengthsIndices],
    ) -> dict[str | tuple[str, ...], list[int]]:
        """Get band indices from band list or ranges indices.

        Args:
            wavelengths (torch.Tensor): The tensor containing the wavelengths.
            input_band_indices (dict[str | tuple[str, ...], ListOfWavelengthsIndices]):
                A dictionary mapping segment labels to a list of wavelength indices.

        Returns:
            dict[str | tuple[str, ...], list[int]]: A dictionary mapping segment labels to a list of band indices.

        Raises:
            TypeError: If `input_band_indices` is not a dictionary.
        """
        if not isinstance(input_band_indices, dict):
            raise TypeError("input_band_indices should be a dictionary")

        band_indices: dict[str | tuple[str, ...], list[int]] = {}
        for segment_label, indices in input_band_indices.items():
            try:
                if isinstance(indices, int):
                    indices = [indices]  # type: ignore
                if not (isinstance(indices, list) and all(isinstance(x, int) for x in indices)):
                    valid_indices_format = validate_segment_format(indices)  # type: ignore
                    valid_range_indices = adjust_and_validate_segment_ranges(wavelengths, valid_indices_format)
                    indices = Lime._get_indices_from_wavelength_indices_range(wavelengths, valid_range_indices)  # type: ignore

                band_indices[segment_label] = indices  # type: ignore
            except Exception as e:
                raise ValueError(f"Problem with segment {segment_label}") from e

        return band_indices
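    # Illustrative sketch, not part of the library source: a plain list such as
    # {"RGB": [0, 1, 2]} is returned unchanged, while a tuple such as
    # {"RGB": (0, 3)} is validated as an index range and expanded via
    # _get_indices_from_wavelength_indices_range.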

    @staticmethod
    def _check_overlapping_segments(dict_labels_to_indices: dict[str | tuple[str, ...], list[int]]) -> None:
        """Check for overlapping segments.

        Args:
            dict_labels_to_indices (dict[str | tuple[str, ...], list[int]]):
                A dictionary mapping segment labels to indices.

        Returns:
            None
        """
        overlapping_segments: list[tuple[str | tuple[str, ...], str | tuple[str, ...]]] = []
        labels = list(dict_labels_to_indices.keys())

        for i, segment_label in enumerate(labels):
            for second_label in labels[i + 1 :]:
                indices = dict_labels_to_indices[segment_label]
                second_indices = dict_labels_to_indices[second_label]

                if set(indices) & set(second_indices):
                    overlapping_segments.append((segment_label, second_label))

        for label_first, label_second in overlapping_segments:
            label_first_str = label_first if isinstance(label_first, str) else "/".join(label_first)
            label_second_str = label_second if isinstance(label_second, str) else "/".join(label_second)

            logger.warning(
                f"Segments {label_first_str} and {label_second_str} are overlapping; "
                "overlapping wavelengths will be assigned to only one of them"
            )

    @staticmethod
    def _validate_and_create_dict_labels_to_segment_ids(
        dict_labels_to_segment_ids: dict[str | tuple[str, ...], int] | None,
        segment_labels: list[str | tuple[str, ...]],
    ) -> dict[str | tuple[str, ...], int]:
        """Validates and creates a dictionary mapping segment labels to segment IDs.

        Args:
            dict_labels_to_segment_ids (dict[str | tuple[str, ...], int] | None):
                The existing mapping from segment labels to segment IDs, or None if it doesn't exist.
            segment_labels (list[str | tuple[str, ...]]): The list of segment labels.

        Returns:
            dict[str | tuple[str, ...], int]: The validated dictionary mapping segment labels to segment IDs.

        Raises:
            ValueError: If the length of `dict_labels_to_segment_ids` doesn't match the length of `segment_labels`.
            ValueError: If there are non-unique segment IDs in `dict_labels_to_segment_ids`.
        """
        if dict_labels_to_segment_ids is None:
            logger.debug("Creating mapping from segment labels into ids")
            return {segment: idx + 1 for idx, segment in enumerate(segment_labels)}

        logger.debug("Using existing mapping from segment labels into segment ids")

        if len(dict_labels_to_segment_ids) != len(segment_labels):
            raise ValueError(
                (
                    f"Incorrect dict_labels_to_segment_ids - length mismatch. Expected: "
                    f"{len(segment_labels)}, Actual: {len(dict_labels_to_segment_ids)}"
                )
            )

        unique_segment_ids = set(dict_labels_to_segment_ids.values())
        if len(unique_segment_ids) != len(segment_labels):
            raise ValueError("Non unique segment ids in the dict_labels_to_segment_ids")

        logger.debug("Passed mapping is correct")
        return dict_labels_to_segment_ids
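    # Illustrative sketch, not part of the library source: for
    # segment_labels=["R", "G"] and no existing mapping, the method returns
    # {"R": 1, "G": 2}; ids start at 1 so that 0 stays available as the background
    # value in the band mask.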

    @staticmethod
    def _create_single_dim_band_mask(
        hsi: HSI,
        dict_labels_to_indices: dict[str | tuple[str, ...], list[int]],
        dict_labels_to_segment_ids: dict[str | tuple[str, ...], int],
        device: torch.device,
    ) -> torch.Tensor:
        """Create a one-dimensional band mask based on the given image, labels, and segment IDs.

        Args:
            hsi (HSI): The input hsi.
            dict_labels_to_indices (dict[str | tuple[str, ...], list[int]]):
                A dictionary mapping labels or label tuples to lists of indices.
            dict_labels_to_segment_ids (dict[str | tuple[str, ...], int]):
                A dictionary mapping labels or label tuples to segment IDs.
            device (torch.device): The device to use for the tensor.

        Returns:
            torch.Tensor: The one-dimensional band mask tensor.

        Raises:
            ValueError: If the indices for a segment are out of bounds for the one-dimensional band mask.
        """
        band_mask_single_dim = torch.zeros(len(hsi.wavelengths), dtype=torch.int64, device=device)

        segment_labels = list(dict_labels_to_segment_ids.keys())

        for segment_label in segment_labels[::-1]:
            segment_indices = dict_labels_to_indices[segment_label]
            segment_id = dict_labels_to_segment_ids[segment_label]
            are_indices_valid = all(0 <= idx < band_mask_single_dim.shape[0] for idx in segment_indices)
            if not are_indices_valid:
                raise ValueError(
                    (
                        f"Indices for segment {segment_label} are out of bounds for the one-dimensional band mask"
                        f"of shape {band_mask_single_dim.shape}"
                    )
                )
            band_mask_single_dim[segment_indices] = segment_id

        return band_mask_single_dim
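    # Illustrative sketch, not part of the library source: for an HSI with four
    # wavelengths, dict_labels_to_indices={"R": [0, 1], "G": [2]} and
    # dict_labels_to_segment_ids={"R": 1, "G": 2}, the result is
    # tensor([1, 1, 2, 0]); bands not assigned to any segment keep the id 0.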

    @staticmethod
    def _create_tensor_band_mask(
        hsi: HSI,
        dict_labels_to_indices: dict[str | tuple[str, ...], list[int]],
        dict_labels_to_segment_ids: dict[str | tuple[str, ...], int] | None = None,
        device: str | torch.device | None = None,
        repeat_dimensions: bool = False,
        return_dict_labels_to_segment_ids: bool = True,
    ) -> torch.Tensor | tuple[torch.Tensor, dict[tuple[str, ...] | str, int]]:
        """Create a tensor band mask from dictionaries. The band mask is created based on the given hsi, labels, and
        segment IDs. The band mask is a tensor with the same shape as the input hsi and contains segment IDs, where each
        segment is represented by a unique ID. The band mask will be used to attribute the hsi using the LIME method.

        Args:
            hsi (HSI): The input hsi.
            dict_labels_to_indices (dict[str | tuple[str, ...], list[int]]): A dictionary mapping labels to indices.
            dict_labels_to_segment_ids (dict[str | tuple[str, ...], int] | None, optional):
                A dictionary mapping labels to segment IDs. Defaults to None.
            device (str | torch.device | None, optional): The device to use. Defaults to None.
            repeat_dimensions (bool, optional): Whether to repeat dimensions. Defaults to False.
            return_dict_labels_to_segment_ids (bool, optional):
                Whether to return the dictionary mapping labels to segment IDs. Defaults to True.

        Returns:
            torch.Tensor | tuple[torch.Tensor, dict[tuple[str, ...] | str, int]]:
                The tensor band mask or a tuple containing the tensor band mask
                and the dictionary mapping labels to segment IDs.
        """
        if device is None:
            device = hsi.device
        segment_labels = list(dict_labels_to_indices.keys())

        logger.debug(f"Creating a band mask on the device {device} using {len(segment_labels)} segments")

        # Check for overlapping segments
        Lime._check_overlapping_segments(dict_labels_to_indices)

        # Create or validate dict_labels_to_segment_ids
        dict_labels_to_segment_ids = Lime._validate_and_create_dict_labels_to_segment_ids(
            dict_labels_to_segment_ids, segment_labels
        )

        # Create single-dimensional band mask
        band_mask_single_dim = Lime._create_single_dim_band_mask(
            hsi, dict_labels_to_indices, dict_labels_to_segment_ids, device
        )

        # Expand band mask to match image dimensions
        band_mask = expand_spectral_mask(hsi, band_mask_single_dim, repeat_dimensions)

        if return_dict_labels_to_segment_ids:
            return band_mask, dict_labels_to_segment_ids
        return band_mask

    def attribute(  # type: ignore
        self,
        hsi: list[HSI] | HSI,
        target: list[int] | int | None = None,
        attribution_type: Literal["spatial", "spectral"] | None = None,
        additional_forward_args: Any = None,
        **kwargs: Any,
    ) -> HSISpatialAttributes | HSISpectralAttributes | list[HSISpatialAttributes] | list[HSISpectralAttributes]:
        """A wrapper function to attribute the image using the LIME method. It executes either the
        `get_spatial_attributes` or `get_spectral_attributes` method based on the provided `attribution_type`. For more
        detailed description of the methods, please refer to the respective method documentation.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSISpatialAttributes or HSISpectralAttributes objects.
            target (list[int] | int | None, optional): target class index for computing the attributions. If None,
                methods assume that the output has only one class. If the output has multiple classes, the target index
                must be provided. For multiple input images, a list of target indices can be provided, one for each
                image; otherwise, a single target value will be used for all images. Defaults to None.
            attribution_type (Literal["spatial", "spectral"] | None, optional): The type of attribution to be computed.
                User can compute spatial or spectral attributions with the LIME method. If None, the method will
                throw an error. Defaults to None.
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Default: None
            kwargs (Any): Additional keyword arguments for the LIME method.

        Returns:
            HSISpectralAttributes | HSISpatialAttributes | list[HSISpectralAttributes | HSISpatialAttributes]:
                The computed attributions, spectral or spatial, for the input hyperspectral image(s).
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

        Raises:
            RuntimeError: If the Lime object is not initialized or is not an instance of LimeBase.
            ValueError: If number of HSI images is not equal to the number of masks provided.

        Examples:
            >>> simple_model = lambda x: torch.rand((x.shape[0], 2))
            >>> hsi = mt.HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
            >>> segmentation_mask = torch.randint(1, 4, (1, 240, 240))
            >>> lime = meteors.attr.Lime(
            ...     explainable_model=ExplainableModel(simple_model, "regression"), interpretable_model=SkLearnLasso(alpha=0.1)
            ... )
            >>> spatial_attribution = lime.attribute(hsi, segmentation_mask=segmentation_mask, target=0, attribution_type="spatial")
            >>> spatial_attribution.hsi
            HSI(shape=(4, 240, 240), dtype=torch.float32)
            >>> band_mask = torch.randint(1, 4, (4, 1, 1)).repeat(1, 240, 240)
            >>> band_names = ["R", "G", "B"]
            >>> spectral_attribution = lime.attribute(
            ...     hsi, band_mask=band_mask, band_names=band_names, target=0, attribution_type="spectral"
            ... )
            >>> spectral_attribution.hsi
            HSI(shape=(4, 240, 240), dtype=torch.float32)
        """
        if attribution_type == "spatial":
            return self.get_spatial_attributes(
                hsi, target=target, additional_forward_args=additional_forward_args, **kwargs
            )
        elif attribution_type == "spectral":
            return self.get_spectral_attributes(
                hsi, target=target, additional_forward_args=additional_forward_args, **kwargs
            )
        raise ValueError(f"Unsupported attribution type: {attribution_type}. Use 'spatial' or 'spectral'")

    def get_spatial_attributes(
        self,
        hsi: list[HSI] | HSI,
        segmentation_mask: np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor] | None = None,
        target: list[int] | int | None = None,
        n_samples: int = 10,
        perturbations_per_eval: int = 4,
        verbose: bool = False,
        segmentation_method: Literal["slic", "patch"] = "slic",
        additional_forward_args: Any = None,
        **segmentation_method_params: Any,
    ) -> list[HSISpatialAttributes] | HSISpatialAttributes:
        """
        Get spatial attributes of an HSI image using the LIME method. Based on the provided HSI and segmentation mask,
        the LIME method attributes the `superpixels` provided by the segmentation mask. Please refer to the original paper
        `https://arxiv.org/abs/1602.04938` for more details or to Christoph Molnar's book
        `https://christophm.github.io/interpretable-ml-book/lime.html`.

        This function attributes the hyperspectral image using the LIME (Local Interpretable Model-Agnostic Explanations)
        method for spatial data. It returns an `HSISpatialAttributes` object that contains the hyperspectral image,
        the attributions, the segmentation mask, and the score of the interpretable model used for the explanation.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSISpatialAttributes objects.
            segmentation_mask (np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor] | None, optional):
                A segmentation mask according to which the attribution should be performed.
                The segmentation mask should have a 2D or 3D shape, which can be broadcastable to the shape of the
                input image. The only dimension on which the image and the mask shapes can differ is the spectral
                dimension, marked with letter `C` in the `image.orientation` parameter. If None, a new segmentation mask
                is created using the `segmentation_method`. Additional parameters for the segmentation method may be
                passed as kwargs. If multiple HSI images are provided, a list of segmentation masks can be provided,
                one for each image. If a list is not provided, the method assumes that the same segmentation mask is
                used for all images. Defaults to None.
            target (list[int] | int | None, optional): target class index for computing the attributions. If None,
                methods assume that the output has only one class. If the output has multiple classes, the target index
                must be provided. For multiple input images, a list of target indices can be provided, one for each
                image; otherwise, a single target value will be used for all images. Defaults to None.
            n_samples (int, optional): The number of samples to generate/analyze in LIME. The more the better but slower.
                Defaults to 10.
            perturbations_per_eval (int, optional): The number of perturbations to evaluate at once
                (Simply the inner batch size). Defaults to 4.
            verbose (bool, optional): Whether to show the progress bar. Defaults to False.
            segmentation_method (Literal["slic", "patch"], optional):
                Segmentation method used only if `segmentation_mask` is None. Defaults to "slic".
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Default: None
            **segmentation_method_params (Any): Additional parameters for the segmentation method.

        Returns:
            HSISpatialAttributes | list[HSISpatialAttributes]: An object containing the image, the attributions,
                the segmentation mask, and the score of the interpretable model used for the explanation.

        Raises:
            RuntimeError: If the Lime object is not initialized or is not an instance of LimeBase.
            MaskCreationError: If there is an error creating the segmentation mask.
            ValueError: If the number of segmentation masks is not equal to the number of HSI images provided.
            HSIAttributesError: If there is an error during creating spatial attribution.

        Examples:
            >>> simple_model = lambda x: torch.rand((x.shape[0], 2))
            >>> hsi = mt.HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
            >>> segmentation_mask = torch.randint(1, 4, (1, 240, 240))
            >>> lime = meteors.attr.Lime(
            ...     explainable_model=ExplainableModel(simple_model, "regression"), interpretable_model=SkLearnLasso(alpha=0.1)
            ... )
            >>> spatial_attribution = lime.get_spatial_attributes(hsi, segmentation_mask=segmentation_mask, target=0)
            >>> spatial_attribution.hsi
            HSI(shape=(4, 240, 240), dtype=torch.float32)
            >>> spatial_attribution.attributes.shape
            torch.Size([4, 240, 240])
            >>> spatial_attribution.segmentation_mask.shape
            torch.Size([1, 240, 240])
            >>> spatial_attribution.score
            1.0
        """
        if self._attribution_method is None or not isinstance(self._attribution_method, LimeBase):
            raise RuntimeError("Lime object not initialized")  # pragma: no cover

        if isinstance(hsi, HSI):
            hsi = [hsi]

        if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
            raise TypeError("All of the input hyperspectral images must be of type HSI")

        if segmentation_mask is None:
            segmentation_mask = self.get_segmentation_mask(hsi[0], segmentation_method, **segmentation_method_params)

            logger.warning(
                "Segmentation mask is created based on the first HSI image provided. This approach may not be optimal, "
                "as the same segmentation mask may not be well suited to all images.",
            )

        if isinstance(segmentation_mask, tuple):
            segmentation_mask = list(segmentation_mask)
        elif not isinstance(segmentation_mask, list):
            segmentation_mask = [segmentation_mask] * len(hsi)

        if len(hsi) != len(segmentation_mask):
            raise ValueError(
                f"Number of segmentation masks should be equal to the number of HSI images provided: "
                f"got {len(segmentation_mask)} masks for {len(hsi)} images"
            )

        segmentation_mask = [
            ensure_torch_tensor(mask, f"Segmentation mask number {idx+1} should be None, numpy array, or torch tensor")
            for idx, mask in enumerate(segmentation_mask)
        ]
        segmentation_mask = [
            mask.unsqueeze(0).moveaxis(0, hsi_img.spectral_axis) if mask.ndim != hsi_img.image.ndim else mask
            for hsi_img, mask in zip(hsi, segmentation_mask)
        ]
        segmentation_mask = [
            validate_mask_shape("segmentation", hsi_img, mask) for hsi_img, mask in zip(hsi, segmentation_mask)
        ]

        hsi_input = torch.stack([hsi_img.get_image() for hsi_img in hsi], dim=0)
        segmentation_mask = torch.stack(segmentation_mask, dim=0)

        assert segmentation_mask.shape == hsi_input.shape

        segmentation_mask = segmentation_mask.to(self.device)
        hsi_input = hsi_input.to(self.device)

        lime_attributes, score = self._attribution_method.attribute(
            inputs=hsi_input,
            target=target,
            feature_mask=segmentation_mask,
            n_samples=n_samples,
            perturbations_per_eval=perturbations_per_eval,
            additional_forward_args=additional_forward_args,
            show_progress=verbose,
            return_input_shape=True,
        )

        try:
            spatial_attribution = [
                HSISpatialAttributes(
                    hsi=hsi_img,
                    attributes=lime_attr,
                    mask=segmentation_mask[idx].expand_as(hsi_img.image),
                    score=score.item(),
                    attribution_method="Lime",
                )
                for idx, (hsi_img, lime_attr) in enumerate(zip(hsi, lime_attributes))
            ]
        except Exception as e:
            raise HSIAttributesError(f"Error during creating spatial attribution {e}") from e

        return spatial_attribution[0] if len(spatial_attribution) == 1 else spatial_attribution

    def get_spectral_attributes(
        self,
        hsi: list[HSI] | HSI,
        band_mask: np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor] | None = None,
        target: list[int] | int | None = None,
        n_samples: int = 10,
        perturbations_per_eval: int = 4,
        verbose: bool = False,
        additional_forward_args: Any = None,
        band_names: list[str | list[str]] | dict[tuple[str, ...] | str, int] | None = None,
    ) -> HSISpectralAttributes | list[HSISpectralAttributes]:
        """
        Attributes the HSI image using the LIME method for spectral data. Based on the provided HSI and band mask, the
        LIME method attributes the hsi based on `superbands` (clustered bands) provided by the band mask.
        Please refer to the original paper `https://arxiv.org/abs/1602.04938` for more details or to
        Christoph Molnar's book `https://christophm.github.io/interpretable-ml-book/lime.html`.

        The function returns an HSISpectralAttributes object that contains the image, the attributions, the band mask,
        the band names, and the score of the interpretable model used for the explanation.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSISpectralAttributes objects.
            band_mask (np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor] | None, optional): Band mask that
                is used for the spectral attribution. The band mask should have a 1D or 3D shape, which can be
                broadcastable to the shape of the input image. The only dimensions on which the image and the mask
                shapes can differ are the height and width dimensions, marked with letters `H` and `W` in the
                `image.orientation` parameter. If None, the band mask is created within the function. If multiple
                HSI images are provided, a list of band masks can be provided, one for each image. If a list is not
                provided, the method assumes that the same band mask is used for all images. Defaults to None.
            target (list[int] | int | None, optional): target class index for computing the attributions. If None,
                methods assume that the output has only one class. If the output has multiple classes, the target index
                must be provided. For multiple input images, a list of target indices can be provided, one for each
                image; otherwise, a single target value will be used for all images. Defaults to None.
            n_samples (int, optional): The number of samples to generate/analyze in LIME. The more the better but slower.
                Defaults to 10.
            perturbations_per_eval (int, optional): The number of perturbations to evaluate at once
                (Simply the inner batch size). Defaults to 4.
            verbose (bool, optional): Whether to show the progress bar. Defaults to False.
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Default: None
            band_names (list[str | list[str]] | dict[tuple[str, ...] | str, int] | None, optional): Band names used
                to create the band mask when `band_mask` is None. Defaults to None.

        Returns:
            HSISpectralAttributes | list[HSISpectralAttributes]: An object containing the image, the attributions,
                the band mask, the band names, and the score of the interpretable model used for the explanation.

        Raises:
            RuntimeError: If the Lime object is not initialized or is not an instance of LimeBase.
            MaskCreationError: If there is an error creating the band mask.
            ValueError: If the number of band masks is not equal to the number of HSI images provided.
            HSIAttributesError: If there is an error during creating spectral attribution.

        Examples:
            >>> simple_model = lambda x: torch.rand((x.shape[0], 2))
            >>> hsi = mt.HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
            >>> band_mask = torch.randint(1, 4, (4, 1, 1)).repeat(1, 240, 240)
            >>> band_names = ["R", "G", "B"]
            >>> lime = meteors.attr.Lime(
            ...     explainable_model=ExplainableModel(simple_model, "regression"), interpretable_model=SkLearnLasso(alpha=0.1)
            ... )
            >>> spectral_attribution = lime.get_spectral_attributes(hsi, band_mask=band_mask, band_names=band_names, target=0)
            >>> spectral_attribution.hsi
            HSI(shape=(4, 240, 240), dtype=torch.float32)
            >>> spectral_attribution.attributes.shape
            torch.Size([4, 240, 240])
            >>> spectral_attribution.band_mask.shape
            torch.Size([4, 240, 240])
            >>> spectral_attribution.band_names
            ["R", "G", "B"]
            >>> spectral_attribution.score
            1.0
        """

        if self._attribution_method is None or not isinstance(self._attribution_method, LimeBase):
            raise RuntimeError("Lime object not initialized")  # pragma: no cover

        if isinstance(hsi, HSI):
            hsi = [hsi]

        if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
            raise TypeError("All of the input hyperspectral images must be of type HSI")

        if band_mask is None:
            created_bands = [self.get_band_mask(hsi_img, band_names) for hsi_img in hsi]
            band_mask, band_name_list = zip(*created_bands)
            band_names = band_name_list[0]

        if isinstance(band_mask, tuple):
            band_mask = list(band_mask)
        elif not isinstance(band_mask, list):
            band_mask = [band_mask]

        if len(hsi) != len(band_mask):
            if len(band_mask) == 1:
                band_mask = band_mask * len(hsi)
                logger.debug("Reusing the same band mask for all images")
            else:
                raise ValueError(
                    f"Number of band masks should be equal to the number of HSI images provided: "
                    f"got {len(band_mask)} masks for {len(hsi)} images"
                )

        band_mask = [
            ensure_torch_tensor(mask, f"Band mask number {idx+1} should be None, numpy array, or torch tensor")
            for idx, mask in enumerate(band_mask)
        ]
        band_mask = [
            mask.unsqueeze(-1).unsqueeze(-1).moveaxis(0, hsi_img.spectral_axis)
            if mask.ndim != hsi_img.image.ndim
            else mask
            for hsi_img, mask in zip(hsi, band_mask)
        ]
        band_mask = [validate_mask_shape("band", hsi_img, mask) for hsi_img, mask in zip(hsi, band_mask)]

        hsi_input = torch.stack([hsi_img.get_image() for hsi_img in hsi], dim=0)
        band_mask = torch.stack(band_mask, dim=0)

        if band_names is None:
            band_names = {str(segment): idx for idx, segment in enumerate(torch.unique(band_mask))}
        else:
            logger.debug(
                "Band names are provided and will be used. In the future, there should be an option to validate them."
            )

        assert hsi_input.shape == band_mask.shape

        hsi_input = hsi_input.to(self.device)
        band_mask = band_mask.to(self.device)

        lime_attributes, score = self._attribution_method.attribute(
            inputs=hsi_input,
            target=target,
            feature_mask=band_mask,
            n_samples=n_samples,
            perturbations_per_eval=perturbations_per_eval,
            additional_forward_args=additional_forward_args,
            show_progress=verbose,
            return_input_shape=True,
        )

        try:
            spectral_attribution = [
                HSISpectralAttributes(
                    hsi=hsi_img,
                    attributes=lime_attr,
                    mask=band_mask[idx].expand_as(hsi_img.image),
                    band_names=band_names,
                    score=score.item(),
                    attribution_method="Lime",
                )
                for idx, (hsi_img, lime_attr) in enumerate(zip(hsi, lime_attributes))
            ]
        except Exception as e:
            raise HSIAttributesError(f"Error during creating spectral attribution {e}") from e

        return spectral_attribution[0] if len(spectral_attribution) == 1 else spectral_attribution

    @staticmethod
    def _get_slic_segmentation_mask(
        hsi: HSI, num_interpret_features: int = 10, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        """Creates a segmentation mask using the SLIC method.

        Args:
            hsi (HSI): An HSI object for which the segmentation mask is created.
            num_interpret_features (int, optional): Number of segments. Defaults to 10.
            *args: Additional positional arguments to be passed to the SLIC method.
            **kwargs: Additional keyword arguments to be passed to the SLIC method.

        Returns:
            torch.Tensor: An output segmentation mask.
        """
        segmentation_mask = slic(
            hsi.get_image().cpu().detach().numpy(),
            *args,
            n_segments=num_interpret_features,
            mask=hsi.spatial_binary_mask.cpu().detach().numpy(),
            channel_axis=hsi.spectral_axis,
            **kwargs,
        )

        if segmentation_mask.min() == 1:
            segmentation_mask -= 1

        segmentation_mask = torch.from_numpy(segmentation_mask)
        segmentation_mask = segmentation_mask.unsqueeze(dim=hsi.spectral_axis)

        return segmentation_mask
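    # Illustrative sketch, not part of the library source: additional keyword
    # arguments are forwarded to skimage.segmentation.slic, so e.g.
    # Lime._get_slic_segmentation_mask(hsi, num_interpret_features=5, compactness=10.0)
    # tunes the compactness of the generated superpixels.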

    @staticmethod
    def _get_patch_segmentation_mask(hsi: HSI, patch_size: int | float = 10, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Creates a segmentation mask using the patch method, which creates small squares of the same size
        and assigns a unique value to each square.

        Args:
            hsi (HSI): An HSI object for which the segmentation mask is created.
            patch_size (int | float, optional): Size of the patch; the hsi height and width should be divisible
                by this value. Defaults to 10.
            *args: Additional positional arguments (currently unused).
            **kwargs: Additional keyword arguments (currently unused).

        Returns:
            torch.Tensor: An output segmentation mask.

        Raises:
            ValueError: If patch_size is not a positive number or does not evenly divide the height and width
                of the hsi.
        """
        if not isinstance(patch_size, (int, float)) or patch_size < 1:
            raise ValueError("Invalid patch_size. patch_size must be a positive number")

        if hsi.image.shape[1] % patch_size != 0 or hsi.image.shape[2] % patch_size != 0:
            raise ValueError("Invalid patch_size. patch_size must be a factor of both width and height of the hsi")

        height, width = hsi.image.shape[1], hsi.image.shape[2]

        idx_mask = torch.arange((height // patch_size) * (width // patch_size), device=hsi.device).reshape(
            height // patch_size, width // patch_size
        )
        idx_mask += 1
        segmentation_mask = torch.repeat_interleave(idx_mask, patch_size, dim=0)
        segmentation_mask = torch.repeat_interleave(segmentation_mask, patch_size, dim=1)
        segmentation_mask = segmentation_mask * hsi.spatial_binary_mask
        segmentation_mask = segmentation_mask.unsqueeze(dim=hsi.spectral_axis)

        # Relabel patch ids to consecutive integers starting from 0; torch.unique,
        # unlike np.unique, also works for tensors that live on the GPU.
        mask_idx = torch.unique(segmentation_mask).tolist()
        for idx, mask_val in enumerate(mask_idx):
            segmentation_mask[segmentation_mask == mask_val] = idx

        return segmentation_mask
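    # Illustrative sketch, not part of the library source: for a 4x4 image with an
    # all-ones binary mask and patch_size=2, the resulting (1, 4, 4) mask contains,
    # after the relabelling loop above, the patch ids
    #   [[0, 0, 1, 1],
    #    [0, 0, 1, 1],
    #    [2, 2, 3, 3],
    #    [2, 2, 3, 3]]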

attribute(hsi, target=None, attribution_type=None, additional_forward_args=None, **kwargs)

A wrapper function to attribute the image using the LIME method. It executes either the get_spatial_attributes or get_spectral_attributes method based on the provided attribution_type. For a more detailed description of these methods, please refer to the respective method documentation.

Parameters:

- hsi (list[HSI] | HSI, required): Input hyperspectral image(s) for which the attributions are to be computed. If a list of HSI objects is provided, the attributions are computed for each HSI object in the list. The output will be a list of HSISpatialAttributes or HSISpectralAttributes objects.
- target (list[int] | int | None, optional): Target class index for computing the attributions. If None, methods assume that the output has only one class. If the output has multiple classes, the target index must be provided. For multiple input images, a list of target indices can be provided, one for each image, or a single target value will be used for all images. Defaults to None.
- attribution_type (Literal['spatial', 'spectral'] | None, optional): The type of attribution to be computed. The user can compute spatial or spectral attributions with the LIME method. If None, the method raises an error. Defaults to None.
- additional_forward_args (Any, optional): If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Defaults to None.
- kwargs (Any): Additional keyword arguments for the LIME method.

Returns:

- HSISpatialAttributes | HSISpectralAttributes | list[HSISpatialAttributes] | list[HSISpectralAttributes]: The computed spatial or spectral attributions for the input hyperspectral image(s). If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

Raises:

- RuntimeError: If the Lime object is not initialized or is not an instance of LimeBase.
- ValueError: If the number of HSI images is not equal to the number of masks provided.

Examples:

>>> simple_model = lambda x: torch.rand((x.shape[0], 2))
>>> hsi = mt.HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
>>> segmentation_mask = torch.randint(1, 4, (1, 240, 240))
>>> lime = meteors.attr.Lime(
...     explainable_model=ExplainableModel(simple_model, "regression"), interpretable_model=SkLearnLasso(alpha=0.1)
... )
>>> spatial_attribution = lime.attribute(hsi, segmentation_mask=segmentation_mask, target=0, attribution_type="spatial")
>>> spatial_attribution.hsi
HSI(shape=(4, 240, 240), dtype=torch.float32)
>>> band_mask = torch.randint(1, 4, (4, 1, 1)).repeat(1, 240, 240)
>>> band_names = ["R", "G", "B"]
>>> spectral_attribution = lime.attribute(
...     hsi, band_mask=band_mask, band_names=band_names, target=0, attribution_type="spectral"
... )
>>> spectral_attribution.hsi
HSI(shape=(4, 240, 240), dtype=torch.float32)
Source code in src/meteors/attr/lime.py
def attribute(  # type: ignore
    self,
    hsi: list[HSI] | HSI,
    target: list[int] | int | None = None,
    attribution_type: Literal["spatial", "spectral"] | None = None,
    additional_forward_args: Any = None,
    **kwargs: Any,
) -> HSISpatialAttributes | HSISpectralAttributes | list[HSISpatialAttributes] | list[HSISpectralAttributes]:
    """A wrapper function to attribute the image using the LIME method. It executes either the
    `get_spatial_attributes` or `get_spectral_attributes` method based on the provided `attribution_type`. For a more
    detailed description of the methods, please refer to the respective method documentation.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSISpatialAttributes or HSISpectralAttributes objects.
        target (list[int] | int | None, optional): target class index for computing the attributions. If None,
            methods assume that the output has only one class. If the output has multiple classes, the target index
            must be provided. For multiple input images, a list of target indices can be provided, one for each
            image or single target value will be used for all images. Defaults to None.
        attribution_type (Literal["spatial", "spectral"] | None, optional): The type of attribution to be computed.
            User can compute spatial or spectral attributions with the LIME method. If None, the method will
            throw an error. Defaults to None.
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Default: None
        kwargs (Any): Additional keyword arguments for the LIME method.

    Returns:
        HSISpectralAttributes | HSISpatialAttributes | list[HSISpectralAttributes | HSISpatialAttributes]:
            The computed spatial or spectral attributions for the input hyperspectral image(s).
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

    Raises:
        RuntimeError: If the Lime object is not initialized or is not an instance of LimeBase.
        ValueError: If number of HSI images is not equal to the number of masks provided.

    Examples:
        >>> simple_model = lambda x: torch.rand((x.shape[0], 2))
        >>> hsi = mt.HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
        >>> segmentation_mask = torch.randint(1, 4, (1, 240, 240))
        >>> lime = meteors.attr.Lime(
        ...     explainable_model=ExplainableModel(simple_model, "regression"), interpretable_model=SkLearnLasso(alpha=0.1)
        ... )
        >>> spatial_attribution = lime.attribute(hsi, segmentation_mask=segmentation_mask, target=0, attribution_type="spatial")
        >>> spatial_attribution.hsi
        HSI(shape=(4, 240, 240), dtype=torch.float32)
        >>> band_mask = torch.randint(1, 4, (4, 1, 1)).repeat(1, 240, 240)
        >>> band_names = ["R", "G", "B"]
        >>> spectral_attribution = lime.attribute(
        ...     hsi, band_mask=band_mask, band_names=band_names, target=0, attribution_type="spectral"
        ... )
        >>> spectral_attribution.hsi
        HSI(shape=(4, 240, 240), dtype=torch.float32)
    """
    if attribution_type == "spatial":
        return self.get_spatial_attributes(
            hsi, target=target, additional_forward_args=additional_forward_args, **kwargs
        )
    elif attribution_type == "spectral":
        return self.get_spectral_attributes(
            hsi, target=target, additional_forward_args=additional_forward_args, **kwargs
        )
    raise ValueError(f"Unsupported attribution type: {attribution_type}. Use 'spatial' or 'spectral'")

get_band_mask(hsi, band_names=None, band_indices=None, band_wavelengths=None, device=None, repeat_dimensions=False) staticmethod

Generates a band mask based on the provided hsi and band information.

Remember that you need to provide either band_names, band_indices, or band_wavelengths to create the band mask. If you provide more than one, only one of them will be used, following the priority: band_names > band_wavelengths > band_indices.

Parameters:

- hsi (HSI, required): The input hyperspectral image.
- band_names (None | list[str | list[str]] | dict[tuple[str, ...] | str, int], optional): The names of the spectral bands to include in the mask. Defaults to None.
- band_indices (None | dict[str | tuple[str, ...], list[tuple[int, int]] | tuple[int, int] | list[int]], optional): The indices or ranges of indices of the spectral bands to include in the mask. Defaults to None.
- band_wavelengths (None | dict[str | tuple[str, ...], list[tuple[float, float]] | tuple[float, float] | list[float] | float], optional): The wavelengths or ranges of wavelengths of the spectral bands to include in the mask. Defaults to None.
- device (str | device | None, optional): The device to use for computation. Defaults to None.
- repeat_dimensions (bool, optional): Whether to repeat the dimensions of the mask to match the input hsi shape. Defaults to False.

Returns:

- tuple[torch.Tensor, dict[tuple[str, ...] | str, int]]: A tuple containing the band mask tensor and a dictionary mapping band names to segment IDs.

Raises:

- TypeError: If the input hsi is not an instance of the HSI class.
- ValueError: If no band names, indices, or wavelengths are provided.

Examples:

>>> hsi = mt.HSI(image=torch.ones((len(wavelengths), 10, 10)), wavelengths=wavelengths)
>>> band_names = ["R", "G"]
>>> band_mask, dict_labels_to_segment_ids = mt_lime.Lime.get_band_mask(hsi, band_names=band_names)
>>> dict_labels_to_segment_ids
{"R": 1, "G": 2}
>>> band_indices = {"RGB": [0, 1, 2]}
>>> band_mask, dict_labels_to_segment_ids = mt_lime.Lime.get_band_mask(hsi, band_indices=band_indices)
>>> dict_labels_to_segment_ids
{"RGB": 1}
>>> band_wavelengths = {"RGB": [(462.08, 465.27), (465.27, 468.47), (468.47, 471.68)]}
>>> band_mask, dict_labels_to_segment_ids = mt_lime.Lime.get_band_mask(hsi, band_wavelengths=band_wavelengths)
>>> dict_labels_to_segment_ids
{"RGB": 1}
Source code in src/meteors/attr/lime.py
@staticmethod
def get_band_mask(
    hsi: HSI,
    band_names: None | list[str | list[str]] | dict[tuple[str, ...] | str, int] = None,
    band_indices: None | dict[str | tuple[str, ...], ListOfWavelengthsIndices] = None,
    band_wavelengths: None | dict[str | tuple[str, ...], ListOfWavelengths] = None,
    device: str | torch.device | None = None,
    repeat_dimensions: bool = False,
) -> tuple[torch.Tensor, dict[tuple[str, ...] | str, int]]:
    """Generates a band mask based on the provided hsi and band information.

    Remember that you need to provide either band_names, band_indices, or band_wavelengths to create the band mask.
    If you provide more than one, only one of them will be used, following the priority:
    band_names > band_wavelengths > band_indices.

    Args:
        hsi (HSI): The input hyperspectral image.
        band_names (None | list[str | list[str]] | dict[tuple[str, ...] | str, int], optional):
            The names of the spectral bands to include in the mask. Defaults to None.
        band_indices (None | dict[str | tuple[str, ...], list[tuple[int, int]] | tuple[int, int] | list[int]], optional):
            The indices or ranges of indices of the spectral bands to include in the mask. Defaults to None.
        band_wavelengths (None | dict[str | tuple[str, ...], list[tuple[float, float]] | tuple[float, float] | list[float] | float], optional):
            The wavelengths or ranges of wavelengths of the spectral bands to include in the mask. Defaults to None.
        device (str | torch.device | None, optional):
            The device to use for computation. Defaults to None.
        repeat_dimensions (bool, optional):
            Whether to repeat the dimensions of the mask to match the input hsi shape. Defaults to False.

    Returns:
        tuple[torch.Tensor, dict[tuple[str, ...] | str, int]]: A tuple containing the band mask tensor and a dictionary
        mapping band names to segment IDs.

    Raises:
        TypeError: If the input hsi is not an instance of the HSI class.
        ValueError: If no band names, indices, or wavelengths are provided.

    Examples:
        >>> hsi = mt.HSI(image=torch.ones((len(wavelengths), 10, 10)), wavelengths=wavelengths)
        >>> band_names = ["R", "G"]
        >>> band_mask, dict_labels_to_segment_ids = mt_lime.Lime.get_band_mask(hsi, band_names=band_names)
        >>> dict_labels_to_segment_ids
        {"R": 1, "G": 2}
        >>> band_indices = {"RGB": [0, 1, 2]}
        >>> band_mask, dict_labels_to_segment_ids = mt_lime.Lime.get_band_mask(hsi, band_indices=band_indices)
        >>> dict_labels_to_segment_ids
        {"RGB": 1}
        >>> band_wavelengths = {"RGB": [(462.08, 465.27), (465.27, 468.47), (468.47, 471.68)]}
        >>> band_mask, dict_labels_to_segment_ids = mt_lime.Lime.get_band_mask(hsi, band_wavelengths=band_wavelengths)
        >>> dict_labels_to_segment_ids
        {"RGB": 1}
    """
    if not isinstance(hsi, HSI):
        raise TypeError("hsi should be an instance of HSI class")

    try:
        if band_names is None and band_indices is None and band_wavelengths is None:
            raise ValueError("No band names, indices, or wavelengths are provided.")

        # validate types
        dict_labels_to_segment_ids = None
        if band_names is not None:
            logger.debug("Getting band mask from band names of spectral bands")
            if band_wavelengths is not None or band_indices is not None:
                ignored_params = [
                    param
                    for param in ["band_wavelengths", "band_indices"]
                    if param in locals() and locals()[param] is not None
                ]
                ignored_params_str = " and ".join(ignored_params)
                logger.info(
                    f"Only the band names will be used to create the band mask. The additional parameters {ignored_params_str} will be ignored."
                )
            try:
                validate_band_names(band_names)
                band_groups, dict_labels_to_segment_ids = Lime._get_band_wavelengths_indices_from_band_names(
                    hsi.wavelengths, band_names
                )
            except Exception as e:
                raise BandSelectionError(f"Incorrect band names provided: {e}") from e
        elif band_wavelengths is not None:
            logger.debug("Getting band mask from band groups given by ranges of wavelengths")
            if band_indices is not None:
                logger.info(
                    "Only the band wavelengths will be used to create the band mask. The band_indices will be ignored."
                )
            validate_band_format(band_wavelengths, variable_name="band_wavelengths")
            try:
                band_groups = Lime._get_band_indices_from_band_wavelengths(
                    hsi.wavelengths,
                    band_wavelengths,
                )
            except Exception as e:
                raise ValueError(
                    f"Incorrect band ranges wavelengths provided, please check if provided wavelengths are correct: {e}"
                ) from e
        elif band_indices is not None:
            logger.debug("Getting band mask from band groups given by ranges of indices")
            validate_band_format(band_indices, variable_name="band_indices")
            try:
                band_groups = Lime._get_band_indices_from_input_band_indices(hsi.wavelengths, band_indices)
            except Exception as e:
                raise ValueError(
                    f"Incorrect band ranges indices provided, please check if provided indices are correct: {e}"
                ) from e

        return Lime._create_tensor_band_mask(
            hsi,
            band_groups,
            dict_labels_to_segment_ids=dict_labels_to_segment_ids,
            device=device,
            repeat_dimensions=repeat_dimensions,
            return_dict_labels_to_segment_ids=True,
        )
    except Exception as e:
        raise MaskCreationError(f"Error creating band mask: {e}") from e

get_segmentation_mask(hsi, segmentation_method='slic', **segmentation_method_params) staticmethod

Generates a segmentation mask for the given hsi using the specified segmentation method.

Parameters:

- hsi (HSI, required): The input hyperspectral image for which the segmentation mask needs to be generated.
- segmentation_method (Literal['patch', 'slic'], optional): The segmentation method to be used. Defaults to "slic".
- **segmentation_method_params (Any): Additional parameters specific to the chosen segmentation method.

Returns:

- torch.Tensor: The segmentation mask as a tensor.

Raises:

- TypeError: If the input hsi is not an instance of the HSI class.
- ValueError: If an unsupported segmentation method is specified.

Examples:

>>> hsi = meteors.HSI(image=torch.ones((3, 240, 240)), wavelengths=[462.08, 465.27, 468.47])
>>> segmentation_mask = mt_lime.Lime.get_segmentation_mask(hsi, segmentation_method="slic")
>>> segmentation_mask.shape
torch.Size([1, 240, 240])
>>> segmentation_mask = meteors.attr.Lime.get_segmentation_mask(hsi, segmentation_method="patch", patch_size=2)
>>> segmentation_mask.shape
torch.Size([1, 240, 240])
>>> segmentation_mask[0, :2, :2]
torch.tensor([[1, 1],
              [1, 1]])
>>> segmentation_mask[0, 2:4, :2]
torch.tensor([[2, 2],
              [2, 2]])
Source code in src/meteors/attr/lime.py
@staticmethod
def get_segmentation_mask(
    hsi: HSI,
    segmentation_method: Literal["patch", "slic"] = "slic",
    **segmentation_method_params: Any,
) -> torch.Tensor:
    """Generates a segmentation mask for the given hsi using the specified segmentation method.

    Args:
        hsi (HSI): The input hyperspectral image for which the segmentation mask needs to be generated.
        segmentation_method (Literal["patch", "slic"], optional): The segmentation method to be used.
            Defaults to "slic".
        **segmentation_method_params (Any): Additional parameters specific to the chosen segmentation method.

    Returns:
        torch.Tensor: The segmentation mask as a tensor.

    Raises:
        TypeError: If the input hsi is not an instance of the HSI class.
        ValueError: If an unsupported segmentation method is specified.

    Examples:
        >>> hsi = meteors.HSI(image=torch.ones((3, 240, 240)), wavelengths=[462.08, 465.27, 468.47])
        >>> segmentation_mask = mt_lime.Lime.get_segmentation_mask(hsi, segmentation_method="slic")
        >>> segmentation_mask.shape
        torch.Size([1, 240, 240])
        >>> segmentation_mask = meteors.attr.Lime.get_segmentation_mask(hsi, segmentation_method="patch", patch_size=2)
        >>> segmentation_mask.shape
        torch.Size([1, 240, 240])
        >>> segmentation_mask[0, :2, :2]
        torch.tensor([[1, 1],
                      [1, 1]])
        >>> segmentation_mask[0, 2:4, :2]
        torch.tensor([[2, 2],
                      [2, 2]])
    """
    if not isinstance(hsi, HSI):
        raise TypeError("hsi should be an instance of HSI class")

    try:
        if segmentation_method == "slic":
            return Lime._get_slic_segmentation_mask(hsi, **segmentation_method_params)
        elif segmentation_method == "patch":
            return Lime._get_patch_segmentation_mask(hsi, **segmentation_method_params)
        else:
            raise ValueError(f"Unsupported segmentation method: {segmentation_method}")
    except Exception as e:
        raise MaskCreationError(f"Error creating segmentation mask using method {segmentation_method}: {e}")

get_spatial_attributes(hsi, segmentation_mask=None, target=None, n_samples=10, perturbations_per_eval=4, verbose=False, segmentation_method='slic', additional_forward_args=None, **segmentation_method_params)

Get spatial attributes of an hsi image using the LIME method. Based on the provided hsi and segmentation mask, the LIME method attributes the superpixels provided by the segmentation mask. Please refer to the original paper https://arxiv.org/abs/1602.04938 for more details or to Christoph Molnar's book https://christophm.github.io/interpretable-ml-book/lime.html.

This function attributes the hyperspectral image using the LIME (Local Interpretable Model-Agnostic Explanations) method for spatial data. It returns an HSISpatialAttributes object that contains the hyperspectral image, the attributions, the segmentation mask, and the score of the interpretable model used for the explanation.

Parameters:

- hsi (list[HSI] | HSI, required): Input hyperspectral image(s) for which the attributions are to be computed. If a list of HSI objects is provided, the attributions are computed for each HSI object in the list. The output will be a list of HSISpatialAttributes objects.
- segmentation_mask (ndarray | Tensor | list[ndarray | Tensor] | None, optional): A segmentation mask according to which the attribution should be performed. The segmentation mask should have a 2D or 3D shape, which can be broadcastable to the shape of the input image. The only dimension on which the image and the mask shapes can differ is the spectral dimension, marked with letter C in the image.orientation parameter. If None, a new segmentation mask is created using the segmentation_method. Additional parameters for the segmentation method may be passed as kwargs. If multiple HSI images are provided, a list of segmentation masks can be provided, one for each image. If a list is not provided, the method will assume that the same segmentation mask is used for all images. Defaults to None.
- target (list[int] | int | None, optional): Target class index for computing the attributions. If None, methods assume that the output has only one class. If the output has multiple classes, the target index must be provided. For multiple input images, a list of target indices can be provided, one for each image, or a single target value will be used for all images. Defaults to None.
- n_samples (int, optional): The number of samples to generate/analyze in LIME. The more the better, but slower. Defaults to 10.
- perturbations_per_eval (int, optional): The number of perturbations to evaluate at once (simply the inner batch size). Defaults to 4.
- verbose (bool, optional): Whether to show the progress bar. Defaults to False.
- segmentation_method (Literal['slic', 'patch'], optional): Segmentation method used only if segmentation_mask is None. Defaults to "slic".
- additional_forward_args (Any, optional): If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Defaults to None.
- **segmentation_method_params (Any): Additional parameters for the segmentation method.

Returns:

- HSISpatialAttributes | list[HSISpatialAttributes]: An object containing the image, the attributions, the segmentation mask, and the score of the interpretable model used for the explanation.

Raises:

- RuntimeError: If the Lime object is not initialized or is not an instance of LimeBase.
- MaskCreationError: If there is an error creating the segmentation mask.
- ValueError: If the number of segmentation masks is not equal to the number of HSI images provided.
- HSIAttributesError: If there is an error during creating the spatial attribution.

Examples:

>>> simple_model = lambda x: torch.rand((x.shape[0], 2))
>>> hsi = mt.HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
>>> segmentation_mask = torch.randint(1, 4, (1, 240, 240))
>>> lime = meteors.attr.Lime(
...     explainable_model=ExplainableModel(simple_model, "regression"), interpretable_model=SkLearnLasso(alpha=0.1)
... )
>>> spatial_attribution = lime.get_spatial_attributes(hsi, segmentation_mask=segmentation_mask, target=0)
>>> spatial_attribution.hsi
HSI(shape=(4, 240, 240), dtype=torch.float32)
>>> spatial_attribution.attributes.shape
torch.Size([4, 240, 240])
>>> spatial_attribution.segmentation_mask.shape
torch.Size([1, 240, 240])
>>> spatial_attribution.score
1.0
Source code in src/meteors/attr/lime.py
def get_spatial_attributes(
    self,
    hsi: list[HSI] | HSI,
    segmentation_mask: np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor] | None = None,
    target: list[int] | int | None = None,
    n_samples: int = 10,
    perturbations_per_eval: int = 4,
    verbose: bool = False,
    segmentation_method: Literal["slic", "patch"] = "slic",
    additional_forward_args: Any = None,
    **segmentation_method_params: Any,
) -> list[HSISpatialAttributes] | HSISpatialAttributes:
    """
    Get spatial attributes of an hsi image using the LIME method. Based on the provided hsi and segmentation mask,
    the LIME method attributes the `superpixels` provided by the segmentation mask. Please refer to the original paper
    `https://arxiv.org/abs/1602.04938` for more details or to Christoph Molnar's book
    `https://christophm.github.io/interpretable-ml-book/lime.html`.

    This function attributes the hyperspectral image using the LIME (Local Interpretable Model-Agnostic Explanations)
    method for spatial data. It returns an `HSISpatialAttributes` object that contains the hyperspectral image,
    the attributions, the segmentation mask, and the score of the interpretable model used for the explanation.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSISpatialAttributes objects.
        segmentation_mask (np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor] | None, optional):
            A segmentation mask according to which the attribution should be performed.
            The segmentation mask should have a 2D or 3D shape, which can be broadcastable to the shape of the
            input image. The only dimension on which the image and the mask shapes can differ is the spectral
            dimension, marked with letter `C` in the `image.orientation` parameter. If None, a new segmentation mask
            is created using the `segmentation_method`. Additional parameters for the segmentation method may be
            passed as kwargs. If multiple HSI images are provided, a list of segmentation masks can be provided,
            one for each image. If a list is not provided, the method will assume that the same segmentation mask
            is used for all images. Defaults to None.
        target (list[int] | int | None, optional): target class index for computing the attributions. If None,
            methods assume that the output has only one class. If the output has multiple classes, the target index
            must be provided. For multiple input images, a list of target indices can be provided, one for each
            image or single target value will be used for all images. Defaults to None.
        n_samples (int, optional): The number of samples to generate/analyze in LIME. The more the better but slower.
            Defaults to 10.
        perturbations_per_eval (int, optional): The number of perturbations to evaluate at once
            (Simply the inner batch size). Defaults to 4.
        verbose (bool, optional): Whether to show the progress bar. Defaults to False.
        segmentation_method (Literal["slic", "patch"], optional):
            Segmentation method used only if `segmentation_mask` is None. Defaults to "slic".
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Default: None
        **segmentation_method_params (Any): Additional parameters for the segmentation method.

    Returns:
        HSISpatialAttributes | list[HSISpatialAttributes]: An object containing the image, the attributions,
            the segmentation mask, and the score of the interpretable model used for the explanation.

    Raises:
        RuntimeError: If the Lime object is not initialized or is not an instance of LimeBase.
        MaskCreationError: If there is an error creating the segmentation mask.
        ValueError: If the number of segmentation masks is not equal to the number of HSI images provided.
        HSIAttributesError: If there is an error during creating spatial attribution.

    Examples:
        >>> simple_model = lambda x: torch.rand((x.shape[0], 2))
        >>> hsi = mt.HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
        >>> segmentation_mask = torch.randint(1, 4, (1, 240, 240))
        >>> lime = meteors.attr.Lime(
        ...     explainable_model=ExplainableModel(simple_model, "regression"), interpretable_model=SkLearnLasso(alpha=0.1)
        ... )
        >>> spatial_attribution = lime.get_spatial_attributes(hsi, segmentation_mask=segmentation_mask, target=0)
        >>> spatial_attribution.hsi
        HSI(shape=(4, 240, 240), dtype=torch.float32)
        >>> spatial_attribution.attributes.shape
        torch.Size([4, 240, 240])
        >>> spatial_attribution.segmentation_mask.shape
        torch.Size([1, 240, 240])
        >>> spatial_attribution.score
        1.0
    """
    if self._attribution_method is None or not isinstance(self._attribution_method, LimeBase):
        raise RuntimeError("Lime object not initialized")  # pragma: no cover

    if isinstance(hsi, HSI):
        hsi = [hsi]

    if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
        raise TypeError("All of the input hyperspectral images must be of type HSI")

    if segmentation_mask is None:
        segmentation_mask = self.get_segmentation_mask(hsi[0], segmentation_method, **segmentation_method_params)

        logger.warning(
            "Segmentation mask is created based on the first HSI image provided, this approach may not be optimal as "
            "the same segmentation mask may not be the best suitable for all images",
        )

    if isinstance(segmentation_mask, tuple):
        # normalize a tuple of masks to a mutable list
        segmentation_mask = list(segmentation_mask)
    elif not isinstance(segmentation_mask, list):
        segmentation_mask = [segmentation_mask] * len(hsi)

    if len(hsi) != len(segmentation_mask):
        raise ValueError(
            f"Number of segmentation masks should be equal to the number of HSI images provided, provided {len(segmentation_mask)}"
        )

    segmentation_mask = [
        ensure_torch_tensor(mask, f"Segmentation mask number {idx+1} should be None, numpy array, or torch tensor")
        for idx, mask in enumerate(segmentation_mask)
    ]
    segmentation_mask = [
        mask.unsqueeze(0).moveaxis(0, hsi_img.spectral_axis) if mask.ndim != hsi_img.image.ndim else mask
        for hsi_img, mask in zip(hsi, segmentation_mask)
    ]
    segmentation_mask = [
        validate_mask_shape("segmentation", hsi_img, mask) for hsi_img, mask in zip(hsi, segmentation_mask)
    ]

    hsi_input = torch.stack([hsi_img.get_image() for hsi_img in hsi], dim=0)
    segmentation_mask = torch.stack(segmentation_mask, dim=0)

    assert segmentation_mask.shape == hsi_input.shape

    segmentation_mask = segmentation_mask.to(self.device)
    hsi_input = hsi_input.to(self.device)

    lime_attributes, score = self._attribution_method.attribute(
        inputs=hsi_input,
        target=target,
        feature_mask=segmentation_mask,
        n_samples=n_samples,
        perturbations_per_eval=perturbations_per_eval,
        additional_forward_args=additional_forward_args,
        show_progress=verbose,
        return_input_shape=True,
    )

    try:
        spatial_attribution = [
            HSISpatialAttributes(
                hsi=hsi_img,
                attributes=lime_attr,
                mask=segmentation_mask[idx].expand_as(hsi_img.image),
                score=score.item(),
                attribution_method="Lime",
            )
            for idx, (hsi_img, lime_attr) in enumerate(zip(hsi, lime_attributes))
        ]
    except Exception as e:
        raise HSIAttributesError(f"Error during creating spatial attribution {e}") from e

    return spatial_attribution[0] if len(spatial_attribution) == 1 else spatial_attribution
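
To illustrate the mask normalization step above, here is a minimal plain-torch sketch assuming the default ("C", "H", "W") orientation (spectral_axis == 0): a 2D (H, W) segmentation mask is lifted to 3D by inserting a singleton spectral dimension, which then broadcasts against the (C, H, W) image.

import torch

spectral_axis = 0                          # default ("C", "H", "W") orientation
image = torch.ones((4, 240, 240))          # (C, H, W)
mask_2d = torch.randint(1, 4, (240, 240))  # (H, W)

# same normalization as above: add a leading dim and move it to the spectral axis
mask_3d = mask_2d.unsqueeze(0).moveaxis(0, spectral_axis)  # -> (1, 240, 240)
full_mask = mask_3d.expand_as(image)                       # -> (4, 240, 240)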

get_spectral_attributes(hsi, band_mask=None, target=None, n_samples=10, perturbations_per_eval=4, verbose=False, additional_forward_args=None, band_names=None)

Attributes the hsi image using the LIME method for spectral data. Based on the provided hsi and band mask, the LIME method attributes the hsi based on superbands (clustered bands) provided by the band mask. Please refer to the original paper https://arxiv.org/abs/1602.04938 for more details or to Christoph Molnar's book https://christophm.github.io/interpretable-ml-book/lime.html.

The function returns an HSISpectralAttributes object that contains the image, the attributions, the band mask, the band names, and the score of the interpretable model used for the explanation.

Parameters:

- hsi (list[HSI] | HSI, required): Input hyperspectral image(s) for which the attributions are to be computed. If a list of HSI objects is provided, the attributions are computed for each HSI object in the list. The output will be a list of HSISpectralAttributes objects.
- band_mask (ndarray | Tensor | list[ndarray | Tensor] | None, optional): Band mask that is used for the spectral attribution. The band mask should have a 1D or 3D shape, which can be broadcastable to the shape of the input image. The only dimensions on which the image and the mask shapes can differ are the height and width dimensions, marked with letters H and W in the image.orientation parameter. If None, the band mask is created within the function. If multiple HSI images are provided, a list of band masks can be provided, one for each image. If a list is not provided, the method will assume that the same band mask is used for all images. Defaults to None.
- target (list[int] | int | None, optional): Target class index for computing the attributions. If None, methods assume that the output has only one class. If the output has multiple classes, the target index must be provided. For multiple input images, a list of target indices can be provided, one for each image, or a single target value will be used for all images. Defaults to None.
- n_samples (int, optional): The number of samples to generate/analyze in LIME. The more the better, but slower. Defaults to 10.
- perturbations_per_eval (int, optional): The number of perturbations to evaluate at once (simply the inner batch size). Defaults to 4.
- verbose (bool, optional): Whether to show the progress bar. Defaults to False.
- additional_forward_args (Any, optional): If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Defaults to None.
- band_names (list[str] | dict[str | tuple[str, ...], int] | None, optional): Band names. Defaults to None.

Returns:

- HSISpectralAttributes | list[HSISpectralAttributes]: An object containing the image, the attributions, the band mask, the band names, and the score of the interpretable model used for the explanation.

Raises:

- RuntimeError: If the Lime object is not initialized or is not an instance of LimeBase.
- MaskCreationError: If there is an error creating the band mask.
- ValueError: If the number of band masks is not equal to the number of HSI images provided.
- HSIAttributesError: If there is an error during creating the spectral attribution.

Examples:

>>> simple_model = lambda x: torch.rand((x.shape[0], 2))
>>> hsi = mt.HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
>>> band_mask = torch.randint(1, 4, (4, 1, 1)).repeat(1, 240, 240)
>>> band_names = ["R", "G", "B"]
>>> lime = meteors.attr.Lime(
...     explainable_model=ExplainableModel(simple_model, "regression"), interpretable_model=SkLearnLasso(alpha=0.1)
... )
>>> spectral_attribution = lime.get_spectral_attributes(hsi, band_mask=band_mask, band_names=band_names, target=0)
>>> spectral_attribution.hsi
HSI(shape=(4, 240, 240), dtype=torch.float32)
>>> spectral_attribution.attributes.shape
torch.Size([4, 240, 240])
>>> spectral_attribution.band_mask.shape
torch.Size([4, 240, 240])
>>> spectral_attribution.band_names
["R", "G", "B"]
>>> spectral_attribution.score
1.0
Source code in src/meteors/attr/lime.py
def get_spectral_attributes(
    self,
    hsi: list[HSI] | HSI,
    band_mask: np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor] | None = None,
    target: list[int] | int | None = None,
    n_samples: int = 10,
    perturbations_per_eval: int = 4,
    verbose: bool = False,
    additional_forward_args: Any = None,
    band_names: list[str | list[str]] | dict[tuple[str, ...] | str, int] | None = None,
) -> HSISpectralAttributes | list[HSISpectralAttributes]:
    """
    Attributes the hsi image using the LIME method for spectral data. Based on the provided hsi and band mask, the LIME
    method attributes the hsi based on `superbands` (clustered bands) provided by the band mask.
    Please refer to the original paper `https://arxiv.org/abs/1602.04938` for more details or to
    Christoph Molnar's book `https://christophm.github.io/interpretable-ml-book/lime.html`.

    The function returns an HSISpectralAttributes object that contains the image, the attributions, the band mask,
    the band names, and the score of the interpretable model used for the explanation.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSISpectralAttributes objects.
        band_mask (np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor] | None, optional): Band mask that
            is used for the spectral attribution. The band mask should have a 1D or 3D shape, which can be
            broadcastable to the shape of the input image. The only dimensions on which the image and the mask shapes
            can differ are the height and width dimensions, marked with letters `H` and `W` in the `image.orientation`
            parameter. If None, the band mask is created within the function. If multiple HSI images are
            provided, a list of band masks can be provided, one for each image. If a list is not provided, the method
            will assume that the same band mask is used for all images. Defaults to None.
        target (list[int] | int | None, optional): target class index for computing the attributions. If None,
            methods assume that the output has only one class. If the output has multiple classes, the target index
            must be provided. For multiple input images, a list of target indices can be provided, one for each
            image or single target value will be used for all images. Defaults to None.
        n_samples (int, optional): The number of samples to generate/analyze in LIME. The more the better but slower.
            Defaults to 10.
        perturbations_per_eval (int, optional): The number of perturbations to evaluate at once
            (Simply the inner batch size). Defaults to 4.
        verbose (bool, optional): Whether to show the progress bar. Defaults to False.
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Default: None
        band_names (list[str] | dict[str | tuple[str, ...], int] | None, optional): Band names. Defaults to None.

    Returns:
        HSISpectralAttributes | list[HSISpectralAttributes]: An object containing the image, the attributions,
            the band mask, the band names, and the score of the interpretable model used for the explanation.

    Raises:
        RuntimeError: If the Lime object is not initialized or is not an instance of LimeBase.
        MaskCreationError: If there is an error creating the band mask.
        ValueError: If the number of band masks is not equal to the number of HSI images provided.
        HSIAttributesError: If there is an error during creating spectral attribution.

    Examples:
        >>> simple_model = lambda x: torch.rand((x.shape[0], 2))
        >>> hsi = mt.HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
        >>> band_mask = torch.randint(1, 4, (4, 1, 1)).repeat(1, 240, 240)
        >>> band_names = ["R", "G", "B"]
        >>> lime = meteors.attr.Lime(
        ...     explainable_model=ExplainableModel(simple_model, "regression"), interpretable_model=SkLearnLasso(alpha=0.1)
        ... )
        >>> spectral_attribution = lime.get_spectral_attributes(hsi, band_mask=band_mask, band_names=band_names, target=0)
        >>> spectral_attribution.hsi
        HSI(shape=(4, 240, 240), dtype=torch.float32)
        >>> spectral_attribution.attributes.shape
        torch.Size([4, 240, 240])
        >>> spectral_attribution.band_mask.shape
        torch.Size([4, 240, 240])
        >>> spectral_attribution.band_names
        ["R", "G", "B"]
        >>> spectral_attribution.score
        1.0
    """

    if self._attribution_method is None or not isinstance(self._attribution_method, LimeBase):
        raise RuntimeError("Lime object not initialized")  # pragma: no cover

    if isinstance(hsi, HSI):
        hsi = [hsi]

    if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
        raise TypeError("All of the input hyperspectral images must be of type HSI")

    if band_mask is None:
        created_bands = [self.get_band_mask(hsi_img, band_names) for hsi_img in hsi]
        band_mask, band_name_list = zip(*created_bands)
        band_names = band_name_list[0]

    if isinstance(band_mask, tuple):
        band_mask = list(band_mask)
    elif not isinstance(band_mask, list):
        band_mask = [band_mask]

    if len(hsi) != len(band_mask):
        if len(band_mask) == 1:
            band_mask = band_mask * len(hsi)
            logger.debug("Reusing the same band mask for all images")
        else:
            raise ValueError(
                f"Number of band masks should be equal to the number of HSI images provided, provided {len(band_mask)}"
            )

    band_mask = [
        ensure_torch_tensor(mask, f"Band mask number {idx+1} should be None, numpy array, or torch tensor")
        for idx, mask in enumerate(band_mask)
    ]
    band_mask = [
        mask.unsqueeze(-1).unsqueeze(-1).moveaxis(0, hsi_img.spectral_axis)
        if mask.ndim != hsi_img.image.ndim
        else mask
        for hsi_img, mask in zip(hsi, band_mask)
    ]
    band_mask = [validate_mask_shape("band", hsi_img, mask) for hsi_img, mask in zip(hsi, band_mask)]

    hsi_input = torch.stack([hsi_img.get_image() for hsi_img in hsi], dim=0)
    band_mask = torch.stack(band_mask, dim=0)

    if band_names is None:
        band_names = {str(segment.item()): idx for idx, segment in enumerate(torch.unique(band_mask))}
    else:
        logger.debug(
            "Band names are provided and will be used. In the future, there should be an option to validate them."
        )

    assert hsi_input.shape == band_mask.shape

    hsi_input = hsi_input.to(self.device)
    band_mask = band_mask.to(self.device)

    lime_attributes, score = self._attribution_method.attribute(
        inputs=hsi_input,
        target=target,
        feature_mask=band_mask,
        n_samples=n_samples,
        perturbations_per_eval=perturbations_per_eval,
        additional_forward_args=additional_forward_args,
        show_progress=verbose,
        return_input_shape=True,
    )

    try:
        spectral_attribution = [
            HSISpectralAttributes(
                hsi=hsi_img,
                attributes=lime_attr,
                mask=band_mask[idx].expand_as(hsi_img.image),
                band_names=band_names,
                score=score.item(),
                attribution_method="Lime",
            )
            for idx, (hsi_img, lime_attr) in enumerate(zip(hsi, lime_attributes))
        ]
    except Exception as e:
        raise HSIAttributesError(f"Error during creating spectral attribution {e}") from e

    return spectral_attribution[0] if len(spectral_attribution) == 1 else spectral_attribution
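
The analogous normalization for a 1D band mask, as a minimal plain-torch sketch assuming the default ("C", "H", "W") orientation: a per-band mask of shape (C,) receives two trailing singleton dimensions and is then broadcast over the spatial dimensions.

import torch

spectral_axis = 0                          # default ("C", "H", "W") orientation
image = torch.ones((4, 240, 240))          # (C, H, W)
band_mask_1d = torch.tensor([1, 1, 2, 3])  # one segment id per band, shape (4,)

# same normalization as above: (4,) -> (4, 1, 1), spectral dim moved into place
band_mask_3d = band_mask_1d.unsqueeze(-1).unsqueeze(-1).moveaxis(0, spectral_axis)
full_band_mask = band_mask_3d.expand_as(image)  # -> (4, 240, 240)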

Lime Base

The Lime Base class was adapted from the Captum Lime implementation. This adaptation builds upon the original work, extending and customizing it for specific use cases within this project. To see the original implementation, please refer to the Captum repository.

IntegratedGradients

Bases: Explainer

IntegratedGradients explainer class for generating attributions using the Integrated Gradients method. The Integrated Gradients method is based on the captum implementation and is an implementation of an idea coming from the original paper on Integrated Gradients, where more details about this method can be found.

Attributes:

- _attribution_method (IntegratedGradients): The Integrated Gradients method from the captum library.
- multiply_by_inputs: Indicates whether to factor model inputs’ multiplier in the final attribution scores. In the literature this is also known as local vs global attribution. If inputs’ multiplier isn’t factored in, then that type of attribution method is also called local attribution. If it is, then that type of attribution method is called global. More details can be found in this paper. In case of integrated gradients, if multiply_by_inputs is set to True, final sensitivity scores are being multiplied by (inputs - baselines).

Parameters:

- explainable_model (ExplainableModel | Explainer, required): The explainable model to be explained.
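
A short sketch of the constructor option described above, assuming an already-built explainable_model:

>>> ig_global = IntegratedGradients(explainable_model)                           # multiply_by_inputs=True (default)
>>> ig_local = IntegratedGradients(explainable_model, multiply_by_inputs=False)  # local attribution variant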
Source code in src/meteors/attr/integrated_gradients.py
class IntegratedGradients(Explainer):
    """
    IntegratedGradients explainer class for generating attributions using the Integrated Gradients method.
    The Integrated Gradients method is based on the [`captum` implementation](https://captum.ai/api/integrated_gradients.html)
    and is an implementation of an idea coming from the [original paper on Integrated Gradients](https://arxiv.org/pdf/1703.01365),
    where more details about this method can be found.

    Attributes:
        _attribution_method (CaptumIntegratedGradients): The Integrated Gradients method from the `captum` library.
        multiply_by_inputs: Indicates whether to factor model inputs’ multiplier in the final attribution scores.
            In the literature this is also known as local vs global attribution. If inputs’ multiplier isn’t factored
            in, then that type of attribution method is also called local attribution. If it is, then that type of
            attribution method is called global. More details can be found in this [paper](https://arxiv.org/abs/1711.06104).
            In case of integrated gradients, if multiply_by_inputs is set to True,
            final sensitivity scores are being multiplied by (inputs - baselines).

    Args:
        explainable_model (ExplainableModel | Explainer): The explainable model to be explained.
    """

    def __init__(self, explainable_model: ExplainableModel, multiply_by_inputs: bool = True):
        super().__init__(explainable_model)
        self.multiply_by_inputs = multiply_by_inputs

        self._attribution_method = CaptumIntegratedGradients(
            explainable_model.forward_func, multiply_by_inputs=self.multiply_by_inputs
        )

    def attribute(
        self,
        hsi: list[HSI] | HSI,
        baseline: int | float | torch.Tensor | list[int | float | torch.Tensor] | None = None,
        target: list[int] | int | None = None,
        additional_forward_args: Any = None,
        n_steps: int = 50,
        method: Literal[
            "riemann_right", "riemann_left", "riemann_middle", "riemann_trapezoid", "gausslegendre"
        ] = "gausslegendre",
        return_convergence_delta: bool = False,
    ) -> HSIAttributes | list[HSIAttributes]:
        """
        Method for generating attributions using the Integrated Gradients method.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSIAttributes objects.
            baseline (int | float | torch.Tensor | list[int | float | torch.Tensor] | None, optional): Baselines define the
                starting point from which integral is computed and can be provided as:
                    - integer or float representing a constant value used as the baseline for all input pixels.
                    - tensor with the same shape as the input tensor, providing a baseline for each input pixel.
                        if the input is a list of HSI objects, the baseline can be a tensor with the same shape as
                        the input tensor for each HSI object.
                    - list of integers, floats or tensors with the same shape as the input tensor, providing a baseline
                        for each input pixel. If the input is a list of HSI objects, the baseline can be a list of
                        tensors with the same shape as the input tensor for each HSI object. Defaults to None.
            target (list[int] | int | None, optional): target class index for computing the attributions. If None,
                methods assume that the output has only one class. If the output has multiple classes, the target index
                must be provided. For multiple input images, a list of target indices can be provided, one for each
                image or single target value will be used for all images. Defaults to None.
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Default: None
            n_steps (int, optional): The number of steps to approximate the integral. Default: 50.
            method (Literal["riemann_right", "riemann_left", "riemann_middle", "riemann_trapezoid", "gausslegendre"],
                optional): Method for approximating the integral, one of riemann_right, riemann_left, riemann_middle,
                riemann_trapezoid or gausslegendre. Default: gausslegendre if no method is provided.
            return_convergence_delta (bool, optional): Indicates whether to return convergence delta or not.
                If return_convergence_delta is set to True convergence delta will be returned in a tuple following
                attributions. Default: False

        Returns:
            HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
                if a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

        Raises:
            RuntimeError: If the explainer is not initialized.
            HSIAttributesError: If an error occurs during the generation of the attributions.


        Examples:
            >>> integrated_gradients = IntegratedGradients(explainable_model)
            >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
            >>> attributions = integrated_gradients.attribute(hsi, method="riemann_right", baseline=0.0)
            >>> attributions, approximation_error = integrated_gradients.attribute(hsi, return_convergence_delta=True)
            >>> approximation_error
            0.5
            >>> attributions = integrated_gradients.attribute([hsi, hsi])
            >>> len(attributions)
            2
        """
        if self._attribution_method is None:
            raise RuntimeError("IntegratedGradients explainer is not initialized, INITIALIZATION ERROR")

        if not isinstance(hsi, list):
            hsi = [hsi]

        if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
            raise TypeError("All of the input hyperspectral images must be of type HSI")

        if not isinstance(baseline, list):
            baseline = [baseline] * len(hsi)

        baseline = torch.stack(
            [
                validate_and_transform_baseline(base, hsi_image).to(hsi_image.device)
                for hsi_image, base in zip(hsi, baseline)
            ],
            dim=0,
        )
        input_tensor = torch.stack(
            [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
        )

        ig_attributions = self._attribution_method.attribute(
            input_tensor,
            baselines=baseline,
            target=target,
            n_steps=n_steps,
            additional_forward_args=additional_forward_args,
            method=method,
            return_convergence_delta=return_convergence_delta,
        )

        if return_convergence_delta:
            attributions, approximation_error = ig_attributions
        else:
            attributions, approximation_error = ig_attributions, [None] * len(hsi)

        try:
            attributes = [
                HSIAttributes(hsi=hsi_image, attributes=attribution, score=error, attribution_method=self.get_name())
                for hsi_image, attribution, error in zip(hsi, attributions, approximation_error)
            ]
        except Exception as e:
            raise HSIAttributesError(f"Error while creating HSIAttributes: {e}") from e

        return attributes[0] if len(attributes) == 1 else attributes

attribute(hsi, baseline=None, target=None, additional_forward_args=None, n_steps=50, method='gausslegendre', return_convergence_delta=False)

Method for generating attributions using the Integrated Gradients method.

Parameters:

Name Type Description Default
hsi list[HSI] | HSI

Input hyperspectral image(s) for which the attributions are to be computed. If a list of HSI objects is provided, the attributions are computed for each HSI object in the list. The output will be a list of HSIAttributes objects.

required
baseline int | float | torch.Tensor | list[int | float | torch.Tensor]

Baselines define the starting point from which the integral is computed and can be provided as: - an integer or float representing a constant value used as the baseline for all input pixels. - a tensor with the same shape as the input tensor, providing a baseline for each input pixel. If the input is a list of HSI objects, the baseline can be a tensor with the same shape as the input tensor for each HSI object. - a list of integers, floats or tensors with the same shape as the input tensor, providing a baseline for each input pixel. If the input is a list of HSI objects, the baseline can be a list of tensors with the same shape as the input tensor for each HSI object. Defaults to None.

None
target list[int] | int | None

Target class index for computing the attributions. If None, the methods assume that the output has only one class. If the output has multiple classes, the target index must be provided. For multiple input images, a list of target indices can be provided, one for each image, or a single target value will be used for all images. Defaults to None.

None
additional_forward_args Any

If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

None
n_steps int

The number of steps to approximate the integral. Default: 50.

50
return_convergence_delta bool

Indicates whether to return the convergence delta. If return_convergence_delta is set to True, the convergence delta will be returned in a tuple following the attributions. Default: False

False

Returns:

Type Description
HSIAttributes | list[HSIAttributes]

HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s). If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

Raises:

Type Description
RuntimeError

If the explainer is not initialized.

HSIAttributesError

If an error occurs during the generation of the attributions.

Examples:

>>> integrated_gradients = IntegratedGradients(explainable_model)
>>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
>>> attributions = integrated_gradients.attribute(hsi, method="riemann_right", baseline=0.0)
>>> attributions, approximation_error = integrated_gradients.attribute(hsi, return_convergence_delta=True)
>>> approximation_error
0.5
>>> attributions = integrated_gradients.attribute([hsi, hsi])
>>> len(attributions)
2
Source code in src/meteors/attr/integrated_gradients.py
def attribute(
    self,
    hsi: list[HSI] | HSI,
    baseline: int | float | torch.Tensor | list[int | float | torch.Tensor] | None = None,
    target: list[int] | int | None = None,
    additional_forward_args: Any = None,
    n_steps: int = 50,
    method: Literal[
        "riemann_right", "riemann_left", "riemann_middle", "riemann_trapezoid", "gausslegendre"
    ] = "gausslegendre",
    return_convergence_delta: bool = False,
) -> HSIAttributes | list[HSIAttributes]:
    """
    Method for generating attributions using the Integrated Gradients method.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSIAttributes objects.
        baseline (int | float | torch.Tensor | list[int | float | torch.Tensor], optional): Baselines define the
            starting point from which the integral is computed and can be provided as:
                - an integer or float representing a constant value used as the baseline for all input pixels.
                - a tensor with the same shape as the input tensor, providing a baseline for each input pixel.
                    If the input is a list of HSI objects, the baseline can be a tensor with the same shape as
                    the input tensor for each HSI object.
                - a list of integers, floats or tensors with the same shape as the input tensor, providing a
                    baseline for each input pixel. If the input is a list of HSI objects, the baseline can be
                    a list of tensors with the same shape as the input tensor for each HSI object.
            Defaults to None.
        target (list[int] | int | None, optional): Target class index for computing the attributions. If None,
            the methods assume that the output has only one class. If the output has multiple classes, the
            target index must be provided. For multiple input images, a list of target indices can be provided,
            one for each image, or a single target value will be used for all images. Defaults to None.
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Default: None
        n_steps (int, optional): The number of steps to approximate the integral. Default: 50.
        method (Literal["riemann_right", "riemann_left", "riemann_middle", "riemann_trapezoid", "gausslegendre"],
            optional): Method for approximating the integral, one of riemann_right, riemann_left, riemann_middle,
            riemann_trapezoid or gausslegendre. Default: gausslegendre if no method is provided.
        return_convergence_delta (bool, optional): Indicates whether to return the convergence delta. If
            return_convergence_delta is set to True, the convergence delta will be returned in a tuple following
            the attributions. Default: False

    Returns:
        HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

    Raises:
        RuntimeError: If the explainer is not initialized.
        HSIAttributesError: If an error occurs during the generation of the attributions.


    Examples:
        >>> integrated_gradients = IntegratedGradients(explainable_model)
        >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
        >>> attributions = integrated_gradients.attribute(hsi, method="riemann_right", baseline=0.0)
        >>> attributions, approximation_error = integrated_gradients.attribute(hsi, return_convergence_delta=True)
        >>> approximation_error
        0.5
        >>> attributions = integrated_gradients.attribute([hsi, hsi])
        >>> len(attributions)
        2
    """
    if self._attribution_method is None:
        raise RuntimeError("IntegratedGradients explainer is not initialized, INITIALIZATION ERROR")

    if not isinstance(hsi, list):
        hsi = [hsi]

    if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
        raise TypeError("All of the input hyperspectral images must be of type HSI")

    if not isinstance(baseline, list):
        baseline = [baseline] * len(hsi)

    baseline = torch.stack(
        [
            validate_and_transform_baseline(base, hsi_image).to(hsi_image.device)
            for hsi_image, base in zip(hsi, baseline)
        ],
        dim=0,
    )
    input_tensor = torch.stack(
        [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
    )

    ig_attributions = self._attribution_method.attribute(
        input_tensor,
        baselines=baseline,
        target=target,
        n_steps=n_steps,
        additional_forward_args=additional_forward_args,
        method=method,
        return_convergence_delta=return_convergence_delta,
    )

    if return_convergence_delta:
        attributions, approximation_error = ig_attributions
    else:
        attributions, approximation_error = ig_attributions, [None] * len(hsi)

    try:
        attributes = [
            HSIAttributes(hsi=hsi_image, attributes=attribution, score=error, attribution_method=self.get_name())
            for hsi_image, attribution, error in zip(hsi, attributions, approximation_error)
        ]
    except Exception as e:
        raise HSIAttributesError(f"Error while creating HSIAttributes: {e}") from e

    return attributes[0] if len(attributes) == 1 else attributes
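
For context, here is a minimal end-to-end sketch showing how the explainable_model used in the examples above might be constructed. The import paths and the ExplainableModel constructor arguments are assumptions, not part of this reference; consult the ExplainableModel documentation before relying on them.

import torch
from meteors import HSI
from meteors.attr import IntegratedGradients
from meteors.models import ExplainableModel  # assumed import path

class TinyClassifier(torch.nn.Module):
    """Toy model: average-pools the spatial dimensions and maps 4 bands to 2 logits."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):  # x: (batch, bands, height, width)
        return self.linear(x.mean(dim=(2, 3)))

# The "classification" problem-type argument is an assumption about the constructor
explainable_model = ExplainableModel(TinyClassifier(), "classification")
integrated_gradients = IntegratedGradients(explainable_model)

hsi = HSI(image=torch.rand(4, 240, 240), wavelengths=[462.08, 465.27, 468.47, 471.68])
attributions, approximation_error = integrated_gradients.attribute(
    hsi, target=0, n_steps=50, return_convergence_delta=True
)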

InputXGradient

Bases: Explainer

Initializes the InputXGradient explainer. The InputXGradient method is a straightforward approach to computing attribution. It simply multiplies the input image by the gradient of the output with respect to the input. This method is based on the captum implementation

Attributes:

Name Type Description
_attribution_method CaptumInputXGradient

The InputXGradient method from the captum library.

Parameters:

Name Type Description Default
explainable_model ExplainableModel | Explainer

The explainable model to be explained.

required
Source code in src/meteors/attr/input_x_gradients.py
class InputXGradient(Explainer):
    """
    Initializes the InputXGradient explainer. The InputXGradient method is a straightforward approach to
    computing attribution. It simply multiplies the input image by the gradient of the output with respect
    to the input. This method is based on the [`captum` implementation](https://captum.ai/api/input_x_gradient.html)

    Attributes:
        _attribution_method (CaptumInputXGradient): The InputXGradient method from the `captum` library.

    Args:
        explainable_model (ExplainableModel | Explainer): The explainable model to be explained.
    """

    def __init__(self, explainable_model: ExplainableModel):
        super().__init__(explainable_model)

        self._attribution_method = CaptumInputXGradient(explainable_model.forward_func)

    def attribute(
        self,
        hsi: list[HSI] | HSI,
        target: list[int] | int | None = None,
        additional_forward_args: Any = None,
    ) -> HSIAttributes | list[HSIAttributes]:
        """
        Method for generating attributions using the InputXGradient method.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSIAttributes objects.
            target (list[int] | int | None, optional): Target class index for computing the attributions. If None,
                the methods assume that the output has only one class. If the output has multiple classes, the
                target index must be provided. For multiple input images, a list of target indices can be provided,
                one for each image, or a single target value will be used for all images. Defaults to None.
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Default: None

        Returns:
            HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

        Raises:
            RuntimeError: If the explainer is not initialized.
            HSIAttributesError: If an error occurs during the generation of the attributions.

        Examples:
            >>> input_x_gradient = InputXGradient(explainable_model)
            >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
            >>> attributions = input_x_gradient.attribute(hsi)
            >>> attributions = input_x_gradient.attribute([hsi, hsi])
            >>> len(attributions)
            2
        """
        if self._attribution_method is None:
            raise RuntimeError("InputXGradient explainer is not initialized, INITIALIZATION ERROR")

        if not isinstance(hsi, list):
            hsi = [hsi]

        if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
            raise TypeError("All of the input hyperspectral images must be of type HSI")

        input_tensor = torch.stack(
            [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
        )

        gradient_attribution = self._attribution_method.attribute(
            input_tensor, target=target, additional_forward_args=additional_forward_args
        )

        try:
            attributes = [
                HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
                for hsi_image, attribution in zip(hsi, gradient_attribution)
            ]
        except Exception as e:
            raise HSIAttributesError(f"Error in generating InputXGradient attributions: {e}") from e

        return attributes[0] if len(attributes) == 1 else attributes

attribute(hsi, target=None, additional_forward_args=None)

Method for generating attributions using the InputXGradient method.

Parameters:

Name Type Description Default
hsi list[HSI] | HSI

Input hyperspectral image(s) for which the attributions are to be computed. If a list of HSI objects is provided, the attributions are computed for each HSI object in the list. The output will be a list of HSIAttributes objects.

required
target list[int] | int | None

Target class index for computing the attributions. If None, the methods assume that the output has only one class. If the output has multiple classes, the target index must be provided. For multiple input images, a list of target indices can be provided, one for each image, or a single target value will be used for all images. Defaults to None.

None
additional_forward_args Any

If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

None

Returns:

Type Description
HSIAttributes | list[HSIAttributes]

HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s). If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

Raises:

Type Description
RuntimeError

If the explainer is not initialized.

HSIAttributesError

If an error occurs during the generation of the attributions.

Examples:

>>> input_x_gradient = InputXGradient(explainable_model)
>>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
>>> attributions = input_x_gradient.attribute(hsi)
>>> attributions = input_x_gradient.attribute([hsi, hsi])
>>> len(attributions)
2
Source code in src/meteors/attr/input_x_gradients.py
def attribute(
    self,
    hsi: list[HSI] | HSI,
    target: list[int] | int | None = None,
    additional_forward_args: Any = None,
) -> HSIAttributes | list[HSIAttributes]:
    """
    Method for generating attributions using the InputXGradient method.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSIAttributes objects.
        target (list[int] | int | None, optional): Target class index for computing the attributions. If None,
            the methods assume that the output has only one class. If the output has multiple classes, the
            target index must be provided. For multiple input images, a list of target indices can be provided,
            one for each image, or a single target value will be used for all images. Defaults to None.
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Default: None

    Returns:
        HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

    Raises:
        RuntimeError: If the explainer is not initialized.
        HSIAttributesError: If an error occurs during the generation of the attributions.

    Examples:
        >>> input_x_gradient = InputXGradient(explainable_model)
        >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
        >>> attributions = input_x_gradient.attribute(hsi)
        >>> attributions = input_x_gradient.attribute([hsi, hsi])
        >>> len(attributions)
        2
    """
    if self._attribution_method is None:
        raise RuntimeError("InputXGradient explainer is not initialized, INITIALIZATION ERROR")

    if not isinstance(hsi, list):
        hsi = [hsi]

    if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
        raise TypeError("All of the input hyperspectral images must be of type HSI")

    input_tensor = torch.stack(
        [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
    )

    gradient_attribution = self._attribution_method.attribute(
        input_tensor, target=target, additional_forward_args=additional_forward_args
    )

    try:
        attributes = [
            HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
            for hsi_image, attribution in zip(hsi, gradient_attribution)
        ]
    except Exception as e:
        raise HSIAttributesError(f"Error in generating InputXGradient attributions: {e}") from e

    return attributes[0] if len(attributes) == 1 else attributes
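
To make the method above concrete, here is the core idea in plain PyTorch. This is an illustrative sketch only; the class itself delegates the actual computation to captum's InputXGradient.

import torch

def input_x_gradient(forward_func, x: torch.Tensor, target: int) -> torch.Tensor:
    """Sketch of the idea: attribution = input * d(score)/d(input)."""
    x = x.clone().requires_grad_(True)
    score = forward_func(x.unsqueeze(0))[0, target]  # scalar score of the target class
    score.backward()
    return (x * x.grad).detach()  # elementwise product, same shape as the input cube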

Occlusion

Bases: Explainer

Occlusion explainer class for generating attributions using the Occlusion method. This attribution method perturbs the input by replacing a contiguous rectangular region with a given baseline and computing the difference in output. In our case, features are located in multiple regions, and attribution from different hyper-rectangles is averaged. The implementation of this method is also based on the captum repository. More details about this approach can be found in the original paper
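
The perturb-and-compare idea can be sketched in a few lines of plain PyTorch for a single window (illustrative only; the class itself delegates to captum's Occlusion and averages over all windows):

import torch

def occlusion_effect(forward_func, x, target, window, start, baseline=0.0):
    """Toy single-window occlusion: score drop when one region is replaced."""
    c0, h0, w0 = start   # window origin, (channel, row, col)
    c, h, w = window     # window extent along each axis
    occluded = x.clone()
    occluded[c0:c0 + c, h0:h0 + h, w0:w0 + w] = baseline
    with torch.no_grad():
        reference = forward_func(x.unsqueeze(0))[0, target]
        perturbed = forward_func(occluded.unsqueeze(0))[0, target]
    return reference - perturbed  # large value => the occluded region mattered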

Attributes:

Name Type Description
_attribution_method Occlusion

The Occlusion method from the captum library.

Parameters:

Name Type Description Default
explainable_model ExplainableModel | Explainer

The explainable model to be explained.

required
Source code in src/meteors/attr/occlusion.py
class Occlusion(Explainer):
    """
    Occlusion explainer class for generating attributions using the Occlusion method.
    This attribution method perturbs the input by replacing a contiguous rectangular region
    with a given baseline and computing the difference in output.
    In our case, features are located in multiple regions, and attribution from different hyper-rectangles is averaged.
    The implementation of this method is also based on the [`captum` repository](https://captum.ai/api/occlusion.html).
    More details about this approach can be found in the [original paper](https://arxiv.org/abs/1311.2901)

    Attributes:
        _attribution_method (CaptumOcclusion): The Occlusion method from the `captum` library.

    Args:
        explainable_model (ExplainableModel | Explainer): The explainable model to be explained.
    """

    def __init__(self, explainable_model: ExplainableModel):
        super().__init__(explainable_model)

        self._attribution_method = CaptumOcclusion(explainable_model.forward_func)

    @staticmethod
    def _create_segmentation_mask(
        input_shape: tuple[int, int, int], sliding_window_shapes: tuple[int, int, int], strides: tuple[int, int, int]
    ) -> torch.Tensor:
        """
        Create a binary segmentation mask based on sliding windows.

        Args:
            input_shape (Tuple[int, int, int]): Shape of the input tensor (e.g., (H, W, C))
            sliding_window_shapes (Tuple[int, int, int]): Shape of the sliding window (e.g., (h, w, c))
            strides (Tuple[int, int, int]): Strides for the sliding window (e.g., (s_h, s_w, s_c))

        Returns:
            torch.Tensor: Mask tensor labeling each sliding-window position with a distinct positive integer
        """
        # Initialize empty mask
        mask = torch.zeros(input_shape, dtype=torch.int32)

        # Calculate number of windows in each dimension
        windows = []
        for dim_size, window_size, stride in zip(input_shape, sliding_window_shapes, strides):
            if stride == 0:
                raise ValueError("Stride cannot be zero.")
            # Number of window placements along this dimension; the last window may
            # extend past the edge when the stride does not divide the size evenly
            n_windows = dim_size // stride if (dim_size - window_size) % stride == 0 else dim_size // stride + 1
            windows.append(n_windows)

        # Generate all possible indices using itertools.product
        for i, indices in enumerate(itertools.product(*[range(w) for w in windows])):
            # Calculate start position for each dimension
            starts = [idx * stride for idx, stride in zip(indices, strides)]

            # Calculate end position for each dimension
            ends = [start + window for start, window in zip(starts, sliding_window_shapes)]

            # Create slice objects for each dimension
            slices = tuple(slice(start, end) for start, end in zip(starts, ends))

            # Mark window position in mask
            mask[slices] = i + 1

        return mask

    def attribute(
        self,
        hsi: list[HSI] | HSI,
        target: list[int] | int | None = None,
        sliding_window_shapes: int | tuple[int, int, int] = (1, 1, 1),
        strides: int | tuple[int, int, int] = (1, 1, 1),
        baseline: int | float | torch.Tensor | list[int | float | torch.Tensor] | None = None,
        additional_forward_args: Any = None,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
    ) -> HSIAttributes | list[HSIAttributes]:
        """
        Method for generating attributions using the Occlusion method.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSIAttributes objects.
            target (list[int] | int | None, optional): Target class index for computing the attributions. If None,
                the methods assume that the output has only one class. If the output has multiple classes, the
                target index must be provided. For multiple input images, a list of target indices can be provided,
                one for each image, or a single target value will be used for all images. Defaults to None.
            sliding_window_shapes (int | tuple[int, int, int]):
                The shape of the sliding window. If an integer is provided, it will be used for all dimensions.
                Defaults to (1, 1, 1).
            strides (int | tuple[int, int, int], optional): The stride of the sliding window. Defaults to (1, 1, 1).
                Simply put, the stride is the number of pixels by which the sliding window is moved in each dimension.
            baseline (int | float | torch.Tensor | list[int | float | torch.Tensor], optional): Baselines define the
                reference value that replaces each feature when it is occluded, and can be provided as:
                    - an integer or float representing a constant value used as the baseline for all input pixels.
                    - a tensor with the same shape as the input tensor, providing a baseline for each input pixel.
                        If the input is a list of HSI objects, the baseline can be a tensor with the same shape as
                        the input tensor for each HSI object.
                    - a list of integers, floats or tensors with the same shape as the input tensor, providing a
                        baseline for each input pixel. If the input is a list of HSI objects, the baseline can be
                        a list of tensors with the same shape as the input tensor for each HSI object.
                Defaults to None.
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Default: None
            perturbations_per_eval (int, optional): Allows multiple occlusions to be included in one batch
                (one call to forward_fn). By default, perturbations_per_eval is 1, so each occlusion is processed
                individually. Each forward pass will contain a maximum of perturbations_per_eval * #examples samples.
                For DataParallel models, each batch is split among the available devices, so evaluations on each
                available device contain at most (perturbations_per_eval * #examples) / num_devices samples. When
                working with multiple examples, the number of perturbations per evaluation should be set to at least
                the number of examples. Defaults to 1.
            show_progress (bool, optional): If True, displays a progress bar. Defaults to False.

        Returns:
            HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

        Raises:
            RuntimeError: If the explainer is not initialized.
            ValueError: If the sliding window shapes or strides are not a tuple of three integers.
            HSIAttributesError: If an error occurs during the generation of the attributions.

        Example:
            >>> occlusion = Occlusion(explainable_model)
            >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
            >>> attributions = occlusion.attribute(hsi, baseline=0, sliding_window_shapes=(4, 3, 3), strides=(1, 1, 1))
            >>> attributions = occlusion.attribute([hsi, hsi], baseline=0, sliding_window_shapes=(4, 3, 3), strides=(1, 2, 2))
            >>> len(attributions)
            2
        """
        if self._attribution_method is None:
            raise RuntimeError("Occlusion explainer is not initialized, INITIALIZATION ERROR")

        if not isinstance(hsi, list):
            hsi = [hsi]

        if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
            raise TypeError("All of the input hyperspectral images must be of type HSI")

        if not isinstance(baseline, list):
            baseline = [baseline] * len(hsi)

        baseline = torch.stack(
            [
                validate_and_transform_baseline(base, hsi_image).to(hsi_image.device)
                for hsi_image, base in zip(hsi, baseline)
            ],
            dim=0,
        )
        input_tensor = torch.stack(
            [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
        )

        if isinstance(sliding_window_shapes, int):
            sliding_window_shapes = (sliding_window_shapes, sliding_window_shapes, sliding_window_shapes)
        if isinstance(strides, int):
            strides = (strides, strides, strides)

        if len(strides) != 3:
            raise ValueError("Strides must be a tuple of three integers")
        if len(sliding_window_shapes) != 3:
            raise ValueError("Sliding window shapes must be a tuple of three integers")

        assert len(sliding_window_shapes) == len(strides) == 3
        occlusion_attributions = self._attribution_method.attribute(
            input_tensor,
            sliding_window_shapes=sliding_window_shapes,
            strides=strides,
            target=target,
            baselines=baseline,
            additional_forward_args=additional_forward_args,
            perturbations_per_eval=min(perturbations_per_eval, len(hsi)),
            show_progress=show_progress,
        )

        try:
            attributes = [
                HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
                for hsi_image, attribution in zip(hsi, occlusion_attributions)
            ]
        except Exception as e:
            raise HSIAttributesError(f"Error in generating Occlusion attributions: {e}") from e

        return attributes[0] if len(attributes) == 1 else attributes

    def get_spatial_attributes(
        self,
        hsi: list[HSI] | HSI,
        target: list[int] | int | None = None,
        sliding_window_shapes: int | tuple[int, int] = (1, 1),
        strides: int | tuple[int, int] = 1,
        baseline: int | float | torch.Tensor | list[int | float | torch.Tensor] | None = None,
        additional_forward_args: Any = None,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
    ) -> HSISpatialAttributes | list[HSISpatialAttributes]:
        """Compute spatial attributions for the input HSI using the Occlusion method. In this case, the sliding window
        is applied to the spatial dimensions only.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSIAttributes objects.
            target (list[int] | int | None, optional): Target class index for computing the attributions. If None,
                the methods assume that the output has only one class. If the output has multiple classes, the
                target index must be provided. For multiple input images, a list of target indices can be provided,
                one for each image, or a single target value will be used for all images. Defaults to None.
            sliding_window_shapes (int | tuple[int, int]): The shape of the sliding window for spatial dimensions.
                If an integer is provided, it will be used for both spatial dimensions. Defaults to (1, 1).
            strides (int | tuple[int, int], optional): The stride of the sliding window for spatial dimensions.
                Defaults to 1. Simply put, the stride is the number of pixels by which the sliding window is moved
                in each spatial dimension.
            baseline (int | float | torch.Tensor | list[int | float | torch.Tensor], optional): Baselines define the
                reference value that replaces each feature when it is occluded, and can be provided as:
                    - an integer or float representing a constant value used as the baseline for all input pixels.
                    - a tensor with the same shape as the input tensor, providing a baseline for each input pixel.
                        If the input is a list of HSI objects, the baseline can be a tensor with the same shape as
                        the input tensor for each HSI object.
                    - a list of integers, floats or tensors with the same shape as the input tensor, providing a
                        baseline for each input pixel. If the input is a list of HSI objects, the baseline can be
                        a list of tensors with the same shape as the input tensor for each HSI object.
                Defaults to None.
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Default: None
            perturbations_per_eval (int, optional): Allows multiple occlusions to be included in one batch
                (one call to forward_fn). By default, perturbations_per_eval is 1, so each occlusion is processed
                individually. Each forward pass will contain a maximum of perturbations_per_eval * #examples samples.
                For DataParallel models, each batch is split among the available devices, so evaluations on each
                available device contain at most (perturbations_per_eval * #examples) / num_devices samples. When
                working with multiple examples, the number of perturbations per evaluation should be set to at least
                the number of examples. Defaults to 1.
            show_progress (bool, optional): If True, displays a progress bar. Defaults to False.

        Returns:
            HSISpatialAttributes | list[HSISpatialAttributes]: The computed attributions for the input hyperspectral image(s).
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

        Raises:
            RuntimeError: If the explainer is not initialized.
            ValueError: If the sliding window shapes or strides are not a tuple of two integers.
            HSIAttributesError: If an error occurs during the generation of the attributions.

        Example:
            >>> occlusion = Occlusion(explainable_model)
            >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
            >>> attributions = occlusion.get_spatial_attributes(hsi, baseline=0, sliding_window_shapes=(3, 3), strides=(1, 1))
            >>> attributions = occlusion.get_spatial_attributes([hsi, hsi], baseline=0, sliding_window_shapes=(3, 3), strides=(2, 2))
            >>> len(attributions)
            2
        """
        if self._attribution_method is None:
            raise RuntimeError("Occlusion explainer is not initialized, INITIALIZATION ERROR")

        if not isinstance(hsi, list):
            hsi = [hsi]

        if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
            raise TypeError("All of the input hyperspectral images must be of type HSI")

        if not isinstance(baseline, list):
            baseline = [baseline] * len(hsi)

        baseline = torch.stack(
            [
                validate_and_transform_baseline(base, hsi_image).to(hsi_image.device)
                for hsi_image, base in zip(hsi, baseline)
            ],
            dim=0,
        )
        input_tensor = torch.stack(
            [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
        )

        if isinstance(sliding_window_shapes, int):
            sliding_window_shapes = (sliding_window_shapes, sliding_window_shapes)
        if isinstance(strides, int):
            strides = (strides, strides)

        if len(strides) != 2:
            raise ValueError("Strides must be a tuple of two integers")
        if len(sliding_window_shapes) != 2:
            raise ValueError("Sliding window shapes must be a tuple of two integers")

        list_sliding_window_shapes = list(sliding_window_shapes)
        list_strides = list(strides)
        # hsi is always a list at this point; expand the 2D spatial window so it
        # spans the full spectral axis of the image
        list_sliding_window_shapes.insert(hsi[0].spectral_axis, hsi[0].image.shape[hsi[0].spectral_axis])
        list_strides.insert(hsi[0].spectral_axis, hsi[0].image.shape[hsi[0].spectral_axis])
        sliding_window_shapes = tuple(list_sliding_window_shapes)  # type: ignore
        strides = tuple(list_strides)  # type: ignore

        assert len(sliding_window_shapes) == len(strides) == 3
        segment_mask = [
            self._create_segmentation_mask(hsi_image.image.shape, sliding_window_shapes, strides) for hsi_image in hsi
        ]

        occlusion_attributions = self._attribution_method.attribute(
            input_tensor,
            sliding_window_shapes=sliding_window_shapes,
            strides=strides,
            target=target,
            baselines=baseline,
            additional_forward_args=additional_forward_args,
            perturbations_per_eval=min(perturbations_per_eval, len(hsi)),
            show_progress=show_progress,
        )

        try:
            spatial_attributes = [
                HSISpatialAttributes(
                    hsi=hsi_image, attributes=attribution, attribution_method=self.get_name(), mask=mask
                )
                for hsi_image, attribution, mask in zip(hsi, occlusion_attributions, segment_mask)
            ]
        except Exception as e:
            raise HSIAttributesError(f"Error in generating Occlusion attributions: {e}") from e

        return spatial_attributes[0] if len(spatial_attributes) == 1 else spatial_attributes

    def get_spectral_attributes(
        self,
        hsi: list[HSI] | HSI,
        target: list[int] | int | None = None,
        sliding_window_shapes: int | tuple[int] = 1,
        strides: int | tuple[int] = 1,
        baseline: int | float | torch.Tensor | list[int | float | torch.Tensor] | None = None,
        additional_forward_args: Any = None,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
    ) -> HSISpectralAttributes | list[HSISpectralAttributes]:
        """Compute spectral attributions for the input HSI using the Occlusion method. In this case, the sliding window
        is applied to the spectral dimension only.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSIAttributes objects.
            target (list[int] | int | None, optional): Target class index for computing the attributions. If None,
                the methods assume that the output has only one class. If the output has multiple classes, the
                target index must be provided. For multiple input images, a list of target indices can be provided,
                one for each image, or a single target value will be used for all images. Defaults to None.
            sliding_window_shapes (int | tuple[int]): The size of the sliding window for the spectral dimension.
                Defaults to 1.
            strides (int | tuple[int], optional): The stride of the sliding window for the spectral dimension.
                Defaults to 1. Simply put, the stride is the number of pixels by which the sliding window is moved
                in spectral dimension.
            baseline (int | float | torch.Tensor | list[int | float | torch.Tensor], optional): Baselines define the
                reference value that replaces each feature when it is occluded, and can be provided as:
                    - an integer or float representing a constant value used as the baseline for all input pixels.
                    - a tensor with the same shape as the input tensor, providing a baseline for each input pixel.
                        If the input is a list of HSI objects, the baseline can be a tensor with the same shape as
                        the input tensor for each HSI object.
                    - a list of integers, floats or tensors with the same shape as the input tensor, providing a
                        baseline for each input pixel. If the input is a list of HSI objects, the baseline can be
                        a list of tensors with the same shape as the input tensor for each HSI object.
                Defaults to None.
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Default: None
            perturbations_per_eval (int, optional): Allows multiple occlusions to be included in one batch
                (one call to forward_fn). By default, perturbations_per_eval is 1, so each occlusion is processed
                individually. Each forward pass will contain a maximum of perturbations_per_eval * #examples samples.
                For DataParallel models, each batch is split among the available devices, so evaluations on each
                available device contain at most (perturbations_per_eval * #examples) / num_devices samples. When
                working with multiple examples, the number of perturbations per evaluation should be set to at least
                the number of examples. Defaults to 1.
            show_progress (bool, optional): If True, displays a progress bar. Defaults to False.

        Returns:
            HSISpectralAttributes | list[HSISpectralAttributes]: The computed attributions for the input hyperspectral
                image(s). If a list of HSI objects is provided, the attributions are computed for each HSI object
                in the list.

        Raises:
            RuntimeError: If the explainer is not initialized.
            ValueError: If the sliding window shapes or strides are not a tuple of a single integer.
            TypeError: If the sliding window shapes or strides are not a single integer.
            HSIAttributesError: If an error occurs during the generation of the attributions.

        Example:
            >>> occlusion = Occlusion(explainable_model)
            >>> hsi = HSI(image=torch.ones((10, 240, 240)), wavelengths=torch.arange(10))
            >>> attributions = occlusion.get_spectral_attributes(hsi, baseline=0, sliding_window_shapes=3, strides=1)
            >>> attributions = occlusion.get_spectral_attributes([hsi, hsi], baseline=0, sliding_window_shapes=3, strides=2)
            >>> len(attributions)
            2
        """
        if self._attribution_method is None:
            raise RuntimeError("Occlusion explainer is not initialized, INITIALIZATION ERROR")

        if not isinstance(hsi, list):
            hsi = [hsi]

        if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
            raise TypeError("All of the input hyperspectral images must be of type HSI")

        if not isinstance(baseline, list):
            baseline = [baseline] * len(hsi)

        baseline = torch.stack(
            [
                validate_and_transform_baseline(base, hsi_image).to(hsi_image.device)
                for hsi_image, base in zip(hsi, baseline)
            ],
            dim=0,
        )
        input_tensor = torch.stack(
            [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
        )

        if isinstance(sliding_window_shapes, tuple):
            if len(sliding_window_shapes) != 1:
                raise ValueError("Sliding window shapes must be a single integer or a tuple of a single integer")
            sliding_window_shapes = sliding_window_shapes[0]
        if isinstance(strides, tuple):
            if len(strides) != 1:
                raise ValueError("Strides must be a single integer or a tuple of a single integer")
            strides = strides[0]

        if not isinstance(sliding_window_shapes, int):
            raise TypeError("Sliding window shapes must be a single integer")
        if not isinstance(strides, int):
            raise TypeError("Strides must be a single integer")

        # hsi is always a list at this point; occlude the full spatial extent and
        # slide the window only along the spectral axis
        full_sliding_window_shapes = list(hsi[0].image.shape)
        full_sliding_window_shapes[hsi[0].spectral_axis] = sliding_window_shapes
        full_strides = list(hsi[0].image.shape)
        full_strides[hsi[0].spectral_axis] = strides

        sliding_window_shapes = tuple(full_sliding_window_shapes)
        strides = tuple(full_strides)

        assert len(sliding_window_shapes) == len(strides) == 3
        band_mask = [
            self._create_segmentation_mask(hsi_image.image.shape, sliding_window_shapes, strides) for hsi_image in hsi
        ]
        band_names = {str(ui.item()): ui.item() for ui in torch.unique(band_mask[0])}

        occlusion_attributions = self._attribution_method.attribute(
            input_tensor,
            sliding_window_shapes=sliding_window_shapes,
            strides=strides,
            target=target,
            baselines=baseline,
            additional_forward_args=additional_forward_args,
            perturbations_per_eval=min(perturbations_per_eval, len(hsi)),
            show_progress=show_progress,
        )

        try:
            spectral_attributes = [
                HSISpectralAttributes(
                    hsi=hsi_image,
                    attributes=attribution,
                    attribution_method=self.get_name(),
                    mask=mask,
                    band_names=band_names,
                )
                for hsi_image, attribution, mask in zip(hsi, occlusion_attributions, band_mask)
            ]
        except Exception as e:
            raise HSIAttributesError(f"Error in generating Occlusion attributions: {e}") from e

        return spectral_attributes[0] if len(spectral_attributes) == 1 else spectral_attributes
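
The sliding-window masks built internally by get_spatial_attributes and get_spectral_attributes come from the private _create_segmentation_mask helper above. Its behaviour can be checked on a toy shape (a sketch; the import path is assumed, and real code should rely on the public methods instead):

import torch
from meteors.attr import Occlusion  # assumed import path

# A (2, 4, 4) cube tiled by non-overlapping (2, 2, 2) windows with stride (2, 2, 2):
# one window along the 2-band axis and a 2x2 spatial grid -> four labeled regions.
mask = Occlusion._create_segmentation_mask(
    input_shape=(2, 4, 4), sliding_window_shapes=(2, 2, 2), strides=(2, 2, 2)
)
print(mask.shape)          # torch.Size([2, 4, 4])
print(torch.unique(mask))  # tensor([1, 2, 3, 4], dtype=torch.int32)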

attribute(hsi, target=None, sliding_window_shapes=(1, 1, 1), strides=(1, 1, 1), baseline=None, additional_forward_args=None, perturbations_per_eval=1, show_progress=False)

Method for generating attributions using the Occlusion method.

Parameters:

Name Type Description Default
hsi list[HSI] | HSI

Input hyperspectral image(s) for which the attributions are to be computed. If a list of HSI objects is provided, the attributions are computed for each HSI object in the list. The output will be a list of HSIAttributes objects.

required
target list[int] | int | None

Target class index for computing the attributions. If None, the methods assume that the output has only one class. If the output has multiple classes, the target index must be provided. For multiple input images, a list of target indices can be provided, one for each image, or a single target value will be used for all images. Defaults to None.

None
sliding_window_shapes int | tuple[int, int, int]

The shape of the sliding window. If an integer is provided, it will be used for all dimensions. Defaults to (1, 1, 1).

(1, 1, 1)
strides int | tuple[int, int, int]

The stride of the sliding window. Defaults to (1, 1, 1). Simply put, the stride is the number of pixels by which the sliding window is moved in each dimension.

(1, 1, 1)
baseline int | float | Tensor | list[int | float | Tensor]

Baselines define the reference value that replaces each feature when it is occluded, and can be provided as: - an integer or float representing a constant value used as the baseline for all input pixels. - a tensor with the same shape as the input tensor, providing a baseline for each input pixel. If the input is a list of HSI objects, the baseline can be a tensor with the same shape as the input tensor for each HSI object. - a list of integers, floats or tensors with the same shape as the input tensor, providing a baseline for each input pixel. If the input is a list of HSI objects, the baseline can be a list of tensors with the same shape as the input tensor for each HSI object. Defaults to None.

None
additional_forward_args Any

If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

None
perturbations_per_eval int

Allows multiple occlusions to be included in one batch (one call to forward_fn). By default, perturbations_per_eval is 1, so each occlusion is processed individually. Each forward pass will contain a maximum of perturbations_per_eval * #examples samples. For DataParallel models, each batch is split among the available devices, so evaluations on each available device contain at most (perturbations_per_eval * #examples) / num_devices samples. When working with multiple examples, the number of perturbations per evaluation should be set to at least the number of examples. Defaults to 1.

1
show_progress bool

If True, displays a progress bar. Defaults to False.

False

Returns:

Type Description
HSIAttributes | list[HSIAttributes]

The computed attributions for the input hyperspectral image(s). If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

Raises:

Type Description
RuntimeError

If the explainer is not initialized.

ValueError

If the sliding window shapes or strides are not a tuple of three integers.

HSIAttributesError

If an error occurs during the generation of the attributions.

Example

>>> occlusion = Occlusion(explainable_model)
>>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
>>> attributions = occlusion.attribute(hsi, baseline=0, sliding_window_shapes=(4, 3, 3), strides=(1, 1, 1))
>>> attributions = occlusion.attribute([hsi, hsi], baseline=0, sliding_window_shapes=(4, 3, 3), strides=(1, 2, 2))
>>> len(attributions)
2
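For intuition, the sketch below shows what a single occlusion step does, independent of the library: a window of the input is replaced with the baseline value, and the change in the model output serves as that window's attribution score. The toy_model function and the window coordinates are illustrative assumptions, not part of the meteors API.

import torch

# A stand-in model: the mean over all pixels of a (C, H, W) hyperspectral cube.
def toy_model(x: torch.Tensor) -> torch.Tensor:
    return x.mean()

hsi_cube = torch.ones((4, 240, 240))  # matches the HSI in the example above
baseline_value = 0.0

# Occlude one (4, 3, 3) window at the top-left corner,
# as sliding_window_shapes=(4, 3, 3) would on its first step.
occluded = hsi_cube.clone()
occluded[0:4, 0:3, 0:3] = baseline_value

# The attribution signal for this window is the change in the model output.
delta = toy_model(hsi_cube) - toy_model(occluded)
print(delta)  # a small positive value: this window contributed positively

Roughly speaking, captum's Occlusion repeats this for every window position (shifted by strides) and accumulates the deltas into a per-pixel attribution map.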

Source code in src/meteors/attr/occlusion.py
def attribute(
    self,
    hsi: list[HSI] | HSI,
    target: list[int] | int | None = None,
    sliding_window_shapes: int | tuple[int, int, int] = (1, 1, 1),
    strides: int | tuple[int, int, int] = (1, 1, 1),
    baseline: int | float | torch.Tensor | list[int | float | torch.Tensor] | None = None,
    additional_forward_args: Any = None,
    perturbations_per_eval: int = 1,
    show_progress: bool = False,
) -> HSIAttributes | list[HSIAttributes]:
    """
    Method for generating attributions using the Occlusion method.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSIAttributes objects.
        target (list[int] | int | None, optional): target class index for computing the attributions. If None,
            methods assume that the output has only one class. If the output has multiple classes, the target index
            must be provided. For multiple input images, a list of target indices can be provided, one for each
            image or single target value will be used for all images. Defaults to None.
        sliding_window_shapes (int | tuple[int, int, int]):
            The shape of the sliding window. If an integer is provided, it will be used for all dimensions.
            Defaults to (1, 1, 1).
        strides (int | tuple[int, int, int], optional): The stride of the sliding window. Defaults to (1, 1, 1).
            Simply put, the stride is the number of pixels by which the sliding window is moved in each dimension.
        baseline (int | float | torch.Tensor | list[int | float | torch.Tensor], optional): Baselines define
            the reference value that replaces each feature when it is occluded. A baseline can be provided as:
                - an integer or float representing a constant value used as the baseline for all input pixels.
                - a tensor with the same shape as the input tensor, providing a baseline for each input pixel.
                    If the input is a list of HSI objects, the baseline can be a tensor with the same shape as
                    the input tensor for each HSI object.
                - a list of integers, floats or tensors with the same shape as the input tensor, providing a
                    baseline for each input pixel. If the input is a list of HSI objects, the baseline can be a
                    list of tensors with the same shape as the input tensor for each HSI object. Defaults to None.
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Default: None
        perturbations_per_eval (int, optional): Allows multiple occlusions to be included in one batch
            (one call to forward_fn). By default, perturbations_per_eval is 1, so each occlusion is processed
            individually. Each forward pass will contain a maximum of perturbations_per_eval * #examples samples.
            For DataParallel models, each batch is split among the available devices, so evaluations on each
            available device contain at most (perturbations_per_eval * #examples) / num_devices samples. When
            working with multiple examples, the number of perturbations per evaluation should be set to at least
            the number of examples. Defaults to 1.
        show_progress (bool, optional): If True, displays a progress bar. Defaults to False.

    Returns:
        HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

    Raises:
        RuntimeError: If the explainer is not initialized.
        ValueError: If the sliding window shapes or strides are not a tuple of three integers.
        HSIAttributesError: If an error occurs during the generation of the attributions.

    Example:
        >>> occlusion = Occlusion(explainable_model)
        >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
        >>> attributions = occlusion.attribute(hsi, baseline=0, sliding_window_shapes=(4, 3, 3), strides=(1, 1, 1))
        >>> attributions = occlusion.attribute([hsi, hsi], baseline=0, sliding_window_shapes=(4, 3, 3), strides=(1, 2, 2))
        >>> len(attributions)
        2
    """
    if self._attribution_method is None:
        raise RuntimeError("Occlusion explainer is not initialized, INITIALIZATION ERROR")

    if not isinstance(hsi, list):
        hsi = [hsi]

    if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
        raise TypeError("All of the input hyperspectral images must be of type HSI")

    if not isinstance(baseline, list):
        baseline = [baseline] * len(hsi)

    baseline = torch.stack(
        [
            validate_and_transform_baseline(base, hsi_image).to(hsi_image.device)
            for hsi_image, base in zip(hsi, baseline)
        ],
        dim=0,
    )
    input_tensor = torch.stack(
        [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
    )

    if isinstance(sliding_window_shapes, int):
        sliding_window_shapes = (sliding_window_shapes, sliding_window_shapes, sliding_window_shapes)
    if isinstance(strides, int):
        strides = (strides, strides, strides)

    if len(strides) != 3:
        raise ValueError("Strides must be a tuple of three integers")
    if len(sliding_window_shapes) != 3:
        raise ValueError("Sliding window shapes must be a tuple of three integers")

    assert len(sliding_window_shapes) == len(strides) == 3
    occlusion_attributions = self._attribution_method.attribute(
        input_tensor,
        sliding_window_shapes=sliding_window_shapes,
        strides=strides,
        target=target,
        baselines=baseline,
        additional_forward_args=additional_forward_args,
        perturbations_per_eval=min(perturbations_per_eval, len(hsi)),
        show_progress=show_progress,
    )

    try:
        attributes = [
            HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
            for hsi_image, attribution in zip(hsi, occlusion_attributions)
        ]
    except Exception as e:
        raise HSIAttributesError(f"Error in generating Occlusion attributions: {e}") from e

    return attributes[0] if len(attributes) == 1 else attributes

get_spatial_attributes(hsi, target=None, sliding_window_shapes=(1, 1), strides=1, baseline=None, additional_forward_args=None, perturbations_per_eval=1, show_progress=False)

Compute spatial attributions for the input HSI using the Occlusion method. In this case, the sliding window is applied to the spatial dimensions only.

Parameters:

Name Type Description Default
hsi list[HSI] | HSI

Input hyperspectral image(s) for which the attributions are to be computed. If a list of HSI objects is provided, the attributions are computed for each HSI object in the list. The output will be a list of HSIAttributes objects.

required
target list[int] | int | None

Target class index for computing the attributions. If None, the method assumes that the output has only one class. If the output has multiple classes, the target index must be provided. For multiple input images, either a list of target indices can be provided, one for each image, or a single target value will be used for all images. Defaults to None.

None
sliding_window_shapes int | tuple[int, int]

The shape of the sliding window for spatial dimensions. If an integer is provided, it will be used for both spatial dimensions. Defaults to (1, 1).

(1, 1)
strides int | tuple[int, int]

The stride of the sliding window for spatial dimensions. Defaults to 1. Simply put, the stride is the number of pixels by which the sliding window is moved in each spatial dimension.

1
baseline int | float | Tensor | list[int | float | Tensor]

Baselines define the reference value that replaces each feature when it is occluded. A baseline can be provided as: an integer or float representing a constant value used as the baseline for all input pixels; a tensor with the same shape as the input tensor, providing a baseline for each input pixel (if the input is a list of HSI objects, a tensor of this shape can be given for each HSI object); or a list of integers, floats, or such tensors, one per HSI object in the input list. Defaults to None.

None
additional_forward_args Any

If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

None
perturbations_per_eval int

Allows multiple occlusions to be included in one batch (one call to forward_fn). By default, perturbations_per_eval is 1, so each occlusion is processed individually. Each forward pass will contain a maximum of perturbations_per_eval * #examples samples. For DataParallel models, each batch is split among the available devices, so evaluations on each available device contain at most (perturbations_per_eval * #examples) / num_devices samples. When working with multiple examples, the number of perturbations per evaluation should be set to at least the number of examples. Defaults to 1.

1
show_progress bool

If True, displays a progress bar. Defaults to False.

False

Returns:

Type Description
HSISpatialAttributes | list[HSISpatialAttributes]

The computed attributions for the input hyperspectral image(s). If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

Raises:

Type Description
RuntimeError

If the explainer is not initialized.

ValueError

If the sliding window shapes or strides are not a tuple of two integers.

HSIAttributesError

If an error occurs during the generation of the attributions.

Example

>>> occlusion = Occlusion(explainable_model)
>>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
>>> attributions = occlusion.get_spatial_attributes(hsi, baseline=0, sliding_window_shapes=(3, 3), strides=(1, 1))
>>> attributions = occlusion.get_spatial_attributes([hsi, hsi], baseline=0, sliding_window_shapes=(3, 3), strides=(2, 2))
>>> len(attributions)
2
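The essential step in get_spatial_attributes, visible in the source below, is promoting the 2D spatial window to a 3D one by inserting the full band count at the spectral axis, so that all bands inside a spatial window are occluded together. Here is a minimal sketch of that promotion; the concrete shape and spectral_axis value are assumptions mirroring the example above:

image_shape = (4, 240, 240)     # (C, H, W), so the spectral axis is 0
spectral_axis = 0
sliding_window_shapes = (3, 3)  # spatial-only window
strides = (2, 2)

# Insert the full spectral extent so every window spans all bands.
window = list(sliding_window_shapes)
window.insert(spectral_axis, image_shape[spectral_axis])
stride = list(strides)
stride.insert(spectral_axis, image_shape[spectral_axis])

print(tuple(window))  # (4, 3, 3)
print(tuple(stride))  # (4, 2, 2)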

Source code in src/meteors/attr/occlusion.py
def get_spatial_attributes(
    self,
    hsi: list[HSI] | HSI,
    target: list[int] | int | None = None,
    sliding_window_shapes: int | tuple[int, int] = (1, 1),
    strides: int | tuple[int, int] = 1,
    baseline: int | float | torch.Tensor | list[int | float | torch.Tensor] | None = None,
    additional_forward_args: Any = None,
    perturbations_per_eval: int = 1,
    show_progress: bool = False,
) -> HSISpatialAttributes | list[HSISpatialAttributes]:
    """Compute spatial attributions for the input HSI using the Occlusion method. In this case, the sliding window
    is applied to the spatial dimensions only.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSIAttributes objects.
        target (list[int] | int | None, optional): target class index for computing the attributions. If None,
            methods assume that the output has only one class. If the output has multiple classes, the target index
            must be provided. For multiple input images, a list of target indices can be provided, one for each
            image or single target value will be used for all images. Defaults to None.
        sliding_window_shapes (int | tuple[int, int]): The shape of the sliding window for spatial dimensions.
            If an integer is provided, it will be used for both spatial dimensions. Defaults to (1, 1).
        strides (int | tuple[int, int], optional): The stride of the sliding window for spatial dimensions.
            Defaults to 1. Simply put, the stride is the number of pixels by which the sliding window is moved
            in each spatial dimension.
        baseline (int | float | torch.Tensor | list[int | float | torch.Tensor], optional): Baselines define
            the reference value that replaces each feature when it is occluded. A baseline can be provided as:
                - an integer or float representing a constant value used as the baseline for all input pixels.
                - a tensor with the same shape as the input tensor, providing a baseline for each input pixel.
                    If the input is a list of HSI objects, the baseline can be a tensor with the same shape as
                    the input tensor for each HSI object.
                - a list of integers, floats or tensors with the same shape as the input tensor, providing a
                    baseline for each input pixel. If the input is a list of HSI objects, the baseline can be a
                    list of tensors with the same shape as the input tensor for each HSI object. Defaults to None.
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Default: None
        perturbations_per_eval (int, optional): Allows multiple occlusions to be included in one batch
            (one call to forward_fn). By default, perturbations_per_eval is 1, so each occlusion is processed
            individually. Each forward pass will contain a maximum of perturbations_per_eval * #examples samples.
            For DataParallel models, each batch is split among the available devices, so evaluations on each
            available device contain at most (perturbations_per_eval * #examples) / num_devices samples. When
            working with multiple examples, the number of perturbations per evaluation should be set to at least
            the number of examples. Defaults to 1.
        show_progress (bool, optional): If True, displays a progress bar. Defaults to False.

    Returns:
        HSISpatialAttributes | list[HSISpatialAttributes]: The computed attributions for the input hyperspectral image(s).
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

    Raises:
        RuntimeError: If the explainer is not initialized.
        ValueError: If the sliding window shapes or strides are not a tuple of two integers.
        HSIAttributesError: If an error occurs during the generation of the attributions

    Example:
        >>> occlusion = Occlusion(explainable_model)
        >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
        >>> attributions = occlusion.get_spatial_attributes(hsi, baseline=0, sliding_window_shapes=(3, 3), strides=(1, 1))
        >>> attributions = occlusion.get_spatial_attributes([hsi, hsi], baseline=0, sliding_window_shapes=(3, 3), strides=(2, 2))
        >>> len(attributions)
        2
    """
    if self._attribution_method is None:
        raise RuntimeError("Occlusion explainer is not initialized, INITIALIZATION ERROR")

    if not isinstance(hsi, list):
        hsi = [hsi]

    if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
        raise TypeError("All of the input hyperspectral images must be of type HSI")

    if not isinstance(baseline, list):
        baseline = [baseline] * len(hsi)

    baseline = torch.stack(
        [
            validate_and_transform_baseline(base, hsi_image).to(hsi_image.device)
            for hsi_image, base in zip(hsi, baseline)
        ],
        dim=0,
    )
    input_tensor = torch.stack(
        [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
    )

    if isinstance(sliding_window_shapes, int):
        sliding_window_shapes = (sliding_window_shapes, sliding_window_shapes)
    if isinstance(strides, int):
        strides = (strides, strides)

    if len(strides) != 2:
        raise ValueError("Strides must be a tuple of two integers")
    if len(sliding_window_shapes) != 2:
        raise ValueError("Sliding window shapes must be a tuple of two integers")

    list_sliding_window_shapes = list(sliding_window_shapes)
    list_strides = list(strides)
    if isinstance(hsi, list):
        list_sliding_window_shapes.insert(hsi[0].spectral_axis, hsi[0].image.shape[hsi[0].spectral_axis])
        list_strides.insert(hsi[0].spectral_axis, hsi[0].image.shape[hsi[0].spectral_axis])
    else:
        list_sliding_window_shapes.insert(hsi.spectral_axis, hsi.image.shape[hsi.spectral_axis])
        list_strides.insert(hsi.spectral_axis, hsi.image.shape[hsi.spectral_axis])
    sliding_window_shapes = tuple(list_sliding_window_shapes)  # type: ignore
    strides = tuple(list_strides)  # type: ignore

    assert len(sliding_window_shapes) == len(strides) == 3
    segment_mask = [
        self._create_segmentation_mask(hsi_image.image.shape, sliding_window_shapes, strides) for hsi_image in hsi
    ]

    occlusion_attributions = self._attribution_method.attribute(
        input_tensor,
        sliding_window_shapes=sliding_window_shapes,
        strides=strides,
        target=target,
        baselines=baseline,
        additional_forward_args=additional_forward_args,
        perturbations_per_eval=min(perturbations_per_eval, len(hsi)),
        show_progress=show_progress,
    )

    try:
        spatial_attributes = [
            HSISpatialAttributes(
                hsi=hsi_image, attributes=attribution, attribution_method=self.get_name(), mask=mask
            )
            for hsi_image, attribution, mask in zip(hsi, occlusion_attributions, segment_mask)
        ]
    except Exception as e:
        raise HSIAttributesError(f"Error in generating Occlusion attributions: {e}") from e

    return spatial_attributes[0] if len(spatial_attributes) == 1 else spatial_attributes

get_spectral_attributes(hsi, target=None, sliding_window_shapes=1, strides=1, baseline=None, additional_forward_args=None, perturbations_per_eval=1, show_progress=False)

Compute spectral attributions for the input HSI using the Occlusion method. In this case, the sliding window is applied to the spectral dimension only.

Parameters:

Name Type Description Default
hsi list[HSI] | HSI

Input hyperspectral image(s) for which the attributions are to be computed. If a list of HSI objects is provided, the attributions are computed for each HSI object in the list. The output will be a list of HSIAttributes objects.

required
target list[int] | int | None

Target class index for computing the attributions. If None, the method assumes that the output has only one class. If the output has multiple classes, the target index must be provided. For multiple input images, either a list of target indices can be provided, one for each image, or a single target value will be used for all images. Defaults to None.

None
sliding_window_shapes int | tuple[int]

The size of the sliding window for the spectral dimension. Defaults to 1.

1
strides int | tuple[int]

The stride of the sliding window for the spectral dimension. Defaults to 1. Simply put, the stride is the number of pixels by which the sliding window is moved in the spectral dimension.

1
baseline int | float | Tensor | list[int | float | Tensor]

Baselines define the reference value that replaces each feature when it is occluded. A baseline can be provided as: an integer or float representing a constant value used as the baseline for all input pixels; a tensor with the same shape as the input tensor, providing a baseline for each input pixel (if the input is a list of HSI objects, a tensor of this shape can be given for each HSI object); or a list of integers, floats, or such tensors, one per HSI object in the input list. Defaults to None.

None
additional_forward_args Any

If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

None
perturbations_per_eval int

Allows multiple occlusions to be included in one batch (one call to forward_fn). By default, perturbations_per_eval is 1, so each occlusion is processed individually. Each forward pass will contain a maximum of perturbations_per_eval * #examples samples. For DataParallel models, each batch is split among the available devices, so evaluations on each available device contain at most (perturbations_per_eval * #examples) / num_devices samples. When working with multiple examples, the number of perturbations per evaluation should be set to at least the number of examples. Defaults to 1.

1
show_progress bool

If True, displays a progress bar. Defaults to False.

False

Returns:

Type Description
HSISpectralAttributes | list[HSISpectralAttributes]

The computed attributions for the input hyperspectral image(s). If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

Raises:

Type Description
RuntimeError

If the explainer is not initialized.

ValueError

If the sliding window shapes or strides are not a tuple of a single integer.

TypeError

If the sliding window shapes or strides are not a single integer.

HSIAttributesError

If an error occurs during the generation of the attributions.

Example

>>> occlusion = Occlusion(explainable_model)
>>> hsi = HSI(image=torch.ones((10, 240, 240)), wavelengths=torch.arange(10))
>>> attributions = occlusion.get_spectral_attributes(hsi, baseline=0, sliding_window_shapes=3, strides=1)
>>> attributions = occlusion.get_spectral_attributes([hsi, hsi], baseline=0, sliding_window_shapes=3, strides=2)
>>> len(attributions)
2
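get_spectral_attributes performs the complementary construction, as the source below shows: the window always spans the full spatial extent and slides only along the band axis. A minimal sketch follows; the shape and spectral_axis value are assumptions taken from the example above:

image_shape = (10, 240, 240)  # (C, H, W), so the spectral axis is 0
spectral_axis = 0
sliding_window_shapes = 3     # window size along the spectral dimension
strides = 2

# Start from the full image shape, then shrink only the spectral dimension.
window = list(image_shape)
window[spectral_axis] = sliding_window_shapes
stride = list(image_shape)
stride[spectral_axis] = strides

print(tuple(window))  # (3, 240, 240)
print(tuple(stride))  # (2, 240, 240)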

Source code in src/meteors/attr/occlusion.py
def get_spectral_attributes(
    self,
    hsi: list[HSI] | HSI,
    target: list[int] | int | None = None,
    sliding_window_shapes: int | tuple[int] = 1,
    strides: int | tuple[int] = 1,
    baseline: int | float | torch.Tensor | list[int | float | torch.Tensor] | None = None,
    additional_forward_args: Any = None,
    perturbations_per_eval: int = 1,
    show_progress: bool = False,
) -> HSISpectralAttributes | list[HSISpectralAttributes]:
    """Compute spectral attributions for the input HSI using the Occlusion method. In this case, the sliding window
    is applied to the spectral dimension only.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSIAttributes objects.
        target (list[int] | int | None, optional): target class index for computing the attributions. If None,
            methods assume that the output has only one class. If the output has multiple classes, the target index
            must be provided. For multiple input images, a list of target indices can be provided, one for each
            image or single target value will be used for all images. Defaults to None.
        sliding_window_shapes (int | tuple[int]): The size of the sliding window for the spectral dimension.
            Defaults to 1.
        strides (int | tuple[int], optional): The stride of the sliding window for the spectral dimension.
            Defaults to 1. Simply put, the stride is the number of pixels by which the sliding window is moved
            in the spectral dimension.
        baseline (int | float | torch.Tensor | list[int | float | torch.Tensor], optional): Baselines define
            the reference value that replaces each feature when it is occluded. A baseline can be provided as:
                - an integer or float representing a constant value used as the baseline for all input pixels.
                - a tensor with the same shape as the input tensor, providing a baseline for each input pixel.
                    If the input is a list of HSI objects, the baseline can be a tensor with the same shape as
                    the input tensor for each HSI object.
                - a list of integers, floats or tensors with the same shape as the input tensor, providing a
                    baseline for each input pixel. If the input is a list of HSI objects, the baseline can be a
                    list of tensors with the same shape as the input tensor for each HSI object. Defaults to None.
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Default: None
        perturbations_per_eval (int, optional): Allows multiple occlusions to be included in one batch
            (one call to forward_fn). By default, perturbations_per_eval is 1, so each occlusion is processed
            individually. Each forward pass will contain a maximum of perturbations_per_eval * #examples samples.
            For DataParallel models, each batch is split among the available devices, so evaluations on each
            available device contain at most (perturbations_per_eval * #examples) / num_devices samples. When
            working with multiple examples, the number of perturbations per evaluation should be set to at least
            the number of examples. Defaults to 1.
        show_progress (bool, optional): If True, displays a progress bar. Defaults to False.

    Returns:
        HSISpectralAttributes | list[HSISpectralAttributes]: The computed attributions for the input hyperspectral
            image(s). If a list of HSI objects is provided, the attributions are computed for each HSI object in
            the list.

    Raises:
        RuntimeError: If the explainer is not initialized.
        ValueError: If the sliding window shapes or strides are not a tuple of a single integer.
        TypeError: If the sliding window shapes or strides are not a single integer.
        HSIAttributesError: If an error occurs during the generation of the attributions

    Example:
        >>> occlusion = Occlusion(explainable_model)
        >>> hsi = HSI(image=torch.ones((10, 240, 240)), wavelengths=torch.arange(10))
        >>> attributions = occlusion.get_spectral_attributes(hsi, baseline=0, sliding_window_shapes=3, strides=1)
        >>> attributions = occlusion.get_spectral_attributes([hsi, hsi], baseline=0, sliding_window_shapes=3, strides=2)
        >>> len(attributions)
        2
    """
    if self._attribution_method is None:
        raise RuntimeError("Occlusion explainer is not initialized, INITIALIZATION ERROR")

    if not isinstance(hsi, list):
        hsi = [hsi]

    if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
        raise TypeError("All of the input hyperspectral images must be of type HSI")

    if not isinstance(baseline, list):
        baseline = [baseline] * len(hsi)

    baseline = torch.stack(
        [
            validate_and_transform_baseline(base, hsi_image).to(hsi_image.device)
            for hsi_image, base in zip(hsi, baseline)
        ],
        dim=0,
    )
    input_tensor = torch.stack(
        [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
    )

    if isinstance(sliding_window_shapes, tuple):
        if len(sliding_window_shapes) != 1:
            raise ValueError("Sliding window shapes must be a single integer or a tuple of a single integer")
        sliding_window_shapes = sliding_window_shapes[0]
    if isinstance(strides, tuple):
        if len(strides) != 1:
            raise ValueError("Strides must be a single integer or a tuple of a single integer")
        strides = strides[0]

    if not isinstance(sliding_window_shapes, int):
        raise TypeError("Sliding window shapes must be a single integer")
    if not isinstance(strides, int):
        raise TypeError("Strides must be a single integer")

    if isinstance(hsi, list):
        full_sliding_window_shapes = list(hsi[0].image.shape)
        full_sliding_window_shapes[hsi[0].spectral_axis] = sliding_window_shapes
        full_strides = list(hsi[0].image.shape)
        full_strides[hsi[0].spectral_axis] = strides
    else:
        full_sliding_window_shapes = list(hsi.image.shape)
        full_sliding_window_shapes[hsi.spectral_axis] = sliding_window_shapes
        full_strides = list(hsi.image.shape)
        full_strides[hsi.spectral_axis] = strides

    sliding_window_shapes = tuple(full_sliding_window_shapes)
    strides = tuple(full_strides)

    assert len(sliding_window_shapes) == len(strides) == 3
    band_mask = [
        self._create_segmentation_mask(hsi_image.image.shape, sliding_window_shapes, strides) for hsi_image in hsi
    ]
    band_names = {str(ui.item()): ui.item() for ui in torch.unique(band_mask[0])}

    occlusion_attributions = self._attribution_method.attribute(
        input_tensor,
        sliding_window_shapes=sliding_window_shapes,
        strides=strides,
        target=target,
        baselines=baseline,
        additional_forward_args=additional_forward_args,
        perturbations_per_eval=min(perturbations_per_eval, len(hsi)),
        show_progress=show_progress,
    )

    try:
        spectral_attributes = [
            HSISpectralAttributes(
                hsi=hsi_image,
                attributes=attribution,
                attribution_method=self.get_name(),
                mask=mask,
                band_names=band_names,
            )
            for hsi_image, attribution, mask in zip(hsi, occlusion_attributions, band_mask)
        ]
    except Exception as e:
        raise HSIAttributesError(f"Error in generating Occlusion attributions: {e}") from e

    return spectral_attributes[0] if len(spectral_attributes) == 1 else spectral_attributes

Saliency

Bases: Explainer

Saliency explainer class for generating attributions using the Saliency method. This baseline method for computing input attributions calculates the gradients of the output with respect to the inputs. It also has an option to return the absolute value of the gradients, which is the default behaviour. The implementation of this method is based on the captum library.

Attributes:

Name Type Description
_attribution_method Saliency

The Saliency method from the captum library.

Parameters:

Name Type Description Default
explainable_model ExplainableModel | Explainer

The explainable model to be explained.

required
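As a rough illustration of what the underlying captum call computes, a hand-rolled saliency map is just the (absolute) gradient of a scalar model output with respect to the input. The toy_model below is an assumption for demonstration only, not the library's API:

import torch

# A stand-in scalar score for a (C, H, W) input.
def toy_model(x: torch.Tensor) -> torch.Tensor:
    return (x ** 2).sum()

hsi_cube = torch.ones((4, 240, 240), requires_grad=True)

score = toy_model(hsi_cube)
score.backward()

saliency_map = hsi_cube.grad.abs()  # corresponds to abs=True
print(saliency_map.shape)  # torch.Size([4, 240, 240])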
Source code in src/meteors/attr/saliency.py
class Saliency(Explainer):
    """
    Saliency explainer class for generating attributions using the Saliency method.
    This baseline method for computing input attribution calculates gradients with respect to inputs.
    It also has an option to return the absolute value of the gradients, which is the default behaviour.
    Implementation of this method is based on the [`captum` repository](https://captum.ai/api/saliency.html)

    Attributes:
        _attribution_method (CaptumSaliency): The Saliency method from the `captum` library.

    Args:
        explainable_model (ExplainableModel | Explainer): The explainable model to be explained.
    """

    def __init__(self, explainable_model: ExplainableModel):
        super().__init__(explainable_model)

        self._attribution_method = CaptumSaliency(explainable_model.forward_func)

    def attribute(
        self,
        hsi: list[HSI] | HSI,
        target: list[int] | int | None = None,
        abs: bool = True,
        additional_forward_args: Any = None,
    ) -> HSIAttributes | list[HSIAttributes]:
        """
        Method for generating attributions using the Saliency method.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSIAttributes objects.
            target (list[int] | int | None, optional): target class index for computing the attributions. If None,
                methods assume that the output has only one class. If the output has multiple classes, the target index
                must be provided. For multiple input images, a list of target indices can be provided, one for each
                image or single target value will be used for all images. Defaults to None.
            abs (bool, optional): Returns absolute value of gradients if set to True,
                otherwise returns the (signed) gradients if False. Default: True
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Default: None

        Returns:
            HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

        Raises:
            RuntimeError: If the explainer is not initialized.
            HSIAttributesError: If an error occurs during the generation of the attributions

        Examples:
            >>> saliency = Saliency(explainable_model)
            >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
            >>> attributions = saliency.attribute(hsi)
            >>> attributions = saliency.attribute([hsi, hsi])
            >>> len(attributions)
            2
        """
        if self._attribution_method is None:
            raise RuntimeError("Saliency explainer is not initialized, INITIALIZATION ERROR")

        if not isinstance(hsi, list):
            hsi = [hsi]

        if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
            raise TypeError("All of the input hyperspectral images must be of type HSI")

        input_tensor = torch.stack(
            [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
        )

        saliency_attributions = self._attribution_method.attribute(
            input_tensor, target=target, abs=abs, additional_forward_args=additional_forward_args
        )

        try:
            attributes = [
                HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
                for hsi_image, attribution in zip(hsi, saliency_attributions)
            ]
        except Exception as e:
            raise HSIAttributesError(f"Error in generating Saliency attributions: {e}") from e

        return attributes[0] if len(attributes) == 1 else attributes

attribute(hsi, target=None, abs=True, additional_forward_args=None)

Method for generating attributions using the Saliency method.

Parameters:

Name Type Description Default
hsi list[HSI] | HSI

Input hyperspectral image(s) for which the attributions are to be computed. If a list of HSI objects is provided, the attributions are computed for each HSI object in the list. The output will be a list of HSIAttributes objects.

required
target list[int] | int | None

Target class index for computing the attributions. If None, the method assumes that the output has only one class. If the output has multiple classes, the target index must be provided. For multiple input images, either a list of target indices can be provided, one for each image, or a single target value will be used for all images. Defaults to None.

None
abs bool

Returns absolute value of gradients if set to True, otherwise returns the (signed) gradients if False. Default: True

True
additional_forward_args Any

If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

None

Returns:

Type Description
HSIAttributes | list[HSIAttributes]

The computed attributions for the input hyperspectral image(s). If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

Raises:

Type Description
RuntimeError

If the explainer is not initialized.

HSIAttributesError

If an error occurs during the generation of the attributions.

Examples:

>>> saliency = Saliency(explainable_model)
>>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
>>> attributions = saliency.attribute(hsi)
>>> attributions = saliency.attribute([hsi, hsi])
>>> len(attributions)
2
Source code in src/meteors/attr/saliency.py
def attribute(
    self,
    hsi: list[HSI] | HSI,
    target: list[int] | int | None = None,
    abs: bool = True,
    additional_forward_args: Any = None,
) -> HSIAttributes | list[HSIAttributes]:
    """
    Method for generating attributions using the Saliency method.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSIAttributes objects.
        target (list[int] | int | None, optional): target class index for computing the attributions. If None,
            methods assume that the output has only one class. If the output has multiple classes, the target index
            must be provided. For multiple input images, a list of target indices can be provided, one for each
            image or single target value will be used for all images. Defaults to None.
        abs (bool, optional): Returns absolute value of gradients if set to True,
            otherwise returns the (signed) gradients if False. Default: True
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Default: None

    Returns:
        HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

    Raises:
        RuntimeError: If the explainer is not initialized.
        HSIAttributesError: If an error occurs during the generation of the attributions

    Examples:
        >>> saliency = Saliency(explainable_model)
        >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
        >>> attributions = saliency.attribute(hsi)
        >>> attributions = saliency.attribute([hsi, hsi])
        >>> len(attributions)
        2
    """
    if self._attribution_method is None:
        raise RuntimeError("Saliency explainer is not initialized, INITIALIZATION ERROR")

    if not isinstance(hsi, list):
        hsi = [hsi]

    if not all(isinstance(hsi_image, HSI) for hsi_image in hsi):
        raise TypeError("All of the input hyperspectral images must be of type HSI")

    input_tensor = torch.stack(
        [hsi_image.get_image().requires_grad_(True).to(hsi_image.device) for hsi_image in hsi], dim=0
    )

    saliency_attributions = self._attribution_method.attribute(
        input_tensor, target=target, abs=abs, additional_forward_args=additional_forward_args
    )

    try:
        attributes = [
            HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
            for hsi_image, attribution in zip(hsi, saliency_attributions)
        ]
    except Exception as e:
        raise HSIAttributesError(f"Error in generating Saliency attributions: {e}") from e

    return attributes[0] if len(attributes) == 1 else attributes

NoiseTunnel

Bases: BaseNoiseTunnel

Noise Tunnel is a method that explains the model's predictions by adding noise to the input tensor. Attributions are computed for each noisy copy of the input, and the process is repeated multiple times to obtain a distribution of attributions. The final attribution is an aggregate of these samples: the mean for smoothgrad (the default), the mean of squared attributions for smoothgrad_sq, or the variance for vargrad. For more information about the method, see the captum documentation.

Parameters:

Name Type Description Default
chained_explainer

The explainable method that will be used to compute the attributions.

required

Raises:

Type Description
RuntimeError

If the callable object is not an instance of the Explainer class
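For intuition, a bare-bones smoothgrad loop looks like the sketch below: perturb the input with Gaussian noise n times, compute a gradient attribution for each noisy copy, and average the results. toy_model and the hyper-parameters are assumptions for illustration only:

import torch

def toy_model(x: torch.Tensor) -> torch.Tensor:
    return (x ** 2).sum()  # a stand-in scalar score

hsi_cube = torch.ones((4, 240, 240))
n_samples, stdev = 5, 1.0

grads = []
for _ in range(n_samples):
    noisy = (hsi_cube + torch.normal(0.0, stdev, size=hsi_cube.shape)).requires_grad_(True)
    toy_model(noisy).backward()
    grads.append(noisy.grad)

# "smoothgrad" aggregation: the mean of the per-sample attributions.
attribution = torch.stack(grads).mean(dim=0)
print(attribution.shape)  # torch.Size([4, 240, 240])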

Source code in src/meteors/attr/noise_tunnel.py
class NoiseTunnel(BaseNoiseTunnel):
    """Noise Tunnel is a method that is used to explain the model's predictions by adding noise to the input tensor.
    The noise is added to the input tensor, and the model's output is computed. The process is repeated multiple times
    to obtain a distribution of the model's output. The final attribution is computed as the mean of the outputs.
    For more information about the method, see [`captum` documentation](https://captum.ai/api/noise_tunnel.html).

    Arguments:
        chained_explainer: The explainable method that will be used to compute the attributions.

    Raises:
        RuntimeError: If the callable object is not an instance of the Explainer class
    """

    @staticmethod
    def perturb_input(
        input: torch.Tensor,
        n_samples: int = 1,
        perturbation_axis: None | tuple[int | slice] = None,
        stdevs: float = 1,
        **kwargs: Any,
    ) -> torch.Tensor:
        """
        The default perturbation function used in the noise tunnel with small enhancement for hyperspectral images.
        It randomly adds noise to the input tensor from a normal distribution with a given standard deviation.
        The noise is added to the selected bands (channels) of the input tensor.
        The bands to be perturbed are selected based on the `perturbation_axis` parameter.
        By default all bands are perturbed, which is equivalent to the standard noise tunnel method.

        Args:
            input (torch.Tensor): An input tensor to be perturbed. It should have the shape (C, H, W).
            n_samples (int): A number of samples to be drawn - number of perturbed inputs to be generated.
            perturbation_axis (None | tuple[int | slice]): The indices of the bands to be perturbed.
                If set to None, all bands are perturbed. Defaults to None.
            stdevs (float): The standard deviation of gaussian noise with zero mean that is added to each input
                in the batch. Defaults to 1.0.

        Returns:
            torch.Tensor: A perturbed tensor, which contains `n_samples` perturbed inputs.
        """
        if n_samples < 1:
            raise ValueError("Number of perturbed samples to be generated must be greater than 0")

        # the perturbation
        perturbed_input = input.clone().unsqueeze(0)
        # repeat the perturbed_input on the first dimension n_samples times
        perturbed_input = perturbed_input.repeat_interleave(n_samples, dim=0)

        # the perturbation shape
        if perturbation_axis is None:
            perturbation_shape = perturbed_input.shape
        else:
            perturbation_axis = (slice(None),) + perturbation_axis  # type: ignore
            perturbation_shape = perturbed_input[perturbation_axis].shape

        # the noise
        noise = torch.normal(0, stdevs, size=perturbation_shape).to(input.device)

        # add the noise to the perturbed_input
        if perturbation_axis is None:
            perturbed_input += noise
        else:
            perturbed_input[perturbation_axis] += noise

        perturbed_input.requires_grad_(True)

        return perturbed_input

    def attribute(
        self,
        hsi: list[HSI] | HSI,
        target: list[int] | int | None = None,
        additional_forward_args: Any = None,
        n_samples: int = 5,
        steps_per_batch: int = 1,
        perturbation_axis: None | tuple[int | slice] = None,
        stdevs: float | tuple[float, ...] = 1.0,
        method: Literal["smoothgrad", "smoothgrad_sq", "vargrad"] = "smoothgrad",
    ) -> HSIAttributes | list[HSIAttributes]:
        """
        Method for generating attributions using the Noise Tunnel method.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSIAttributes objects.
            target (list[int] | int | None, optional): target class index for computing the attributions. If None,
                methods assume that the output has only one class. If the output has multiple classes, the target index
                must be provided. For multiple input images, a list of target indices can be provided, one for each
                image or single target value will be used for all images. Defaults to None.
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Default: None
            n_samples (int, optional): The number of randomly generated examples per sample in the input batch.
                Random examples are generated by adding gaussian random noise to each sample. Defaults to 5.
            steps_per_batch (int, optional): The number of the n_samples that will be processed together.
                This parameter helps avoid out-of-memory situations by reducing the number of randomly
                generated examples per sample processed in each batch. Defaults to 1.
            perturbation_axis (None | tuple[int | slice], optional): The indices of the input image to be perturbed.
                If set to None, all bands are perturbed, which corresponds to a traditional noise tunnel method.
                Defaults to None.
            stdevs (float | tuple[float, ...], optional): The standard deviation of gaussian noise with zero mean that
                is added to each input in the batch. If stdevs is a single float value then that same value is used
                for all inputs. If stdevs is a tuple, then the length of the tuple must match the number of inputs as
                each value in the tuple is used for the corresponding input. Default: 1.0
            method (Literal["smoothgrad", "smoothgrad_sq", "vargrad"], optional): Smoothing type of the attributions:
                smoothgrad, smoothgrad_sq or vargrad. Defaults to smoothgrad.

        Returns:
            HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

        Raises:
            HSIAttributesError: If an error occurs during the generation of the attributions.

        Examples:
            >>> noise_tunnel = NoiseTunnel(explainable_model)
            >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
            >>> attributions = noise_tunnel.attribute(hsi)
            >>> attributions = noise_tunnel.attribute([hsi, hsi])
            >>> len(attributions)
            2
        """
        if isinstance(stdevs, list):
            stdevs = tuple(stdevs)

        if not isinstance(hsi, list):
            hsi = [hsi]

        if not all([isinstance(input, HSI) for input in hsi]):
            raise TypeError("All inputs must be HSI objects")

        if isinstance(stdevs, tuple):
            if len(stdevs) != len(hsi):
                raise ValueError(
                    "The number of stdevs must match the number of input images, number of stdevs:"
                    f"{len(stdevs)}, number of input images: {len(hsi)}"
                )
        else:
            stdevs = tuple([stdevs] * len(hsi))

        if not isinstance(target, list):
            target = [target] * len(hsi)  # type: ignore

        nt_attributes = torch.empty((n_samples, len(hsi)) + hsi[0].image.shape, device=hsi[0].device)

        for batch in range(0, len(hsi)):
            input = hsi[batch]
            targeted = target[batch]
            stdev = stdevs[batch]
            perturbed_input = self.perturb_input(input.image, n_samples, perturbation_axis, stdev)
            nt_attributes[:, batch] = self._forward_loop(
                perturbed_input, input, targeted, additional_forward_args, n_samples, steps_per_batch
            )

        nt_attributes = self._aggregate_attributions(nt_attributes, method)

        try:
            attributes = [
                HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
                for hsi_image, attribution in zip(hsi, nt_attributes)
            ]
        except Exception as e:
            raise HSIAttributesError(f"Error in generating NoiseTunnel attributions: {e}") from e

        return attributes[0] if len(attributes) == 1 else attributes

attribute(hsi, target=None, additional_forward_args=None, n_samples=5, steps_per_batch=1, perturbation_axis=None, stdevs=1.0, method='smoothgrad')

Method for generating attributions using the Noise Tunnel method.

Parameters:

    hsi (list[HSI] | HSI, required): Input hyperspectral image(s) for which the attributions are to be computed.
        If a list of HSI objects is provided, the attributions are computed for each HSI object in the list and
        the output will be a list of HSIAttributes objects.
    target (list[int] | int | None, default None): Target class index for computing the attributions. If None,
        the methods assume that the output has only one class. If the output has multiple classes, the target
        index must be provided. For multiple input images, either a list of target indices (one per image) or a
        single target value used for all images can be provided.
    additional_forward_args (Any, default None): If the forward function requires additional arguments other
        than the inputs for which attributions should not be computed, this argument can be provided. It must be
        either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing
        multiple additional arguments including tensors or any arbitrary python types. These arguments are
        provided to forward_func in order, following the arguments in inputs. Attributions are not computed with
        respect to these arguments.
    n_samples (int, default 5): The number of randomly generated examples per sample in the input batch. Random
        examples are generated by adding Gaussian noise to each sample.
    steps_per_batch (int, default 1): The number of the n_samples that will be processed together in a single
        forward pass. This parameter helps to avoid out-of-memory situations when many randomly generated
        examples are created per sample.
    perturbation_axis (None | tuple[int | slice], default None): The indices of the input image to be perturbed.
        If set to None, all bands are perturbed, which corresponds to the traditional noise tunnel method.
    stdevs (float | tuple[float, ...], default 1.0): The standard deviation of the zero-mean Gaussian noise that
        is added to each input in the batch. If stdevs is a single float value, that same value is used for all
        inputs. If stdevs is a tuple, its length must match the number of inputs, as each value in the tuple is
        used for the corresponding input.
    method (Literal['smoothgrad', 'smoothgrad_sq', 'vargrad'], default 'smoothgrad'): Smoothing type of the
        attributions: smoothgrad, smoothgrad_sq or vargrad.

Returns:

    HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s). If a
        list of HSI objects is provided, the attributions are computed for each HSI object in the list.

Raises:

    HSIAttributesError: If an error occurs during the generation of the attributions.

Examples:

>>> noise_tunnel = NoiseTunnel(explainable_model)
>>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
>>> attributions = noise_tunnel.attribute(hsi)
>>> attributions = noise_tunnel.attribute([hsi, hsi])
>>> len(attributions)
2
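
A slightly fuller usage sketch than the doctest above, shown for context. It assumes, as in the Examples, that explainable_model is an ExplainableModel-compatible wrapper around a single-output network; the import paths follow the package layout shown in this reference, and all variable names are illustrative.

import torch

from meteors import HSI
from meteors.attr import NoiseTunnel

hsi_a = HSI(image=torch.rand(4, 240, 240), wavelengths=[462.08, 465.27, 468.47, 471.68])
hsi_b = HSI(image=torch.rand(4, 240, 240), wavelengths=[462.08, 465.27, 468.47, 471.68])

noise_tunnel = NoiseTunnel(explainable_model)
attrs = noise_tunnel.attribute(
    [hsi_a, hsi_b],
    n_samples=10,
    steps_per_batch=2,      # evaluate 2 of the 10 noisy samples per forward pass
    stdevs=(0.5, 1.0),      # one noise level per input image
    method="smoothgrad_sq",
)                           # returns a list of two HSIAttributes objects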
Source code in src/meteors/attr/noise_tunnel.py
def attribute(
    self,
    hsi: list[HSI] | HSI,
    target: list[int] | int | None = None,
    additional_forward_args: Any = None,
    n_samples: int = 5,
    steps_per_batch: int = 1,
    perturbation_axis: None | tuple[int | slice] = None,
    stdevs: float | tuple[float, ...] = 1.0,
    method: Literal["smoothgrad", "smoothgrad_sq", "vargrad"] = "smoothgrad",
) -> HSIAttributes | list[HSIAttributes]:
    """
    Method for generating attributions using the Noise Tunnel method.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSIAttributes objects.
        target (list[int] | int | None, optional): Target class index for computing the attributions. If None,
            the methods assume that the output has only one class. If the output has multiple classes, the target
            index must be provided. For multiple input images, either a list of target indices (one per image) or
            a single target value used for all images can be provided. Defaults to None.
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order, following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Defaults to None.
        n_samples (int, optional): The number of randomly generated examples per sample in the input batch.
            Random examples are generated by adding Gaussian noise to each sample. Defaults to 5.
        steps_per_batch (int, optional): The number of the n_samples that will be processed together in a single
            forward pass. This parameter helps to avoid out-of-memory situations when many randomly generated
            examples are created per sample. Defaults to 1.
        perturbation_axis (None | tuple[int | slice], optional): The indices of the input image to be perturbed.
            If set to None, all bands are perturbed, which corresponds to the traditional noise tunnel method.
            Defaults to None.
        stdevs (float | tuple[float, ...], optional): The standard deviation of the zero-mean Gaussian noise that
            is added to each input in the batch. If stdevs is a single float value, that same value is used
            for all inputs. If stdevs is a tuple, its length must match the number of inputs, as each value
            in the tuple is used for the corresponding input. Defaults to 1.0.
        method (Literal["smoothgrad", "smoothgrad_sq", "vargrad"], optional): Smoothing type of the
            attributions: smoothgrad, smoothgrad_sq or vargrad. Defaults to smoothgrad.

    Returns:
        HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

    Raises:
        HSIAttributesError: If an error occurs during the generation of the attributions.

    Examples:
        >>> noise_tunnel = NoiseTunnel(explainable_model)
        >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
        >>> attributions = noise_tunnel.attribute(hsi)
        >>> attributions = noise_tunnel.attribute([hsi, hsi])
        >>> len(attributions)
        2
    """
    if isinstance(stdevs, list):
        stdevs = tuple(stdevs)

    if not isinstance(hsi, list):
        hsi = [hsi]

    if not all([isinstance(input, HSI) for input in hsi]):
        raise TypeError("All inputs must be HSI objects")

    if isinstance(stdevs, tuple):
        if len(stdevs) != len(hsi):
            raise ValueError(
                "The number of stdevs must match the number of input images, number of stdevs:"
                f"{len(stdevs)}, number of input images: {len(hsi)}"
            )
    else:
        stdevs = tuple([stdevs] * len(hsi))

    if not isinstance(target, list):
        target = [target] * len(hsi)  # type: ignore

    nt_attributes = torch.empty((n_samples, len(hsi)) + hsi[0].image.shape, device=hsi[0].device)

    for batch in range(0, len(hsi)):
        input = hsi[batch]
        targeted = target[batch]
        stdev = stdevs[batch]
        perturbed_input = self.perturb_input(input.image, n_samples, perturbation_axis, stdev)
        nt_attributes[:, batch] = self._forward_loop(
            perturbed_input, input, targeted, additional_forward_args, n_samples, steps_per_batch
        )

    nt_attributes = self._aggregate_attributions(nt_attributes, method)

    try:
        attributes = [
            HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
            for hsi_image, attribution in zip(hsi, nt_attributes)
        ]
    except Exception as e:
        raise HSIAttributesError(f"Error in generating NoiseTunnel attributions: {e}") from e

    return attributes[0] if len(attributes) == 1 else attributes

perturb_input(input, n_samples=1, perturbation_axis=None, stdevs=1, **kwargs) staticmethod

The default perturbation function used in the noise tunnel, with a small enhancement for hyperspectral images. It randomly adds noise drawn from a normal distribution with a given standard deviation to the input tensor. The noise is added to the selected bands (channels) of the input tensor; the bands to be perturbed are selected based on the perturbation_axis parameter. By default all bands are perturbed, which is equivalent to the standard noise tunnel method.

Parameters:

    input (Tensor, required): An input tensor to be perturbed. It should have the shape (C, H, W).
    n_samples (int, default 1): A number of samples to be drawn - the number of perturbed inputs to be generated.
    perturbation_axis (None | tuple[int | slice], default None): The indices of the bands to be perturbed.
        If set to None, all bands are perturbed.
    stdevs (float, default 1): The standard deviation of the zero-mean Gaussian noise that is added to each
        input in the batch.

Returns:

    torch.Tensor: A perturbed tensor, which contains n_samples perturbed inputs.
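
A minimal sketch of how perturbation_axis selects the bands that receive noise (shapes and values are illustrative; the import path follows this reference). With perturbation_axis=(slice(0, 2),), noise is added only to the first two channels, so the remaining bands stay untouched in every generated sample:

import torch

from meteors.attr import NoiseTunnel

x = torch.zeros(4, 8, 8)  # (C, H, W)
perturbed = NoiseTunnel.perturb_input(x, n_samples=3, perturbation_axis=(slice(0, 2),), stdevs=0.1)
print(perturbed.shape)                    # torch.Size([3, 4, 8, 8])
print(torch.all(perturbed[:, 2:] == 0))   # tensor(True): bands 2 and 3 were not perturbed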

Source code in src/meteors/attr/noise_tunnel.py
@staticmethod
def perturb_input(
    input: torch.Tensor,
    n_samples: int = 1,
    perturbation_axis: None | tuple[int | slice] = None,
    stdevs: float = 1,
    **kwargs: Any,
) -> torch.Tensor:
    """
    The default perturbation function used in the noise tunnel, with a small enhancement for hyperspectral images.
    It randomly adds noise drawn from a normal distribution with a given standard deviation to the input tensor.
    The noise is added to the selected bands (channels) of the input tensor.
    The bands to be perturbed are selected based on the `perturbation_axis` parameter.
    By default all bands are perturbed, which is equivalent to the standard noise tunnel method.

    Args:
        input (torch.Tensor): An input tensor to be perturbed. It should have the shape (C, H, W).
        n_samples (int): A number of samples to be drawn - number of perturbed inputs to be generated.
        perturbation_axis (None | tuple[int | slice]): The indices of the bands to be perturbed.
            If set to None, all bands are perturbed. Defaults to None.
        stdevs (float): The standard deviation of gaussian noise with zero mean that is added to each input
            in the batch. Defaults to 1.0.

    Returns:
        torch.Tensor: A perturbed tensor, which contains `n_samples` perturbed inputs.
    """
    if n_samples < 1:
        raise ValueError("Number of perturbated samples to be generated must be greater than 0")

    # the perturbation
    perturbed_input = input.clone().unsqueeze(0)
    # repeat the perturbed_input on the first dimension n_samples times
    perturbed_input = perturbed_input.repeat_interleave(n_samples, dim=0)

    # the perturbation shape
    if perturbation_axis is None:
        perturbation_shape = perturbed_input.shape
    else:
        perturbation_axis = (slice(None),) + perturbation_axis  # type: ignore
        perturbation_shape = perturbed_input[perturbation_axis].shape

    # the noise
    noise = torch.normal(0, stdevs, size=perturbation_shape).to(input.device)

    # add the noise to the perturbed_input
    if perturbation_axis is None:
        perturbed_input += noise
    else:
        perturbed_input[perturbation_axis] += noise

    perturbed_input.requires_grad_(True)

    return perturbed_input

HyperNoiseTunnel

Bases: BaseNoiseTunnel

Hyper Noise Tunnel is our novel method, designed specifically to explain hyperspectral satellite images. It is inspired by the behaviour of the classical Noise Tunnel (SmoothGrad) method, but instead of adding sampled noise to the original image, it randomly masks some of the bands with the baseline. In the process, the created noised samples stay close to the distribution of the original image, yet differ enough to smooth the produced attribution map.
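
The masking idea can be sketched in a few lines of plain PyTorch. This is an illustration of the concept only, not the library's implementation (see perturb_input below): each band is either kept or swapped wholesale for the corresponding baseline band.

import torch

x = torch.rand(4, 8, 8)                               # (C, H, W) hyperspectral cube
baseline = torch.zeros_like(x)                        # reference values for masked bands
mask = torch.bernoulli(torch.full((4,), 0.5)).bool()  # each band masked with probability 0.5
noised_sample = torch.where(mask[:, None, None], baseline, x)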

Parameters:

    chained_explainer (required): The explainable method that will be used to compute the attributions.

Raises:

    RuntimeError: If the callable object is not an instance of the Explainer class.

Source code in src/meteors/attr/noise_tunnel.py
class HyperNoiseTunnel(BaseNoiseTunnel):
    """Hyper Noise Tunnel is our novel method, designed specifically to explain hyperspectral satellite images. It is
    inspired by the behaviour of the classical Noise Tunnel (Smooth Grad) method, but instead of sampling noise into the
    original image, it randomly masks some of the bands with the baseline. In the process, the created _noised_ samples
    are close to the distribution of the original image yet differ enough to smoothen the produced attribution map.

    Arguments:
        chained_explainer: The explainable method that will be used to compute the attributions.

    Raises:
        RuntimeError: If the callable object is not an instance of the Explainer class
    """

    @staticmethod
    def perturb_input(
        input: torch.Tensor,
        baseline: torch.Tensor | None = None,
        n_samples: int = 1,
        perturbation_prob: float = 0.5,
        num_perturbed_bands: int | None = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """The perturbation function used in the hyper noise tunnel. It randomly selects a subset of the input bands
        that will be masked out and replaced with the baseline. The parameters `num_perturbed_bands` and
        `perturbation_prob` control the number of bands that will be perturbed (masked). If `num_perturbed_bands` is
        set, it will be used as the number of bands to perturb, which will be randomly selected. Otherwise, the number
        of bands will be drawn from a binomial distribution with `perturbation_prob` as the probability of success.

        Args:
            input (torch.Tensor): An input tensor to be perturbed. It should have the shape (C, H, W).
            baseline (torch.Tensor | None, optional): A tensor that will be used to replace the perturbed bands.
            n_samples (int): A number of samples to be drawn - number of perturbed inputs to be generated.
            perturbation_prob (float, optional): A probability that each band will be perturbed independently.
                Defaults to 0.5.
            num_perturbed_bands (int | None, optional): A number of perturbed bands in the whole image.
                If set to None, the bands are perturbed with probability `perturbation_prob` each. Defaults to None.

        Returns:
            torch.Tensor: A perturbed tensor, which contains `n_samples` perturbed inputs.
        """
        # validate the baseline against the input
        if baseline is None:
            raise ValueError("Baseline must be provided for the HyperNoiseTunnel method")

        if baseline.shape != input.shape:
            raise ShapeMismatchError(f"Baseline shape {baseline.shape} does not match input shape {input.shape}")

        if n_samples < 1:
            raise ValueError("Number of perturbated samples to be generated must be greater than 0")

        if perturbation_prob < 0 or perturbation_prob > 1:
            raise ValueError("Perturbation probability must be in the range [0, 1]")

        # the perturbation
        perturbed_input = input.clone().unsqueeze(0)
        # repeat the perturbed_input on the first dimension n_samples times
        perturbed_input = perturbed_input.repeat_interleave(n_samples, dim=0)

        n_samples_x_channels_shape = (
            n_samples,
            input.shape[0],
        )  # shape of the tensor containing the perturbed channels for each sample

        channels_to_be_perturbed: torch.Tensor = torch.zeros(n_samples_x_channels_shape, device=input.device).bool()

        if num_perturbed_bands is None:
            channel_perturbation_probabilities = (
                torch.ones(n_samples_x_channels_shape, device=input.device) * perturbation_prob
            )
            channels_to_be_perturbed = torch.bernoulli(channel_perturbation_probabilities).bool()

        else:
            if num_perturbed_bands < 0 or num_perturbed_bands > input.shape[0]:
                raise ValueError(
                    f"Cannot perturb {num_perturbed_bands} bands in the input with {input.shape[0]} channels. The number of perturbed bands must be in the range [0, {input.shape[0]}]"
                )

            channels_to_be_perturbed = torch_random_choice(
                input.shape[0], num_perturbed_bands, n_samples, device=input.device
            )

        # now having chosen the perturbed channels, we can replace them with the baseline
        reshaped_baseline = baseline.unsqueeze(0).repeat_interleave(n_samples, dim=0)
        perturbed_input[channels_to_be_perturbed] = reshaped_baseline[channels_to_be_perturbed]

        perturbed_input.requires_grad_(True)

        return perturbed_input

    def attribute(
        self,
        hsi: list[HSI] | HSI,
        baseline: int | float | torch.Tensor | list[int | float | torch.Tensor] | None = None,
        target: list[int] | int | None = None,
        additional_forward_args: Any = None,
        n_samples: int = 5,
        steps_per_batch: int = 1,
        perturbation_prob: float = 0.5,
        num_perturbed_bands: int | None = None,
        method: Literal["smoothgrad", "smoothgrad_sq", "vargrad"] = "smoothgrad",
    ) -> HSIAttributes | list[HSIAttributes]:
        """
        Method for generating attributions using the Hyper Noise Tunnel method.

        Args:
            hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
                The output will be a list of HSIAttributes objects.
            baseline (int | float | torch.Tensor | list[int | float | torch.Tensor], optional): Baselines define the
                reference value that replaces each masked feature and can be provided as:
                    - an integer or float representing a constant value used as the baseline for all input pixels.
                    - a tensor with the same shape as the input tensor, providing a baseline for each input pixel.
                        If the input is a list of HSI objects, the baseline can be a tensor with the same shape as
                        the input tensor for each HSI object or a list of tensors with the same length as the input list.
            target (list[int] | int | None, optional): Target class index for computing the attributions. If None,
                the methods assume that the output has only one class. If the output has multiple classes, the
                target index must be provided. For multiple input images, either a list of target indices (one per
                image) or a single target value used for all images can be provided. Defaults to None.
            additional_forward_args (Any, optional): If the forward function requires additional arguments other than
                the inputs for which attributions should not be computed, this argument can be provided.
                It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors or any arbitrary python types.
                These arguments are provided to forward_func in order, following the arguments in inputs.
                Note that attributions are not computed with respect to these arguments. Defaults to None.
            n_samples (int, optional): The number of randomly generated examples per sample in the input batch.
                Random examples are generated by randomly masking bands of each sample with the baseline.
                Defaults to 5.
            steps_per_batch (int, optional): The number of the n_samples that will be processed together in a single
                forward pass. This parameter helps to avoid out-of-memory situations when many randomly generated
                examples are created per sample. Defaults to 1.
            perturbation_prob (float, optional): The probability that each band will be perturbed independently.
                Defaults to 0.5.
            num_perturbed_bands (int | None, optional): The number of perturbed bands in the whole image.
                The bands to be perturbed are selected randomly with no replacement.
                If set to None, each band is perturbed with probability `perturbation_prob`. Defaults to None.
            method (Literal["smoothgrad", "smoothgrad_sq", "vargrad"], optional): Smoothing type of the
                attributions: smoothgrad, smoothgrad_sq or vargrad. Defaults to smoothgrad.

        Returns:
            HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
                If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

        Raises:
            HSIAttributesError: If an error occurs during the generation of the attributions.

        Examples:
            >>> hyper_noise_tunnel = HyperNoiseTunnel(explainable_model)
            >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
            >>> attributions = hyper_noise_tunnel.attribute(hsi)
            >>> attributions = hyper_noise_tunnel.attribute([hsi, hsi])
            >>> len(attributions)
            2
        """
        change_orientation = []
        original_orientation = []

        if not isinstance(hsi, list):
            hsi = [hsi]

        if not all([isinstance(input, HSI) for input in hsi]):
            raise TypeError("All inputs must be HSI objects")

        for i in range(len(hsi)):
            if hsi[i].orientation != ("C", "H", "W"):
                change_orientation.append(True)
                original_orientation.append(hsi[i].orientation)
                hsi[i] = hsi[i].change_orientation("CHW")
            else:
                change_orientation.append(False)

        if not isinstance(baseline, list):
            baseline = [baseline] * len(hsi)
        elif len(baseline) != len(hsi):
            raise ValueError("The number of baseline must match the number of input images")

        baseline = [validate_and_transform_baseline(base, hsi_image) for base, hsi_image in zip(baseline, hsi)]

        if not isinstance(target, list):
            target = [target] * len(hsi)  # type: ignore

        hnt_attributes = torch.empty((n_samples, len(hsi)) + hsi[0].image.shape, device=hsi[0].device)
        for batch in range(0, len(hsi)):
            input = hsi[batch]
            targeted = target[batch]
            base = baseline[batch]
            perturbed_input = self.perturb_input(input.image, base, n_samples, perturbation_prob, num_perturbed_bands)

            hnt_attributes[:, batch] = self._forward_loop(
                perturbed_input, input, targeted, additional_forward_args, n_samples, steps_per_batch
            )

        hnt_attributes = self._aggregate_attributions(hnt_attributes, method)

        try:
            attributes = [
                HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
                for hsi_image, attribution in zip(hsi, hnt_attributes)
            ]
        except Exception as e:
            raise HSIAttributesError(f"Error in generating HyperNoiseTunnel attributions: {e}") from e

        for i in range(len(change_orientation)):
            if change_orientation[i]:
                attributes[i].hsi = attributes[i].hsi.change_orientation(original_orientation[i])

        return attributes[0] if len(attributes) == 1 else attributes

attribute(hsi, baseline=None, target=None, additional_forward_args=None, n_samples=5, steps_per_batch=1, perturbation_prob=0.5, num_perturbed_bands=None, method='smoothgrad')

Method for generating attributions using the Hyper Noise Tunnel method.

Parameters:

    hsi (list[HSI] | HSI, required): Input hyperspectral image(s) for which the attributions are to be computed.
        If a list of HSI objects is provided, the attributions are computed for each HSI object in the list and
        the output will be a list of HSIAttributes objects.
    baseline (int | float | Tensor | list[int | float | Tensor], default None): Baselines define the reference
        value that replaces each masked feature. They can be provided as an integer or float representing a
        constant value used as the baseline for all input pixels, or as a tensor with the same shape as the
        input tensor, providing a baseline for each input pixel. If the input is a list of HSI objects, the
        baseline can be a tensor with the same shape as the input tensor for each HSI object or a list of
        tensors with the same length as the input list.
    target (list[int] | int | None, default None): Target class index for computing the attributions. If None,
        the methods assume that the output has only one class. If the output has multiple classes, the target
        index must be provided. For multiple input images, either a list of target indices (one per image) or a
        single target value used for all images can be provided.
    additional_forward_args (Any, default None): If the forward function requires additional arguments other
        than the inputs for which attributions should not be computed, this argument can be provided. It must be
        either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing
        multiple additional arguments including tensors or any arbitrary python types. These arguments are
        provided to forward_func in order, following the arguments in inputs. Attributions are not computed with
        respect to these arguments.
    n_samples (int, default 5): The number of randomly generated examples per sample in the input batch. Random
        examples are generated by randomly masking bands of each sample with the baseline.
    steps_per_batch (int, default 1): The number of the n_samples that will be processed together in a single
        forward pass. This parameter helps to avoid out-of-memory situations when many randomly generated
        examples are created per sample.
    perturbation_prob (float, default 0.5): The probability that each band will be perturbed independently.
    num_perturbed_bands (int | None, default None): The number of perturbed bands in the whole image. The bands
        to be perturbed are selected randomly with no replacement. If set to None, each band is perturbed with
        probability perturbation_prob.
    method (Literal['smoothgrad', 'smoothgrad_sq', 'vargrad'], default 'smoothgrad'): Smoothing type of the
        attributions: smoothgrad, smoothgrad_sq or vargrad.

Returns:

    HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s). If a
        list of HSI objects is provided, the attributions are computed for each HSI object in the list.

Raises:

    HSIAttributesError: If an error occurs during the generation of the attributions.

Examples:

>>> hyper_noise_tunnel = HyperNoiseTunnel(explainable_model)
>>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
>>> attributions = hyper_noise_tunnel.attribute(hsi)
>>> attributions = hyper_noise_tunnel.attribute([hsi, hsi])
>>> len(attributions)
2
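
An expanded sketch of the baseline handling (variable names are illustrative, and explainable_model is assumed to be defined as in the Examples): a zero tensor shaped like each input is a common choice, and a list of baselines can be supplied for a list of inputs.

import torch

from meteors import HSI
from meteors.attr import HyperNoiseTunnel

hsi = HSI(image=torch.rand(4, 240, 240), wavelengths=[462.08, 465.27, 468.47, 471.68])
hyper_noise_tunnel = HyperNoiseTunnel(explainable_model)

baseline = torch.zeros_like(hsi.image)   # same (C, H, W) shape as the input
attrs = hyper_noise_tunnel.attribute(
    [hsi, hsi],
    baseline=[baseline, baseline],       # one baseline per input image
    n_samples=8,
    num_perturbed_bands=2,               # mask exactly 2 of the 4 bands in each sample
)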
Source code in src/meteors/attr/noise_tunnel.py
def attribute(
    self,
    hsi: list[HSI] | HSI,
    baseline: int | float | torch.Tensor | list[int | float | torch.Tensor] | None = None,
    target: list[int] | int | None = None,
    additional_forward_args: Any = None,
    n_samples: int = 5,
    steps_per_batch: int = 1,
    perturbation_prob: float = 0.5,
    num_perturbed_bands: int | None = None,
    method: Literal["smoothgrad", "smoothgrad_sq", "vargrad"] = "smoothgrad",
) -> HSIAttributes | list[HSIAttributes]:
    """
    Method for generating attributions using the Hyper Noise Tunnel method.

    Args:
        hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
            The output will be a list of HSIAttributes objects.
        baseline (int | float | torch.Tensor | list[int | float | torch.Tensor], optional): Baselines define the
            reference value that replaces each masked feature and can be provided as:
                - an integer or float representing a constant value used as the baseline for all input pixels.
                - a tensor with the same shape as the input tensor, providing a baseline for each input pixel.
                    If the input is a list of HSI objects, the baseline can be a tensor with the same shape as
                    the input tensor for each HSI object or a list of tensors with the same length as the input list.
        target (list[int] | int | None, optional): Target class index for computing the attributions. If None,
            the methods assume that the output has only one class. If the output has multiple classes, the target
            index must be provided. For multiple input images, either a list of target indices (one per image) or
            a single target value used for all images can be provided. Defaults to None.
        additional_forward_args (Any, optional): If the forward function requires additional arguments other than
            the inputs for which attributions should not be computed, this argument can be provided.
            It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple
            containing multiple additional arguments including tensors or any arbitrary python types.
            These arguments are provided to forward_func in order, following the arguments in inputs.
            Note that attributions are not computed with respect to these arguments. Defaults to None.
        n_samples (int, optional): The number of randomly generated examples per sample in the input batch.
            Random examples are generated by randomly masking bands of each sample with the baseline.
            Defaults to 5.
        steps_per_batch (int, optional): The number of the n_samples that will be processed together in a single
            forward pass. This parameter helps to avoid out-of-memory situations when many randomly generated
            examples are created per sample. Defaults to 1.
        perturbation_prob (float, optional): The probability that each band will be perturbed independently.
            Defaults to 0.5.
        num_perturbed_bands (int | None, optional): The number of perturbed bands in the whole image.
            The bands to be perturbed are selected randomly with no replacement.
            If set to None, each band is perturbed with probability `perturbation_prob`. Defaults to None.
        method (Literal["smoothgrad", "smoothgrad_sq", "vargrad"], optional): Smoothing type of the
            attributions: smoothgrad, smoothgrad_sq or vargrad. Defaults to smoothgrad.

    Returns:
        HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
            If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.

    Raises:
        HSIAttributesError: If an error occurs during the generation of the attributions.

    Examples:
        >>> hyper_noise_tunnel = HyperNoiseTunnel(explainable_model)
        >>> hsi = HSI(image=torch.ones((4, 240, 240)), wavelengths=[462.08, 465.27, 468.47, 471.68])
        >>> attributions = hyper_noise_tunnel.attribute(hsi)
        >>> attributions = hyper_noise_tunnel.attribute([hsi, hsi])
        >>> len(attributions)
        2
    """
    change_orientation = []
    original_orientation = []

    if not isinstance(hsi, list):
        hsi = [hsi]

    if not all([isinstance(input, HSI) for input in hsi]):
        raise TypeError("All inputs must be HSI objects")

    for i in range(len(hsi)):
        if hsi[i].orientation != ("C", "H", "W"):
            change_orientation.append(True)
            original_orientation.append(hsi[i].orientation)
            hsi[i] = hsi[i].change_orientation("CHW")
        else:
            change_orientation.append(False)

    if not isinstance(baseline, list):
        baseline = [baseline] * len(hsi)
    elif len(baseline) != len(hsi):
        raise ValueError("The number of baseline must match the number of input images")

    baseline = [validate_and_transform_baseline(base, hsi_image) for base, hsi_image in zip(baseline, hsi)]

    if not isinstance(target, list):
        target = [target] * len(hsi)  # type: ignore

    hnt_attributes = torch.empty((n_samples, len(hsi)) + hsi[0].image.shape, device=hsi[0].device)
    for batch in range(0, len(hsi)):
        input = hsi[batch]
        targeted = target[batch]
        base = baseline[batch]
        perturbed_input = self.perturb_input(input.image, base, n_samples, perturbation_prob, num_perturbed_bands)

        hnt_attributes[:, batch] = self._forward_loop(
            perturbed_input, input, targeted, additional_forward_args, n_samples, steps_per_batch
        )

    hnt_attributes = self._aggregate_attributions(hnt_attributes, method)

    try:
        attributes = [
            HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
            for hsi_image, attribution in zip(hsi, hnt_attributes)
        ]
    except Exception as e:
        raise HSIAttributesError(f"Error in generating HyperNoiseTunnel attributions: {e}") from e

    for i in range(len(change_orientation)):
        if change_orientation[i]:
            attributes[i].hsi = attributes[i].hsi.change_orientation(original_orientation[i])

    return attributes[0] if len(attributes) == 1 else attributes

perturb_input(input, baseline=None, n_samples=1, perturbation_prob=0.5, num_perturbed_bands=None, **kwargs) staticmethod

The perturbation function used in the hyper noise tunnel. It randomly selects a subset of the input bands that will be masked out and replaced with the baseline. The parameters num_perturbed_bands and perturbation_prob control the number of bands that will be perturbed (masked). If num_perturbed_bands is set, it will be used as the number of bands to perturb, which will be randomly selected. Otherwise, the number of bands will be drawn from a binomial distribution with perturbation_prob as the probability of success.

Parameters:

    input (Tensor, required): An input tensor to be perturbed. It should have the shape (C, H, W).
    baseline (Tensor | None, default None): A tensor that will be used to replace the perturbed bands.
    n_samples (int, default 1): A number of samples to be drawn - the number of perturbed inputs to be generated.
    perturbation_prob (float, default 0.5): A probability that each band will be perturbed independently.
    num_perturbed_bands (int | None, default None): A number of perturbed bands in the whole image. If set to
        None, each band is perturbed with probability perturbation_prob.

Returns:

    torch.Tensor: A perturbed tensor, which contains n_samples perturbed inputs.
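
A small sketch of the fixed-count selection mode (shapes are illustrative; the import path follows this reference). With num_perturbed_bands set, exactly that many bands are replaced by the baseline in every sample, which is easy to verify with a zero baseline and strictly positive inputs:

import torch

from meteors.attr import HyperNoiseTunnel

x = torch.rand(4, 8, 8) + 0.1   # strictly positive (C, H, W) cube
base = torch.zeros_like(x)

out = HyperNoiseTunnel.perturb_input(x, baseline=base, n_samples=5, num_perturbed_bands=2)
masked_per_sample = (out.abs().sum(dim=(2, 3)) == 0).sum(dim=1)
print(masked_per_sample)        # tensor([2, 2, 2, 2, 2])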

Source code in src/meteors/attr/noise_tunnel.py
@staticmethod
def perturb_input(
    input: torch.Tensor,
    baseline: torch.Tensor | None = None,
    n_samples: int = 1,
    perturbation_prob: float = 0.5,
    num_perturbed_bands: int | None = None,
    **kwargs: Any,
) -> torch.Tensor:
    """The perturbation function used in the hyper noise tunnel. It randomly selects a subset of the input bands
    that will be masked out and replaced with the baseline. The parameters `num_perturbed_bands` and
    `perturbation_prob` control the number of bands that will be perturbed (masked). If `num_perturbed_bands` is
    set, it will be used as the number of bands to perturb, which will be randomly selected. Otherwise, the number
    of bands will be drawn from a binomial distribution with `perturbation_prob` as the probability of success.

    Args:
        input (torch.Tensor): An input tensor to be perturbed. It should have the shape (C, H, W).
        baseline (torch.Tensor | None, optional): A tensor that will be used to replace the perturbed bands.
        n_samples (int): A number of samples to be drawn - number of perturbed inputs to be generated.
        perturbation_prob (float, optional): A probability that each band will be perturbed independently.
            Defaults to 0.5.
        num_perturbed_bands (int | None, optional): A number of perturbed bands in the whole image.
            If set to None, the bands are perturbed with probability `perturbation_prob` each. Defaults to None.

    Returns:
        torch.Tensor: A perturbed tensor, which contains `n_samples` perturbed inputs.
    """
    # validate the baseline against the input
    if baseline is None:
        raise ValueError("Baseline must be provided for the HyperNoiseTunnel method")

    if baseline.shape != input.shape:
        raise ShapeMismatchError(f"Baseline shape {baseline.shape} does not match input shape {input.shape}")

    if n_samples < 1:
        raise ValueError("Number of perturbated samples to be generated must be greater than 0")

    if perturbation_prob < 0 or perturbation_prob > 1:
        raise ValueError("Perturbation probability must be in the range [0, 1]")

    # the perturbation
    perturbed_input = input.clone().unsqueeze(0)
    # repeat the perturbed_input on the first dimension n_samples times
    perturbed_input = perturbed_input.repeat_interleave(n_samples, dim=0)

    n_samples_x_channels_shape = (
        n_samples,
        input.shape[0],
    )  # shape of the tensor containing the perturbed channels for each sample

    channels_to_be_perturbed: torch.Tensor = torch.zeros(n_samples_x_channels_shape, device=input.device).bool()

    if num_perturbed_bands is None:
        channel_perturbation_probabilities = (
            torch.ones(n_samples_x_channels_shape, device=input.device) * perturbation_prob
        )
        channels_to_be_perturbed = torch.bernoulli(channel_perturbation_probabilities).bool()

    else:
        if num_perturbed_bands < 0 or num_perturbed_bands > input.shape[0]:
            raise ValueError(
                f"Cannot perturb {num_perturbed_bands} bands in the input with {input.shape[0]} channels. The number of perturbed bands must be in the range [0, {input.shape[0]}]"
            )

        channels_to_be_perturbed = torch_random_choice(
            input.shape[0], num_perturbed_bands, n_samples, device=input.device
        )

    # now having chosen the perturbed channels, we can replace them with the baseline
    reshaped_baseline = baseline.unsqueeze(0).repeat_interleave(n_samples, dim=0)
    perturbed_input[channels_to_be_perturbed] = reshaped_baseline[channels_to_be_perturbed]

    perturbed_input.requires_grad_(True)

    return perturbed_input