WenjieDu / PyPOTS / 10097470274

25 Jul 2024 04:00PM UTC coverage: 83.904% (-0.1%) from 84.035%
Build 10097470274 · push · github · web-flow
Merge pull request #475 from WenjieDu/dev

Add attention map visualization func

5 of 27 new or added lines in 3 files covered. (18.52%)

3 existing lines in 2 files now uncovered.

10582 of 12612 relevant lines covered (83.9%)

5.03 hits per line

Source File: /pypots/utils/visual/attention_map.py (0.0% covered)
"""
Utilities for attention map visualization.
"""

# Created by Anshuman Swain <aswai@seas.upenn.edu> and Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import ArrayLike

# seaborn is optional at import time, but plot_attention needs it at call time.
try:
    import seaborn as sns
except Exception:
    pass


def plot_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale=None):
    """Visualize the map of attention weights from Transformer-based models.

    Parameters
    ----------
    timeSteps: 1D array-like object, preferably a list of strings
        A vector containing the time steps of the input.
        The time steps will be converted to a list of strings if they are not already.

    attention: 2D array-like object
        A 2D matrix representing the attention weights.

    fontscale: float/int
        Sets the scale for fonts in the Seaborn heatmap (passed to sns.set_theme(font_scale=...)).

    Returns
    -------
    fig: Matplotlib figure object
        The figure that contains the attention heatmap.

    """

    # Tick labels must be strings, so convert any non-string time steps.
    if not all(isinstance(ele, str) for ele in timeSteps):
        timeSteps = [str(step) for step in timeSteps]

    # Apply the global Seaborn font scale only when the caller asks for it.
    if fontscale is not None:
        sns.set_theme(font_scale=fontscale)

    fig, ax = plt.subplots()
    ax.tick_params(left=True, bottom=True, labelsize=10)
    # Keep every second default tick on each axis.
    ax.set_xticks(ax.get_xticks()[::2])
    ax.set_yticks(ax.get_yticks()[::2])

    assert attention.ndim == 2, "The attention matrix is not two-dimensional"
    # Draw the heatmap with the same time-step labels on both axes.
    sns.heatmap(
        attention,
        ax=ax,
        xticklabels=timeSteps,
        yticklabels=timeSteps,
        linewidths=0,
        cbar=True,
    )
    # Match the colorbar tick label size to the axis tick labels.
    cb = ax.collections[0].colorbar
    cb.ax.tick_params(labelsize=10)

    return fig
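
Since the report shows none of the lines in this file exercised, a small smoke run may help illustrate how the function is meant to be called. The following is only a sketch under stated assumptions: the helper `toy_attention` and the labels `t0`..`t5` are hypothetical and not part of PyPOTS, the import path is inferred from the file path shown above, and the plotting step is skipped when matplotlib, NumPy, seaborn, or a PyPOTS build containing this module is not installed.

```python
import math
import random


def toy_attention(n_steps, seed=0):
    """Hypothetical helper: build an n_steps x n_steps matrix whose rows are
    softmax-normalized random scores, mimicking attention weights."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n_steps):
        scores = [rng.gauss(0.0, 1.0) for _ in range(n_steps)]
        peak = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows


# Illustrative inputs: six labelled time steps and a matching 6 x 6 matrix.
time_steps = [f"t{i}" for i in range(6)]
attention = toy_attention(len(time_steps))

try:
    import matplotlib

    matplotlib.use("Agg")  # headless backend, so no display is required
    import numpy as np
    import seaborn  # noqa: F401  (plot_attention relies on it at call time)

    # Assumed import path, inferred from /pypots/utils/visual/attention_map.py.
    from pypots.utils.visual.attention_map import plot_attention
except ImportError:
    fig = None  # plotting stack unavailable; only the toy inputs are built
else:
    # plot_attention checks attention.ndim, so pass a NumPy array.
    fig = plot_attention(time_steps, np.asarray(attention), fontscale=0.8)
```

When the plotting branch runs, `fig` is the Matplotlib figure returned by `plot_attention`, so it can be saved with `fig.savefig(...)` or checked in a unit test.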

© 2025 Coveralls, Inc