
Graph Convolutional Networks: Implementation in PyTorch



Real-world objects can be defined in terms of their connections to other things, so graphs are effectively all around us. A graph is a natural representation of the relationship between a collection of things. For more than a decade, researchers have been working on neural networks that operate on graph data (known as graph neural networks, or GNNs). 

Recent technological advances have increased their expressive power and capabilities. Practical uses include gene discovery, physical simulations, fake news detection, traffic prediction and recommendation systems. Some forms of graph data, such as social networks, are probably familiar to you. However, because graphs are such a versatile way of representing data, we will also show you how to model two types of data as graphs that you might not have considered: images and text.

Graph convolutional networks (GCNs) offer a new way to model many problems in machine learning. Problems such as image captioning often require extremely complex neural networks to make a prediction, and even then the results can be poor on much of the data. GCNs provide a way to build meaningful connections in your data, using nodes and edges, to capture the underlying correlations in the data. In this article, we explore the nature of GCNs, the data they operate on, how they differ from classical Convolutional Neural Networks (CNNs), and how to implement them in PyTorch. The article is divided into the following sections:

  1. What are graphs?
  2. What is graphical data?
  3. What are Graph Convolutional Networks (GCNs)?
  4. Challenges of using graphical data
  5. Applications of Graph Convolutional Networks
  6. What is PyTorch?
  7. Implementation of GCN in PyTorch
  8. Conclusion

What are Graphs?

A graph is a set of connections, or relationships, between entities. The entities are represented as nodes (also called vertices), and the connections between them are called edges.

A graph is often written as G = (V, E). V is the set of vertices (nodes), and E is the set of edges, each of which links a pair of nodes. Edges can be directed or undirected. The number of neighbours a node has is called its degree.

In the figure above, the yellow part is the vertex (or node) and the blue part is the edge.

What is graphical data?

Many kinds of data can be represented as graphs. Here, we look at three of them:

  1. Social Network Graph
  2. Images as Graphs
  3. Text as Graphs

 

Social Network Graph

A social network graph shows the connections between a selected entity and all the other entities it is associated with. It is a natural way to show "who knows whom".

The social network graph shows:

  1. Entity-to-entity links: All entities associated with the primary (hub) entity are displayed. The attributes that link the entities, on the other hand, are not shown on the graph itself but can be accessed using the Attribute Explorer alongside the graph.
  2. Relationship clusters: The social network graph groups, or clusters, related entities. It can be used to examine all the relationship clusters a given entity belongs to, and to search for patterns within those clusters and relationships.

Images as Graphs

Images are usually thought of as rectangular grids of image channels, represented as arrays (e.g., 244x244x3 floats). They can also be thought of as graphs with a regular structure, in which each pixel is a node linked by an edge to each of its neighboring pixels. Each non-border pixel has exactly eight neighbors (counting diagonals), and each node stores a 3-dimensional vector encoding the pixel's RGB value.

One way to visualize a graph's connectivity is its adjacency matrix: an n_nodes x n_nodes matrix in which entry (i, j) is filled whenever nodes i and j share an edge. Note that the three representations below are all different views of the same data.
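To make the image-as-graph view concrete, here is a minimal PyTorch sketch (our own illustration, not code from any library) that builds the adjacency matrix of a small image grid, assuming 8-connectivity as described above, and reshapes an H x W x 3 image into a node feature matrix with one RGB row per pixel. The helper name `image_grid_adjacency` is hypothetical.

```python
import torch


def image_grid_adjacency(height: int, width: int) -> torch.Tensor:
    """Adjacency matrix of an image treated as a graph (hypothetical helper).

    Pixel (r, c) becomes node r * width + c and is linked to each of its
    (up to) eight surrounding pixels, diagonals included. No self-loops.
    """
    n_nodes = height * width
    adj = torch.zeros(n_nodes, n_nodes)
    for r in range(height):
        for c in range(width):
            i = r * width + c
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    if dr == 0 and dc == 0:
                        continue  # skip the pixel itself
                    rr, cc = r + dr, c + dc
                    if 0 <= rr < height and 0 <= cc < width:
                        adj[i, rr * width + cc] = 1.0
    return adj


image = torch.rand(3, 3, 3)       # toy 3x3 RGB image: height x width x channels
x = image.reshape(-1, 3)          # node features: row r * width + c holds pixel (r, c)
adj = image_grid_adjacency(3, 3)

print(adj.sum(dim=1))             # neighbors per pixel: corners 3, edges 5, centre 8
```

In a 3x3 image only the centre pixel is a non-border pixel, so it is the only node with eight neighbors. For larger images, the non-zero entries of `adj` form a banded pattern.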

Text as Graphs

Text can be digitized by assigning an index to each character, word, or token and representing the text as a sequence of these indices. This gives a simple directed graph in which each character or token is a node, connected by an edge to the node that follows it.

Of course, this is not how text and images are usually stored in practice: these graph representations are redundant, because all images and all text have very regular structure. Images have a banded adjacency matrix because all nodes (pixels) are connected in a grid. The adjacency matrix for text is just a diagonal line of entries next to the main diagonal, because each word is connected only to the words immediately before and after it.
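A similarly small sketch for text, assuming naive whitespace tokenization and one node per token position (both are our simplifying assumptions): the directed adjacency matrix has ones only directly above the main diagonal.

```python
import torch

tokens = "graphs are all around us".split()   # naive whitespace tokenizer (assumption)
vocab = {word: idx for idx, word in enumerate(sorted(set(tokens)))}
token_ids = [vocab[word] for word in tokens]   # the text as a sequence of indices

# One node per token position; a directed edge links each token to the next one.
n = len(tokens)
adj = torch.zeros(n, n)
for i in range(n - 1):
    adj[i, i + 1] = 1.0

print(token_ids)  # [3, 1, 0, 2, 4]
print(adj)
```

Adding `adj.T` to `adj` gives the undirected version, in which each word is linked to both the preceding and the following word.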

What are Graph Convolutional Networks?

A Graph Neural Network (GNN) is a type of neural network that operates directly on graph structures, making complex graph data tractable for learning. The simplest application is node classification: each node has a label, and we predict the labels of nodes whose ground truth is unknown. Convolutional Neural Networks and Graph Neural Networks differ in several respects, including pipeline architecture, loss functions, techniques, and computations. GCNs are a type of GNN that carries the idea of convolution over to graphs: they operate directly on graphs and exploit their structure. They address the problem of classifying nodes (such as documents) in a graph (such as a citation network) where only a small proportion of the nodes are labeled (semi-supervised learning).

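As the name “Convolutional” suggests, the idea comes from images and was then carried over to graphs. However, whereas images have a fixed grid structure, graphs are much more complex: the number of neighbors varies from node to node, and there is no natural ordering of the nodes.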

The basic principle behind a GCN is that, for each node, we gather feature information from all of its neighbors as well as from the node itself. Suppose we combine them by taking their average. We do the same for every node. Finally, we feed these averaged features into a neural network.

In practice, rather than simply using the average, we can use more sophisticated aggregation functions. To build a deeper GCN, we stack more layers on top of each other, with each layer's output used as the input to the next layer.
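The averaging scheme above can be sketched as a small PyTorch module. This is our own minimal illustration, assuming a dense adjacency matrix and a plain mean over each node and its neighbors; `MeanAggregationLayer` is a hypothetical name, and practical GCN implementations usually use a different normalization.

```python
import torch
import torch.nn as nn


class MeanAggregationLayer(nn.Module):
    """Hypothetical layer: average each node's own features with its neighbors'
    features, then apply a learned linear map followed by a ReLU."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # Add self-loops so that each node also includes its own features.
        adj_hat = adj + torch.eye(adj.size(0), device=adj.device)
        # Row i of (adj_hat @ x) sums node i and its neighbors; dividing by the
        # row sum (degree including the self-loop) turns the sum into a mean.
        deg = adj_hat.sum(dim=1, keepdim=True)
        h = (adj_hat @ x) / deg
        return torch.relu(self.linear(h))


# A toy undirected path graph with 4 nodes: edges 0-1, 1-2, 2-3.
adj = torch.tensor([[0., 1., 0., 0.],
                    [1., 0., 1., 0.],
                    [0., 1., 0., 1.],
                    [0., 0., 1., 0.]])
x = torch.randn(4, 5)  # 4 nodes, 5 input features each

# Stacking layers: each layer's output is the next layer's input.
layers = nn.ModuleList([MeanAggregationLayer(5, 16), MeanAggregationLayer(16, 8)])
h = x
for layer in layers:
    h = layer(h, adj)
print(h.shape)  # torch.Size([4, 8])
```

Because every row of `adj_hat` includes a self-loop, the degree is never zero, so the division is always safe.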

A graph's adjacency matrix is a square matrix that describes the connections between nodes. It specifies whether or not each pair of nodes is connected (adjacent): 1 means connected and 0 means not connected.
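For example, here is one way to build an adjacency matrix from a list of undirected edges in PyTorch (the edge list is made up for illustration). Summing each row gives the number of neighbors, or degree, of each node.

```python
import torch

num_nodes = 4
edges = [(0, 1), (0, 2), (2, 3)]  # hypothetical undirected edges

adj = torch.zeros(num_nodes, num_nodes, dtype=torch.long)
for u, v in edges:
    adj[u, v] = 1
    adj[v, u] = 1  # undirected graph: the matrix is symmetric

print(adj)
print(adj.sum(dim=1))  # tensor([2, 1, 2, 1])
```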

Challenges of using Graphical Data

Inputs to machine learning models are usually rectangular or grid-like arrays, so it is not immediately obvious how to represent graphs in a form compatible with deep learning. Graphs carry up to four kinds of information that we may want to use for predictions: nodes, edges, global context, and connectivity. The first three are straightforward: for example, we can build a node feature matrix N by assigning each node an index i and storing the feature of node i in row i of N. Although these matrices have a variable number of examples, they can be processed without any special techniques.

Representing a graph's connectivity, on the other hand, is more difficult. The most obvious option is an adjacency matrix, which is easy to turn into a tensor. This representation, however, has a few drawbacks. The number of nodes in a graph can be in the millions, and the number of edges per node can vary enormously, as seen in the example dataset table. This often produces very sparse adjacency matrices, which waste a lot of space.
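A common workaround is to store only the edges, for example as a 2 x num_edges tensor of (source, target) index pairs in coordinate (COO) format. The sketch below uses a made-up ring graph to compare the number of entries a dense adjacency matrix would need with the number actually stored, and shows that `torch.sparse.mm` can still aggregate neighbor features without ever building the dense matrix.

```python
import torch

num_nodes = 10_000
# Hypothetical sparse graph: a ring where node i is linked to node i + 1 (mod n),
# stored in both directions because the graph is undirected.
src = torch.arange(num_nodes)
dst = (src + 1) % num_nodes
edge_index = torch.cat([torch.stack([src, dst]), torch.stack([dst, src])], dim=1)

values = torch.ones(edge_index.size(1))
adj_sparse = torch.sparse_coo_tensor(edge_index, values, (num_nodes, num_nodes)).coalesce()

dense_entries = num_nodes * num_nodes  # 100,000,000 entries if stored densely
stored_entries = edge_index.size(1)    # 20,000 non-zero entries actually stored
print(dense_entries, stored_entries)

# Node feature matrix: row i holds the features of node i.
x = torch.randn(num_nodes, 16)
out = torch.sparse.mm(adj_sparse, x)   # row i = sum of node i's neighbors' features
print(out.shape)  # torch.Size([10000, 16])
```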

Another issue is that many different adjacency matrices can encode the same connectivity (one for each ordering of the nodes), and there is no guarantee that these matrices produce the same output in a deep neural network.
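The sketch below illustrates the problem with a made-up three-node graph: relabeling the nodes yields a different adjacency matrix for the same graph. It also shows why neighborhood aggregation (multiplying by the adjacency matrix, as a GCN layer does) sidesteps the issue: reordering the nodes only reorders the output rows, a property known as permutation equivariance.

```python
import torch

adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 0.],
                    [1., 0., 0.]])   # node 0 is linked to nodes 1 and 2
x = torch.tensor([[1., 0.],
                  [0., 1.],
                  [2., 2.]])         # one feature row per node

perm = torch.tensor([2, 0, 1])       # an arbitrary different node ordering
P = torch.eye(3)[perm]               # the corresponding permutation matrix

adj_perm = P @ adj @ P.T             # same graph, different adjacency matrix
x_perm = P @ x                       # node features in the new order

print(torch.equal(adj, adj_perm))    # False: the two matrices differ

# Aggregation with adj @ x is permutation-equivariant: relabeling the nodes
# just reorders the rows of the result.
print(torch.equal(adj_perm @ x_perm, P @ (adj @ x)))  # True
```

By contrast, a model that, say, flattened the adjacency matrix into a fixed-length input vector would see two different inputs for the same graph.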

Applications of Graph Convolutional Networks

  • Graph Convolutional Networks (GCNs) can model physical systems as graphs of interacting objects and make predictions about them. They can also provide reliable estimates of the properties of real-world objects and systems, such as collision dynamics and object trajectories.
  • GCNs have been applied to image classification through 'Zero-Shot Learning'. The goal is to recognize images of classes that have no labeled training examples by relating them to known classes, using semantic information collected from the class labels.
  • GCNs can be applied to a wide range of problems in operations research and combinatorial optimization, such as the classic traveling salesperson problem and quadratic assignment problems. By exploiting the structure of the input graph, they can outperform typical sophisticated algorithms.

PyTorch

PyTorch is an open-source machine learning library based on the Torch library, originally developed largely by Facebook's AI Research (FAIR) lab for applications such as computer vision and natural language processing. PyTorch also provides a C++ interface, although the Python interface is more polished and is the main focus of development.

Implementation in PyTorch

Now, we use PyTorch to implement a simple graph convolutional network, following a GitHub repository for this project. You can find the repository here. The dataset used in this tutorial is the Cora dataset, which can be found here.

To run the code on Google Colab, follow the steps below.

First, we clone the repository and move to the working directory:

!git clone https://github.com/andrejmiscic/gcn-pytorch.git

Then, we configure the environment with the proper version.

Next, we install the proper module versions.

Then, we import the modules into our Python environment.

Now, we define a utility function.

We set the required variables.

Now, we declare the model.

Next, we evaluate the model.

We evaluate a standard GCN model and a variant of the GCN model that has residual connections between hidden layers. The performance of the model without residual connections deteriorates as the number of hidden layers grows, because training becomes more difficult.
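The repository's own model code is not reproduced here. As a rough, self-contained sketch of the idea, the module below stacks a configurable number of hidden GCN layers and optionally adds a residual connection around each hidden layer. We assume the symmetric normalization D^-1/2 (A + I) D^-1/2 from Kipf and Welling's GCN paper and run it on a random synthetic graph rather than Cora; `DeepGCN` and `normalize_adjacency` are hypothetical names, not the repository's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def normalize_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """Assumed Kipf & Welling-style normalization: D^-1/2 (A + I) D^-1/2."""
    adj_hat = adj + torch.eye(adj.size(0))
    deg_inv_sqrt = adj_hat.sum(dim=1).pow(-0.5)  # degree >= 1 thanks to self-loops
    return deg_inv_sqrt[:, None] * adj_hat * deg_inv_sqrt[None, :]


class DeepGCN(nn.Module):
    """Hypothetical GCN with n_hidden_layers hidden layers and optional
    residual (skip) connections between them."""

    def __init__(self, in_dim, hidden_dim, n_classes, n_hidden_layers, residual):
        super().__init__()
        self.residual = residual
        self.input_layer = nn.Linear(in_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_hidden_layers)]
        )
        self.output_layer = nn.Linear(hidden_dim, n_classes)

    def forward(self, x, adj_norm):
        h = F.relu(adj_norm @ self.input_layer(x))
        for layer in self.hidden_layers:
            out = F.relu(adj_norm @ layer(h))
            h = h + out if self.residual else out  # residual connection if enabled
        return adj_norm @ self.output_layer(h)     # one row of class logits per node


# Random undirected synthetic graph (stand-in for Cora).
torch.manual_seed(0)
n_nodes, in_dim, n_classes = 50, 10, 3
adj = (torch.rand(n_nodes, n_nodes) < 0.1).float()
adj = ((adj + adj.T) > 0).float()  # make the graph undirected
adj.fill_diagonal_(0)              # self-loops are added during normalization
adj_norm = normalize_adjacency(adj)
x = torch.randn(n_nodes, in_dim)

plain = DeepGCN(in_dim, 16, n_classes, n_hidden_layers=8, residual=False)
resnet = DeepGCN(in_dim, 16, n_classes, n_hidden_layers=8, residual=True)
print(plain(x, adj_norm).shape, resnet(x, adj_norm).shape)
```

With `residual=True`, each hidden layer only has to learn a correction to its input, which tends to keep deep stacks trainable. This mirrors the comparison described above but does not reproduce the repository's results.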

Now, we compute the accuracy of both models and plot them.
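Accuracy for node classification is usually computed only over the nodes in a given split, such as the test nodes. A minimal helper, assuming the model outputs one row of class logits per node and the split is given as a boolean mask (`masked_accuracy` is a hypothetical name):

```python
import torch


def masked_accuracy(logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> float:
    """Fraction of the nodes selected by `mask` whose predicted class matches the label."""
    preds = logits.argmax(dim=1)
    correct = (preds[mask] == labels[mask]).sum().item()
    return correct / int(mask.sum().item())


logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.2], [0.1, 0.9]])
labels = torch.tensor([0, 1, 1, 1])
test_mask = torch.tensor([True, True, True, False])
print(masked_accuracy(logits, labels, test_mask))  # 2 of the 3 masked nodes are correct
```

Computing this for each model at several depths gives the numbers to plot.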

Conclusion

Graph Convolutional Networks provide an efficient and elegant way to understand the relationships hidden within datasets. We have given an extremely simple and limited explanation of GCNs, just to scratch the surface and lay out the basics for the reader. For a detailed understanding of the concepts and the mathematics behind them, there are many resources that explain the mathematics and intuition behind GCNs.

 
