Skip to content

Commit

Permalink
Update notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
cdpierse committed Feb 23, 2021
1 parent c5cd147 commit 66538ff
Showing 1 changed file with 217 additions and 114 deletions.
331 changes: 217 additions & 114 deletions notebooks/multiclass_classification_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -199,63 +199,72 @@
{
"data": {
"text/plain": [
"[('BOS_TOKEN', 0.0),\n",
" ('Stocks', -0.004456741467730228),\n",
" ('ended', 0.13193411668768767),\n",
" ('a', 0.056411549659932295),\n",
" ('choppy', -0.012832653695455687),\n",
" ('session', -0.13861638195151932),\n",
" ('mixed', 0.046772014985904914),\n",
" ('as', 0.004562625031740414),\n",
" ('investors', -0.02196100971228579),\n",
" ('digested', 0.3463117699963133),\n",
" ('a', -0.15007237744066892),\n",
" ('host', -0.2723240045062653),\n",
" ('of', 0.021451277365530166),\n",
" ('corporate', 0.19881984653293971),\n",
" ('earnings', 0.08638392319925758),\n",
" ('results', 0.20302100667530776),\n",
" ('and', 0.21544563171165798),\n",
" ('considered', 0.16101577787308416),\n",
" ('policymakers’', 0.026943727362501158),\n",
" ('next', -0.07904325843448648),\n",
" ('moves', 0.23334843054630117),\n",
" ('to', 0.037448353157484925),\n",
" ('support', -0.04450222286965038),\n",
" ('the', 0.04380056195019436),\n",
" ('still', 0.15665500295281176),\n",
" ('virus-stricken', -0.04790115963194608),\n",
" ('economy', 0.07189596512282359),\n",
" ('.', 0.02121638714028611),\n",
" ('The', -0.2329136487423103),\n",
" ('S&P', 0.04750629281304725),\n",
" ('500', -0.15105901317108278),\n",
" ('shook', -0.18294099820643525),\n",
" ('off', 0.316082550830697),\n",
" ('earlier', -0.076398177169727),\n",
" ('declines', 0.030009315889644415),\n",
" ('to', 0.08982719113716602),\n",
" ('narrowly', -0.1011465651299368),\n",
" ('eke', -0.1702941002616553),\n",
" ('out', -0.01952184223907717),\n",
" ('a', -0.011227586017313017),\n",
" ('record', 0.01133963563437068),\n",
" ('closing', -0.06897702495823117),\n",
" ('high', 0.09192553992692651),\n",
" ('.', -0.07352340021075736),\n",
" ('The', 0.17522286083591584),\n",
" ('Dow', -0.05732253543843262),\n",
" ('ended', -0.16490589953495466),\n",
" ('a', -0.04211285881867463),\n",
" ('tick', 0.04642973321019768),\n",
" ('below', 0.14323973665558232),\n",
" ('its', -0.061569071265021987),\n",
" ('recent', 0.1520246427799714),\n",
" ('record', -0.022675238424145417),\n",
" ('closing', 0.04157891279787718),\n",
" ('level', 0.1173478877671702),\n",
" ('.', 0.06642661005463817),\n",
" ('EOS_TOKEN', 0.02858934938430591)]"
"[('[CLS]', 0.0),\n",
" ('stocks', -0.004456745041476635),\n",
" ('ended', 0.13193417174421274),\n",
" ('a', 0.05641152917733269),\n",
" ('chop', -0.012832735630589245),\n",
" ('##py', -0.13861640177798698),\n",
" ('session', 0.04677203070064154),\n",
" ('mixed', 0.00456270865206808),\n",
" ('as', -0.02196107395623291),\n",
" ('investors', 0.3463117654678203),\n",
" ('digest', -0.1500723809400845),\n",
" ('##ed', -0.27232389332855483),\n",
" ('a', 0.02145133233175586),\n",
" ('host', 0.19881990197329233),\n",
" ('of', 0.08638398072671191),\n",
" ('corporate', 0.20302091157480334),\n",
" ('earnings', 0.2154456398173302),\n",
" ('results', 0.16101571665374576),\n",
" ('and', 0.026943728367538014),\n",
" ('considered', -0.07904327915567533),\n",
" ('policy', 0.23334850964528223),\n",
" ('##makers', 0.03744834172027172),\n",
" ('’', -0.044502146745965795),\n",
" ('next', 0.04380059324826015),\n",
" ('moves', 0.15665502037005144),\n",
" ('to', -0.0479011674293024),\n",
" ('support', 0.07189598606497517),\n",
" ('the', 0.021216427378130336),\n",
" ('still', -0.23291363921919042),\n",
" ('virus', 0.047506291065723814),\n",
" ('-', -0.15105902994209225),\n",
" ('stricken', -0.1829409892474084),\n",
" ('economy', 0.3160825260114336),\n",
" ('.', -0.0763981593059189),\n",
" ('the', 0.030009362468362923),\n",
" ('s', 0.08982718910250631),\n",
" ('&', -0.10114649727856133),\n",
" ('p', -0.17029413762276835),\n",
" ('500', -0.019521830267569346),\n",
" ('shook', -0.011227632667560073),\n",
" ('off', 0.011339665722377214),\n",
" ('earlier', -0.06897708396530404),\n",
" ('declines', 0.09192554579385859),\n",
" ('to', -0.0735234673514767),\n",
" ('narrowly', 0.175222857161181),\n",
" ('ek', -0.057322578552753085),\n",
" ('##e', -0.16490588015574514),\n",
" ('out', -0.04211287407983076),\n",
" ('a', 0.04642970085085266),\n",
" ('record', 0.14323972059147766),\n",
" ('closing', -0.06156912300554976),\n",
" ('high', 0.15202475415471225),\n",
" ('.', -0.022675231849007386),\n",
" ('the', 0.0415788626061664),\n",
" ('dow', 0.11734785896868836),\n",
" ('ended', 0.06642663475176527),\n",
" ('a', 0.028589349365118952),\n",
" ('tick', 0.02279932426731028),\n",
" ('below', -0.017487545974190215),\n",
" ('its', -0.10996836616475462),\n",
" ('recent', 0.1953634357235276),\n",
" ('record', 0.05400907087801813),\n",
" ('closing', -0.020406808200502915),\n",
" ('level', -0.07232510452252366),\n",
" ('.', 0.033205229401471324),\n",
" ('[SEP]', 0.0)]"
]
},
"execution_count": 9,
Expand Down Expand Up @@ -323,63 +332,72 @@
{
"data": {
"text/plain": [
"[('BOS_TOKEN', 0.0),\n",
" ('Stocks', -0.013076093437139466),\n",
" ('ended', 0.21332438787938696),\n",
" ('a', -0.06501736140537122),\n",
" ('choppy', -0.05377650405681654),\n",
" ('session', -0.13618844604494268),\n",
" ('mixed', 0.0524596274309297),\n",
" ('as', 0.11133689370437873),\n",
" ('investors', 0.029009141422987397),\n",
" ('digested', 0.35082490371802794),\n",
" ('a', -0.03137677849547649),\n",
" ('host', -0.0824256120985008),\n",
" ('of', -0.12824966411558325),\n",
" ('corporate', 0.1932059939797767),\n",
" ('earnings', 0.14453150173440635),\n",
" ('results', 0.1749567986388768),\n",
" ('and', 0.18457682852945065),\n",
" ('considered', 0.16502232930342287),\n",
" ('policymakers’', 0.05724637287425622),\n",
" ('next', 0.04470654954517382),\n",
" ('moves', 0.4196253362060768),\n",
" ('to', 0.0772446951479946),\n",
" ('support', 0.06417823661238761),\n",
" ('the', 0.08633115904130688),\n",
" ('still', 0.1391353262265933),\n",
" ('virus-stricken', 0.06091509898963959),\n",
" ('economy', 0.12192535620726684),\n",
" ('.', -0.09255242912925289),\n",
" ('The', -0.07162788259214121),\n",
" ('S&P', 0.10373776462033504),\n",
" ('500', -0.05428224389987557),\n",
" ('shook', -0.067379712805794),\n",
" ('off', 0.28345425712356614),\n",
" ('earlier', -0.04093528552305287),\n",
" ('declines', 0.020150030561753855),\n",
" ('to', 0.24218183633261192),\n",
" ('narrowly', -0.12194819122978925),\n",
" ('eke', -0.12596020671255193),\n",
" ('out', -0.04975339542811973),\n",
" ('a', 0.002074405522171988),\n",
" ('record', -0.001946865510249872),\n",
" ('closing', 0.009730553882977411),\n",
" ('high', 0.09915560429706247),\n",
" ('.', 0.0030853513515697907),\n",
" ('The', 0.07999628385872036),\n",
" ('Dow', -0.16172046681749222),\n",
" ('ended', -0.11540646760210642),\n",
" ('a', -0.04021302467665264),\n",
" ('tick', -0.03195752102771043),\n",
" ('below', 0.10409122357310087),\n",
" ('its', -0.014241891952442663),\n",
" ('recent', 0.13736952656692017),\n",
" ('record', -0.010935398916344388),\n",
" ('closing', -0.00479973377612086),\n",
" ('level', 0.03672195907512466),\n",
" ('.', 0.17244671984950705),\n",
" ('EOS_TOKEN', -0.06339452698593473)]"
"[('[CLS]', 0.0),\n",
" ('stocks', -0.013076095817261638),\n",
" ('ended', 0.21332437398379314),\n",
" ('a', -0.0650173968095115),\n",
" ('chop', -0.053776494369439945),\n",
" ('##py', -0.1361883302504456),\n",
" ('session', 0.05245961920595882),\n",
" ('mixed', 0.11133688837284314),\n",
" ('as', 0.0290092563922162),\n",
" ('investors', 0.3508248779709538),\n",
" ('digest', -0.031376772246840434),\n",
" ('##ed', -0.08242556105918372),\n",
" ('a', -0.1282496697255458),\n",
" ('host', 0.1932059547552784),\n",
" ('of', 0.14453142789631288),\n",
" ('corporate', 0.174956812088488),\n",
" ('earnings', 0.18457683227048555),\n",
" ('results', 0.16502234537811186),\n",
" ('and', 0.057246362109120805),\n",
" ('considered', 0.044706561381177),\n",
" ('policy', 0.419625329483318),\n",
" ('##makers', 0.07724468902782176),\n",
" ('’', 0.0641782420831589),\n",
" ('next', 0.0863311384704364),\n",
" ('moves', 0.13913528586360044),\n",
" ('to', 0.06091511183016304),\n",
" ('support', 0.12192537817088814),\n",
" ('the', -0.09255245322690696),\n",
" ('still', -0.07162789933654426),\n",
" ('virus', 0.10373773961955873),\n",
" ('-', -0.054282238520226525),\n",
" ('stricken', -0.06737975605130951),\n",
" ('economy', 0.2834542905917583),\n",
" ('.', -0.04093521309644109),\n",
" ('the', 0.02014996912513288),\n",
" ('s', 0.2421818723864095),\n",
" ('&', -0.1219481894035061),\n",
" ('p', -0.1259602374090837),\n",
" ('500', -0.049753348446608135),\n",
" ('shook', 0.0020743908453390346),\n",
" ('off', -0.0019468691081433412),\n",
" ('earlier', 0.009730542394735642),\n",
" ('declines', 0.09915561855443226),\n",
" ('to', 0.0030853149473039927),\n",
" ('narrowly', 0.07999627029088487),\n",
" ('ek', -0.16172049279622133),\n",
" ('##e', -0.11540647722477954),\n",
" ('out', -0.040213103001637024),\n",
" ('a', -0.03195762072555721),\n",
" ('record', 0.10409121946019516),\n",
" ('closing', -0.014241849296305585),\n",
" ('high', 0.13736957349377404),\n",
" ('.', -0.010935425504967166),\n",
" ('the', -0.004799759112557838),\n",
" ('dow', 0.0367219813799116),\n",
" ('ended', 0.1724467542111746),\n",
" ('a', -0.06339459661998191),\n",
" ('tick', 0.010420206121999408),\n",
" ('below', -0.033368891375772494),\n",
" ('its', -0.023351263505381116),\n",
" ('recent', 0.18783244048312728),\n",
" ('record', 0.08150135064652181),\n",
" ('closing', 0.04479632753264403),\n",
" ('level', -0.05045663823013374),\n",
" ('.', -0.017507157626781053),\n",
" ('[SEP]', 0.0)]"
]
},
"execution_count": 12,
Expand Down Expand Up @@ -466,6 +484,91 @@
"multiclass_explainer.visualize()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('[CLS]', 0.0),\n",
" ('stocks', -0.09201294942361385),\n",
" ('ended', 0.04407271448219919),\n",
" ('a', 0.16321588179875865),\n",
" ('chop', 0.04632777499253846),\n",
" ('##py', -0.25187041889583356),\n",
" ('session', -0.03191380380572038),\n",
" ('mixed', -0.08486779783022193),\n",
" ('as', 0.0990228369494118),\n",
" ('investors', 0.15616562033570283),\n",
" ('digest', -0.05320900579431406),\n",
" ('##ed', -0.1943873312313048),\n",
" ('a', 0.0701817057427743),\n",
" ('host', 0.02375706910921956),\n",
" ('of', 0.1987030179538432),\n",
" ('corporate', 0.21816854083740192),\n",
" ('earnings', 0.1034336318064541),\n",
" ('results', 0.25163873610032916),\n",
" ('and', -0.009519313354122897),\n",
" ('considered', 0.14355278315043185),\n",
" ('policy', 0.25695840883129323),\n",
" ('##makers', -0.0067529899810760685),\n",
" ('’', 0.042844477015168185),\n",
" ('next', -0.09635841194957916),\n",
" ('moves', 0.04030577027611355),\n",
" ('to', -0.05089957413373526),\n",
" ('support', -0.13898610509971007),\n",
" ('the', 0.16586873963353052),\n",
" ('still', -0.1971275206247004),\n",
" ('virus', 0.048957715766651075),\n",
" ('-', 0.012727397177708568),\n",
" ('stricken', -0.30602791462280804),\n",
" ('economy', 0.3048668000494108),\n",
" ('.', 0.10547830871969005),\n",
" ('the', 0.09001971948094911),\n",
" ('s', -0.07332696135899623),\n",
" ('&', 0.07849644128355428),\n",
" ('p', -0.04898511320126377),\n",
" ('500', -0.12436827518140792),\n",
" ('shook', -0.11176538223886069),\n",
" ('off', -0.14822268555576723),\n",
" ('earlier', 0.0005296906093799398),\n",
" ('declines', 0.09579157569095596),\n",
" ('to', -0.005555441254145121),\n",
" ('narrowly', 0.1891132151430131),\n",
" ('ek', 0.015248211205243483),\n",
" ('##e', 0.0822380826582115),\n",
" ('out', -0.10995897797408631),\n",
" ('a', -0.0038717444955518924),\n",
" ('record', 0.1155507215289045),\n",
" ('closing', 0.01452437989012333),\n",
" ('high', -0.0147534023245287),\n",
" ('.', 0.004128039160521726),\n",
" ('the', 0.0059524965492529355),\n",
" ('dow', 0.06239211946756769),\n",
" ('ended', -0.03860381346838376),\n",
" ('a', -0.16989376804745981),\n",
" ('tick', -0.013011859417387965),\n",
" ('below', 0.024458771568195883),\n",
" ('its', 0.075779587950879),\n",
" ('recent', 0.1642747622333274),\n",
" ('record', 0.08250841362265618),\n",
" ('closing', -0.02657225572482565),\n",
" ('level', 0.010329179592552197),\n",
" ('.', -0.18554493843258346),\n",
" ('[SEP]', 0.0)]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"attributions.word_attributions"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 66538ff

Please sign in to comment.