
Conversation

@jackmcrider

This commit introduces several new components to enhance k-means-based model attribution and provide a new LRP rule for pooling layers (illustrative sketches follow at the end of this description):

1. New Layers (`src/zennit/layer.py`):
   - `PairwiseCentroidDistance`: Computes pairwise distances between inputs and a set of centroids.
   - `NeuralizedKMeans`: A layer representing the k-means discriminants as a linear transformation (tensor-matrix product with bias); see the first sketch below.
   - `MinPool1d` and `MinPool2d`: Min-pooling layers for 1D and 2D inputs, implemented by negating the inputs and outputs of MaxPool; see the second sketch below.

2. New Canonizer (`src/zennit/canonizers.py`):
   - `KMeansCanonizer`: Replaces `PairwiseCentroidDistance` layers (with `power=2`) with a sequence of `NeuralizedKMeans`, `MinPool1d`, and `torch.nn.Flatten`. This neuralization enables the application of LRP rules to k-means-like clustering outputs.

3. New LRP Rule (`src/zennit/rules.py`):
   - `TakesMost`: An LRP rule designed for max-pooling layers (and, by extension, min-pooling layers after neuralization).
   - Distributes relevance according to a softmax over the input contributions within each pooling window, providing a "soft" alternative to winner-takes-all; see the third sketch below.
   - Subtracts the window-wise maximum before the softmax for numerical stability.
   - Dynamically adapts the internal convolution, pooling, and transposed-convolution operations to match the input dimensionality (1D/2D).
   - Uses common, refactored parameters for the unpooling operations.
   - Respects the original module's hyperparameters.

4. New Tests (`tests/`):
   - `test_kmeans_canonizer` in `test_canonizers.py`: Verifies the correctness of the `KMeansCanonizer`, ensuring module replacement, functional equivalence of cluster assignments, and proper restoration.

Replaces #197
Closes #198
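
For illustration, here is a minimal sketch of how `PairwiseCentroidDistance` with `power=2` relates to the `NeuralizedKMeans` discriminants followed by min-pooling. This is not the PR's exact code; the shapes and the `einsum` formulation are assumptions made for the example:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 2)    # batch of inputs
mu = torch.randn(3, 2)   # k=3 centroids

# PairwiseCentroidDistance with power=2: squared euclidean distances
dist = torch.cdist(x, mu) ** 2    # shape (8, 3)
assignment = dist.argmin(dim=1)

# NeuralizedKMeans discriminants f[c,k](x) = w[c,k] @ x + b[c,k], with
# w[c,k] = 2 * (mu[c] - mu[k]) and b[c,k] = |mu[k]|^2 - |mu[c]|^2, so that
# f[c,k](x) = |x - mu[k]|^2 - |x - mu[c]|^2 (the margin of c against k)
weight = 2 * (mu[:, None, :] - mu[None, :, :])                 # (3, 3, 2)
bias = (mu ** 2).sum(1)[None, :] - (mu ** 2).sum(1)[:, None]   # (3, 3)
logits = torch.einsum('nd,ckd->nck', x, weight) + bias         # (8, 3, 3)

# min-pooling over the competitor axis: the assigned cluster is the only
# one whose worst-case margin is non-negative (the diagonal f[c,c] is 0)
scores = logits.amin(dim=2)
assert (scores.argmax(dim=1) == assignment).all()
```

The assigned cluster's min-pooled score is exactly zero while all others are negative, which is why the `NeuralizedKMeans` → `MinPool1d` → `Flatten` chain installed by the canonizer preserves cluster assignments.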

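The min-pooling layers themselves reduce to negated max-pooling. A minimal sketch of that idea, assuming the actual classes in `src/zennit/layer.py` may differ in detail:

```python
import torch

class MinPool1d(torch.nn.MaxPool1d):
    """Min-pooling implemented as max-pooling on the negated input."""
    def forward(self, input):
        return -super().forward(-input)

pool = MinPool1d(kernel_size=3)
x = torch.randn(1, 4, 9)    # (batch, channels, length)
assert torch.equal(pool(x), -torch.nn.functional.max_pool1d(-x, kernel_size=3))
```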
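
Finally, a small numeric sketch of the `TakesMost` redistribution principle for 1D max-pooling with non-overlapping windows. The helper below is hypothetical and sidesteps the PR's convolution/transposed-convolution machinery:

```python
import torch

def takes_most_1d(input, relevance, kernel_size, beta=1.0):
    """Redistribute pooled relevance within each window via a softmax."""
    # non-overlapping windows: (batch, channels, n_windows, kernel_size)
    windows = input.unfold(2, kernel_size, kernel_size)
    # window-wise maximum subtraction stabilizes the softmax numerically
    logits = beta * (windows - windows.amax(dim=3, keepdim=True))
    weights = logits.softmax(dim=3)
    # spread each window's output relevance over the window's inputs
    return (weights * relevance.unsqueeze(3)).flatten(2)

x = torch.randn(2, 3, 8)
rel_out = torch.randn(2, 3, 4)    # relevance at the max_pool1d(x, 2) output
rel_in = takes_most_1d(x, rel_out, kernel_size=2)
assert rel_in.shape == x.shape
# the softmax weights sum to one per window, so relevance is conserved
assert torch.allclose(rel_in.sum(), rel_out.sum(), atol=1e-5)
```

As `beta` grows, the softmax concentrates on each window's maximum and the rule approaches winner-takes-all; relevance is conserved within every window by construction.
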
@jackmcrider (Author)

Some tests fail, but I think these are mostly lint errors coming from the main branch.

@chr5tphr (Owner)

Hey Jacob,

Thanks a lot for finalizing the PR! I will get rid of the tests on Python 3.7, so that's fine.
The pylint errors are also fine, I do not think the number of positional arguments is too bad.
The flake8 error is valid; you can just add whitespace where it is reported.

Could you go through the places that are missing in the coverage report? Some might be unrelated to KMeans and TakesMost, so you can ignore those. It would be good to have those places at least called in some test.

```
 Name                       Stmts   Miss Branch BrPart  Cover   Missing
----------------------------------------------------------------------
src/zennit/canonizers.py     144      2     34      1    98%   381-382
src/zennit/composites.py     128      0     34      1    99%   159->161
src/zennit/layer.py           27      1      0      0    96%   157
src/zennit/rules.py          138     43     14      0    69%   455-460, 464, 468, 471, 475-482, 486-499, 503-507, 510-564
----------------------------------------------------------------------
TOTAL                       1123     46    312      2    96%
```

@chr5tphr (Owner) commented Jul 9, 2025

Just a heads-up, @jackmcrider: I am planning to go over this again, fix potential issues, and merge once #215 is merged. Ideally, I will be done today.

@jackmcrider (Author) commented Jul 9, 2025

I did not see your reply until now, sorry. You mean more tests, right? I will provide more, but not today. Thank you for considering this PR!

@chr5tphr (Owner) commented Jul 9, 2025

Thanks! I might also look into it, as I am pushing Zennit this week for a new release. I have some other things still on the list, but would be happy to get this PR into the release.

Some building infrastructure changed, but the building and testing should still work with the same commands. Be sure to rebase.

chr5tphr force-pushed the main branch 4 times, most recently from 3fbdb43 to 6204e31 on July 31, 2025 at 14:44.

Development

Successfully merging this pull request may close these issues.

Neuralized K-Means: make k-means amenable to neural network explanations
