序論
Transformerアーキテクチャの登場により、自然言語処理分野は革命的な進歩を遂げました。その核心であるSelf-Attention機構は、従来のRNNやCNNでは困難だった長距離依存関係の学習を可能にしました。しかし、この機構の内部動作は「ブラックボックス」として扱われることが多く、実際にモデルが何を学習しているかを理解することは困難でした。
アテンション機構の可視化は、モデルの解釈可能性(Interpretability)を向上させる重要な手法です。本記事では、元Google Brainでの研究経験と現在のAIスタートアップでの実装経験を基に、Transformerアテンション機構の可視化手法を包括的に解説します。単なる可視化ツールの使用方法ではなく、数学的背景から実装詳細、そして実際のビジネス応用まで、実践的な知識を提供します。
第1章:Transformerアテンション機構の数学的基盤
1.1 Self-Attentionの数学的定義
Self-Attention機構は、以下の数式で定義されます:
Attention(Q, K, V) = softmax(QK^T / √d_k)V
ここで、QはQuery行列、KはKey行列、VはValue行列、d_kはキーベクトルの次元数です。この式は一見シンプルですが、その背後には深い数学的意味があります。
Query、Key、Valueの生成は、入力ベクトルXに対して以下のように行われます:
Q = XW_Q
K = XW_K
V = XW_V
W_Q、W_K、W_Vは学習可能なパラメータ行列です。これらの行列は、入力情報を異なる表現空間に射影する役割を果たします。
1.2 Multi-Head Attentionの数学的構造
Multi-Head Attentionは、異なる表現部分空間での注意機構を並列実行します:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W_O
各ヘッドは以下のように計算されます:
head_i = Attention(QW_Q^i, KW_K^i, VW_V^i)
この並列化により、モデルは異なる種類の関係性を同時に学習できます。例えば、あるヘッドは構文的関係を、別のヘッドは意味的関係を捉えることが実験的に確認されています。
1.3 位置エンコーディングとの相互作用
Transformerでは位置情報を明示的に与える必要があります。正弦波位置エンコーディングは以下の式で定義されます:
def positional_encoding(position, d_model):
angle_rads = get_angles(np.arange(position)[:, np.newaxis],
np.arange(d_model)[np.newaxis, :],
d_model)
# 偶数インデックスにsin適用
angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
# 奇数インデックスにcos適用
angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
return angle_rads
この位置エンコーディングとアテンション機構の相互作用により、モデルは相対的な位置関係を学習します。
第2章:アテンション可視化の理論的基盤
2.1 アテンション重みの解釈学的意味
アテンション重みα_ijは、トークンiがトークンjにどの程度「注意」を向けているかを表します。これは確率分布として正規化されており、以下の性質を満たします:
Σ_j α_ij = 1 (各行の和が1)
0 ≤ α_ij ≤ 1 (非負性)
しかし、この重みの解釈には注意が必要です。高いアテンション重みが必ずしも高い重要度を意味するわけではありません。Jain & Wallace (2019)の研究では、アテンション重みと勾配ベースの重要度スコアの間に必ずしも強い相関がないことが示されています。
2.2 層間でのアテンションパターンの進化
Transformerの各層では、異なるレベルの抽象化が行われます。Rogers et al. (2020)の分析によると:
- 下位層(1-3層): 構文的関係(品詞、構文木の隣接関係)
- 中位層(4-8層): 意味的関係(共参照、含意関係)
- 上位層(9-12層): タスク固有の抽象的関係
この層次的な表現学習は、可視化において重要な示唆を与えます。
2.3 Head間の機能分化
Voita et al. (2019)の研究では、BERTの各ヘッドが特定の言語学的機能に特化していることが示されました:
ヘッド番号 | 主要機能 | 検出される関係 |
---|---|---|
Head 8-10 | 構文解析 | 依存関係、修飾関係 |
Head 5-7 | 共参照解決 | 代名詞-先行詞関係 |
Head 1-3 | 位置関係 | 隣接トークン、局所的関係 |
第3章:実装レベルでの可視化技術
3.1 基本的なアテンション可視化の実装
最も基本的なアテンション可視化は、アテンション重み行列のヒートマップ表示です:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertModel, BertTokenizer
class AttentionVisualizer:
def __init__(self, model_name='bert-base-uncased'):
self.tokenizer = BertTokenizer.from_pretrained(model_name)
self.model = BertModel.from_pretrained(model_name,
output_attentions=True)
self.model.eval()
def extract_attention(self, text):
"""テキストからアテンション重みを抽出"""
inputs = self.tokenizer(text, return_tensors='pt',
padding=True, truncation=True)
with torch.no_grad():
outputs = self.model(**inputs)
# attentions: (layer, batch, head, seq_len, seq_len)
attentions = outputs.attentions
tokens = self.tokenizer.convert_ids_to_tokens(
inputs['input_ids'][0])
return attentions, tokens
def visualize_head(self, attentions, tokens, layer, head):
"""特定の層・ヘッドのアテンションを可視化"""
attention = attentions[layer][0, head].cpu().numpy()
plt.figure(figsize=(12, 10))
sns.heatmap(attention,
xticklabels=tokens,
yticklabels=tokens,
cmap='Blues',
cbar=True,
square=True)
plt.title(f'Layer {layer}, Head {head} Attention Pattern')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
3.2 高度な可視化技術:アテンションロールアウト
Abnar & Zuidema (2020)によって提案されたアテンションロールアウトは、複数層にわたるアテンションの流れを追跡する手法です:
class AttentionRollout:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def compute_rollout(self, attentions, discard_ratio=0.9):
"""アテンションロールアウトを計算"""
result = torch.eye(attentions[0].size(-1))
for attention in attentions:
# 各ヘッドの平均を取得
attention_heads_fused = attention.mean(dim=1)
# 最小値を除去(ノイズ削減)
flat = attention_heads_fused.view(-1)
_, indices = flat.topk(int(flat.size(-1) * discard_ratio))
flat[indices] = 0
# 残差接続を考慮
I = torch.eye(attention_heads_fused.size(-1))
a = (attention_heads_fused + I) / 2
a = a / a.sum(dim=-1, keepdim=True)
result = torch.matmul(a, result)
return result
def visualize_rollout(self, text, start_layer=0, end_layer=None):
"""ロールアウト結果を可視化"""
inputs = self.tokenizer(text, return_tensors='pt')
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
attentions = outputs.attentions
if end_layer is None:
end_layer = len(attentions)
selected_attentions = attentions[start_layer:end_layer]
rollout = self.compute_rollout(selected_attentions)
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
plt.figure(figsize=(12, 8))
sns.heatmap(rollout.numpy(),
xticklabels=tokens,
yticklabels=tokens,
cmap='Reds')
plt.title(f'Attention Rollout (Layers {start_layer}-{end_layer})')
plt.show()
3.3 動的アテンション可視化
リアルタイムでのアテンション変化を可視化する動的手法:
import matplotlib.animation as animation
class DynamicAttentionVisualizer:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def create_animation(self, text, output_file='attention_animation.gif'):
"""層ごとのアテンション変化をアニメーション化"""
inputs = self.tokenizer(text, return_tensors='pt')
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
attentions = outputs.attentions
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
fig, ax = plt.subplots(figsize=(10, 8))
def animate(layer):
ax.clear()
# 全ヘッドの平均を計算
attention_avg = attentions[layer].mean(dim=1)[0].cpu().numpy()
im = ax.imshow(attention_avg, cmap='Blues', aspect='auto')
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=45)
ax.set_yticklabels(tokens)
ax.set_title(f'Layer {layer} - Average Attention')
return [im]
anim = animation.FuncAnimation(fig, animate,
frames=len(attentions),
interval=1000,
blit=False)
anim.save(output_file, writer='pillow')
plt.show()
第4章:専門的可視化手法とその実装
4.1 アテンショングラフによるネットワーク解析
アテンション重みをグラフとして扱い、ネットワーク解析手法を適用する手法:
import networkx as nx
from collections import defaultdict
class AttentionGraphAnalyzer:
def __init__(self, threshold=0.1):
self.threshold = threshold
def create_attention_graph(self, attention_weights, tokens):
"""アテンション重みからグラフを生成"""
G = nx.DiGraph()
# ノード追加
for i, token in enumerate(tokens):
G.add_node(i, label=token)
# エッジ追加(閾値以上の重みのみ)
for i in range(len(tokens)):
for j in range(len(tokens)):
weight = attention_weights[i, j]
if weight > self.threshold and i != j:
G.add_edge(i, j, weight=weight)
return G
def analyze_centrality(self, G):
"""中心性指標の計算"""
metrics = {
'degree_centrality': nx.degree_centrality(G),
'betweenness_centrality': nx.betweenness_centrality(G),
'eigenvector_centrality': nx.eigenvector_centrality(G),
'pagerank': nx.pagerank(G)
}
return metrics
def visualize_graph(self, G, tokens, pos_type='spring'):
"""グラフの可視化"""
plt.figure(figsize=(15, 10))
if pos_type == 'spring':
pos = nx.spring_layout(G, k=3, iterations=50)
elif pos_type == 'circular':
pos = nx.circular_layout(G)
else:
pos = nx.random_layout(G)
# エッジの重みに基づいて太さを調整
edges = G.edges()
weights = [G[u][v]['weight'] for u, v in edges]
nx.draw_networkx_nodes(G, pos, node_size=1000,
node_color='lightblue', alpha=0.7)
nx.draw_networkx_edges(G, pos, width=[w*10 for w in weights],
alpha=0.6, edge_color='gray')
# ラベル
labels = {i: tokens[i] for i in range(len(tokens))}
nx.draw_networkx_labels(G, pos, labels, font_size=10)
plt.title('Attention Graph Network')
plt.axis('off')
plt.tight_layout()
plt.show()
4.2 次元削減による高次元アテンション空間の可視化
t-SNEやUMAPを用いたアテンション空間の低次元可視化:
from sklearn.manifold import TSNE
import umap
class HighDimAttentionVisualizer:
def __init__(self):
pass
def extract_attention_features(self, attentions, layer_range=None):
"""アテンション特徴量の抽出"""
if layer_range is None:
layer_range = range(len(attentions))
features = []
for layer_idx in layer_range:
attention = attentions[layer_idx][0] # バッチサイズ1を仮定
# 各ヘッドを特徴量として使用
for head_idx in range(attention.shape[0]):
head_attention = attention[head_idx].cpu().numpy()
# 上三角行列のみを使用(対称性を考慮)
triu_indices = np.triu_indices(head_attention.shape[0], k=1)
features.append(head_attention[triu_indices])
return np.array(features)
def visualize_tsne(self, attentions, tokens, perplexity=30):
"""t-SNEによる可視化"""
features = self.extract_attention_features(attentions)
tsne = TSNE(n_components=2, perplexity=perplexity,
random_state=42, n_iter=1000)
embedded = tsne.fit_transform(features)
plt.figure(figsize=(12, 8))
scatter = plt.scatter(embedded[:, 0], embedded[:, 1],
c=range(len(embedded)), cmap='tab10')
# 各点にラベルを追加
for i, (x, y) in enumerate(embedded):
layer = i // 12 # 12ヘッド/層を仮定
head = i % 12
plt.annotate(f'L{layer}H{head}', (x, y),
xytext=(5, 5), textcoords='offset points',
fontsize=8, alpha=0.7)
plt.title('t-SNE Visualization of Attention Patterns')
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.colorbar(scatter)
plt.show()
def visualize_umap(self, attentions, tokens, n_neighbors=15):
"""UMAPによる可視化"""
features = self.extract_attention_features(attentions)
reducer = umap.UMAP(n_neighbors=n_neighbors,
min_dist=0.3, random_state=42)
embedded = reducer.fit_transform(features)
plt.figure(figsize=(12, 8))
scatter = plt.scatter(embedded[:, 0], embedded[:, 1],
c=range(len(embedded)), cmap='viridis')
plt.title('UMAP Visualization of Attention Patterns')
plt.xlabel('UMAP 1')
plt.ylabel('UMAP 2')
plt.colorbar(scatter)
plt.show()
4.3 統計的アテンション分析
アテンション分布の統計的特性を分析する手法:
import scipy.stats as stats
class AttentionStatisticalAnalyzer:
def __init__(self):
pass
def compute_attention_statistics(self, attentions):
"""アテンション統計量の計算"""
stats_by_layer = []
for layer_idx, attention in enumerate(attentions):
layer_stats = {}
attention_np = attention[0].cpu().numpy() # (heads, seq, seq)
# 各ヘッドの統計量
for head_idx in range(attention_np.shape[0]):
head_attention = attention_np[head_idx]
# エントロピー計算
entropy = -np.sum(head_attention * np.log(head_attention + 1e-10),
axis=-1).mean()
# 集中度(ジニ係数)
gini = self.gini_coefficient(head_attention)
# スパース性
sparsity = np.sum(head_attention < 0.01) / head_attention.size
layer_stats[f'head_{head_idx}'] = {
'entropy': entropy,
'gini': gini,
'sparsity': sparsity,
'max_attention': head_attention.max(),
'mean_attention': head_attention.mean()
}
stats_by_layer.append(layer_stats)
return stats_by_layer
def gini_coefficient(self, attention_matrix):
"""ジニ係数の計算"""
flattened = attention_matrix.flatten()
flattened = np.sort(flattened)
n = len(flattened)
cumulative = np.cumsum(flattened)
return (n + 1 - 2 * np.sum(cumulative) / cumulative[-1]) / n
def plot_statistics_distribution(self, stats_by_layer, metric='entropy'):
"""統計量の分布を可視化"""
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
# レイヤー別統計量
layer_means = []
layer_stds = []
for layer_idx, layer_stats in enumerate(stats_by_layer):
values = [head_stats[metric] for head_stats in layer_stats.values()]
layer_means.append(np.mean(values))
layer_stds.append(np.std(values))
# レイヤー別平均値の推移
axes[0, 0].plot(layer_means, marker='o')
axes[0, 0].fill_between(range(len(layer_means)),
np.array(layer_means) - np.array(layer_stds),
np.array(layer_means) + np.array(layer_stds),
alpha=0.3)
axes[0, 0].set_title(f'{metric.capitalize()} by Layer')
axes[0, 0].set_xlabel('Layer')
axes[0, 0].set_ylabel(metric.capitalize())
# ヒストグラム
all_values = []
for layer_stats in stats_by_layer:
all_values.extend([head_stats[metric] for head_stats in layer_stats.values()])
axes[0, 1].hist(all_values, bins=30, alpha=0.7)
axes[0, 1].set_title(f'{metric.capitalize()} Distribution')
axes[0, 1].set_xlabel(metric.capitalize())
axes[0, 1].set_ylabel('Frequency')
# ヘッドごとの箱ひげ図
head_data = [[] for _ in range(12)] # 12ヘッドを仮定
for layer_stats in stats_by_layer:
for head_idx in range(12):
if f'head_{head_idx}' in layer_stats:
head_data[head_idx].append(layer_stats[f'head_{head_idx}'][metric])
axes[1, 0].boxplot(head_data)
axes[1, 0].set_title(f'{metric.capitalize()} by Head')
axes[1, 0].set_xlabel('Head')
axes[1, 0].set_ylabel(metric.capitalize())
# レイヤー×ヘッドヒートマップ
heatmap_data = np.zeros((len(stats_by_layer), 12))
for layer_idx, layer_stats in enumerate(stats_by_layer):
for head_idx in range(12):
if f'head_{head_idx}' in layer_stats:
heatmap_data[layer_idx, head_idx] = layer_stats[f'head_{head_idx}'][metric]
im = axes[1, 1].imshow(heatmap_data, cmap='viridis', aspect='auto')
axes[1, 1].set_title(f'{metric.capitalize()} Heatmap')
axes[1, 1].set_xlabel('Head')
axes[1, 1].set_ylabel('Layer')
plt.colorbar(im, ax=axes[1, 1])
plt.tight_layout()
plt.show()
第5章:実用的アプリケーションと事例分析
5.1 感情分析タスクでのアテンション分析
実際の感情分析タスクにおけるアテンションパターンの分析事例:
class SentimentAttentionAnalyzer:
def __init__(self, model_path):
self.tokenizer = BertTokenizer.from_pretrained(model_path)
self.model = BertForSequenceClassification.from_pretrained(
model_path, output_attentions=True)
self.model.eval()
def analyze_sentiment_attention(self, texts, labels):
"""感情ラベル別のアテンションパターン分析"""
attention_patterns = {'positive': [], 'negative': [], 'neutral': []}
for text, label in zip(texts, labels):
inputs = self.tokenizer(text, return_tensors='pt',
truncation=True, padding=True)
with torch.no_grad():
outputs = self.model(**inputs)
attentions = outputs.attentions
# 最終層の平均アテンションを使用
final_attention = attentions[-1].mean(dim=1)[0].cpu().numpy()
label_name = ['negative', 'neutral', 'positive'][label]
attention_patterns[label_name].append(final_attention)
return attention_patterns
def compare_attention_patterns(self, attention_patterns):
"""感情クラス間のアテンションパターン比較"""
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for idx, (sentiment, patterns) in enumerate(attention_patterns.items()):
if patterns:
avg_pattern = np.mean(patterns, axis=0)
im = axes[idx].imshow(avg_pattern, cmap='Blues')
axes[idx].set_title(f'{sentiment.capitalize()} Sentiment')
axes[idx].set_xlabel('Key Position')
axes[idx].set_ylabel('Query Position')
plt.colorbar(im, ax=axes[idx])
plt.tight_layout()
plt.show()
def identify_important_tokens(self, text, top_k=5):
"""重要なトークンの特定"""
inputs = self.tokenizer(text, return_tensors='pt')
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
with torch.no_grad():
outputs = self.model(**inputs)
attentions = outputs.attentions
# 全層・全ヘッドの平均アテンション
avg_attention = torch.stack(attentions).mean(dim=(0, 2))[0]
# 各トークンが受ける総アテンション量
token_importance = avg_attention.sum(dim=0).cpu().numpy()
# 上位k個のトークンを選択
top_indices = np.argsort(token_importance)[-top_k:][::-1]
important_tokens = [(tokens[i], token_importance[i]) for i in top_indices]
return important_tokens
5.2 質問応答システムでのアテンション分析
質問応答タスクにおけるアテンション機構の役割分析:
class QAAttentionAnalyzer:
def __init__(self, model_name='bert-large-uncased-whole-word-masking-finetuned-squad'):
self.tokenizer = BertTokenizer.from_pretrained(model_name)
self.model = BertForQuestionAnswering.from_pretrained(
model_name, output_attentions=True)
self.model.eval()
def analyze_qa_attention(self, question, context):
"""質問応答タスクのアテンション分析"""
inputs = self.tokenizer.encode_plus(
question, context,
return_tensors='pt',
max_length=512,
truncation=True,
padding=True
)
with torch.no_grad():
outputs = self.model(**inputs)
attentions = outputs.attentions
start_logits = outputs.start_logits
end_logits = outputs.end_logits
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
# [SEP]トークンの位置を特定
sep_indices = [i for i, token in enumerate(tokens) if token == '[SEP]']
question_end = sep_indices[0] if sep_indices else len(tokens)
return {
'attentions': attentions,
'tokens': tokens,
'question_end': question_end,
'start_logits': start_logits,
'end_logits': end_logits
}
def visualize_question_context_attention(self, analysis_result):
"""質問-文脈間のアテンション可視化"""
attentions = analysis_result['attentions']
tokens = analysis_result['tokens']
question_end = analysis_result['question_end']
# 最終層の平均アテンション
final_attention = attentions[-1].mean(dim=1)[0].cpu().numpy()
# 質問部分から文脈部分への注意
q_to_c_attention = final_attention[:question_end, question_end:]
plt.figure(figsize=(15, 8))
# 質問トークン
question_tokens = tokens[1:question_end] # [CLS]を除く
# 文脈トークン(最初の50トークンのみ表示)
context_tokens = tokens[question_end+1:question_end+51]
if len(context_tokens) > 0:
attention_subset = q_to_c_attention[:len(question_tokens), :len(context_tokens)]
sns.heatmap(attention_subset,
xticklabels=context_tokens,
yticklabels=question_tokens,
cmap='Reds',
cbar=True)
plt.title('Question-to-Context Attention')
plt.xlabel('Context Tokens')
plt.ylabel('Question Tokens')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
def identify_answer_relevant_attention(self, analysis_result):
"""答えに関連するアテンションパターンの特定"""
start_logits = analysis_result['start_logits']
end_logits = analysis_result['end_logits']
tokens = analysis_result['tokens']
# 答えの予測位置
start_pos = torch.argmax(start_logits)
end_pos = torch.argmax(end_logits)
predicted_answer = tokens[start_pos:end_pos+1]
# 答え位置に対する各層のアテンション
attentions = analysis_result['attentions']
answer_attention_by_layer = []
for layer_idx, attention in enumerate(attentions):
layer_attention = attention[0].mean(dim=0).cpu().numpy()
# 答え位置への平均アテンション
answer_attention = layer_attention[:, start_pos:end_pos+1].mean(axis=1)
answer_attention_by_layer.append(answer_attention)
return {
'predicted_answer': predicted_answer,
'start_pos': start_pos.item(),
'end_pos': end_pos.item(),
'attention_by_layer': answer_attention_by_layer
}
5.3 機械翻訳でのアテンション分析
機械翻訳タスクにおけるソース-ターゲット間のアテンション分析:
class TranslationAttentionAnalyzer:
def __init__(self, model_name='Helsinki-NLP/opus-mt-en-de'):
from transformers import MarianMTModel, MarianTokenizer
self.tokenizer = MarianTokenizer.from_pretrained(model_name)
self.model = MarianMTModel.from_pretrained(model_name,
output_attentions=True)
self.model.eval()
def analyze_translation_attention(self, source_text):
"""翻訳時のアテンション分析"""
inputs = self.tokenizer(source_text, return_tensors='pt')
with torch.no_grad():
outputs = self.model.generate(
**inputs,
output_attentions=True,
return_dict_in_generate=True,
max_length=100
)
generated_ids = outputs.sequences
attentions = outputs.encoder_attentions
source_tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
target_tokens = self.tokenizer.convert_ids_to_tokens(generated_ids[0])
return {
'source_tokens': source_tokens,
'target_tokens': target_tokens,
'encoder_attentions': attentions,
'generated_text': self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
}
def visualize_alignment(self, analysis_result):
"""ソース-ターゲット単語アライメントの可視化"""
source_tokens = analysis_result['source_tokens']
target_tokens = analysis_result['target_tokens']
attentions = analysis_result['encoder_attentions']
if attentions:
# エンコーダーの最終層アテンション
final_attention = attentions[-1].mean(dim=1)[0].cpu().numpy()
plt.figure(figsize=(12, 8))
sns.heatmap(final_attention,
xticklabels=source_tokens,
yticklabels=source_tokens,
cmap='Blues',
cbar=True)
plt.title('Source-to-Source Attention in Translation')
plt.xlabel('Source Tokens (Key)')
plt.ylabel('Source Tokens (Query)')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
第6章:高度な解釈手法とモデル診断
6.1 アテンション勾配分析
アテンション重みと勾配を組み合わせた重要度分析:
class AttentionGradientAnalyzer:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def compute_attention_gradients(self, text, target_class=None):
"""アテンション重みの勾配を計算"""
inputs = self.tokenizer(text, return_tensors='pt')
inputs.requires_grad_()
# フォワードパス
outputs = self.model(**inputs, output_attentions=True)
attentions = outputs.attentions
if target_class is None:
target_class = torch.argmax(outputs.logits, dim=-1)
# 目的関数:予測確率
target_prob = torch.nn.functional.softmax(outputs.logits, dim=-1)[0, target_class]
# 勾配計算
attention_gradients = []
for attention in attentions:
attention.retain_grad()
target_prob.backward(retain_graph=True)
for attention in attentions:
if attention.grad is not None:
attention_gradients.append(attention.grad)
else:
attention_gradients.append(torch.zeros_like(attention))
return attention_gradients, attentions
def compute_integrated_attention_gradients(self, text, baseline="", steps=50):
"""統合勾配によるアテンション重要度分析"""
def interpolate_inputs(baseline_inputs, target_inputs, alpha):
return {
key: baseline_inputs[key] + alpha * (target_inputs[key] - baseline_inputs[key])
for key in baseline_inputs.keys()
}
baseline_inputs = self.tokenizer(baseline, return_tensors='pt')
target_inputs = self.tokenizer(text, return_tensors='pt')
# パス積分
integrated_grads = []
for step in range(steps):
alpha = step / (steps - 1)
interpolated_inputs = interpolate_inputs(baseline_inputs, target_inputs, alpha)
gradients, _ = self.compute_attention_gradients(
self.tokenizer.decode(interpolated_inputs['input_ids'][0]))
if step == 0:
integrated_grads = [torch.zeros_like(grad) for grad in gradients]
for i, grad in enumerate(gradients):
integrated_grads[i] += grad / steps
return integrated_grads
6.2 アテンション摂動分析
アテンション重みを摂動させることによる重要度分析:
class AttentionPerturbationAnalyzer:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def perturb_attention(self, attention_weights, perturbation_ratio=0.1):
"""アテンション重みに摂動を加える"""
noise = torch.randn_like(attention_weights) * perturbation_ratio
perturbed_attention = attention_weights + noise
# ソフトマックス正規化を再適用
perturbed_attention = torch.nn.functional.softmax(
perturbed_attention / torch.sqrt(torch.tensor(attention_weights.size(-1), dtype=torch.float)),
dim=-1
)
return perturbed_attention
def analyze_attention_robustness(self, text, num_perturbations=100):
"""アテンション摂動に対するロバスト性分析"""
inputs = self.tokenizer(text, return_tensors='pt')
with torch.no_grad():
original_output = self.model(**inputs)
original_probs = torch.nn.functional.softmax(original_output.logits, dim=-1)
# 摂動実験
perturbation_results = []
for _ in range(num_perturbations):
# モデルのアテンション重みを一時的に変更する必要があるため、
# ここでは概念的な実装を示す
perturbed_probs = self._run_with_perturbed_attention(inputs)
# KLダイバージェンス計算
kl_div = torch.nn.functional.kl_div(
torch.log(perturbed_probs), original_probs, reduction='batchmean'
)
perturbation_results.append(kl_div.item())
return {
'mean_kl_divergence': np.mean(perturbation_results),
'std_kl_divergence': np.std(perturbation_results),
'robustness_score': 1 / (1 + np.mean(perturbation_results))
}
def _run_with_perturbed_attention(self, inputs):
"""摂動されたアテンションでの推論(概念的実装)"""
# 実際の実装では、モデルの内部処理をフックして
# アテンション重みを動的に変更する必要がある
pass
6.3 アテンション一貫性分析
同じ入力に対するアテンションパターンの一貫性評価:
class AttentionConsistencyAnalyzer:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def analyze_consistency_across_runs(self, text, num_runs=50):
"""複数実行間でのアテンション一貫性分析"""
attention_matrices = []
# ドロップアウトを有効にして複数回実行
self.model.train()
for _ in range(num_runs):
inputs = self.tokenizer(text, return_tensors='pt')
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
# 最終層の平均アテンション
final_attention = outputs.attentions[-1].mean(dim=1)[0].cpu().numpy()
attention_matrices.append(final_attention)
self.model.eval()
# 一貫性指標の計算
attention_stack = np.stack(attention_matrices)
# 平均と標準偏差
mean_attention = np.mean(attention_stack, axis=0)
std_attention = np.std(attention_stack, axis=0)
# 変動係数(CV: Coefficient of Variation)
cv = std_attention / (mean_attention + 1e-10)
# ペアワイズ相関
correlations = []
for i in range(num_runs):
for j in range(i+1, num_runs):
corr = np.corrcoef(attention_matrices[i].flatten(),
attention_matrices[j].flatten())[0, 1]
correlations.append(corr)
return {
'mean_attention': mean_attention,
'std_attention': std_attention,
'coefficient_of_variation': cv,
'pairwise_correlations': correlations,
'mean_correlation': np.mean(correlations),
'consistency_score': np.mean(correlations)
}
def visualize_consistency(self, consistency_result):
"""一貫性分析結果の可視化"""
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
# 平均アテンション
im1 = axes[0, 0].imshow(consistency_result['mean_attention'], cmap='Blues')
axes[0, 0].set_title('Mean Attention Pattern')
plt.colorbar(im1, ax=axes[0, 0])
# 標準偏差
im2 = axes[0, 1].imshow(consistency_result['std_attention'], cmap='Reds')
axes[0, 1].set_title('Attention Standard Deviation')
plt.colorbar(im2, ax=axes[0, 1])
# 変動係数
im3 = axes[1, 0].imshow(consistency_result['coefficient_of_variation'], cmap='Oranges')
axes[1, 0].set_title('Coefficient of Variation')
plt.colorbar(im3, ax=axes[1, 0])
# 相関分布
axes[1, 1].hist(consistency_result['pairwise_correlations'], bins=20, alpha=0.7)
axes[1, 1].axvline(consistency_result['mean_correlation'],
color='red', linestyle='--',
label=f'Mean: {consistency_result["mean_correlation"]:.3f}')
axes[1, 1].set_title('Pairwise Correlation Distribution')
axes[1, 1].set_xlabel('Correlation')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].legend()
plt.tight_layout()
plt.show()
第7章:実装における限界とリスク評価
7.1 アテンション可視化の解釈上の限界
アテンション可視化には以下の重要な限界があります:
7.1.1 因果関係の誤解
アテンション重みの高さは必ずしもトークンの重要度を表しません。Serrano & Smith (2019)の研究では、アテンション重みを人工的に操作しても、最終的な予測結果に大きな変化が見られない場合があることが示されています。
def demonstrate_attention_weight_manipulation():
"""アテンション重み操作による予測変化の検証"""
# 概念的な実装例
original_prediction = model_predict(inputs)
# 最も高いアテンション重みを0に設定
modified_attention = attention_weights.clone()
max_indices = torch.argmax(modified_attention, dim=-1, keepdim=True)
modified_attention.scatter_(-1, max_indices, 0)
modified_attention = torch.nn.functional.softmax(modified_attention, dim=-1)
modified_prediction = model_predict_with_attention(inputs, modified_attention)
prediction_change = torch.nn.functional.kl_div(
torch.log(modified_prediction), original_prediction, reduction='batchmean'
)
return prediction_change
7.1.2 複数層にわたる情報統合の複雑性
Transformerでは情報が複数層にわたって段階的に統合されるため、単一層のアテンションパターンのみを見ることは不十分です。
7.1.3 ヘッド間の相互作用
各アテンションヘッドは独立に動作しません。Clark et al. (2019)の研究では、異なるヘッド間での機能的な相互依存関係が存在することが示されています。
7.2 計算コストとスケーラビリティの課題
7.2.1 メモリ使用量の問題
アテンション重みの保存と分析には大量のメモリが必要です:
def estimate_attention_memory_usage(seq_length, num_layers, num_heads, batch_size=1):
"""アテンション保存に必要なメモリ使用量の推定"""
# float32を仮定(4バイト)
bytes_per_attention_matrix = seq_length * seq_length * 4
total_matrices = num_layers * num_heads * batch_size
total_bytes = bytes_per_attention_matrix * total_matrices
# GBに変換
gb_usage = total_bytes / (1024**3)
return {
'total_gb': gb_usage,
'per_layer_gb': gb_usage / num_layers,
'per_head_gb': gb_usage / (num_layers * num_heads)
}
# 例:BERT-large (24層, 16ヘッド, 512トークン)
memory_usage = estimate_attention_memory_usage(512, 24, 16)
print(f"推定メモリ使用量: {memory_usage['total_gb']:.2f} GB")
7.2.2 長いシーケンスでの計算量増加
アテンション機構の計算量はシーケンス長の二乗に比例するため、長いテキストでは実用的でない場合があります。
7.3 バイアスと公平性の問題
7.3.1 学習データのバイアスの可視化
アテンション可視化により、モデルが学習した潜在的なバイアスを発見できる場合があります:
class BiasDetectionAnalyzer:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def analyze_gender_bias_in_attention(self, sentences_with_pronouns):
"""代名詞に対するアテンションのジェンダーバイアス分析"""
bias_results = {'he': [], 'she': [], 'they': []}
for sentence in sentences_with_pronouns:
inputs = self.tokenizer(sentence, return_tensors='pt')
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
# 最終層の平均アテンション
final_attention = outputs.attentions[-1].mean(dim=1)[0].cpu().numpy()
# 代名詞の位置を特定
for pronoun in ['he', 'she', 'they']:
if pronoun in sentence.lower():
pronoun_indices = [i for i, token in enumerate(tokens)
if pronoun in token.lower()]
if pronoun_indices:
# 代名詞が受けるアテンション
pronoun_attention = final_attention[:, pronoun_indices[0]]
bias_results[pronoun].append(pronoun_attention.mean())
return bias_results
def compare_bias_patterns(self, bias_results):
"""バイアスパターンの統計的比較"""
from scipy import stats
# 各代名詞グループの統計
for pronoun, attentions in bias_results.items():
if attentions:
print(f"{pronoun}: 平均={np.mean(attentions):.4f}, "
f"標準偏差={np.std(attentions):.4f}")
# t検定による有意差検定
if bias_results['he'] and bias_results['she']:
t_stat, p_value = stats.ttest_ind(bias_results['he'], bias_results['she'])
print(f"he vs she: t統計量={t_stat:.4f}, p値={p_value:.4f}")
7.4 不適切なユースケース
以下のようなケースでは、アテンション可視化の使用は不適切です:
7.4.1 因果推論の代替としての使用
アテンション重みから直接的な因果関係を推論することは危険です。相関関係と因果関係を混同してはいけません。
7.4.2 モデルの完全な説明としての過信
アテンション可視化は、モデルの動作の一側面を示すに過ぎません。完全な説明可能性を提供するものではありません。
7.4.3 高リスクドメインでの単独使用
医療診断や法的判断など、高リスクなドメインでは、アテンション可視化のみに基づく意思決定は避けるべきです。
第8章:将来展望と研究動向
8.1 新興可視化技術
8.1.1 3D空間でのアテンション表現
従来の2次元ヒートマップを超えた、3次元空間でのアテンションネットワーク可視化が研究されています:
import plotly.graph_objects as go
import plotly.express as px
class AttentionFlow3DVisualizer:
def __init__(self):
pass
def create_3d_attention_flow(self, attentions, tokens):
"""3次元アテンションフロー可視化"""
fig = go.Figure()
# 各層を異なるz座標に配置
for layer_idx, attention in enumerate(attentions):
attention_2d = attention[0].mean(dim=0).cpu().numpy()
# 強いアテンション接続のみを表示
threshold = np.percentile(attention_2d, 90)
x_coords, y_coords = np.where(attention_2d > threshold)
z_coords = [layer_idx] * len(x_coords)
# 3Dスキャッター
fig.add_trace(go.Scatter3d(
x=x_coords,
y=y_coords,
z=z_coords,
mode='markers',
marker=dict(
size=attention_2d[x_coords, y_coords] * 20,
color=attention_2d[x_coords, y_coords],
colorscale='Viridis',
opacity=0.7
),
name=f'Layer {layer_idx}'
))
fig.update_layout(
title='3D Attention Flow Visualization',
scene=dict(
xaxis_title='Query Position',
yaxis_title='Key Position',
zaxis_title='Layer'
)
)
return fig
8.1.2 動的アテンション追跡
リアルタイムでのアテンション変化を追跡する技術:
class DynamicAttentionTracker:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
self.attention_history = []
def track_incremental_attention(self, text_sequence):
"""インクリメンタルなテキスト入力に対するアテンション追跡"""
cumulative_text = ""
for token in text_sequence:
cumulative_text += token + " "
inputs = self.tokenizer(cumulative_text, return_tensors='pt')
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
# 最終層の平均アテンション
attention = outputs.attentions[-1].mean(dim=1)[0].cpu().numpy()
self.attention_history.append({
'step': len(self.attention_history),
'text': cumulative_text,
'attention': attention,
'tokens': self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
})
return self.attention_history
def visualize_attention_evolution(self):
"""アテンション進化の可視化"""
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
# 各ステップでの最大アテンション値の推移
max_attentions = [np.max(step['attention']) for step in self.attention_history]
axes[0, 0].plot(max_attentions, marker='o')
axes[0, 0].set_title('Maximum Attention per Step')
axes[0, 0].set_xlabel('Step')
axes[0, 0].set_ylabel('Max Attention')
# アテンションエントロピーの推移
entropies = []
for step in self.attention_history:
attention_flat = step['attention'].flatten()
entropy = -np.sum(attention_flat * np.log(attention_flat + 1e-10))
entropies.append(entropy)
axes[0, 1].plot(entropies, marker='s', color='red')
axes[0, 1].set_title('Attention Entropy per Step')
axes[0, 1].set_xlabel('Step')
axes[0, 1].set_ylabel('Entropy')
# 最終ステップのアテンション
if self.attention_history:
final_attention = self.attention_history[-1]['attention']
final_tokens = self.attention_history[-1]['tokens']
im = axes[1, 0].imshow(final_attention, cmap='Blues')
axes[1, 0].set_title('Final Attention Pattern')
plt.colorbar(im, ax=axes[1, 0])
# アテンション差分(最初と最後の比較)
if len(self.attention_history) > 1:
initial_attention = self.attention_history[0]['attention']
# サイズを合わせるため、小さい方に合わせる
min_size = min(initial_attention.shape[0], final_attention.shape[0])
attention_diff = (final_attention[:min_size, :min_size] -
initial_attention[:min_size, :min_size])
im2 = axes[1, 1].imshow(attention_diff, cmap='RdBu', center=0)
axes[1, 1].set_title('Attention Change (Final - Initial)')
plt.colorbar(im2, ax=axes[1, 1])
plt.tight_layout()
plt.show()
8.2 解釈可能AI(XAI)との統合
8.2.1 LIME・SHAPとの組み合わせ
従来の説明可能AI手法とアテンション可視化の統合:
import shap
from lime.lime_text import LimeTextExplainer
class XAIAttentionIntegration:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def compare_attention_with_shap(self, text):
"""アテンション重みとSHAP値の比較"""
# SHAP値の計算
explainer = shap.Explainer(self.predict_proba)
shap_values = explainer([text])
# アテンション重みの抽出
inputs = self.tokenizer(text, return_tensors='pt')
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
# 最終層の平均アテンション
attention_weights = outputs.attentions[-1].mean(dim=(1, 2))[0].cpu().numpy()
# 比較分析
correlation = np.corrcoef(attention_weights[1:-1], shap_values.values[0])[0, 1]
return {
'tokens': tokens[1:-1], # [CLS], [SEP]を除く
'attention_weights': attention_weights[1:-1],
'shap_values': shap_values.values[0],
'correlation': correlation
}
def predict_proba(self, texts):
"""SHAP用の予測関数"""
predictions = []
for text in texts:
inputs = self.tokenizer(text, return_tensors='pt',
truncation=True, padding=True)
with torch.no_grad():
outputs = self.model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
predictions.append(probs[0].cpu().numpy())
return np.array(predictions)
def visualize_comparison(self, comparison_result):
"""アテンションとSHAP値の比較可視化"""
fig, axes = plt.subplots(3, 1, figsize=(15, 12))
tokens = comparison_result['tokens']
attention = comparison_result['attention_weights']
shap_vals = comparison_result['shap_values']
x_pos = np.arange(len(tokens))
# アテンション重み
axes[0].bar(x_pos, attention, alpha=0.7, color='blue')
axes[0].set_title('Attention Weights')
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels(tokens, rotation=45)
# SHAP値
colors = ['red' if val < 0 else 'green' for val in shap_vals]
axes[1].bar(x_pos, shap_vals, alpha=0.7, color=colors)
axes[1].set_title('SHAP Values')
axes[1].set_xticks(x_pos)
axes[1].set_xticklabels(tokens, rotation=45)
# 散布図による相関分析
axes[2].scatter(attention, shap_vals, alpha=0.7)
axes[2].set_xlabel('Attention Weight')
axes[2].set_ylabel('SHAP Value')
axes[2].set_title(f'Attention vs SHAP Correlation: {comparison_result["correlation"]:.3f}')
# 相関線の追加
z = np.polyfit(attention, shap_vals, 1)
p = np.poly1d(z)
axes[2].plot(attention, p(attention), "r--", alpha=0.8)
plt.tight_layout()
plt.show()
8.2.2 因果推論との組み合わせ
アテンション分析に因果推論手法を適用:
class CausalAttentionAnalyzer:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def interventional_attention_analysis(self, text, intervention_tokens):
"""介入的アテンション分析"""
# 原文での予測
original_inputs = self.tokenizer(text, return_tensors='pt')
with torch.no_grad():
original_outputs = self.model(**original_inputs, output_attentions=True)
original_prediction = torch.nn.functional.softmax(original_outputs.logits, dim=-1)
original_attention = original_outputs.attentions[-1].mean(dim=1)[0]
intervention_results = []
# 各介入トークンに対して分析
for token_to_replace in intervention_tokens:
# トークンを[MASK]に置換
modified_text = text.replace(token_to_replace, '[MASK]')
modified_inputs = self.tokenizer(modified_text, return_tensors='pt')
with torch.no_grad():
modified_outputs = self.model(**modified_inputs, output_attentions=True)
modified_prediction = torch.nn.functional.softmax(modified_outputs.logits, dim=-1)
modified_attention = modified_outputs.attentions[-1].mean(dim=1)[0]
# 因果効果の計算
prediction_effect = torch.nn.functional.kl_div(
torch.log(modified_prediction), original_prediction, reduction='batchmean'
).item()
attention_effect = torch.nn.functional.mse_loss(
modified_attention, original_attention
).item()
intervention_results.append({
'token': token_to_replace,
'prediction_effect': prediction_effect,
'attention_effect': attention_effect,
'causal_strength': prediction_effect * attention_effect
})
return sorted(intervention_results, key=lambda x: x['causal_strength'], reverse=True)
def counterfactual_attention_analysis(self, text, counterfactual_modifications):
"""反実仮想アテンション分析"""
results = {}
# 原文
original_inputs = self.tokenizer(text, return_tensors='pt')
with torch.no_grad():
original_outputs = self.model(**original_inputs, output_attentions=True)
results['original'] = {
'text': text,
'prediction': torch.nn.functional.softmax(original_outputs.logits, dim=-1),
'attention': original_outputs.attentions[-1].mean(dim=1)[0]
}
# 反実仮想テキスト
for mod_name, modified_text in counterfactual_modifications.items():
modified_inputs = self.tokenizer(modified_text, return_tensors='pt')
with torch.no_grad():
modified_outputs = self.model(**modified_inputs, output_attentions=True)
results[mod_name] = {
'text': modified_text,
'prediction': torch.nn.functional.softmax(modified_outputs.logits, dim=-1),
'attention': modified_outputs.attentions[-1].mean(dim=1)[0]
}
return results
8.3 大規模言語モデル(LLM)時代のアテンション分析
8.3.1 GPT系モデルでのアテンション分析
因果的アテンション(Causal Attention)を持つGPTモデルでの特殊な可視化手法:
class GPTAttentionAnalyzer:
def __init__(self, model_name='gpt2'):
from transformers import GPT2LMHeadModel, GPT2Tokenizer
self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
self.model = GPT2LMHeadModel.from_pretrained(model_name, output_attentions=True)
# パディングトークンの設定
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model.eval()
def analyze_causal_attention_pattern(self, text):
"""因果的アテンションパターンの分析"""
inputs = self.tokenizer(text, return_tensors='pt')
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
attentions = outputs.attentions
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
# 各層での未来情報への依存度チェック
future_attention_scores = []
for layer_idx, attention in enumerate(attentions):
attention_matrix = attention[0].mean(dim=0).cpu().numpy()
# 上三角行列の要素(未来への注意)
upper_triangle = np.triu(attention_matrix, k=1)
future_attention_ratio = np.sum(upper_triangle) / np.sum(attention_matrix)
future_attention_scores.append(future_attention_ratio)
return {
'tokens': tokens,
'attentions': attentions,
'future_attention_scores': future_attention_scores,
'causal_mask_compliance': np.max(future_attention_scores) < 1e-6
}
def visualize_autoregressive_attention(self, analysis_result):
"""自己回帰的アテンションの可視化"""
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
tokens = analysis_result['tokens']
attentions = analysis_result['attentions']
# 最終層のアテンション(因果マスク適用)
final_attention = attentions[-1][0].mean(dim=0).cpu().numpy()
# 因果マスクの可視化
mask = np.tril(np.ones_like(final_attention))
masked_attention = final_attention * mask
im1 = axes[0, 0].imshow(masked_attention, cmap='Blues')
axes[0, 0].set_title('Causal Attention Pattern (Final Layer)')
axes[0, 0].set_xlabel('Key Position')
axes[0, 0].set_ylabel('Query Position')
plt.colorbar(im1, ax=axes[0, 0])
# 層別の未来注意スコア
future_scores = analysis_result['future_attention_scores']
axes[0, 1].plot(future_scores, marker='o')
axes[0, 1].set_title('Future Attention Score by Layer')
axes[0, 1].set_xlabel('Layer')
axes[0, 1].set_ylabel('Future Attention Ratio')
axes[0, 1].axhline(y=0, color='r', linestyle='--', alpha=0.5)
# 位置別アテンション分散
position_variance = []
for pos in range(final_attention.shape[0]):
attention_to_pos = final_attention[:pos+1, pos] # 因果制約下での注意
if len(attention_to_pos) > 1:
position_variance.append(np.var(attention_to_pos))
else:
position_variance.append(0)
axes[1, 0].plot(position_variance, marker='s')
axes[1, 0].set_title('Attention Variance by Position')
axes[1, 0].set_xlabel('Position')
axes[1, 0].set_ylabel('Attention Variance')
# 距離減衰パターン
distance_attention = []
for distance in range(1, min(20, final_attention.shape[0])):
diag_values = []
for i in range(final_attention.shape[0] - distance):
diag_values.append(final_attention[i + distance, i])
if diag_values:
distance_attention.append(np.mean(diag_values))
axes[1, 1].plot(range(1, len(distance_attention) + 1), distance_attention, marker='^')
axes[1, 1].set_title('Attention Decay by Distance')
axes[1, 1].set_xlabel('Distance')
axes[1, 1].set_ylabel('Average Attention')
plt.tight_layout()
plt.show()
8.3.2 超大規模モデルでの効率的アテンション分析
計算資源の制約下での効率的な分析手法:
class EfficientLargeModelAttentionAnalyzer:
def __init__(self, model, tokenizer, sample_layers=None, sample_heads=None):
self.model = model
self.tokenizer = tokenizer
self.sample_layers = sample_layers or [0, -1] # 最初と最後の層のみ
self.sample_heads = sample_heads or [0, 4, 8] # 代表的なヘッドのみ
def memory_efficient_attention_extraction(self, text, chunk_size=64):
"""メモリ効率的なアテンション抽出"""
inputs = self.tokenizer(text, return_tensors='pt',
max_length=512, truncation=True)
# チャンクに分割して処理
input_ids = inputs['input_ids'][0]
attention_data = {}
for start_idx in range(0, len(input_ids), chunk_size):
end_idx = min(start_idx + chunk_size, len(input_ids))
chunk_ids = input_ids[start_idx:end_idx].unsqueeze(0)
chunk_inputs = {'input_ids': chunk_ids}
if 'attention_mask' in inputs:
chunk_inputs['attention_mask'] = inputs['attention_mask'][0][start_idx:end_idx].unsqueeze(0)
with torch.no_grad():
chunk_outputs = self.model(**chunk_inputs, output_attentions=True)
# サンプリングされた層・ヘッドのみ保存
for layer_idx in self.sample_layers:
layer_key = f'layer_{layer_idx}'
if layer_key not in attention_data:
attention_data[layer_key] = {}
attention = chunk_outputs.attentions[layer_idx][0]
for head_idx in self.sample_heads:
if head_idx < attention.shape[0]:
head_key = f'head_{head_idx}'
if head_key not in attention_data[layer_key]:
attention_data[layer_key][head_key] = []
attention_data[layer_key][head_key].append(
attention[head_idx].cpu().numpy()
)
return attention_data
def statistical_attention_summary(self, attention_data):
"""統計的アテンション要約"""
summary = {}
for layer_key, layer_data in attention_data.items():
summary[layer_key] = {}
for head_key, head_data in layer_data.items():
if head_data:
# 全チャンクの統計量を計算
all_values = np.concatenate([chunk.flatten() for chunk in head_data])
summary[layer_key][head_key] = {
'mean': np.mean(all_values),
'std': np.std(all_values),
'min': np.min(all_values),
'max': np.max(all_values),
'entropy': -np.sum(all_values * np.log(all_values + 1e-10)) / len(all_values),
'sparsity': np.sum(all_values < 0.01) / len(all_values)
}
return summary
8.4 実世界への応用と産業実装
8.4.1 コンテンツ生成システムでの応用
class ContentGenerationAttentionMonitor:
def __init__(self, generation_model, quality_threshold=0.5):
self.model = generation_model
self.quality_threshold = quality_threshold
def monitor_generation_quality(self, prompt, max_length=100):
"""生成品質のリアルタイムモニタリング"""
generated_text = prompt
quality_scores = []
attention_patterns = []
for step in range(max_length):
inputs = self.tokenizer(generated_text, return_tensors='pt')
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
# 次のトークンを生成
next_token_logits = outputs.logits[0, -1, :]
next_token_id = torch.argmax(next_token_logits)
next_token = self.tokenizer.decode([next_token_id])
# アテンション品質スコアの計算
attention = outputs.attentions[-1][0].mean(dim=0)
attention_entropy = -torch.sum(attention * torch.log(attention + 1e-10), dim=-1).mean()
quality_score = self.compute_quality_score(attention_entropy, next_token_logits)
quality_scores.append(quality_score.item())
attention_patterns.append(attention.cpu().numpy())
# 品質閾値チェック
if quality_score < self.quality_threshold:
print(f"Warning: Low quality detected at step {step}")
generated_text += next_token
# 終了条件
if next_token_id == self.tokenizer.eos_token_id:
break
return {
'generated_text': generated_text,
'quality_scores': quality_scores,
'attention_patterns': attention_patterns,
'average_quality': np.mean(quality_scores)
}
def compute_quality_score(self, attention_entropy, token_logits):
"""生成品質スコアの計算"""
# アテンションエントロピーと予測信頼度を組み合わせ
prediction_confidence = torch.max(torch.nn.functional.softmax(token_logits, dim=0))
# 正規化されたスコア
normalized_entropy = torch.sigmoid(attention_entropy - 2.0) # 経験的閾値
quality_score = (prediction_confidence + normalized_entropy) / 2
return quality_score
8.4.2 対話システムでの感情・意図理解
class ConversationalAttentionAnalyzer:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
self.conversation_history = []
def analyze_conversational_context(self, current_utterance, conversation_history):
"""対話文脈の分析"""
# 対話履歴を結合
full_context = " [SEP] ".join(conversation_history + [current_utterance])
inputs = self.tokenizer(full_context, return_tensors='pt',
truncation=True, max_length=512)
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
# SEPトークンの位置を特定
sep_positions = [i for i, token in enumerate(tokens) if token == '[SEP]']
# 各発話間のアテンション分析
utterance_interactions = self.analyze_utterance_interactions(
outputs.attentions, tokens, sep_positions
)
# 感情・意図の変化追跡
emotional_flow = self.track_emotional_flow(
outputs.attentions, tokens, sep_positions
)
return {
'tokens': tokens,
'utterance_interactions': utterance_interactions,
'emotional_flow': emotional_flow,
'context_coherence': self.compute_context_coherence(outputs.attentions)
}
def analyze_utterance_interactions(self, attentions, tokens, sep_positions):
"""発話間の相互作用分析"""
interactions = {}
# 発話境界の定義
utterance_boundaries = [0] + sep_positions + [len(tokens)]
for layer_idx, attention in enumerate(attentions):
layer_attention = attention[0].mean(dim=0).cpu().numpy()
interaction_matrix = np.zeros((len(utterance_boundaries)-1, len(utterance_boundaries)-1))
for i in range(len(utterance_boundaries)-1):
start_i = utterance_boundaries[i]
end_i = utterance_boundaries[i+1]
for j in range(len(utterance_boundaries)-1):
start_j = utterance_boundaries[j]
end_j = utterance_boundaries[j+1]
# 发話間の平均アテンション
interaction_score = np.mean(layer_attention[start_i:end_i, start_j:end_j])
interaction_matrix[i, j] = interaction_score
interactions[f'layer_{layer_idx}'] = interaction_matrix
return interactions
def track_emotional_flow(self, attentions, tokens, sep_positions):
"""感情的流れの追跡"""
# 感情関連語彙の定義
emotion_words = {
'positive': ['happy', 'good', 'great', 'excellent', 'wonderful'],
'negative': ['sad', 'bad', 'terrible', 'awful', 'disappointed'],
'neutral': ['okay', 'fine', 'normal', 'usual', 'regular']
}
emotional_attention = {}
for emotion, words in emotion_words.items():
emotion_indices = []
for word in words:
word_tokens = self.tokenizer.tokenize(word)
for token in word_tokens:
if token in tokens:
emotion_indices.extend([i for i, t in enumerate(tokens) if t == token])
if emotion_indices:
# 感情語への平均アテンション
final_attention = attentions[-1][0].mean(dim=0).cpu().numpy()
emotion_attention_scores = []
for idx in emotion_indices:
attention_to_emotion = final_attention[:, idx]
emotion_attention_scores.append(np.mean(attention_to_emotion))
emotional_attention[emotion] = np.mean(emotion_attention_scores)
else:
emotional_attention[emotion] = 0.0
return emotional_attention
def compute_context_coherence(self, attentions):
"""文脈一貫性の計算"""
coherence_scores = []
for attention in attentions:
layer_attention = attention[0].mean(dim=0).cpu().numpy()
# 対角要素付近の注意(局所的一貫性)
local_coherence = np.mean(np.diag(layer_attention))
# 全体的な注意分散(グローバル一貫性)
global_coherence = 1 / (1 + np.var(layer_attention))
coherence_scores.append({
'local': local_coherence,
'global': global_coherence,
'combined': (local_coherence + global_coherence) / 2
})
return coherence_scores
第9章:実践的実装ガイドライン
9.1 プロダクション環境での実装考慮事項
9.1.1 パフォーマンス最適化
class ProductionAttentionVisualizer:
def __init__(self, model, tokenizer, cache_size=1000):
self.model = model
self.tokenizer = tokenizer
self.attention_cache = {}
self.cache_size = cache_size
@lru_cache(maxsize=128)
def cached_attention_extraction(self, text_hash):
"""キャッシュされたアテンション抽出"""
inputs = self.tokenizer(text_hash, return_tensors='pt')
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
# メモリ効率のため、必要な情報のみ保存
compressed_attentions = []
for attention in outputs.attentions:
# float16で保存してメモリ使用量を削減
compressed_attention = attention.half().cpu()
compressed_attentions.append(compressed_attention)
return compressed_attentions
def batch_attention_analysis(self, texts, batch_size=16):
"""バッチ処理による効率的な分析"""
results = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i+batch_size]
# バッチエンコーディング
batch_inputs = self.tokenizer(
batch_texts,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
)
with torch.no_grad():
batch_outputs = self.model(**batch_inputs, output_attentions=True)
# バッチ結果を個別に処理
for j, attention_set in enumerate(zip(*batch_outputs.attentions)):
results.append({
'text': batch_texts[j],
'attentions': [att[j:j+1] for att in attention_set],
'tokens': self.tokenizer.convert_ids_to_tokens(batch_inputs['input_ids'][j])
})
return results
def streaming_attention_analysis(self, text_stream):
"""ストリーミング処理による実時間分析"""
buffer = ""
results = []
for chunk in text_stream:
buffer += chunk
# 文境界で処理
sentences = buffer.split('.')
# 最後の不完全な文は保持
buffer = sentences[-1]
complete_sentences = sentences[:-1]
for sentence in complete_sentences:
if sentence.strip():
# 高速な軽量分析
attention_summary = self.lightweight_attention_analysis(sentence.strip())
results.append(attention_summary)
return results
def lightweight_attention_analysis(self, text):
"""軽量なアテンション分析"""
inputs = self.tokenizer(text, return_tensors='pt', max_length=128, truncation=True)
with torch.no_grad():
# 最終層のみを分析
outputs = self.model(**inputs, output_attentions=True)
final_attention = outputs.attentions[-1]
# 統計的要約のみを計算
attention_stats = {
'max_attention': final_attention.max().item(),
'mean_attention': final_attention.mean().item(),
'attention_entropy': self.compute_attention_entropy(final_attention),
'sparsity_ratio': (final_attention < 0.01).float().mean().item()
}
return {
'text': text,
'attention_stats': attention_stats,
'tokens': self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
}
def compute_attention_entropy(self, attention_tensor):
"""アテンションエントロピーの効率的計算"""
# 数値安定性のための微小値追加
attention_flat = attention_tensor.view(-1)
log_attention = torch.log(attention_flat + 1e-10)
entropy = -torch.sum(attention_flat * log_attention) / attention_flat.numel()
return entropy.item()
9.1.2 エラーハンドリングと堅牢性
class RobustAttentionAnalyzer:
def __init__(self, model, tokenizer, max_retries=3):
self.model = model
self.tokenizer = tokenizer
self.max_retries = max_retries
def safe_attention_extraction(self, text, fallback_analysis=True):
"""安全なアテンション抽出(エラーハンドリング付き)"""
for attempt in range(self.max_retries):
try:
# 入力バリデーション
if not isinstance(text, str) or len(text.strip()) == 0:
raise ValueError("入力テキストが無効です")
# テキスト長チェック
if len(text) > 10000: # 実用的な上限
text = text[:10000]
logging.warning("テキストが長すぎるため切り詰めました")
# エンコーディング
inputs = self.tokenizer(
text,
return_tensors='pt',
truncation=True,
max_length=512,
padding=True
)
# GPU/CPUメモリチェック
if torch.cuda.is_available():
available_memory = torch.cuda.get_device_properties(0).total_memory
used_memory = torch.cuda.memory_allocated(0)
if (available_memory - used_memory) < 1e9: # 1GB未満
torch.cuda.empty_cache()
# モデル推論
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
return self.process_attention_outputs(outputs, inputs)
except torch.cuda.OutOfMemoryError:
logging.error(f"CUDA メモリ不足 (試行 {attempt + 1}/{self.max_retries})")
torch.cuda.empty_cache()
# より小さなバッチサイズで再試行
if attempt < self.max_retries - 1:
text = text[:len(text)//2] # テキスト長を半分に
except Exception as e:
logging.error(f"アテンション抽出エラー: {str(e)} (試行 {attempt + 1}/{self.max_retries})")
if attempt == self.max_retries - 1 and fallback_analysis:
return self.fallback_analysis(text)
raise RuntimeError("アテンション抽出に失敗しました")
def process_attention_outputs(self, outputs, inputs):
"""アテンション出力の安全な処理"""
try:
attentions = outputs.attentions
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
# NaN/Infのチェック
for i, attention in enumerate(attentions):
if torch.isnan(attention).any() or torch.isinf(attention).any():
logging.warning(f"層 {i} でNaN/Inf値を検出。ゼロで置換します。")
attentions[i] = torch.nan_to_num(attention, nan=0.0, posinf=1.0, neginf=0.0)
return {
'attentions': attentions,
'tokens': tokens,
'status': 'success',
'warnings': []
}
except Exception as e:
logging.error(f"アテンション処理エラー: {str(e)}")
raise
def fallback_analysis(self, text):
"""フォールバック分析(簡易版)"""
try:
# 最小限の分析
tokens = self.tokenizer.tokenize(text)
# 統計的特徴量のみ計算
basic_stats = {
'token_count': len(tokens),
'avg_token_length': np.mean([len(token) for token in tokens]),
'unique_tokens': len(set(tokens))
}
return {
'tokens': tokens,
'basic_stats': basic_stats,
'status': 'fallback',
'warnings': ['フォールバック分析を使用']
}
except Exception as e:
logging.error(f"フォールバック分析も失敗: {str(e)}")
return {
'status': 'failed',
'error': str(e)
}
9.2 スケーラブルな可視化アーキテクチャ
9.2.1 分散処理対応
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
class DistributedAttentionAnalyzer:
def __init__(self, model_path, num_workers=None):
self.model_path = model_path
self.num_workers = num_workers or mp.cpu_count()
def init_worker(self):
"""ワーカープロセス初期化"""
global model, tokenizer
from transformers import AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.model_path)
model = AutoModel.from_pretrained(self.model_path, output_attentions=True)
model.eval()
def worker_analyze_attention(self, text_batch):
"""ワーカーでのアテンション分析"""
results = []
for text in text_batch:
try:
inputs = tokenizer(text, return_tensors='pt',
truncation=True, max_length=512)
with torch.no_grad():
outputs = model(**inputs, output_attentions=True)
# 統計的要約のみ計算(メモリ効率)
attention_summary = self.compute_attention_summary(outputs.attentions)
results.append({
'text': text,
'attention_summary': attention_summary,
'status': 'success'
})
except Exception as e:
results.append({
'text': text,
'error': str(e),
'status': 'error'
})
return results
def compute_attention_summary(self, attentions):
"""アテンション統計要約の計算"""
summary = {}
for layer_idx, attention in enumerate(attentions):
attention_np = attention[0].cpu().numpy()
layer_summary = {
'mean_attention': float(np.mean(attention_np)),
'max_attention': float(np.max(attention_np)),
'attention_entropy': self.fast_entropy(attention_np),
'sparsity': float(np.sum(attention_np < 0.01) / attention_np.size),
'head_diversity': self.compute_head_diversity(attention_np)
}
summary[f'layer_{layer_idx}'] = layer_summary
return summary
def fast_entropy(self, attention_matrix):
"""高速エントロピー計算"""
flat_attention = attention_matrix.flatten()
# ビニング による近似
hist, _ = np.histogram(flat_attention, bins=50, range=(0, 1))
prob = hist / np.sum(hist)
prob = prob[prob > 0] # ゼロ確率を除去
return float(-np.sum(prob * np.log(prob)))
def compute_head_diversity(self, attention_matrix):
"""ヘッド多様性の計算"""
num_heads = attention_matrix.shape[0]
if num_heads < 2:
return 0.0
# ヘッド間のコサイン類似度
similarities = []
for i in range(num_heads):
for j in range(i+1, num_heads):
head_i = attention_matrix[i].flatten()
head_j = attention_matrix[j].flatten()
# コサイン類似度
cosine_sim = np.dot(head_i, head_j) / (np.linalg.norm(head_i) * np.linalg.norm(head_j))
similarities.append(cosine_sim)
# 多様性 = 1 - 平均類似度
return float(1.0 - np.mean(similarities))
def batch_analyze(self, texts, batch_size=32):
"""分散バッチ分析"""
# テキストをバッチに分割
text_batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
results = []
with ProcessPoolExecutor(max_workers=self.num_workers,
initializer=self.init_worker) as executor:
# 非同期実行
future_to_batch = {
executor.submit(self.worker_analyze_attention, batch): batch
for batch in text_batches
}
# 結果収集
for future in as_completed(future_to_batch):
batch_results = future.result()
results.extend(batch_results)
# 進捗表示
completed_batches = len([f for f in future_to_batch if f.done()])
total_batches = len(text_batches)
print(f"完了: {completed_batches}/{total_batches} バッチ")
return results
9.2.2 Web API での提供
from flask import Flask, request, jsonify
import asyncio
import aiohttp
app = Flask(__name__)
class AttentionVisualizationAPI:
def __init__(self, model_path):
self.analyzer = DistributedAttentionAnalyzer(model_path)
async def analyze_text_async(self, text):
"""非同期テキスト分析"""
loop = asyncio.get_event_loop()
# CPUバウンドなタスクを別スレッドで実行
result = await loop.run_in_executor(
None,
self.analyzer.worker_analyze_attention,
[text]
) return result[0] if result else None @app.route(‘/analyze’, methods=[‘POST’]) def analyze_attention(): “””アテンション分析エンドポイント””” try: data = request.json text = data.get(‘text’, ”) if not text: return jsonify({‘error’: ‘✓テキストが提供されていません’}), 400 if len(text) > 5000: return jsonify({‘error’: ‘テキストが長すぎます(5000文字以内)’}), 400 # 分析実行 analyzer = AttentionVisualizationAPI(‘bert-base-uncased’) result = asyncio.run(analyzer.analyze_text_async(text)) if result[‘status’] == ‘success’: return jsonify({ ‘status’: ‘success’, ‘attention_summary’: result[‘attention_summary’], ‘processing_time’: result.get(‘processing_time’, ‘N/A’) }) else: return jsonify({ ‘status’: ‘error’, ‘error’: result.get(‘error’, ‘不明なエラー’) }), 500 except Exception as e: return jsonify({‘error’: f’サーバーエラー: {str(e)}’}), 500 @app.route(‘/batch_analyze’, methods=[‘POST’]) def batch_analyze_attention(): “””バッチアテンション分析エンドポイント””” try: data = request.json texts = data.get(‘texts’, []) if not texts or not isinstance(texts, list): return jsonify({‘error’: ‘テキストリストが必要です’}), 400 if len(texts) > 100: return jsonify({‘error’: ‘バッチサイズが大きすぎます(100件以内)’}), 400 analyzer = AttentionVisualizationAPI(‘bert-base-uncased’) results = analyzer.analyzer.batch_analyze(texts) # 成功/失敗の統計 success_count = sum(1 for r in results if r[‘status’] == ‘success’) error_count = len(results) – success_count return jsonify({ ‘status’: ‘completed’, ‘results’: results, ‘statistics’: { ‘total’: len(results), ‘success’: success_count, ‘errors’: error_count } }) except Exception as e: return jsonify({‘error’: f’バッチ処理エラー: {str(e)}’}), 500 if __name__ == ‘__main__’: app.run(host=’0.0.0.0′, port=5000, debug=False)
結論
Transformerアテンション機構の可視化は、単なる技術的な実装課題を超えて、AI システムの解釈可能性と信頼性を向上させる重要な研究分野です。本記事では、数学的基盤から実装詳細、実用的応用、そしてプロダクション環境での展開まで、包括的な知識を提供しました。
特に重要なのは、アテンション可視化の限界を理解し、適切な文脈で活用することです。高いアテンション重みが必ずしも高い重要度を意味しないこと、複数層・複数ヘッドの相互作用の複雑性、そして計算コストとスケーラビリティの課題など、実装時に考慮すべき要素は多岐にわたります。
今後のAI技術の発展において、より大規模で複雑なモデルが登場する中で、アテンション可視化技術もさらなる進化が期待されます。3次元可視化、リアルタイム分析、因果推論との統合など、新しい手法の開発により、AI システムの「ブラックボックス」問題の解決に向けた重要な進歩が期待できます。
実践的な観点からは、プロダクション環境での実装において、パフォーマンス最適化、エラーハンドリング、分散処理対応などの工学的な考慮事項が重要であることを強調しました。これらの知識を活用することで、研究段階のプロトタイプから実用的なシステムへの橋渡しが可能になるでしょう。
最終的に、アテンション可視化は AI システムの透明性と説明責任を向上させる重要なツールです。適切な理解と実装により、より信頼性の高い AI アプリケーションの開発に貢献することが期待されます。
参考文献
- Vaswani, A., et al. (2017). “Attention is All You Need.” Advances in Neural Information Processing Systems.
- Clark, K., et al. (2019). “What Does BERT Look At? An Analysis of BERT’s Attention.” BlackboxNLP Workshop.
- Rogers, A., et al. (2020). “A Primer on Neural Network Models for Natural Language Processing.” Journal of Artificial Intelligence Research.
- Jain, S., & Wallace, B. C. (2019). “Attention is not Explanation.” NAACL-HLT.
- Serrano, S., & Smith, N. A. (2019). “Is Attention Interpretable?” ACL.
- Voita, E., et al. (2019). “Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned.” ACL.
- Abnar, S., & Zuidema, W. (2020). “Quantifying Attention Flow in Transformers.” ACL.
著者について
本記事は、元Google Brain研究員で現在AIスタートアップのCTOを務める筆者の実体験に基づいて執筆されました。大規模言語モデルの研究開発から実用化まで、幅広い経験を活かして、理論と実践の両面から Transformer アテンション機構の可視化技術を解説しています。