From ab399b6fb34d7271be4b5f00f6bfd5a22d691464 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 6 Jan 2025 17:39:33 -0100 Subject: [PATCH] add self-attention.typ converted from self-attention.tex for visualizing self-attention mechanism in neural networks --- assets/self-attention/self-attention.typ | 105 +++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 assets/self-attention/self-attention.typ diff --git a/assets/self-attention/self-attention.typ b/assets/self-attention/self-attention.typ new file mode 100644 index 0000000..8b504ea --- /dev/null +++ b/assets/self-attention/self-attention.typ @@ -0,0 +1,105 @@ +#import "@preview/cetz:0.3.1": canvas, draw +#import draw: line, content, circle, rect, bezier + +#set page(width: auto, height: auto, margin: 3pt) + +#canvas({ + // Define spacing constants + let node-spacing = 2 + let layer-spacing = 2 + let vertical-spacing = 1.3 + + // Input nodes + let y1 = 6 + let y_dots_1 = y1 - vertical-spacing + let yj = y_dots_1 - vertical-spacing + let y_dots_2 = yj - vertical-spacing + let yn = y_dots_2 - vertical-spacing + let arrow_style = (end: "stealth", fill: black, scale: 0.7) + + // First column (input vectors) + content((0, y1), $arrow(e)_1$, name: "arrow1", padding: 2pt) + content((0, y_dots_1), $dots$) + content((0, yj), $arrow(e)_j$, name: "arrowj", padding: 2pt) + content((0, y_dots_2), $dots$) + content((0, yn), $arrow(e)_n$, name: "arrown", padding: 2pt) + + // Second column (attention nodes) + let x2 = layer-spacing + content((x2, y1), $a_phi$, frame: "rect", padding: (3pt, 4pt), name: "attn1") + content((x2, yj), $a_phi$, frame: "rect", padding: (3pt, 4pt), name: "attnj") + content((x2, yn), $a_phi$, frame: "rect", padding: (3pt, 4pt), name: "attnn") + + // Third column (alpha values) + let x3 = x2 + layer-spacing + content((x3, y1), text(fill: rgb(0, 0, 0, 20%))[$alpha_(1j)$], name: "alpha1j", padding: 3pt) + content((x3, yj), $alpha_(j j)$, name: "alphajj", padding: 3pt) + content((x3, yn), text(fill: rgb(0, 0, 0, 60%))[$alpha_(n j)$], name: "alphanj", padding: 3pt) + + // Fourth column (multiplication nodes) + let x4 = x3 + layer-spacing + content((x4, y1), name: "times1", $times$, frame: "circle", padding: 3pt, stroke: .7pt) + content((x4, yj), name: "timesj", $times$, frame: "circle", padding: 3pt, stroke: .7pt) + content((x4, yn), name: "timesn", $times$, frame: "circle", padding: 3pt, stroke: .7pt) + + // Fifth column (sum node) + let x5 = x4 + layer-spacing + content((x5, yj), $Sigma$, frame: "rect", padding: 4pt, name: "sum") + + // Output node + let x6 = x5 + 1 + content((x6, yj), $arrow(e)'_j$, name: "output", padding: 2pt) + + // Draw connections + line("arrow1.east", "attn1.west", mark: arrow_style) + line("arrowj.east", "attnj.west", mark: arrow_style) + line("arrown.east", "attnn.west", mark: arrow_style) + line("arrowj.east", "attn1.west", mark: arrow_style) + line("arrowj.east", "attnn.west", mark: arrow_style) + + line("attn1.east", "alpha1j.west", mark: arrow_style) + line("attnj.east", "alphajj.west", mark: arrow_style) + line("attnn.east", "alphanj.west", mark: arrow_style) + + line("alpha1j.east", "times1.west", mark: arrow_style) + line("alphajj.east", "timesj.west", mark: arrow_style) + line("alphanj.east", "timesn.west", mark: arrow_style) + + line("times1.east", "sum.north-west", mark: arrow_style) + line("timesj.east", "sum.west", mark: arrow_style) + line("timesn.east", "sum.south-west", mark: arrow_style) + + line("sum.east", "output.west", mark: arrow_style) + + // Draw f_psi connections with labels + for (idx, (start, end)) in ( + ("arrow1.east", "times1.south-west"), + ("arrowj.east", "timesj.south-west"), + ("arrown.east", "timesn.south-west"), + ).enumerate(start: 1) { + bezier( + start, + end, + ( + (v1, v2) => { + let (x1, y1, ..) = v1 + let (x2, y2, ..) = v2 + return ((x1 + x2) / 2, (y1 + y2) / 2 - 2) + }, + start, + end, + ), + mark: arrow_style, + stroke: 1pt, + name: "fpsi" + str(idx), + ) + content( + "fpsi" + str(idx) + ".50%", + text(fill: rgb(0, 0, 0, 60%))[$f_psi$], + frame: "rect", + padding: (3pt, 4pt), + name: "fpsi", + fill: white, + ) + } +})