import%20marimo%0A%0A__generated_with%20%3D%20%220.18.4%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20return%20(mo%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%20Encoding%20and%20Learning%20Relationships%20with%20Graph%20Convolutions%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20What%20is%20a%20graph%20%3F%0A%0A%20%20%20%20If%20you%20wanted%20to%20map%20the%20friendships%20in%20a%20room%2C%20you%20wouldn't%20use%20a%20list%20of%20numbers%20or%20a%20complex%20formula.%20You%20would%20draw%20circles%20for%20people%20and%20lines%20to%20connect%20them.%20This%20is%20a%20graph.%20It%20allows%20us%20to%20move%20beyond%20'how%20much'%20of%20something%20there%20is%2C%20and%20focus%20on%20'who%20is%20connected%20to%20whom'.%0A%0A%20%20%20%20For%20our%20purposes%2C%20it%20suffices%20to%20define%20a%20graph%20as%20%3A%0A%0A%20%20%20%20%3E%20A%20structural%20map%20of%20pairwise%20connections%20between%20discrete%20entities.%0A%0A%20%20%20%20!%5B%5D(%2FUsers%2Fudeshhabaraduwa%2FLibrary%2FCloudStorage%2FGoogleDrive-udesh86%40gmail.com%2FMy%20Drive%2FPhd%2FMisc%20projects%2FGCN%2Fpublic%2Fimage.png)%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.image(src%3D%22friends.png%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Here%2C%20we%20have%20a%20relationship%20*friend%20of*%2C%20these%20are%20the%20**edges**%20of%20our%20graph.%20They%20are%20used%20to%20represent%20the%20relationships%20between%20each%20person%20%2C%20**the%20nodes**.%20Notice%20that%20in%20building%20this%20relationship%20map%2C%20we%20have%20**encoded%20some%20specific%20beliefs**.%0A%0A%20%20%20%201.%20%60Friendship%20is%20not%20a%20one-way%20street%60.%20The%20relationship%20is%20defined%20as%20**undirected**%20(or%20bi-directional)%20implying%20that%20the%20connection%20exists%20in%20both%20directions.%0A%20%20%20%202.%20%60In%20this%20graph%2C%20they%20are%20all%20friends%60.%20The%20graph%20is%20**fully%20connected**.%20In%20a%20bigger%20network%2C%20say%20everyone%20in%20a%20town%2C%20you%20can%20have%20many%20nodes%20that%20are%20not%20connected.%0A%20%20%20%203.%20%60The%20connection%20is%20%22on%20or%20off%22%60.%20There%20is%20no%20%22degree%20of%20friendship%22%20in%20this%20case.%20You%20are%20or%20are%20not%20a%20friend%20-%20it's%20%22not%20complicated%22.%20Imagine%20instead%20that%20we%20define%20degree%20of%20friendship%20as%20the%20number%20of%20years%20they%20have%20known%20each%20other.%20If%20Ash%20has%20known%20Misty%20for%203%20years%20and%20Brock%20for%201%2C%20they%20could%20have%20an%20**edge%20weight**%2C%203%20for%20the%20Edge%20(Ash%2C%20Misty)%20and%201%20for%20Edge(Brock%2C%20Ash).%0A%20%20%20%204.%20What%20is%20an%20%22Ash%22%20%3F.%20We%20have%20this%20abstract%20concept%20of%20a%20person%2C%20but%20how%20is%20it%20defined%20mathematically%3F%20We%20represent%20%22Ash%22%20as%20a%20list%20of%20everything%20we%20know%20about%20him%3A%20he%E2%80%99s%20from%20Pallet%20Town%2C%20he%E2%80%99s%2010%20years%20old%2C%20and%20he%20trains%205%20Pok%C3%A9mon.%20These%20are%20**node%20features**.%20For%20each%20person%20in%20our%20collection%2C%20we%20have%20such%20information.%0A%0A%20%20%20%20Indeed%2C%20some%20relationships%20are%20most%20naturally%20represented%20in%20this%20way.%20Examples%20abound.%20The%20internet%2C%20roads%20between%20cities%2C%20interaction%20between%20muscle%20groups%2C%20etc.%20All%20of%20these%20could%20and%20are%20often%20represented%20in%20this%20way.%0A%0A%20%20%20%20Graphs%20provide%20us%20with%20a%20powerful%20tool%20represent%20abstract%20ideas%20like%20friendship%20in%20a%20computationally%20tractable%20form.%20Then%2C%20we%20can%20do%20interesting%20things%20like%20predicting%20where%20someone%20is%20from%20based%20on%20their%20friendships%20or%20predict%20if%20two%20people%20are%20friends.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Scaling%20up%20to%20a%20real%20dataset%20%3A%20A%20citation%20network.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Our%20Pok%C3%A9mon%20trio%20helped%20us%20define%20the%20'what'%20of%20a%20graph.%20But%20to%20see%20the%20'how'%20of%20a%20Graph%20Convolutional%20Network%2C%20we%20need%20a%20dataset%20with%20more%20complexity.%0A%0A%20%20%20%20In%20the%20code%20below%2C%20we%20use%20the%20Cora%20dataset%2C%20a%20classic%20benchmark%20in%20graph%20machine%20learning.%0A%0A%20%20%20%201.%20Instead%20of%20friends%2C%20we%20have%202%2C708%20scientific%20papers.%0A%20%20%20%202.%20Instead%20of%20social%20ties%2C%20we%20have%205%2C429%20citations%20connecting%20them.%0A%20%20%20%203.%20Instead%20of%20'Age'%20or%20'Home%20Town'%2C%20each%20paper%20has%20a%20list%20of%201%2C433%20features%20representing%20the%20presence%20or%20absence%20of%20specific%20words.%0A%0A%20%20%20%20Just%20as%20you%20might%20guess%20Ash%E2%80%99s%20home%20town%20by%20looking%20at%20his%20friends%2C%20a%20GCN%20will%20try%20to%20guess%20a%20paper's%20research%20category%20by%20looking%20at%20how%20papers%20cite%20each%20other.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20torch%0A%20%20%20%20import%20torch.nn%20as%20nn%0A%20%20%20%20import%20torch.optim%20as%20optim%0A%20%20%20%20import%20torch.nn.functional%20as%20F%0A%20%20%20%20import%20pandas%20as%20pd%0A%20%20%20%20import%20numpy%20as%20np%0A%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%20%20%20%20from%20sklearn.preprocessing%20import%20LabelEncoder%0A%20%20%20%20import%20networkx%20as%20nx%0A%20%20%20%20from%20scipy.sparse%20import%20csr_matrix%0A%20%20%20%20import%20scipy.sparse%20as%20sp%0A%20%20%20%20from%20sklearn.model_selection%20import%20train_test_split%0A%20%20%20%20from%20tqdm%20import%20tqdm%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20LabelEncoder%2C%0A%20%20%20%20%20%20%20%20csr_matrix%2C%0A%20%20%20%20%20%20%20%20nn%2C%0A%20%20%20%20%20%20%20%20np%2C%0A%20%20%20%20%20%20%20%20nx%2C%0A%20%20%20%20%20%20%20%20optim%2C%0A%20%20%20%20%20%20%20%20pd%2C%0A%20%20%20%20%20%20%20%20plt%2C%0A%20%20%20%20%20%20%20%20sp%2C%0A%20%20%20%20%20%20%20%20torch%2C%0A%20%20%20%20%20%20%20%20train_test_split%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20For%20a%20given%20paper%2C%20it%20has%20information%20about%20which%20papers%20reference%20it%20and%20which%20papers%20it%20references.%20Each%20paper%20also%20has%20a%20label%0A%0A%20%20%20%20In%20this%20data%2C%20we%20represent%20each%20paper%20**as%20a%20node**.%20Each%20paper%20has%20an%20ID%2C%20a%20label%0A%0A%20%20%20%20Each%20paper%20is%201%20of%207%20classes%20(https%3A%2F%2Fgraphsandnetworks.com%2Fthe-cora-dataset%2F)%0A%0A%20%20%20%20**It's%20not%20the%20case%20that%20every%20paper%20is%20cited%2C%20but%20every%20paper%20cites%20atleast%20one%20other%20paper%20in%20the%20datase.**%20There%20are%20isolated%20nodes%20in%20this%20dataset.%20Each%20has%20atleast%20one%20connection.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.show_code()%20%0A%20%20%20%20%23%20Load%20papers%0A%20%20%20%20%23%20Here%2C%20each%20%20line%20is%20a%20paper%0A%20%20%20%20with%20open('cora%2Fcora.content')%20as%20file%3A%0A%20%20%20%20%20%20%20%20papers%20%3D%20file.readlines()%0A%20%20%20%20%23%20This%20will%20be%20loaded%20tab%20seperated%2C%20so%20let's%20remove%20the%20tabs%0A%20%20%20%20%23%20and%20get%20a%20list%20of%20papers%0A%20%20%20%20papers%20%3D%20%5Bp.split('%5Ct')%20for%20p%20in%20papers%5D%0A%20%20%20%20%23%20The%20first%20element%20is%20the%20paper%20ID%20and%20the%20last%20is%20the%20group%20classification%0A%20%20%20%20%23%20so%20lets%20get%20those%0A%20%20%20%20ids%20%3D%20%5Bp%5B0%5D%20for%20p%20in%20papers%5D%0A%20%20%20%20labels%20%3D%20%5Bp%5B-1%5D%20for%20p%20in%20papers%5D%0A%20%20%20%20ids%5B0%5D%2C%20labels%5B-1%5D%0A%0A%20%20%20%20%23%20let's%20take%20a%20look%20at%20the%20labels%0A%20%20%20%20%23%20use%20set%20to%20get%20unique%20values%0A%20%20%20%20%23%20check%20the%20raw%20strings%0A%20%20%20%20for%20l%20in%20set(labels)%3A%0A%20%20%20%20%20%20%20%20print(repr(l))%0A%20%20%20%20%23%20remove%20the%20new%20line%20characters%0A%20%20%20%20labels%20%3D%20%5Bl.strip()%20for%20l%20in%20labels%5D%0A%20%20%20%20labels%5B0%5D%0A%20%20%20%20return%20labels%2C%20papers%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Let's%20take%20a%20look%20at%20a%20few%20nodes.%20Here%2C%20each%20column%20is%20the%20presence%20or%20absence%20of%20a%20word%20in%20the%20paper%20the%20first%20column%20is%20the%20paper%20ID%20and%20the%20last%20column%20is%20the%20label.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(papers)%3A%0A%0A%20%20%20%20for%20i%20in%20range(5)%3A%0A%20%20%20%20%20%20%20%20print(papers%5Bi%5D)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Let's%20take%20a%20look%20at%20the%20edges.%20Note%20here%20that%20unlike%20the%20friends%20network%2C%20the%20edges%20are%20not%20undirected.%20The%20link%20is%20from%20**citer%20to%20citee**.%20If%20paper%20A%20cites%20paper%20B%2C%20then%20paper%20be%20does%20not%20automatically%20cite%20paper%20A.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20%23%20Load%20the%20citation%20data%0A%20%20%20%20with%20open('cora%2Fcora.cites')%20as%20cites_file%3A%0A%20%20%20%20%20%20%20%20cites%20%3D%20cites_file.readlines()%0A%0A%20%20%20%20%23%20check%20the%20raw%20strings%0A%20%20%20%20%23%20this%20shows%20A%20cites%20B%20as%20A%20%5Ctab%20B%0A%20%20%20%20print(repr(cites%5B0%5D))%0A%20%20%20%20%23%20%20remove%20the%20new%20line%2C%20make%20the%20links%20sublists%0A%20%20%20%20cites%20%3D%20%5Bc.strip().split('%5Ct')%20for%20c%20in%20cites%5D%0A%20%20%20%20return%20(cites%2C)%0A%0A%0A%40app.cell%0Adef%20_(cites)%3A%0A%20%20%20%20%23%20%5Bi%2Cj%5D%20means%20j%20cites%20i.%20%0A%20%20%20%20%23%20This%20only%20has%20the%20incoming%20edges%20to%20all%20the%20papers%0A%20%20%20%20%23%20So%20here%20%20paper%201033%20cites%20paper%2035%2C%20etc.%0A%20%20%20%20cites%5B%3A5%5D%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Building%20the%20graph%20representations%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%20The%20nodes%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(LabelEncoder%2C%20labels%2C%20mo%2C%20papers%2C%20pd)%3A%0A%20%20%20%20mo.show_code()%20%0A%20%20%20%20%23%20let's%20make%20them%20dataframes%20for%20easy%20manipulation%0A%20%20%20%20papers_df%20%3D%20pd.DataFrame(papers)%0A%0A%20%20%20%20papers_df%5Bpapers_df.columns%5B0%3A1433%5D%5D%20%3D%20papers_df%5Bpapers_df.columns%5B0%3A1433%5D%5D.astype(int)%0A%20%20%20%20%23%20Let's%20call%20%20papers%20X%20and%20remove%20the%20ID%20and%20label%0A%20%20%20%20%23%20Each%20row%20represents%20a%20node%0A%20%20%20%20X%20%3D%20papers_df.drop(columns%20%3D%20%5B0%2C1434%5D)%20%20%20%20%20%20%0A%20%20%20%20%23%20do%20integer%20encoding%20of%20the%20text%20labels%20%3A%20map%20the%20labels%20to%20integers%0A%20%20%20%20le%20%3D%20LabelEncoder()%0A%20%20%20%20y%20%20%3D%20le.fit_transform(labels)%0A%20%20%20%20papers_df.head()%0A%20%20%20%20%23%20features%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20print(X.head())%0A%20%20%20%20%23%20labels%0A%20%20%20%20print(y%5B%3A5%5D)%0A%20%20%20%20return%20X%2C%20papers_df%2C%20y%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%20The%20edges%20%3A%20An%20adjacency%20matrix%0A%0A%20%20%20%20The%20adjacency%20matrix%20contains%20all%20the%20information%20about%20how%20the%20nodes%20in%20a%20graph%20are%20connected%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(cites%2C%20papers_df)%3A%0A%20%20%20%20%23%20How%20many%20papers%20are%20there%20%3F%0A%20%20%20%20N_papers%20%3D%20papers_df.shape%5B0%5D%0A%20%20%20%20%23%20a%20dictionary%20for%20reindexing%20%0A%20%20%20%20%23%20Why%20%3F%20Because%20these%20papers%20are%20a%20sample%20of%20a%20bigger%20dataset%2C%20so%20the%20paper%20IDs%20are%20not%20contiguous%20(i.e.%2C%20there%20are%20gaps%20in%20the%20id%20numbers).%0A%20%20%20%20%23%20so%20we%20map%20the%20paper%20IDs%20to%20contiguous%20integers%0A%20%20%20%20to_new_index%20%3D%20%7Bint(paper_id)%20%3A%20index%20for%20index%2C%20paper_id%20in%20enumerate(papers_df%5B0%5D)%7D%0A%20%20%20%20index_to_paper%20%3D%20%7B%20int(index)%20%3A%20int(paper_id)%20for%20paper_id%2C%20index%20in%20to_new_index.items()%7D%0A%0A%20%20%20%20%23%20map%20all%20the%20cites%20to%20the%20new%20index%0A%20%20%20%20cites_new_index%20%3D%20%5B%5Bto_new_index%5Bint(cite%5B0%5D)%5D%2C%20to_new_index%5Bint(cite%5B1%5D)%5D%5D%20for%20cite%20in%20cites%5D%0A%0A%20%20%20%20for%20c%20in%20cites_new_index%5B%3A5%5D%3A%0A%20%20%20%20%20%20%20%20print(f%22paper%20%7Bc%5B1%5D%7D%20--cites--%3E%20paper%20%7Bc%5B0%5D%7D%22)%0A%20%20%20%20return%20N_papers%2C%20cites_new_index%2C%20index_to_paper%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20now%20create%20a%20matrix%20representation%20for%20our%20adjacency%20information.%0A%0A%20%20%20%20What%20we%20want%20is%20something%20like%20this%3A%0A%20%20%20%20where%20each%20column%20represents%20a%20citing%20paper%20and%20each%20row%20represents%20a%20cited%20paper.%0A%0A%20%20%20%20%7C%20%20%7C%20**Paper%200**%20%7C%20**Paper%201**%20%7C%20**Paper%202**%20%7C%20**Paper%203**%20%7C%20**Paper%204**%20%7C%0A%20%20%20%20%7C%3A---%3A%7C%3A---%3A%7C%3A---%3A%7C%3A---%3A%7C%3A---%3A%7C%3A---%3A%7C%0A%20%20%20%20%7C%20**Paper%200**%20%7C%200%20%7C%200%20%7C%200%20%7C%201%20%7C%200%20%7C%0A%20%20%20%20%7C%20**Paper%201**%20%7C%201%20%7C%200%20%7C%200%20%7C%200%20%7C%200%20%7C%0A%20%20%20%20%7C%20**Paper%202**%20%7C%200%20%7C%201%20%7C%200%20%7C%200%20%7C%200%20%7C%0A%20%20%20%20%7C%20**Paper%203**%20%7C%201%20%7C%200%20%7C%201%20%7C%200%20%7C%200%20%7C%0A%20%20%20%20%7C%20**Paper%204**%20%7C%200%20%7C%200%20%7C%200%20%7C%200%20%7C%200%20%7C%0A%0A%20%20%20%20**Reading%20the%20matrix%3A**%0A%20%20%20%20-%20**Columns**%20%E2%86%92%20Citing%20papers%0A%20%20%20%20-%20**Rows**%20%E2%86%92%20Cited%20papers%0A%20%20%20%20-%20A%20%601%60%20at%20entry%20(row%2C%20col)%20means%20paper%20%60col%60%20cites%20paper%20%60row%60%0A%0A%20%20%20%20If%20we%20have%20a%20very%20larg%20network%2C%20having%20a%20massive%20matrix%20with%20a%20lot%20of%200s%20is%20unncessary%0A%20%20%20%20https%3A%2F%2Fyoutu.be%2FQi7FcjN7nsc%3Fsi%3D3_lKyS6MlgD2IWKy%0A%20%20%20%20We%20make%20a%20csr-representation.%20Instead%20of%20holding%20the%20full%20matrix%2C%20we%20maintain%20%20information%20that%0A%20%20%20%20defines%20it.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(N_papers%2C%20cites_new_index%2C%20csr_matrix%2C%20np%2C%20sp)%3A%0A%20%20%20%20%23%20turn%20the%20cites%20into%20a%20numpy%20array%3A%0A%20%20%20%20cites_new_index_arr%20%3D%20np.array(cites_new_index)%0A%0A%20%20%20%20%23%20cites_new_index_arr%20looks%20like%3A%0A%20%20%20%20%23%20%5B%5Bcited%2C%20citer%5D%2C%0A%20%20%20%20%23%20%20%5Bcited%2C%20citer%5D%2C%0A%20%20%20%20%23%20%20......%5D%0A%20%20%20%20%23%20We%20get%20the%20cited%20papers%20%3A%20all%20the%20values%20in%20the%20first%20column%0A%20%20%20%20row%20%3D%20%20cites_new_index_arr%5B%3A%2C0%5D%0A%20%20%20%20%23%20We%20get%20the%20citing%20papers%20%3A%20all%20the%20values%20in%20the%20second%20column%0A%20%20%20%20col%20%3D%20cites_new_index_arr%5B%3A%2C1%5D%0A%20%20%20%20%23%20We%20want%20to%20place%20a%20one%20at%20each%20%5Bcited%2C%20citer%5D%20element%0A%20%20%20%20values%20%3D%20np.ones(len(col))%20%0A%20%20%20%20%23%20This%20puts%20a%201%20in%20every%20%5Bcited%2C%20citer%5D%20element%2C%20holding%20only%20these%20non-zero%20elements%20which%20conserves%20space.%0A%20%20%20%20A%20%20%3D%20sp.csr_matrix((values%2C(row%2Ccol))%2C%20shape%3D(N_papers%2CN_papers))%0A%20%20%20%20A%20%3D%20csr_matrix(A)%20%20%0A%20%20%20%20%23%20It%20still%20has%20the%20same%20shape%20technically%2C%20but%20it%20is%20a%20sparse%20matrix%0A%20%20%20%20print(A.shape)%0A%0A%0A%20%20%20%20%23%20%20Node%20degree%0A%20%20%20%20%23%20The%20node%20degree%20answers%20the%20question%20%3A%20how%20many%20connections%20does%20a%20given%20node%20have%3F%0A%0A%20%20%20%20%23%20Symmetrize%20the%20original%20citation%20adjacency%3A%0A%20%20%20%20%23%20This%20simply%20removes%20the%20directions%20of%20the%20edges%20(e.g.%2C%20paper%203%20--cites--%3E%20paper%204)%20and%20makes%20it%20an%20undirected%20edge%20(paper%203%20and%20paper%204%20are%20linked.).%20%0A%20%20%20%20%23%20We%20treat%20citations%20as%20undirected%20to%20allow%20information%20to%20flow%20'backwards'%20from%20a%20newer%20paper%20to%20the%20older%20seminal%20work%20it%20cites.%0A%20%20%20%20%23%20This%20is%20necessary%20for%20some%20of%20the%20underlying%20linear%20algebra%20assumptions%20to%20apply.%0A%20%20%20%20A%20%3D%20A%20%2B%20A.T%0A%20%20%20%20%23%20We%20add%20self-loops%20(a%20paper%20cites%20iself)%20because%20we%20will%20need%20this%20for%20the%20normalization%20step%20coming%20next%0A%20%20%20%20A%20%20%2B%3D%20np.eye(N_papers)%0A%0A%20%20%20%20%23%20Since%20this%20is%20now%20symmetric%2C%20we%20can%20simply%20sum%20along%20either%20of%20the%20axes%0A%20%20%20%20degrees%20%3D%20A.sum(axis%3D0).A1%20%23%20.A1%20to%20get%20a%20numpy%20array%20%0A%0A%20%20%20%20%23%20We%20create%20a%20matrix%20D%20which%20is%20a%20diagonal%20matrix%20(only%20the%20main%20diagonal%20has%20non-zero%20values).%0A%20%20%20%20%23%20This%20represents%20the%20edges%20that%20a%20given%20paper%20has%0A%20%20%20%20%23%20We%20could%20also%20make%20it%20a%20sparse%20representation%20but%20we%20skip%20this%20for%20now%0A%20%20%20%20D%20%3D%20np.diag(degrees)%0A%20%20%20%20adj_matrix%20%3D%20csr_matrix(A)%0A%20%20%20%20return%20D%2C%20adj_matrix%2C%20col%2C%20row%2C%20values%0A%0A%0A%40app.cell%0Adef%20_(cites_new_index%2C%20nx%2C%20plt)%3A%0A%20%20%20%20%23%20Create%20a%20directed%20graph%0A%20%20%20%20G%20%3D%20nx.DiGraph()%0A%20%20%20%20import%20random%0A%20%20%20%20%23%20Add%20edges%3A%20the%20edge%20is%20directed%20from%20%22citing%22%20to%20%22cited%22%0A%20%20%20%20%23%20lets%20look%20at%2015%20edges%0A%20%20%20%20for%20cited%2C%20citing%20in%20cites_new_index%5B200%3A215%5D%3A%0A%20%20%20%20%20%20%20%20G.add_edge(citing%2C%20cited)%0A%20%20%20%20%23%20Generate%20a%20layout%20for%20our%20nodes%20(here%20we%20use%20a%20circular%20layout)%0A%20%20%20%20pos%20%3D%20nx.circular_layout(G)%0A%0A%20%20%20%20%23%20Draw%20nodes%2C%20edges%2C%20and%20labels%0A%20%20%20%20plt.figure(figsize%3D(8%2C%208))%0A%20%20%20%20nx.draw_networkx_nodes(G%2C%20pos%2C%20node_color%3D'skyblue'%2C%20node_size%3D700)%0A%20%20%20%20nx.draw_networkx_edges(G%2C%20pos%2C%20arrowstyle%3D'-%3E'%2C%20arrowsize%3D20%2C%20edge_color%3D'gray')%0A%20%20%20%20nx.draw_networkx_labels(G%2C%20pos%2C%20font_size%3D12%2C%20font_color%3D'black')%0A%0A%20%20%20%20plt.title(%22Citation%20Network%20Visualization%22)%0A%20%20%20%20plt.axis('off')%0A%20%20%20%20plt.show()%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(D)%3A%0A%20%20%20%20print(D)%0A%20%20%20%20print(D.shape)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(cites%2C%20index_to_paper)%3A%0A%20%20%20%20%23%20lets%20check%0A%20%20%20%20%23%20%20paperID%20of%20index%200%0A%20%20%20%20p%20%3D%20index_to_paper%5B0%5D%0A%20%20%20%20print(p)%0A%20%20%20%20for%20c_1%20in%20cites%3A%0A%20%20%20%20%20%20%20%20if%20c_1%5B0%5D%20%3D%3D%20str(p)%20or%20c_1%5B1%5D%20%3D%3D%20str(p)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20print(c_1)%0A%0A%20%20%20%20%23%20%20This%20is%20correct.%20There%20are%205%20edges%20total%20for%20paper%20in%20a%5B0%2C0%5D%2C%20but%20we%20see%206%20in%20the%20first%20element%20of%20D%20because%20we%20added%20self%20loops.%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(D%2C%20X%2C%20adj_matrix)%3A%0A%20%20%20%20%23%20At%20this%20point%2C%20we%20have%20the%20following%20important%20coponents%0A%20%20%20%20%23%20X%20%3A%20These%20are%20the%20node%20features%20%0A%20%20%20%20%23%20%20%20%3A%20This%20has%20shape%20N_papers%20x%20BOW%20features%0A%20%20%20%20print(%22%20X%20shape%20%3A%20%22%2C%20X.shape)%0A%0A%20%20%20%20%23%20A%20%3A%20This%20is%20the%20adacency%20matrix.%20It%20contains%20the%20information%20about%20which%20papers%20are%20connected%20to%20each%20other%0A%20%20%20%20%23%20%20%20%3A%20This%20has%20shape%20N_papers%20x%20N_papers%0A%20%20%20%20print(%22%20A%20shape%20%3A%20%22%2C%20adj_matrix.shape)%0A%0A%20%20%20%20%23%20D%20%3A%20This%20is%20the%20node%20degree%20matrix.%20It%20contains%20the%20information%20about%20how%20many%20incoming%20and%20outgoing%20edges%20each%20paper%20has%0A%20%20%20%20%23%20%20%20%3A%20This%20has%20shape%20N_papers%20x%20N_papers%0A%20%20%20%20print(%22%20D%20shape%20%3A%20%22%2C%20D.shape)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Graph%20Convolutions%20TL%3BDR%0A%0A%0A%20%20%20%20A%20Graph%20convolution%20layer%20(https%3A%2F%2Farxiv.org%2Fpdf%2F1609.02907)%20applies%20this%20function%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20H%5E%7B(l%2B1)%7D%20%3D%20%5Csigma%20%5Cleft(%20%5Ctilde%7BD%7D%5E%7B-%5Cfrac%7B1%7D%7B2%7D%7D%20%5Ctilde%7BA%7D%20%5Ctilde%7BD%7D%5E%7B-%5Cfrac%7B1%7D%7B2%7D%7D%20H%5E%7B(l)%7D%20W%5E%7B(l)%7D%20%5Cright)%0A%20%20%20%20%24%24%0A%0A%20%20%20%20where%20%3A%0A%0A%20%20%20%20-%20%24H%24%20%3A%20Node%20features%20matrix%20(%24n%24%20papers%20x%20%24m%24%20words)%0A%20%20%20%20-%20%24A%24%2C%24%5Ctilde%7BA%7D%24%20%3A%20Raw%20Adjacency%20matrix%20(%24n%24%20papers%20x%20%24n%24%20papers)%20and%20Augmented%20with%20added%20self%20loops%0A%20%20%20%20-%20%24D%24%2C%20%24%5Ctilde%7BD%7D%24%20%3A%20Raw%20Degree%20matrix%20(%24n%24%20papers%20x%20%24n%24%20papers)%20and%20degree%20matrix%20of%20%24%5Ctilde%7BA%7D%24%0A%20%20%20%20-%20%24W%24%20%3A%20Weights%20matrix%20for%20each%20layer%20(shape%20depends%20on%20which%20layer%20it%20is)%0A%0A%0A%20%20%20%20Each%20node%20(i.e.%2C%20paper)%20is%20a%20feature%20vector.%20This%20is%20the%20art%20of%20taking%20something%20in%20the%20real%20world%20and%20*representing*%20it%20in%20a%20way%20that%20can%20be%20computed%20on.%0A%20%20%20%20In%20this%20example%2C%20we%20have%20chosen%20to%20represent%20a%20paper%20as%20a%20vector%20of%201433%200s%20and%201s.%20There%20are%201433%20unique%20words%20(e.g.%2C%20index%200%20of%20this%20vector%20might%20be%20%60machine%60)%20and%201%20in%20an%20any%20given%20index%20means%20this%20word%20is%20in%20the%20paper.%0A%0A%20%20%20%20We%20have%20a%20label%20for%20each%20of%20these%20vectors%20(e.g.%2C%20neural%20networks%2C%20reinforcement%20learning%20etc.).%20Could%20we%20predict%20a%20paper's%20label%20using%20only%20the%20feature%20vector%20%3F%20Probably.%20Consider%20that%20reinforcement%20learning%20papers%20might%20have%20the%20word%20%22reinforcement%22%20much%20more%20frequently%20than%20other%20papers.%0A%0A%20%20%20%20However%2C%20we%20also%20have%20another%20source%20of%20information%20%3A%20We%20know%20that%20papers%20on%20reinforcement%20learning%20are%20likely%20to%20*cite*%20other%20papers%20on%20reinforcement%20learning.%20One%20paper%20does%20not%20exist%20in%20a%20vacuum%20%3B%20there%20is%20some%20abstract%20relationship%20(i.e.%2C%20citations)%20between%20papers%20that%20influences%20what%20the%20label%20of%20a%20paper%20is.%20How%20can%20we%20use%20this%20information%20as%20well%20%3F%20This%20is%20the%20problem%20a%20graph%20solves.%20It%20is%20a%20way%20to%20encode%20abstract%20dependencies%20in%20the%20real%20world.%20For%20humans%2C%20it%20is%20sufficient%20to%20say%20%22paper%203%20cite%20paper%202%22%20but%20we%20need%20a%20way%20to%20encode%20this%20for%20our%20machines.%20We%20do%20so%20with%20matrices%20%24A%24%20and%20%24D%24.%0A%0A%20%20%20%20**The%20core%20idea%20of%20neural%20networks%20is%20to%20progressively%20learn%20abstract%20transformations%20of%20the%20input%20features%20to%20solve%20a%20task**%20(e.g.%2C%20classification).%20How%20do%20we%20update%20the%20features%20of%20a%20given%20paper%20with%20information%20about%20the%20citation%20graph%20%3F%20This%20is%20the%20problem%20**graph%20convolution**%20solves.%0A%0A%20%20%20%20That%20is%2C%20the%20node%20features%20(H)%20are%20updated%20with%20the%20connectivity%20information%20of%20each%20of%20the%20nodes%20in%20the%20network%20(provided%20by%20A%20and%20D)%20and%20W%20(learnable%20weights).%20Each%20node%20*passes%20a%20message*%20about%20itself%20to%20through%20its%20edges.%20A%20node%20(%24H_i%24)%20receives%20these%20messages%2C%20aggregates%20them%2C%20and%20updates%20itself%20with%20a%20weighted%20sum%20(%24W%24).%20To%20ensure%20that%20a%20given%20node%20also%20uses%20information%20about%20itself%20we%20add%20a%20self-loop%20%24%5Ctilde%7BA%7D%20%3D%20A%20%2B%20I%24%20and%20recalculate%20the%20degree%20to%20get%20%24%5Ctilde%7BD%7D%24.%0A%0A%20%20%20%20Each%20convolution%20is%20one%20round%20of%20message%20passing%20with%20a%20node's%20**direct%20neighbors**%20%2C%20or%201-hop%20in%20the%20graph%20(i.e.%2C%20direct%20ciations%20%3A%20paper%201%20cites%20paper%203).%20Every%20subsequent%20convlution%20is%20one%20layer%20up%20(paper%202%20cites%20paper%201%20cites%20paper%203).%0A%0A%20%20%20%20The%20learnable%20weights%20serve%20the%20same%20role%20as%20in%20a%20standard%20neural%20network%20where%20the%20output%20of%20a%20given%20layer%20is%20%20%24a%20%3D%20%5Csigma(WX%2Bb)%24.%0A%0A%20%20%20%20%23%23%23%20The%20normalization%20step%0A%0A%20%20%20%20A%20popular%20paper%20will%20have%20more%20citations%20(i.e.%2C%20more%20incoming%20edges%20%3D%20higher%20degree).%20Thus%2C%20during%20the%20message%20aggregation%2C%20the%20features%20of%20higher-degree%20nodes%20can%20explode%20in%20size.%20We%20correct%20for%20this%20by%20normalizing%20by%20the%20degree%20of%20a%20given%20node%20(%24%5Ctilde%7BD%7D%5E%7B-%5Cfrac%7B1%7D%7B2%7D%7D%24)%20(%5Bmore%20info%5D(https%3A%2F%2Farc.net%2Fl%2Fquote%2Faislqbvv)).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(D%2C%20X%2C%20adj_matrix%2C%20torch%2C%20y)%3A%0A%20%20%20%20%23%20convert%20everything%20to%20tensors%2C%20and%20rename%20to%20match%20the%20conventions%0A%20%20%20%20X_tensor%20%3D%20torch.Tensor(X.astype(int).to_numpy())%0A%20%20%20%20A_tensor%20%3D%20torch.sparse_csr_tensor(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20torch.from_numpy(adj_matrix.indptr)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20torch.from_numpy(adj_matrix.indices)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20torch.from_numpy(adj_matrix.data))%0A%0A%20%20%20%20%23%20Ensure%20both%20tensors%20are%20of%20the%20same%20type%20(Float%20or%20Double)%0A%20%20%20%20X_tensor%20%3D%20X_tensor.to(torch.float32)%20%20%23%20Convert%20X_tensor%20to%20Float%0A%20%20%20%20A_tensor%20%3D%20A_tensor.to(torch.float32)%20%20%23%20Convert%20A_tensor%20to%20Float%0A%0A%20%20%20%20D_tensor%20%3D%20torch.Tensor(D)%0A%0A%20%20%20%20%23%20The%20labels%20assigned%20to%20a%20given%20paper%20from%20the%20original%20dataset%0A%20%20%20%20y_tensor%3D%20torch.Tensor(y)%0A%0A%20%20%20%20%23%20Let'compute%20the%20inverse%20square%20root%20of%20D%0A%20%20%20%20%23%20Get%20the%20diagonal%20elements%20of%20D%0A%20%20%20%20D_tensor%20%3D%20torch.diag(D_tensor)%0A%20%20%20%20%23%20Compute%20the%20inverse%20square%20root%0A%20%20%20%20D_inv_sqrt%20%3D%201%20%2F%20torch.sqrt(D_tensor)%0A%20%20%20%20%23%20Create%20a%20diagonal%20matrix%20from%20the%20inverse%20square%20root%0A%20%20%20%20D_inv_sqrt%20%3D%20torch.diag(D_inv_sqrt)%0A%20%20%20%20return%20A_tensor%2C%20D_inv_sqrt%2C%20X_tensor%2C%20y_tensor%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Building%20a%20Graph%20Convoluitional%20Layer%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Let's%20build%20the%20operation%20graph%20convolution%20operation%20out%20in%20PyTorch%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(nn%2C%20torch)%3A%0A%0A%20%20%20%20%23%20let's%20define%20a%20GCN%20layer%0A%20%20%20%20class%20GCNlayer(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20X_in_shape%20%3A%20int%2C%20X_out_shape%20%3A%20int%2C%20%20D_inv_sqrt%20%3A%20torch.Tensor%2C%20A%20%3A%20torch.Tensor)%20-%3E%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super(GCNlayer%2C%20self).__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Since%20we%20want%20to%20do%20X%20%40%20W%2C%20we%20need%20to%20make%20sure%20the%20dimensions%20match%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20X%20has%20shape%20N_papers%20x%20BOW%20features%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20W%20has%20shape%20BOW%20features%20x%20W_cols%20%3A%20We%20can%20adjust%20the%20columns%20to%20change%20the%20output%20size%0A%20%20%20%20%20%20%20%20%20%20%20%20self.W_rows%20%3D%20X_in_shape%0A%20%20%20%20%20%20%20%20%20%20%20%20self.W_cols%20%3D%20X_out_shape%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20we%20have%20learnable%20weights%20for%20the%20W%20matrix%0A%20%20%20%20%20%20%20%20%20%20%20%20self.W%20%3D%20nn.Linear(in_features%3Dself.W_rows%2C%20out_features%3Dself.W_cols)%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20These%20values%20do%20not%20get%20updated%20during%20training%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20constant%20graph%20statistics%20%E2%80%94%20register%20as%20buffers%0A%20%20%20%20%20%20%20%20%20%20%20%20self.register_buffer(%22D_inv_sqrt%22%2C%20D_inv_sqrt)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.register_buffer(%22A%22%2C%20A)%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20X%20%3A%20torch.Tensor)%20-%3E%20torch.Tensor%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20X_updated%20%3D%20self.W(X)%20%23%20X%20%40%20W%0A%20%20%20%20%20%20%20%20%20%20%20%20X_updated%20%3D%20self.D_inv_sqrt%20%40%20X_updated%20%23%20D%5E(-1%2F2)%20%40%20X%20%40%20W%0A%20%20%20%20%20%20%20%20%20%20%20%20X_updated%20%3D%20self.A%20%40%20X_updated%20%23%20A%20%40%20D%5E(-1%2F2)%20%40%20X%20%40%20W%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Returns%20the%20transformed%20X%20of%20shape%20N_papers%20x%20W_cols%0A%20%20%20%20%20%20%20%20%20%20%20%20X_updated%20%3D%20self.D_inv_sqrt%20%40%20X_updated%20%23%20D%5E(-1%2F2)%20%40%20A%20%40%20D%5E(-1%2F2)%20%40%20X%20%40%20W%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20X_updated%20%23%20shape%20N_papers%20x%20W_cols%0A%0A%20%20%20%20%23%20We%20can%20have%20any%20number%20of%20layers%20we%20want.%0A%20%20%20%20class%20GCNManual(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2CX_in_shape%20%20%3A%20int%2C%20layer_out_sizes%20%3A%20list%5Bint%5D%2C%20D_inv_sqrt%20%3A%20torch.Tensor%2C%20A%20%3A%20torch.Tensor)%20-%3E%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super(GCNManual%2C%20self).__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20%23first%20layer%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layer1%20%3D%20GCNlayer(X_in_shape%3D%20X_in_shape%2C%20%23%20%3D%201433%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20X_out_shape%3D%20layer_out_sizes%5B0%5D%2C%20%23%20%3D%20100%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20D_inv_sqrt%3D%20D_inv_sqrt%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20A%3D%20A)%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20second%20layer%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layer2%20%3D%20GCNlayer(X_in_shape%3D%20layer_out_sizes%5B0%5D%2C%20%23%20The%20input%20shape%20of%20the%20second%20layer%20is%20the%20output%20shape%20of%20the%20first%20layer%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20X_out_shape%3D%20layer_out_sizes%5B1%5D%2C%20%23%20%3D%207%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20D_inv_sqrt%3D%20D_inv_sqrt%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20A%3D%20A)%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20X)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20X%20%3D%20torch.relu(self.layer1(X))%20%23%20sigmoid%20(D%5E(-1%2F2)%20%40%20A%20%40%20D%5E(-1%2F2)%20%40%20X%20%40%20W)%0A%20%20%20%20%20%20%20%20%20%20%20%20X%20%3D%20self.layer2(X)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20X%0A%20%20%20%20return%20GCNManual%2C%20GCNlayer%0A%0A%0A%40app.cell%0Adef%20_(A_tensor%2C%20D_inv_sqrt%2C%20GCNlayer%2C%20X_tensor)%3A%0A%0A%0A%20%20%20%20gcn_layer%20%3D%20GCNlayer(X_in_shape%20%3D%20X_tensor.shape%5B1%5D%2C%20%23%20This%20is%20the%20number%20of%20features%20in%20the%20input%20node%20features%20%3D%20BOW%20features%20%3D%201433%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20This%20is%20the%20number%20of%20features%20in%20the%20output%20node%20features%20%3D%20100.%20We%20can%20choose%20this%20arbitrarily.%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20The%20output%20of%20one%20layer%20will%20be%20of%20shape%20X_previous_layer%20x%20100.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20X_out_shape%20%3D%20100%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20D_inv_sqrt%20%3D%20D_inv_sqrt%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20A%20%3D%20A_tensor)%0A%20%20%20%20gcn_layer.forward(X_tensor).shape%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(A_tensor%2C%20D_inv_sqrt%2C%20GCNManual%2C%20X_tensor)%3A%0A%20%20%20%20%23%20test%20the%20network%0A%20%20%20%20net%20%3D%20GCNManual(%0A%20%20%20%20%20%20%20%20X_in_shape%3DX_tensor.shape%5B1%5D%2C%20%23%20input%20shape%20%3D%20N_papers%20x%20BOW%20features%0A%20%20%20%20%20%20%20%20layer_out_sizes%3D%5B100%2C%207%5D%2C%20%23%20first%20layer%20has%20100%20features%2C%20second%20layer%20has%207%20features%0A%20%20%20%20%20%20%20%20D_inv_sqrt%3DD_inv_sqrt%2C%0A%20%20%20%20%20%20%20%20A%3DA_tensor%0A%20%20%20%20)%0A%20%20%20%20%23%20We%20should%20expect%20the%20output%20to%20be%20of%20shape%20N_papers%20x%20layer_out_sizes%20after%20each%20layer%0A%20%20%20%20%23%20and%20the%20final%20output%20to%20be%20of%20shape%20N_papers%20x%207%0A%20%20%20%20net(X_tensor).shape%0A%20%20%20%20return%20(net%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Let's%20train%20a%20network%20to%20predict%20the%20label%20for%20a%20given%20node%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(labels%2C%20y)%3A%0A%20%20%20%20%23%20Recall%20that%20we%20have%207%20classes%20%0A%20%20%20%20print(set(labels))%0A%20%20%20%20%23%20encoded%20as%20integers%0A%20%20%20%20print(set(y))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Let's%20separate%20the%20data%20into%20training%20and%20test%20sets.%20Create%20boolean%20masks%20for%20training%2C%20validation%2C%20and%20testing%20with%2060%25%20for%20training%2C%2020%25%20for%20validation%2C%20and%2020%25%20for%20testing%20use%20train%20test%20split%20to%20get%20the%20indices.%0A%0A%20%20%20%20Notice%20that%20this%20is%20different%20from%20standard%20machine%20learning.%20Instead%20of%20splitting%20the%20dataset%20into%20separate%20X_train%20and%20X_test%20matrices%2C%20we%20feed%20the%20entire%20graph%20into%20the%20network.%0A%0A%20%20%20%20Why%3F%20Because%20in%20a%20graph%2C%20removing%20data%20points%20(nodes)%20breaks%20the%20connections.%20We%20need%20the%20edges%20from%20test%20nodes%20to%20help%20learn%20representations%20for%20training%20nodes%2C%20and%20vice%20versa.%20So%2C%20we%20use%20boolean%20masks%20to%20tell%20the%20model%20which%20nodes%20to%20calculate%20loss%20on%20(training)%20and%20which%20to%20hide%20until%20evaluation%20(validation%2Ftest).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(torch%2C%20train_test_split)%3A%0A%20%20%20%20train_indices%2C%20val_test_indices%20%3D%20train_test_split(range(2708)%2C%20test_size%3D0.4%2C%20random_state%3D42)%0A%20%20%20%20val_indices%2C%20test_indices%20%3D%20train_test_split(val_test_indices%2C%20test_size%3D0.5%2C%20random_state%3D42)%0A%20%20%20%20%23%20use%20the%20indices%20to%20get%20the%20masks%0A%20%20%20%20train_mask%20%3D%20torch.zeros(2708%2C%20dtype%3Dtorch.bool)%0A%20%20%20%20train_mask%5Btrain_indices%5D%20%3D%20True%0A%0A%20%20%20%20val_mask%20%3D%20torch.zeros(2708%2C%20dtype%3Dtorch.bool)%0A%20%20%20%20val_mask%5Bval_indices%5D%20%3D%20True%0A%0A%20%20%20%20test_mask%20%3D%20torch.zeros(2708%2C%20dtype%3Dtorch.bool)%0A%20%20%20%20test_mask%5Btest_indices%5D%20%3D%20True%0A%20%20%20%20return%20test_mask%2C%20train_mask%2C%20val_mask%0A%0A%0A%40app.cell%0Adef%20_(net)%3A%0A%20%20%20%20net%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(net)%3A%0A%20%20%20%20%23%20Make%20sure%20that%20only%20the%20W%20matrices%20are%20trainable%0A%20%20%20%20%23%20Pytorch%20will%20automatically%20set%20the%20requires_grad%20attribute%20to%20True%20for%20all%20parameters%20by%20default%0A%20%20%20%20%23%20Named%20parameters%20returns%20a%20list%20of%20tuples%2C%20where%20each%20tuple%20contains%20the%20name%20of%20the%20parameter%20and%20the%20parameter%20itself%0A%20%20%20%20%23%20We%20can%20use%20this%20to%20check%20which%20parameters%20are%20trainable%0A%20%20%20%20%23%20We%20should%20expect%20only%20the%20W%20matrices%20to%20have%20requires_grad%20%3D%20%20True%0A%20%20%20%20for%20name%2C%20param%20in%20net.named_parameters()%3A%0A%20%20%20%20%20%20%20%20if%20param.requires_grad%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20print(name%2C%20param.shape)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(X_tensor%2C%20net%2C%20nn%2C%20optim%2C%20torch%2C%20train_mask%2C%20val_mask%2C%20y_tensor)%3A%0A%20%20%20%20%23%20Now%20we%20can%20do%20a%20standard%20training%20loop%0A%20%20%20%20criterion%20%3D%20nn.CrossEntropyLoss()%0A%20%20%20%20optimizer%20%3D%20optim.Adam(net.parameters()%2C%20lr%3D0.01)%0A%20%20%20%20for%20epoch%20in%20range(20)%3A%0A%20%20%20%20%20%20%20%20net.train()%0A%20%20%20%20%20%20%20%20optimizer.zero_grad()%0A%20%20%20%20%20%20%20%20%23%20Each%20forward%20pass%20computes%20the%20output%20for%20all%20nodes%20in%20the%20network%0A%20%20%20%20%20%20%20%20output%20%3D%20net(X_tensor)%0A%20%20%20%20%20%20%20%20%23%20We%20only%20want%20to%20compute%20the%20loss%20for%20the%20training%20nodes%0A%20%20%20%20%20%20%20%20loss%20%3D%20criterion(output%5Btrain_mask%5D%2C%20y_tensor%5Btrain_mask%5D.long())%0A%20%20%20%20%20%20%20%20loss.backward()%0A%20%20%20%20%20%20%20%20optimizer.step()%0A%0A%20%20%20%20%20%20%20%20net.eval()%0A%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20output%20%3D%20net(X_tensor)%0A%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20criterion(output%5Bval_mask%5D%2C%20y_tensor%5Bval_mask%5D.long())%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20epoch%20%25%2010%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20val_outputs%20%3D%20net(X_tensor)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20check%20if%20the%20prediction%20at%20each%20node%20matches%20what%20is%20expected.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20val_outputs%20is%20(n_papers%20x%20n_labels).%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20We%20get%20the%20validation%20nodes%20from%20this%20and%20get%20the%20column%20index%20of%20the%20highest%20value%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20(i.e.%2C%20which%20label%20is%20predicted%20highest%20for%20a%20given%20node%20%3F)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20We%20compare%20these%20labels%20with%20the%20actual.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20val_accuracy%20%3D%20(val_outputs%5Bval_mask%5D.argmax(1)%20%3D%3D%20y_tensor%5Bval_mask%5D).float().mean()%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20print(f%22Validation%20Accuracy%3A%20%7Bval_accuracy.item()%7D%22)%0A%20%20%20%20return%20(criterion%2C)%0A%0A%0A%40app.cell%0Adef%20_(X_tensor%2C%20criterion%2C%20net%2C%20test_mask%2C%20torch%2C%20y_tensor)%3A%0A%20%20%20%20%23%20Test%20the%20network%0A%20%20%20%20net.eval()%0A%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20test_output%20%3D%20net(X_tensor)%0A%20%20%20%20%20%20%20%20test_loss%20%3D%20criterion(test_output%5Btest_mask%5D%2C%20y_tensor%5Btest_mask%5D.long())%0A%20%20%20%20%20%20%20%20test_accuracy%20%3D%20(test_output%5Btest_mask%5D.argmax(1)%20%3D%3D%20y_tensor%5Btest_mask%5D).float().mean()%0A%20%20%20%20%20%20%20%20print(f%22Test%20Accuracy%3A%20%7Btest_accuracy.item()%7D%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Build%20a%20GCN%20with%20torch%20geometric.%0A%0A%20%20%20%20Let's%20build%20the%20same%20thing%20but%20with%20torch.geometric's%20inbuilt%20features.%0A%0A%0A%20%20%20%20Everything%20stays%20pretty%20much%20the%20same%20except%20that%20the%20forward%20methods%20of%20the%20gcn%20layers%20require%20the%20following%20inputs%3A%0A%0A%20%20%20%201.%20x%20%3A%20This%20is%20the%20n_nodes%20(i.e.%2C%20papers)%20x%20n_features%20(i.e.%2C%20BOW%20representation)%20representation%20of%20each%20node.%0A%20%20%20%202.%20edge_index%20%3A%20This%20the%20(row%2Ccolumn)%20representation%20of%20where%20each%201%20in%20our%20matrix%20resides.%0A%20%20%20%203.%20edge_weight%20%3A%20This%20the%20value%20of%20each%20of%20the%20edges%20(1%20in%20our%20case.)%0A%20%20%20%204.%20(optionally)%20we%20can%20also%20include%20the%20ground%20truth%20labels%20for%20each%20node%0A%0A%20%20%20%20Torch%20geometric%20provides%20a%20Data%20class%20that%20we%20can%20use%20to%20wrap%20this%20information%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(N_papers%2C%20X_tensor%2C%20col%2C%20np%2C%20row%2C%20torch%2C%20values%2C%20y_tensor)%3A%0A%20%20%20%20from%20torch_geometric.nn%20import%20GCNConv%20%23%20use%20one%20of%20the%20GNN%20layers%20from%20torch%20geometric%0A%20%20%20%20from%20torch_geometric.data%20import%20Data%2C%20DataLoader%0A%20%20%20%20from%20torch_geometric.utils%20import%20to_undirected%2C%20add_self_loops%0A%0A%20%20%20%20%23%20this%20is%20just%20the%20row%20and%20column%20indices%20we%20had%20earlier%0A%20%20%20%20edge_index%20%3D%20torch.from_numpy(np.vstack((row%2C%20col))).long()%0A%20%20%20%20edge_weight%20%3D%20torch.Tensor(values).reshape(-1%2C1)%0A%0A%20%20%20%20%23Symmetrize%20the%20graph%0A%20%20%20%20edge_index%2C%20edge_weight%20%3D%20to_undirected(edge_index%2C%20edge_weight)%0A%0A%20%20%20%20%23%20add%20self%20loops%0A%20%20%20%20edge_index%2C%20edge_weight%20%3D%20add_self_loops(edge_index%2C%20edge_weight%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20fill_value%3D1.0%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20num_nodes%3DN_papers)%0A%0A%20%20%20%20data%20%3D%20Data(x%20%3D%20X_tensor%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20edge_index%3D%20edge_index%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20edge_attr%3Dedge_weight%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20y%20%3D%20y_tensor)%0A%20%20%20%20return%20GCNConv%2C%20data%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20GCN%20layer%20we%20are%20using%20GCNConv%20layers.%20%5BThere%20are%20others%5D(https%3A%2F%2Fpytorch-geometric.readthedocs.io%2Fen%2F2.5.1%2Fmodules%2Fnn.html%23id42).%0A%0A%20%20%20%20By%20default%2C%20it%20applies%20the%20symmetric%20normalization%20described%20above.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GCNConv%2C%20nn%2C%20torch)%3A%0A%20%20%20%20%23%20Instead%20of%20defining%20our%20own%20layers%20we%20use%20the%20ChebConv%20layers%0A%0A%20%20%20%20class%20GCN(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2CX_in_shape%20%20%3A%20int%2C%20layer_out_sizes%20%3A%20list%5Bint%5D)%20-%3E%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super(GCN%2C%20self).__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20%23first%20layer%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Note%20that%20the%20first%20dimension%20of%20the%20input%20layer%20must%20match%20the%20second%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20dimension%20of%20the%20input%20data%20%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layer1%20%3D%20GCNConv(in_channels%3DX_in_shape%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20out_channels%3Dlayer_out_sizes%5B0%5D)%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20second%20layer%2C%20the%20input%20size%20should%20be%20the%20output%20size%20of%20the%20previous%20layer.%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layer2%20%3D%20GCNConv(in_channels%3Dlayer_out_sizes%5B0%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20out_channels%3Dlayer_out_sizes%5B1%5D)%0A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20data)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20x%2C%20edge_index%2C%20edge_weight%20%3D%20data.x%2C%20data.edge_index%2C%20data.edge_attr%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20torch.relu(self.layer1(x%2C%20edge_index%2C%20edge_weight))%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.layer2(x%2C%20edge_index%2C%20edge_weight)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20x%0A%20%20%20%20return%20(GCN%2C)%0A%0A%0A%40app.cell%0Adef%20_(GCN%2C%20X_tensor)%3A%0A%20%20%20%20net_pyg%20%3D%20GCN(X_in_shape%3DX_tensor.shape%5B1%5D%2C%20%23%20input%20shape%20%3D%201433%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20layer_out_sizes%3D%5B100%2C%207%5D%2C%20%23%20first%20layer%20has%20100%20features%2C%20second%20layer%20has%207%20features%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20return%20(net_pyg%2C)%0A%0A%0A%40app.cell%0Adef%20_(net_pyg)%3A%0A%20%20%20%20net_pyg%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(data%2C%20net_pyg%2C%20nn%2C%20optim%2C%20torch%2C%20train_mask%2C%20val_mask%2C%20y_tensor)%3A%0A%20%20%20%20%23%20Now%20we%20can%20do%20a%20standard%20training%20loop%0A%20%20%20%20%23%20Create%20criterion%20and%20optimizer%20specifically%20for%20the%20PyG%20model%0A%20%20%20%20criterion_pyg%20%3D%20nn.CrossEntropyLoss()%0A%20%20%20%20optimizer_pyg%20%3D%20optim.Adam(net_pyg.parameters()%2C%20lr%3D0.01)%0A%0A%20%20%20%20for%20epoch_pyg%20in%20range(20)%3A%0A%20%20%20%20%20%20%20%20net_pyg.train()%0A%20%20%20%20%20%20%20%20optimizer_pyg.zero_grad()%0A%20%20%20%20%20%20%20%20%23%20Each%20forward%20pass%20computes%20the%20output%20for%20all%20nodes%20in%20the%20network%0A%20%20%20%20%20%20%20%20%23%20output%20%3D%20(n_nodes%20x%20n_labels)%0A%20%20%20%20%20%20%20%20%23%20i.e%20the%20updated%20features%20after%20the%20graph%20convolution.%0A%20%20%20%20%20%20%20%20output_pyg%20%3D%20net_pyg(data)%0A%20%20%20%20%20%20%20%20%23%20We%20only%20want%20to%20compute%20the%20loss%20for%20the%20training%20nodes%0A%20%20%20%20%20%20%20%20loss_pyg%20%3D%20criterion_pyg(output_pyg%5Btrain_mask%5D%2C%20y_tensor%5Btrain_mask%5D.long())%0A%20%20%20%20%20%20%20%20loss_pyg.backward()%0A%20%20%20%20%20%20%20%20optimizer_pyg.step()%0A%0A%20%20%20%20%20%20%20%20net_pyg.eval()%0A%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20output_pyg%20%3D%20net_pyg(data)%0A%20%20%20%20%20%20%20%20%20%20%20%20loss_pyg%20%3D%20criterion_pyg(output_pyg%5Bval_mask%5D%2C%20y_tensor%5Bval_mask%5D.long())%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20epoch_pyg%20%25%2010%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20check%20if%20the%20prediction%20at%20each%20node%20matches%20what%20is%20expected.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20val_outputs_pyg%20is%20(n_papers%20x%20n_labels).%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20We%20get%20the%20validation%20nodes%20from%20this%20and%20get%20the%20column%20index%20of%20the%20highest%20value%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20(i.e.%2C%20which%20label%20is%20predicted%20highest%20for%20a%20given%20node%20%3F)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20We%20compare%20these%20labels%20with%20the%20actual.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20val_outputs_pyg%20%3D%20net_pyg(data)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20val_accuracy_pyg%20%3D%20(val_outputs_pyg%5Bval_mask%5D.argmax(1)%20%3D%3D%20y_tensor%5Bval_mask%5D).float().mean()%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20print(f%22Validation%20Accuracy%3A%20%7Bval_accuracy_pyg.item()%7D%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Working%20with%20a%20dataloader%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20In%20this%20example%20so%20far%2C%20we%20have%20been%20working%20with%20one%20graph%20and%20doing%20**node%20classification**%20on%20that%20graph%20(i.e.%2C%20predicting%20which%20subject%20a%20given%20paper%20belongs%20to).%0A%0A%20%20%20%20When%20we%20do%20%60output%20%3D%20net(data)%60%20in%20our%20training%20loop%2C%20we%20are%20sending%20this%20entire%20graph%20back%20through%20the%20network%20once%20before%20we%20do%20an%20update%20to%20the%20weights%20of%20our%20network.%20This%20is%20fine%20for%20this%20small%20case%20but%20what%20if%20we%20had%20millions%20of%20papers%3F%20We%20would%20have%20to%20wait%20to%20process%20this%20entire%20batch.%0A%0A%20%20%20%20A%20Dataloader%20is%20a%20helpful%20object%20provided%20by%20pyTorch%20to%20send%20**mini-batches**%20of%20your%20data%20into%20the%20network%20(%5Blearn%20more%5D(https%3A%2F%2Fyoutu.be%2F4qJaSmvhxi8%3Fsi%3DoFFVA2U9eqONa6w7)).%0A%0A%20%20%20%20Let%20your%20entire%20training%20dataset%20be%20your%20**batch**.%20In%20our%20example%20so%20far%2C%20that%20would%20be%202708%20papers.%20We%20can%20break%20this%20up%20by%20**training%20with%20sub-graphs%20for%20each%20gradient%20update**.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20from%20torch_geometric.loader%20import%20NeighborLoader%0A%20%20%20%20return%20(NeighborLoader%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20generate%20sub-graphs%20as%20our%20mini-batches%20with%20%60NeighborLoader%60%20(or%20any%20of%20the%20other%20loaders%20here%20%3A%20https%3A%2F%2Fpytorch-geometric.readthedocs.io%2Fen%2F2.5.2%2Fmodules%2Floader.html%23torch_geometric.loader.DataLoader%20).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(NeighborLoader%2C%20data%2C%20test_mask%2C%20train_mask%2C%20val_mask)%3A%0A%20%20%20%20kwargs%20%20%3D%20%7B'batch_size'%20%3A%20100%7D%0A%0A%20%20%20%20%23%20A%20loader%20to%20handle%20only%20our%20training%20data%0A%20%20%20%20train_loader%20%3D%20NeighborLoader(data%20%3D%20data%2C%20%23%20The%20underlying%20graph%20connectivity.%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20num_neighbors%3D%5B10%2C10%5D%2C%20%23%20How%20many%20neighbors%20to%20sample%20for%20each%20node%20at%20each%20hop.%20I%20want%20to%20know%20about%20you%2C%2010%20of%20your%20friends%20and%2010%20of%20each%20of%20their%20friends.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20input_nodes%3Dtrain_mask%2C%20%23%20use%20only%20the%20training%20set%20to%20sample%20the%20batches%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20shuffle%20%3D%20False%2C%20%23%20we%20already%20did%20this%20when%20generating%20the%20labels%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20**kwargs%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20n_val%20%3D%20int(val_mask.sum().item())%0A%20%20%20%20%23%20So%20that%20the%20entire%20validation%20set%20is%20the%20batch%20for%20illustration%0A%20%20%20%20%23%20This%20would%20need%20to%20change%20depending%20on%20how%20big%20this%20set%20is%20and%20if%20it%20can%20fit%20in%20memory%0A%20%20%20%20val_kwargs%20%3D%20%7B'batch_size'%20%3A%20n_val%7D%20%0A%20%20%20%20val_loader%20%3D%20NeighborLoader(data%20%3D%20data%2C%20%23%20The%20underlying%20graph%20connectivity.%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20num_neighbors%3D%5B10%2C10%5D%2C%20%23%20How%20many%20neighbors%20to%20sample%20for%20each%20node.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20input_nodes%3Dval_mask%2C%20%23%20use%20only%20the%20training%20set%20to%20sample%20the%20batches%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20shuffle%20%3D%20False%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20**val_kwargs%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%23%20same%20as%20above%0A%20%20%20%20n_test%20%3D%20int(test_mask.sum().item())%0A%20%20%20%20test_kwargs%20%3D%20%7B'batch_size'%20%3A%20n_test%7D%0A%20%20%20%20test_loader%20%3D%20NeighborLoader(data%20%3D%20data%2C%20%23%20The%20underlying%20graph%20connectivity.%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20num_neighbors%3D%5B10%2C10%5D%2C%20%23%20How%20many%20neighbors%20to%20sample%20for%20each%20node.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20input_nodes%3Dtest_mask%2C%20%23%20use%20only%20the%20training%20set%20to%20sample%20the%20batches%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20shuffle%20%3D%20False%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20**test_kwargs%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20return%20test_loader%2C%20train_loader%2C%20val_loader%0A%0A%0A%40app.cell%0Adef%20_(train_loader)%3A%0A%20%20%20%20%23%20take%20a%20look%20at%20one%20batch%0A%20%20%20%20for%20batch%20in%20train_loader%3A%0A%20%20%20%20%20%20%20%20print(type(batch))%20%23%20Each%20batch%20is%20a%20graph%20data%20object%2C%20with%20the%20same%20attributes%20as%20we%20saw%20above.%0A%20%20%20%20%20%20%20%20%23%20Rest%20of%20the%20attributes%20%3A%20https%3A%2F%2Farc.net%2Fl%2Fquote%2Ftectabvv%20%0A%20%20%20%20%20%20%20%20%23%20Note%20that%20x%20is%20(163%2C%201433)%20even%20though%20we%20have%20a%20batch%20size%20of%20100%20because%20of%20the%20num_neighbors%0A%20%20%20%20%20%20%20%20%23%20around%20it%20that%20are%20sampled%20to%20give%20the%20graph%20some%20context.%0A%20%20%20%20%20%20%20%20print(batch)%20%0A%20%20%20%20%20%20%20%20break%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GCN%2C%20X_tensor)%3A%0A%0A%20%20%20%20net_pyg_mb%20%3D%20GCN(X_in_shape%3DX_tensor.shape%5B1%5D%2C%20%23%20input%20shape%20%3D%201433%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20layer_out_sizes%3D%5B100%2C%207%5D%2C%20%23%20first%20layer%20has%20100%20features%2C%20second%20layer%20has%207%20features%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20return%20(net_pyg_mb%2C)%0A%0A%0A%40app.cell%0Adef%20_(net_pyg_mb%2C%20nn%2C%20optim%2C%20torch%2C%20train_loader%2C%20val_loader)%3A%0A%20%20%20%20%23%20Now%20we%20repeat%20the%20standard%20training%20loop%20with%20mini-batches%0A%20%20%20%20criterion_dl%20%3D%20nn.CrossEntropyLoss()%0A%20%20%20%20optimizer_dl%20%3D%20optim.Adam(net_pyg_mb.parameters()%2C%20lr%3D0.01)%0A%0A%20%20%20%20for%20epoch_dl%20in%20range(20)%3A%0A%20%20%20%20%20%20%20%20for%20mini_batch%20in%20train_loader%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20net_pyg_mb.train()%0A%20%20%20%20%20%20%20%20%20%20%20%20optimizer_dl.zero_grad()%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Each%20forward%20pass%20computes%20the%20output%20for%20all%20nodes%20in%20the%20batch%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20i.e%20the%20updated%20features%20after%20the%20graph%20convolution.%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20this%20returns%20a%20(n_nodes%20in%20the%20batch%20x%20n_labels)%20tensor%20%3A%20for%20each%20node%2C%20the%20logits%20for%20each%20category%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Note%20that%20this%20also%20includes%20the%20neighbors%20of%20the%20root%20sampled%20nodes.%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20However%2C%20the%20loader%20does%20not%20return%20the%20true%20labels%20for%20these%20neighbors%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20So%20we%20index%20the%20batch%20size%20we%20wanted%20(i.e.%2C%20100)%20nodes%0A%20%20%20%20%20%20%20%20%20%20%20%20output_dl%20%3D%20net_pyg_mb(mini_batch)%5B%3Amini_batch.batch_size%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20y_dl%20%3D%20mini_batch.y%5B%3Amini_batch.batch_size%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20loss_dl%20%3D%20criterion_dl(output_dl%2C%20y_dl.long())%0A%20%20%20%20%20%20%20%20%20%20%20%20loss_dl.backward()%0A%20%20%20%20%20%20%20%20%20%20%20%20optimizer_dl.step()%0A%20%20%20%20%20%20%20%20net_pyg_mb.eval()%0A%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20output_dl%20%3D%20net_pyg_mb(mini_batch)%5B%3Amini_batch.batch_size%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20y_dl%20%3D%20mini_batch.y%5B%3Amini_batch.batch_size%5D.long()%0A%20%20%20%20%20%20%20%20%20%20%20%20loss_dl%20%3D%20criterion_dl(output_dl%2C%20y_dl)%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20epoch_dl%20%25%2020%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20for%20val_batch%20in%20val_loader%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20check%20if%20the%20prediction%20at%20each%20node%20matches%20what%20is%20expected.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20val_outputs_dl%20is%20(n_papers%20x%20n_labels).%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20We%20get%20the%20validation%20nodes%20from%20this%20and%20get%20the%20column%20index%20of%20the%20highest%20value%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20(i.e.%2C%20which%20label%20is%20predicted%20highest%20for%20a%20given%20node%20%3F)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20We%20compare%20these%20labels%20with%20the%20actual.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20val_outputs_dl%20%3D%20net_pyg_mb(val_batch)%5B%3Aval_batch.batch_size%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20y_val_dl%20%3D%20val_batch.y%5B%3Aval_batch.batch_size%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20val_accuracy_dl%20%3D%20(val_outputs_dl.argmax(1)%20%3D%3D%20y_val_dl).float().mean()%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20print(f%22Validation%20Accuracy%3A%20%7Bval_accuracy_dl.item()%7D%22)%0A%20%20%20%20return%20(criterion_dl%2C)%0A%0A%0A%40app.cell%0Adef%20_(criterion_dl%2C%20net_pyg_mb%2C%20test_loader%2C%20torch)%3A%0A%20%20%20%20%23%20Test%20the%20network%0A%20%20%20%20net_pyg_mb.eval()%0A%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20for%20test_batch%20in%20test_loader%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20output_dl_test%20%3D%20net_pyg_mb(test_batch)%5B%3Atest_batch.batch_size%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20y_dl_test%20%3D%20test_batch.y%5B%3Atest_batch.batch_size%5D.long()%0A%20%20%20%20%20%20%20%20%20%20%20%20test_loss_dl%20%3D%20criterion_dl(output_dl_test%2C%20y_dl_test)%0A%20%20%20%20%20%20%20%20%20%20%20%20test_accuracy_dl%20%3D%20(output_dl_test.argmax(1)%20%3D%3D%20y_dl_test).float().mean()%0A%20%20%20%20%20%20%20%20%20%20%20%20print(f%22Test%20Accuracy%3A%20%7Btest_accuracy_dl.item()%7D%22)%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
b5847c43f8fbc4d645fa82405049114f