
The why and how of nonnegative matrix factorization

February 18, 2019

The why and how of nonnegative matrix factorization Gillis, arXiv 2014 from: ‘Regularization, Optimization, Kernels, and Support Vector Machines.’

Last week we looked at the paper ‘Beyond news content,’ which made heavy use of nonnegative matrix factorisation. Today we’ll be looking at that technique in a little more detail. As the name suggests, ‘The Why and How of Nonnegative matrix factorisation’ describes both why NMF is interesting (the intuition for how it works), and how to compute an NMF. I’m mostly interested in the intuition (and also out of my depth for some of the how!), but I’ll give you a sketch of the implementation approaches.

Nonnegative matrix factorization (NMF) has become a widely used tool for the analysis of high dimensional data as it automatically extracts sparse and meaningful features from a set of nonnegative data vectors.

NMF was first introduced by Paatero and Tapper in 1994, and popularised in an article by Lee and Seung in 1999. Since then, the number of publications referencing the technique has grown rapidly.

What is NMF?

NMF approximates a matrix \mathbf{X} with a low-rank matrix approximation such that \mathbf{X} \approx \mathbf{WH}.

For the discussion in this paper, we’ll assume that \mathbf{X} is set up so that there are n data points each with p dimensions, and every column of \mathbf{X} is a data point, i.e. \mathbf{X} \in \mathbb{R}^{p \times n}.

We want to reduce the p original dimensions to r (aka, create a rank r approximation). So we’ll have \mathbf{W} \in \mathbb{R}^{p \times r} and \mathbf{H} \in \mathbb{R}^{r \times n}.

The interpretation of \mathbf{W} is that each column is a basis element. By basis element we mean some component that crops up again and again in all of the n original data points. These are the fundamental building blocks from which we can reconstruct approximations to all of the original data points.

The interpretation of \mathbf{H} is that each column gives the ‘coordinates of a data point’ in the basis \mathbf{W}. In other words, it tells you how to reconstruct an approximation to the original data point from a linear combination of the building blocks in \mathbf{W}.

A popular way of measuring how good the approximation \mathbf{WH} actually is, is the squared Frobenius norm of the error (denoted by the F subscript):

\displaystyle ||\mathbf{X} - \mathbf{WH}||^{2}_{F} = \sum_{i,j}(\mathbf{X} - \mathbf{WH})^{2}_{ij}.

Without the nonnegativity constraints, an optimal rank-r approximation with respect to the Frobenius norm can be computed through truncated Singular Value Decomposition (SVD).
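To make the truncated SVD remark concrete, here is a minimal NumPy sketch. The function name `truncated_svd_approx` and the random test matrix are my own and purely illustrative. It keeps the top r singular triplets and compares the squared Frobenius error with the sum of the discarded squared singular values, which should agree.

```python
import numpy as np

def truncated_svd_approx(X, r):
    """Best rank-r approximation of X in the Frobenius norm (no sign constraints)."""
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    # Scale the first r left singular vectors by their singular values.
    return (U[:, :r] * s[:r]) @ Vt[:r, :]

rng = np.random.default_rng(0)
X = rng.random((6, 8))      # p = 6 dimensions, n = 8 data points (columns)
r = 2
X_r = truncated_svd_approx(X, r)

# Squared Frobenius error of the approximation.
err = np.linalg.norm(X - X_r, "fro") ** 2
# Sum of the squared singular values that were thrown away.
s = np.linalg.svd(X, compute_uv=False)
discarded = np.sum(s[r:] ** 2)
```

The SVD factors will generally contain negative entries even when X is nonnegative, which is exactly what NMF rules out.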

Why does it work? The intuition.

The reason why NMF has become so popular is because of its ability to automatically extract sparse and easily interpretable factors.

The authors give three examples of NMF at work: in image processing, text mining, and hyperspectral imaging.

Image processing

Say we take a gray-level image of a face containing p pixels, and squash the data into a single vector such that the ith entry represents the value of the ith pixel. Let the rows of \mathbf{X} \in \mathbb{R}^{p \times n} represent the p pixels, and the n columns each represent one image.

NMF will produce two matrices W and H. The columns of W can be interpreted as images (the basis images), and H tells us how to sum up the basis images in order to reconstruct an approximation to a given face.

In the case of facial images, the basis images are features such as eyes, noses, moustaches, and lips, while the columns of H indicate which feature is present in which image.

Text mining

In text mining consider the bag-of-words matrix representation where each row corresponds to a word, and each column to a document (for the attentive reader, that’s the transpose of the bag-of-words matrix we looked at in ‘Beyond news content’).

NMF will produce two matrices W and H. The columns of W can be interpreted as basis documents (bags of words). What interpretation can we give to such a basis document in this case? They represent topics! Sets of words found simultaneously in different documents. H tells us how to sum contributions from different topics to reconstruct the word mix of a given original document.

Therefore, given a set of documents, NMF identifies topics and simultaneously classifies the documents among these different topics.

Hyperspectral unmixing

A hyperspectral image typically has 100 to 200 wavelength-indexed bands showing the fraction of incident light being reflected by the pixel at each of those wavelengths. Given such an image we want to identify the different materials present in it (e.g. grass, roads, metallic surfaces) – these are called the endmembers. Then we want to know which endmembers are present in each pixel, and in what proportion. For example, a pixel might be reflecting 0.3 x the spectral signal of grass, and 0.7 x the spectral signal of a road surface.

NMF will produce two matrices W and H. The columns of W can be interpreted as basis endmembers. H tells us how to sum contributions from different endmembers to reconstruct the spectral signal observed at a pixel.

…given a hyperspectral image, NMF is able to compute the spectral signatures of the endmembers, and simultaneously the abundance of each endmember in each pixel.

Implementing NMF

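For a rank r factorisation, we have the following optimisation problem:

\displaystyle \min_{\mathbf{W} \in \mathbb{R}^{p \times r}, \mathbf{H} \in \mathbb{R}^{r \times n}} ||\mathbf{X} - \mathbf{WH}||^{2}_{F} \quad \text{such that } \mathbf{W} \geq 0 \text{ and } \mathbf{H} \geq 0.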

Though note that the Frobenius norm shown here assumes Gaussian noise, and other norms may be used in practice depending on the distribution (e.g., Kullback-Leibler divergence for text-mining, the Itakura-Saito distance for music analysis, or the l_1 norm to improve robustness against outliers).

So far everything to do with NMF sounds pretty good, until you reach the key moment in section 3:

There are many issues when using NMF in practice. In particular, NMF is NP-hard. Unfortunately, as opposed to the unconstrained problem which can be solved efficiently using the SVD, NMF is NP-hard in general.

Fortunately there are heuristic approximations which have been proven to work well in many applications.

Another issue with NMF is that there is not guaranteed to be a single unique decomposition (in general, there might be many schemes for defining sets of basis elements). For example, in text mining you would end up with different topics and classifications. “In practice, this issue is tackled using other priors on the factors W and H and adding proper regularization terms in the objective function.”

Finally, it’s hard to know how to choose the factorisation rank, r. Some approaches include trial and error, estimation using SVD based on the decay of the singular values, and insights from experts (e.g., there are roughly so many endmembers you might expect to find in a hyperspectral image).
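As a rough illustration of the singular-value-decay idea, the sketch below picks the smallest r whose leading singular values capture a chosen share of the squared spectrum. The `energy` threshold and the helper name `rank_by_energy` are assumptions of mine, not something the paper prescribes.

```python
import numpy as np

def rank_by_energy(X, energy=0.95):
    """Smallest r whose top-r singular values hold `energy` of the squared spectrum."""
    s = np.linalg.svd(X, compute_uv=False)
    cumulative = np.cumsum(s ** 2) / np.sum(s ** 2)
    # First index where the cumulative share reaches the threshold, as a 1-based rank.
    r = int(np.searchsorted(cumulative, energy)) + 1
    return min(r, len(s))

rng = np.random.default_rng(0)
# A nonnegative matrix that has rank 3 by construction.
X = rng.random((30, 3)) @ rng.random((3, 40))
r = rank_by_energy(X, energy=0.999)
```

On real data the spectrum rarely has such a clean cut-off, so a threshold like this is only a starting point for trial and error.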

Almost all NMF algorithms use a two-block coordinate descent scheme (exact or inexact), that is, they optimize alternatively over one of the two factors, W or H, while keeping the other fixed. The reason is that the subproblem in one factor is convex. More precisely, it is a nonnegative least squares problem (NNLS). Many algorithms exist to solve the NNLS problem; and NMF algorithms based on two-block coordinate descent differ by which NNLS algorithm is used.

Some NNLS algorithms that can be plugged in include multiplicative updates, alternating least squares, alternating nonnegative least squares, and hierarchical alternating least squares.
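Of the options listed, multiplicative updates are the easiest to write down. Here is a minimal NumPy sketch for the Frobenius objective, using the Lee and Seung style update rules. The function name, the iteration count, and the small `eps` guard against division by zero are my own choices.

```python
import numpy as np

def nmf_multiplicative(X, r, n_iter=200, eps=1e-10, seed=0):
    """Approximate nonnegative X (p x n) as W @ H with W (p x r) and H (r x n)."""
    rng = np.random.default_rng(seed)
    p, n = X.shape
    W = rng.random((p, r)) + eps
    H = rng.random((r, n)) + eps
    errors = []
    for _ in range(n_iter):
        # Update H with W fixed, then W with H fixed (two-block scheme).
        H *= (W.T @ X) / (W.T @ W @ H + eps)
        W *= (X @ H.T) / (W @ H @ H.T + eps)
        errors.append(np.linalg.norm(X - W @ H, "fro") ** 2)
    return W, H, errors

rng = np.random.default_rng(1)
X = rng.random((20, 30))
W, H, errors = nmf_multiplicative(X, r=4)
```

Because each update only multiplies by nonnegative ratios, W and H stay nonnegative throughout without any explicit projection step.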

The following charts show the performance of these algorithms on a dense data set (left), and a sparse data set (right).

You can initialise W and H randomly, but there are also alternate strategies designed to give better initial estimates in the hope of converging more rapidly to a good solution:

  • Use some clustering method, taking the cluster means of the top r clusters as the columns of W, and H as a scaling of the cluster indicator matrix (which elements belong to which cluster).
  • Find the best rank-r approximation of X using SVD and use this to initialise W and H (see section 3.1.8).
  • Pick r columns of X and just use those as the initial values for W.
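The last of these strategies is simple enough to sketch in a few lines. The helper name `init_from_columns` is hypothetical, and filling H with random nonnegative values is my own assumption, since the list only says what to do with W.

```python
import numpy as np

def init_from_columns(X, r, seed=0):
    """Use r randomly chosen columns of X as the initial W; H starts random."""
    rng = np.random.default_rng(seed)
    n = X.shape[1]
    cols = rng.choice(n, size=r, replace=False)
    W0 = X[:, cols].copy()
    H0 = rng.random((r, n))     # assumption: random nonnegative start for H
    return W0, H0, cols

rng = np.random.default_rng(2)
X = rng.random((5, 12))
W0, H0, cols = init_from_columns(X, r=3)
```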

Section 3.2 in the paper discusses an emerging class of polynomial time algorithms for NMF in the special case where the matrix X is r-separable. That is, there exists a subset of r columns such that all other columns of X can be reconstructed from them. In the text mining example for instance this would mean that each topic has at least one document focused solely on that topic.
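To get a feel for the separable case, here is a sketch of one well-known method from that family, the successive projection algorithm. I'm naming it plainly as my choice of illustration rather than something this write-up walks through. It assumes noiseless data in which the other columns are convex combinations of the r 'pure' columns. It repeatedly picks the column with the largest norm and projects everything onto the orthogonal complement of that column.

```python
import numpy as np

def successive_projection(X, r):
    """Return indices of r columns of X selected by successive projection."""
    R = X.astype(float).copy()
    selected = []
    for _ in range(r):
        norms = np.sum(R ** 2, axis=0)
        j = int(np.argmax(norms))           # column with the largest residual norm
        selected.append(j)
        u = R[:, j] / np.sqrt(norms[j])
        R -= np.outer(u, u @ R)             # remove the component along u
    return selected

rng = np.random.default_rng(0)
p, r, n = 10, 3, 20
W_true = rng.random((p, r))
# Remaining columns mix the pure columns with convex weights (each column sums to 1).
H_rest = rng.dirichlet(np.ones(r), size=n - r).T
X = W_true @ np.hstack([np.eye(r), H_rest])
anchors = successive_projection(X, r)
```

In this synthetic set-up the first r columns of X are the pure ones, so the selected indices should be exactly those.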

… we believe NMF has a bright future…

A survey on dynamic and stochastic vehicle routing problems

February 15, 2019

A survey on dynamic and stochastic vehicle routing problems Ritzinger et al., International Journal of Production Research

It’s been a while since we last looked at an overview of dynamic vehicle routing problems: that was back in 2014 (See ‘Dynamic vehicle routing, pickup, and delivery problems’). That paper has fond memories for me: I looked at it while doing diligence for our investment in Deliveroo, and my how they’ve grown since then! With vehicle routing problems popping up in a number of interesting businesses, it’s time to take another look! Today’s paper choice is a more recent survey, focusing in on DSVRP problems.

So what exactly is a DSVRP problem? The VRP part stands for vehicle routing problems: typically you have a fleet of vehicles, and you need to use them to make a set of deliveries from point A to point B. How you assign pick-ups and deliveries to vehicles, and the routes those vehicles take, is the VRP problem. Historically the VRP problem would be solved statically (we know up front the set of vehicles, pick-up and drop-off locations, etc.). Much more interesting (and much more realistic for many companies) is when we allow things to change over time. For example, customer requests come in during the day, traffic conditions change meaning that journey times are impacted, and so on. These are dynamic vehicle routing problems (DVRP). Then some bright spark had the insight that we can also learn from past data! That is, even though we don’t know in advance the exact conditions we’re likely to encounter, we can build predictive models that can inform our planning. This is the stochastic element. Put it all together and you have DSVRP: dynamic, stochastic vehicle routing.

In the last years, research interest in vehicle routing has increasingly focused on dynamic and stochastic approaches… The recent development in telematics, such as the widespread use of positioning services and mobile communication, allows gathering real-time information and exact monitoring of vehicles. This builds the basis for extensive data collection and real-time decision support in vehicle routing. Additionally, today’s computing capacity allows the application of routing algorithms which provide real-time solutions by incorporating online information and considering possible future events.

One of the interesting findings from this survey is the impact that different dynamic and stochastic elements can have on the performance of a routing algorithm. Compared to a static model:

  • Adding in consideration of travel time uncertainty gives improvements of up to 10%.
  • When routes are known in advance, but customer demand at each stop is not known (e.g., supplying of gas stations, garbage collection), adding in online decisions (dynamic) can improve performance by around 30-40%
  • When we consider dynamic customer requests, performance can be improved by up to 60%

In other words, invest your energy in building predictive models of customer demand first, and worry about incorporating dynamic information on traffic conditions afterwards.

Now, I waved my hands a little there in talking about the ‘performance’ of a routing algorithm. Setting up an appropriate objective function is in fact an interesting question in its own right. Are you optimising for minimum cost (e.g. by reducing travel time or distances), minimal response times, maximal revenue, number of serviced requests, percentage of requests met within some SLO, etc.?

We can divide the class of DSVRP problems into two broad groups. The first group use preprocessed decisions. They use up front computation to work out ahead of time what the best course of action is to take in any given situation, and then choose from amongst those pre-computed alternatives at runtime as new information comes in. More interesting is the online decision group, which evaluate the best course of action dynamically. There are two basic ways of triggering re-computation here: in a time-driven approach we re-calculate at fixed time intervals; in an event-driven approach we re-calculate whenever a new event arises (aka rolling horizon or look ahead strategy). When an event arrives we take into account the new event information and the latest available stochastic information, and either make a single decision for the current situation, or a full re-optimised plan. A variation is to provide a quick greedy single decision, and schedule more intensive re-optimisations in the background. The best way to go is influenced by the degree of dynamism in the system, and the desired reaction time to new events.

Modelling travel times

There are three possible approaches to modelling changing travel times:

  1. The simplest model is time dependent (i.e., you have a simple lookup table for expected travel time at a given time of day)
  2. Then there are stochastic approaches modelling a distribution of travel times independent of time of day
  3. Finally, stochastic and time-dependent solutions model link travel times as random variables with time-dependent distributions.
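The three modelling options can be contrasted with a toy sketch. All of the numbers here (rush-hour windows, minutes, distribution parameters) are made up for illustration, and the function names are my own.

```python
import random

# 1. Time-dependent: a deterministic lookup by hour of day (values are invented).
EXPECTED_MINUTES = {
    h: (30.0 if (7 <= h < 10 or 16 <= h < 19) else 20.0) for h in range(24)
}

def time_dependent(hour):
    return EXPECTED_MINUTES[hour]

# 2. Stochastic: a single travel time distribution, independent of time of day.
def stochastic(rng):
    return rng.lognormvariate(3.0, 0.25)

# 3. Stochastic and time-dependent: the distribution itself varies with the hour.
def stochastic_time_dependent(hour, rng):
    mean = EXPECTED_MINUTES[hour]
    return max(0.0, rng.gauss(mean, 0.2 * mean))

rng = random.Random(0)
samples = [stochastic_time_dependent(8, rng) for _ in range(5)]
```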

These days I’d be tempted just to throw a neural network at it and learn a predictive model. Time of day would certainly be one of the input features.

Results show that a good strategy is to accept lateness in case of small deviations in travel times, but react on events of a large magnitude… the reliability (i.e., level of customer service) increases considerably when uncertainty in travel times is considered in the solution approach.

Modelling customer demand

Models for dynamic customer demand combine stochastic information about both customer locations and the time of request.

The aim is to appropriately react to new events by introducing rules that can be easily computed in real-time. These rules consider future events in the decision making process by generating scenarios of potential outcomes. A common concept, called sample scenario approach (SSA), is to generate multiple scenarios of future customer requests and include them in the planning process. After selecting the most appropriate solution for the future, the sampled customers are removed but with the effect that the new solution is well prepared for possible future requests.

Incorporating very near future information in the planning process (short-term sampling) proves to be the most beneficial. Besides sampling, other strategies such as waiting strategies are also investigated (see e.g. ‘Putting data in the driver’s seat…’). For example, using a rolling horizon approach where geographical and temporal information about customer requests is exploited to reposition vehicles during idle times. Incorporating stochastic information about customer requests yields up to 60% improvements.
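Here is a deliberately tiny sketch of the sampling idea applied to repositioning an idle vehicle. Everything in it is an assumption for illustration: requests are 2-D points, scenarios are drawn by resampling past requests, and the 'plan' is just a choice between candidate waiting positions scored by average distance to the sampled requests.

```python
import random

def distance(a, b):
    return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5

def sample_scenario(history, k, rng):
    """One scenario: k possible future requests drawn from past request locations."""
    return [rng.choice(history) for _ in range(k)]

def choose_position(candidates, history, n_scenarios=50, k=3, seed=0):
    rng = random.Random(seed)
    scenarios = [sample_scenario(history, k, rng) for _ in range(n_scenarios)]

    def expected_cost(pos):
        total = sum(distance(pos, req) for sc in scenarios for req in sc)
        return total / (n_scenarios * k)

    # The sampled requests only inform the decision; they are discarded afterwards.
    return min(candidates, key=expected_cost)

history = [(9.0, 9.0), (10.0, 8.0), (8.0, 10.0), (9.5, 9.5), (0.0, 0.0)]
candidates = [(0.0, 0.0), (9.0, 9.0)]
best = choose_position(candidates, history)
```

Since most past requests cluster around (9, 9), the vehicle ends up waiting there rather than at the origin.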

What about when both traffic conditions and customer demand vary over time?

Aka, the real world!

Almost all research on DSVRPs considers not more than one stochastic aspect. Only little work is performed on investigating the impact of incorporating multiple stochastic aspects, e.g. consideration of stochastic travel times and customers.

The following chart gives an indication of the volume of published papers investigating different aspects of the DSVRP problem over the years:

Papers considering multiple stochastic aspects are on the rise! Unfortunately, in addition to being an area less covered by research, it’s also less covered in this survey. It seems reasonable to assume that such approaches should yield improvements of up to 60% or more, since approaches considering only variations in customer demand do that well.

I can’t help but wonder if maybe there’s a radically different approach out there using end-to-end reinforcement learning. Given historical data about traffic and demand, we could replay a subsequence of the data and then ask the system to take a decision. We could then give that decision a score based on what we know about what happened next in the real-world… I’ve seen this ‘learning from sub-sequences’ tactic used in one of the previous papers that we’ve looked at on The Morning Paper, but right now I can’t for the life of me remember which one! Let me know if you find it :).

Beyond news contents: the role of social context for fake news detection

February 13, 2019

Beyond news contents: the role of social context for fake news detection Shu et al., WSDM’19

Today we’re looking at a more general fake news problem: detecting fake news that is being spread on a social network. Forgetting the computer science angle for a minute, it seems intuitive to me that some important factors here might be:

  • what is being said (the content of the news), and perhaps how it is being said (although fake news can be deliberately written to mislead users by mimicking true news)
  • where it was published (the credibility / authority of the source publication). For example, something in the Financial Times is more likely to be true than something in The Onion!
  • who is spreading the news (the credibility of the user accounts retweeting it for example – are they bots??)

Therefore I’m a little surprised to read in the introduction that:

The majority of existing detection algorithms focus on finding clues from the news content, which are generally not effective because fake news is often intentionally written to mislead users by mimicking true news.

(The related work section does however discuss several works that include social context.)

So instead of just looking at the content, we should also look at the social context: the publishers and the users spreading the information! The fake news detection system developed in this paper, TriFN, considers tri-relationships between news pieces, publishers, and social network users.

… we are to our best knowledge the first to classify fake news by learning the effective news features through the tri-relationship embedding among publishers, news contents, and social engagements.

And guess what, considering publishers and users does indeed turn out to improve fake news detection!


We have l publishers, m social network users, and n news articles. Using a vocabulary of t words, we can compute an \mathbf{X} \in \mathbb{R}^{n \times t} bag-of-words feature matrix.

For the m users, we can have an m x m adjacency matrix \mathbf{A} \in \{0,1\}^{m \times m}, where \mathbf{A}_{ij} is 1 if i and j are friends, and 0 otherwise.

We also know which users have shared which news pieces. This is encoded in a matrix \mathbf{W} \in \{0,1\}^{m \times n}.

The matrix \mathbf{B} \in \{0,1\}^{l \times n} similarly encodes which publishers have published which news pieces.

For some publishers, we can know their partisan bias. In this work, publisher bias ratings are used, taking just the ‘Left-Bias’, ‘Least-Bias’ (neutral) and ‘Right-Bias’ values (ignoring the intermediate left-center and right-center values) and encoding these as -1, 0, and 1 respectively in a publisher partisan label vector, \mathbf{o}. Not every publisher will have a bias rating available. We’d like to put ‘-’ in the entry for that publisher in \mathbf{o} but since we can’t do that, the separate vector \mathbf{e} \in \{0,1\}^l encodes whether or not we have a bias rating available for each publisher.
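To see how these inputs fit together, here is a small NumPy sketch that builds \mathbf{A}, \mathbf{W}, \mathbf{B}, \mathbf{o} and \mathbf{e} from toy data. The users, publishers, shares, and ratings are all invented, and treating friendship as symmetric is my assumption.

```python
import numpy as np

# Hypothetical toy data: m = 3 users, l = 2 publishers, n = 4 news articles.
m, l, n = 3, 2, 4
friend_pairs = [(0, 1), (1, 2)]
shares = [(0, 0), (0, 2), (1, 2), (2, 3)]        # (user, article)
published = [(0, 0), (0, 1), (1, 2), (1, 3)]     # (publisher, article)
ratings = {0: "Left-Bias"}                       # publisher 1 has no rating

A = np.zeros((m, m), dtype=int)                  # user-user adjacency
for i, j in friend_pairs:
    A[i, j] = A[j, i] = 1                        # assumption: friendship is mutual

W = np.zeros((m, n), dtype=int)                  # user-news sharing
for user, article in shares:
    W[user, article] = 1

B = np.zeros((l, n), dtype=int)                  # publisher-news publishing
for publisher, article in published:
    B[publisher, article] = 1

LABELS = {"Left-Bias": -1, "Least-Bias": 0, "Right-Bias": 1}
o = np.zeros(l, dtype=int)                       # partisan label per publisher
e = np.zeros(l, dtype=int)                       # 1 where a rating is available
for publisher, rating in ratings.items():
    o[publisher] = LABELS[rating]
    e[publisher] = 1
```

Note how publisher 1 gets a 0 in \mathbf{o} that is indistinguishable from ‘Least-Bias’ on its own, which is why \mathbf{e} is needed.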

There’s one last thing at our disposal: a labelled dataset for news articles telling us whether they are fake or not. (Here we have just the news article content, not the social context).

The Tri-relationship embedding framework

TriFN takes all of those inputs and combines them with a fake news binary classifier. Given lots of users and lots of news articles, we can expect some of the raw inputs to be pretty big, so the authors make heavy use of dimensionality reduction using non-negative matrix factorisation to learn latent space embeddings (more on that in a minute!) TriFN combines:

  • A news content embedding
  • A user embedding
  • A user-news interaction embedding
  • A publisher-news interaction embedding, and
  • The prediction made by a linear classifier trained on the labelled fake news dataset

Pictorially it looks like this (with apologies for the poor resolution, which is an artefact of the original):

News content embedding

Let’s take a closer look at non-negative matrix factorisation (NMF) to see how this works to reduce dimensionality. Remember the bag-of-words sketch for news articles? That’s an n x t matrix where n is the number of news articles and t is the number of words in the vocabulary. NMF tries to learn a latent embedding that captures the information in the matrix in a much smaller space. In the general form NMF seeks to factor a (non-negative) matrix M into the product of two (non-negative) matrices W and H (or D and V as used in this paper). How does that help us? We can pick some dimension d (controlling the size of the latent space) and break down the \mathbf{X} \in \mathbb{R}^{n \times t} matrix into a d-dimension representation of news articles \mathbf{D} \in \mathbb{R}^{n \times d}, and a d-dimension representation of words in the vocabulary, \mathbf{V} \in \mathbb{R}^{t \times d}. That means that \mathbf{V}^T has shape d \times t and so DV^T ends up with the desired shape n \times t. Once we’ve learned a good representation of news articles, \mathbf{D} we can use those as the news content embeddings within TriFN.

We’d like to get \mathbf{DV}^T as close to \mathbf{X} as we can, and at the same time keep \mathbf{D} and \mathbf{V} ‘sensible’ to avoid over-fitting. We can do that with a regularisation term. So the overall optimisation problem looks like this:
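As a sketch of what such a regularised factorisation can look like in code, the example below assumes one common form of the objective: the squared Frobenius reconstruction error plus \lambda times the squared Frobenius norms of the two factors. The multiplicative update rules, the value of `lam`, and the function names are my own choices, not the training procedure from the paper.

```python
import numpy as np

def objective(X, D, V, lam):
    """Assumed form: ||X - D V^T||_F^2 + lam * (||D||_F^2 + ||V||_F^2)."""
    fit = np.linalg.norm(X - D @ V.T, "fro") ** 2
    penalty = np.linalg.norm(D, "fro") ** 2 + np.linalg.norm(V, "fro") ** 2
    return fit + lam * penalty

def fit_news_embedding(X, d, lam=0.1, n_iter=200, eps=1e-10, seed=0):
    """Factor X (n x t) into nonnegative D (n x d) and V (t x d)."""
    rng = np.random.default_rng(seed)
    n, t = X.shape
    D = rng.random((n, d)) + eps
    V = rng.random((t, d)) + eps
    for _ in range(n_iter):
        # Multiplicative updates keep both factors nonnegative.
        D *= (X @ V) / (D @ (V.T @ V) + lam * D + eps)
        V *= (X.T @ D) / (V @ (D.T @ D) + lam * V + eps)
    return D, V

rng = np.random.default_rng(3)
X = rng.integers(0, 3, size=(8, 12)).astype(float)   # toy word counts
D, V = fit_news_embedding(X, d=3)
```

Each row of D is then the d-dimensional embedding of one news article.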

User embedding

For the user embedding there’s a similar application of NMF, but in this case we’re splitting the adjacency matrix \mathbf{A} into a user latent matrix \mathbf{U} \in \mathbb{R}^{m \times d}, and a user correlation matrix \mathbf{T} \in \mathbb{R}^{d \times d}. So in this case we’re using NMF to learn \mathbf{UTU^T}, whose factors have shapes m \times d, d \times d, and d \times m, resulting in the desired m \times m shape. There’s also a user-user relation matrix \mathbf{Y} which controls the contribution of \mathbf{A}. The basic idea is that any given user will only share a small fraction of news articles, so a positive case (having shared an article) should have more weight than a negative case (not having shared).
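A quick shape check in NumPy makes the three-factor product concrete. The element-wise weighting of the error by \mathbf{Y}, and the particular weights of 1.0 for positive entries and 0.5 for negative ones, are assumptions I've made for the sake of the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 5, 2

# Toy symmetric 0/1 adjacency matrix with an empty diagonal.
upper = np.triu((rng.random((m, m)) < 0.4).astype(float), k=1)
A = upper + upper.T

U = rng.random((m, d))      # user latent matrix
T = rng.random((d, d))      # user correlation matrix
A_hat = U @ T @ U.T         # (m x d)(d x d)(d x m) -> m x m

# Hypothetical weights: observed links count more than absent ones.
Y = np.where(A > 0, 1.0, 0.5)
loss = np.linalg.norm(Y * (A - A_hat), "fro") ** 2
```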

User-news interaction embedding

For the user-news interaction embedding we want to capture the relationship between user features and the labels of news items. The intuition is that users with low credibility are more likely to spread fake news. So how do we get user credibility? Following ‘Measuring user credibility in social media’ the authors base this on similarity to other users. First users are clustered into groups such that members of the same cluster all tend to share the same news stories. Then each cluster is given a credibility score based on its relative size. Users take on the credibility score of the cluster they belong to. It all seems rather vulnerable to the creation of large numbers of fake bot accounts that collaborate to spread fake news if you ask me. Nevertheless, assuming we have reliable credibility scores then we want to set things up such that the latent features of high-credibility users are close to true news, and the latent features of low-credibility users are close to fake news.

Publisher-news embeddings

Recall we have the matrix \mathbf{B} encoding which publishers have published which news pieces. Let \bar{\mathbf{B}} be the normalised version of the same. We want to find \mathbf{q}, a weighting matrix mapping news publisher’s latent features to the corresponding partisan label vector \mathbf{o}. It looks like this:

Semi-supervised linear classifier

Using the labelled data available, we also learn a weighting matrix \mathbf{p} mapping news latent features to fake news labels.

Putting it all together

The overall objective becomes to find matrices \mathbf{D,U,V,T,p,q} using a weighted combination of each of the above embedding formulae, and a regularisation term combining all of the learned matrices.

It looks like this:

and it’s trained like this:


TriFN is evaluated against several state of the art fake news detection methods using the FakeNewsNet BuzzFeed and PolitiFact datasets.

It gives the best performance on both of them:


ExFaKT: a framework for explaining facts over knowledge graphs and text

February 11, 2019

ExFaKT: a framework for explaining facts over knowledge graphs and text Gad-Elrab et al., WSDM’19

Last week we took a look at Graph Neural Networks for learning with structured representations. Another kind of graph of interest for learning and inference is the knowledge graph.

Knowledge Graphs (KGs) are large collections of factual triples of the form \langle subject\ predicate\ object \rangle (SPO) about people, companies, places etc.

Today’s paper choice focuses on the topical area of fact-checking: how do we know whether a candidate fact, which might for example be harvested from a news article or social media post, is likely to be true? For the first generation of knowledge graphs, fact checking was performed manually by human reviewers, but this clearly doesn’t scale to the volume of information published daily. Automated fact checking methods typically produce a numerical score (probability the fact is true), but these scores are hard to understand and justify without a corresponding explanation.

To better support KG curators in deciding the correctness of candidate facts, we propose a novel framework for finding semantically related evidence in Web sources and the underlying KG, and for computing human-comprehensible explanations for facts. We refer to our framework as ExFaKT (Explaining Facts over KGs and Text resources).

There could be multiple ways of producing evidence for a given fact. Intuitively we prefer a short concise explanation to a long convoluted one. This translates into explanations that use as few atoms and rules as possible. Furthermore, we prefer evidence from trusted sources to evidence from less trusted sources. In this setting, the sources available to us are the existing knowledge graph and external text resources. The assumption is that we have some kind of quality control over the addition of facts to the knowledge graph, and so we prefer explanations making heavy use of knowledge graph facts to those that rely mostly on external texts.

In addition to the facts in the knowledge base, we have at our disposal a collection of rules. These rules may have been automatically harvested from unstructured documents (see e.g. DeepDive), or they might be provided by human authors. In the evaluation 10 students are given a short 20 minute tutorial on how to write rules, and then asked to pick 5 predicates from a list of KG predicates and write at least one supporting rule and one refuting rule for each. It took 30 minutes for the students to produce 96 rules between them. More than half of the rules were strong (represent causality and generalise), and more than 80% were either strong or valid (valid rules capture a correlation but may be tied to specific cases). The remaining incorrect rules were filtered out by having the same participants judge each other’s rules using a voting scheme. The authors conclude that manual creation of rules is not a bottleneck, and could be informed by crowdsourcing at a fairly low cost.

A rule is specified using the implication form of Horn clauses: H \leftarrow B_1, B_2, ..., B_n. H is the head, which we can infer if all of the body clauses are true. For example,

  • citizenOf(X,Y) \leftarrow mayorOf(X,Z), locatedIn(Z,Y)

E.g., if London is located in England, and Sadiq Khan is the mayor of London, then we can infer that Sadiq Khan is a citizen of England.
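Applied forwards, the rule is just a join over two predicates. Here is a minimal Python sketch with facts stored as tuples. The representation and the function name are mine, chosen for illustration.

```python
# Facts as (predicate, subject, object) triples.
facts = {
    ("mayorOf", "sadiq_khan", "london"),
    ("locatedIn", "london", "england"),
}

def apply_citizen_rule(facts):
    """citizenOf(X,Y) <- mayorOf(X,Z), locatedIn(Z,Y)"""
    derived = set()
    for pred1, x, z in facts:
        if pred1 != "mayorOf":
            continue
        for pred2, z2, y in facts:
            # The shared variable Z must take the same value in both body atoms.
            if pred2 == "locatedIn" and z2 == z:
                derived.add(("citizenOf", x, y))
    return derived

inferred = apply_citizen_rule(facts)
```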

One technique we could try is to iterate from known facts and rules to a fixpoint. The challenge here is that we are including an external text corpus as part of the resources available to us. That text corpus could be, e.g. ‘the Web’. Finding every possible deducible fact across the entire web, and then checking to see if our candidate fact is among them isn’t going to be very efficient!

So what ExFaKT does instead is to work backwards from the candidate fact, recursively using query rewriting to break the query down into a set of simpler queries, until we arrive at body atoms. The recursion stops when either no rules can be found or all atoms are instantiated either by facts in the KG or by text.

For example, suppose we have a knowledge graph with three facts:

isDirector(nolan)
directed(lucas, star_wars)
directed(lucas, amer_graffiti)

And two rules:

influencedBy(X,Y) <- isDirector(X), directed(Y,Z), inspiredBy(X,Z)
inspiredBy(X,Y) <- liked(X,Y), isArtist(X)

For our text corpus we’ll assume all Wikipedia articles. The query we want to fact check is influencedBy(nolan, lucas).

Expanding the initial query we have three things to check: isDirector(nolan), directed(lucas,Z) and inspiredBy(nolan, Z). We know that isDirector(nolan) directly from the fact base. Now we can ground the second directed(lucas, Z) from the fact base, leading us to two candidate explanations:

  • isDirector(nolan), directed(lucas, star_wars), inspiredBy(nolan, star_wars)
  • isDirector(nolan), directed(lucas, amer_graffiti), inspiredBy(nolan, amer_graffiti)

Starting with the first candidate, inspiredBy(nolan, star_wars) is found in Wikipedia, so we add this explanation to the output set. Since text is a noisy resource, we can also break down inspiredBy(nolan, star_wars) into the two subgoals liked(nolan, star_wars), isArtist(nolan). If these two atoms are also spotted in the text corpus, we can add this as an additional explanation to the output set.
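The walk-through can be mimicked with a small backward-chaining sketch in Python. This is a toy under several assumptions of mine, not the ExFaKT implementation: the fact base contains isDirector(nolan) as the walk-through states, the first rule's body is the three atoms listed in the expansion, the text corpus is stood in for by a set of atoms, variables are capitalised strings, and only fully instantiated goals are rewritten.

```python
KG = {("isDirector", "nolan"),
      ("directed", "lucas", "star_wars"),
      ("directed", "lucas", "amer_graffiti")}
TEXT = {("inspiredBy", "nolan", "star_wars")}    # stand-in for text corpus hits

RULES = [
    (("influencedBy", "X", "Y"),
     [("isDirector", "X"), ("directed", "Y", "Z"), ("inspiredBy", "X", "Z")]),
    (("inspiredBy", "X", "Y"), [("liked", "X", "Y"), ("isArtist", "X")]),
]

def is_var(term):
    return term[0].isupper()

def substitute(atom, theta):
    return (atom[0],) + tuple(theta.get(a, a) for a in atom[1:])

def rename(atom, suffix):
    # Give rule variables fresh names so they cannot clash with earlier ones.
    return (atom[0],) + tuple(a + suffix if is_var(a) else a for a in atom[1:])

def match(pattern, fact, theta):
    """Extend theta so that pattern equals the ground fact, else return None."""
    if pattern[0] != fact[0] or len(pattern) != len(fact):
        return None
    theta = dict(theta)
    for p, f in zip(pattern[1:], fact[1:]):
        p = theta.get(p, p)
        if is_var(p):
            theta[p] = f
        elif p != f:
            return None
    return theta

def explain(goals, theta=None, depth=2):
    """Yield explanations: lists of (ground atom, source) covering every goal."""
    theta = theta or {}
    if not goals:
        yield []
        return
    first, rest = substitute(goals[0], theta), list(goals[1:])
    # Bind: look the atom up in the KG first, then in the text stand-in.
    for source, facts in (("KG", KG), ("text", TEXT)):
        for fact in sorted(facts):
            theta2 = match(first, fact, theta)
            if theta2 is not None:
                for tail in explain(rest, theta2, depth):
                    yield [(fact, source)] + tail
    # Rewrite: replace a ground goal by the body of a matching rule.
    if depth > 0 and not any(is_var(a) for a in first[1:]):
        for head, body in RULES:
            suffix = f"_{depth}"
            theta2 = match(rename(head, suffix), first, theta)
            if theta2 is not None:
                new_goals = [rename(b, suffix) for b in body] + rest
                yield from explain(new_goals, theta2, depth - 1)

explanations = list(explain([("influencedBy", "nolan", "lucas")]))
```

With only inspiredBy(nolan, star_wars) in the text stand-in, a single explanation comes back. Adding liked(nolan, star_wars) and isArtist(nolan) to `TEXT` yields the second, longer one described above.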

Using bind to refer to the operation of retrieving answers to a given query from underlying data sources, the overall ExFaKT algorithm looks like this:

We keep track of the depth of recursion to limit the total number of rewritings and ensure termination. The algorithm can be stopped whenever it has uncovered satisfactory explanations, but we want to ensure that we find good explanations meeting the criteria we laid out earlier. Thus when picking a promising explanation to explore (line 4), we favour shorter explanations and explanations with fewer rewritings. When picking an atom to search for (line 9), atoms without variables are preferred, then those with some constants, and atoms with only variables are sent to the back of the queue. Atoms with KG substitutions are preferred to those that can only be backed up by text.
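The atom-selection preference boils down to a sort key. A small sketch, assuming atoms are tuples whose variables are capitalised strings (my own convention) and leaving out the KG-versus-text preference:

```python
def atom_priority(atom):
    """0: no variables, 1: some constants and some variables, 2: only variables."""
    args = atom[1:]
    n_vars = sum(1 for a in args if a[0].isupper())
    if n_vars == 0:
        return 0
    if n_vars < len(args):
        return 1
    return 2

atoms = [("inspiredBy", "X", "Z"), ("directed", "lucas", "Z"), ("isDirector", "nolan")]
ordered = sorted(atoms, key=atom_priority)
```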

Key results from the evaluation show that:

  • Combining both KG and textual resources results in better predicate recall than relying on just the KG or just the text corpus.
  • When the evidence produced by ExFaKT is presented to a human, they are able to judge the truthfulness of a fact candidate in 27 seconds on average. Using just standard web searches it took them on average 51 seconds.

This illustrates the benefits of using our method for increasing the productivity and accuracy of human fact-checkers.

One of the things I wondered when working through the paper is that the system seems very vulnerable to confirmation bias. I.e., it deliberately goes looking for confirming facts, and declares the candidate true if it finds them. But maybe there is an overwhelming body of evidence to the contrary, which the system is going to ignore? The answer to this puzzle is found in section 4.5, where the authors evaluate the use of ExFaKT in automated fact checking. For each candidate fact ExFaKT is used to generate two sets of supporting explanations: one set confirming the fact, and one set refuting it. By scoring the evidence presented (roughly, the trust level of the sources used, over the depth of the explanation) it’s possible to come to a judgement as to which scenario is the more likely.

The conducted experiments demonstrate the usefulness of our method for supporting human curators in making accurate decisions about the truthfulness of facts as well as the potential of our explanations for improving automated fact-checking systems.

Graph neural networks: a review of methods and applications

February 8, 2019

Graph neural networks: a review of methods and applications Zhou et al., arXiv 2019

It’s another graph neural networks survey paper today! Cue the obligatory bus joke. Clearly, this covers much of the same territory as we looked at earlier in the week, but when we’re lucky enough to get two surveys published in short succession it can add a lot to compare the two different perspectives and their sense of what’s important. In particular here, Zhou et al. have a different formulation for describing the core GNN problem, and a nice approach to splitting out the various components. Rather than make this a standalone write-up, I’m going to lean heavily on the graph neural network survey we looked at on Wednesday and try to enrich my understanding starting from there.

An abstract GNN model

For this survey, the GNN problem is framed based on the formulation in the original GNN paper, ‘The graph neural network model,’ Scarselli 2009.

Associated with each node is an s-dimensional state vector.

The target of GNN is to learn a state embedding \mathbf{h}_v \in \mathbb{R}^s which contains the information of the neighbourhood for each node.

Given the state embedding we can produce a node-level output \mathbf{o}_v, for example, a label.

The local transition function, f, updates a node’s state based on its input neighbourhood, and the local output function, g, describes how the output is produced. We can express \mathbf{h} and \mathbf{o} in terms of f and g:

  • \mathbf{h}_v = f(\mathbf{x}_v, \mathbf{x}_{co[v]}, \mathbf{h}_{ne[v]}, \mathbf{x}_{ne[v]})
  • \mathbf{o}_v = g(\mathbf{h}_v, \mathbf{x}_v)

Where \mathbf{x}_v are the features of node v, and \mathbf{x}_{co[v]} are the features of its edges. \mathbf{h}_{ne[v]} and \mathbf{x}_{ne[v]} are the states and node features for the neighbourhood of v.

If we stack all the embedded states into one vector, \mathbf{H}, the outputs into one vector \mathbf{O}, the node and edge features into one vector \mathbf{X}, and just the node features into \mathbf{X}_N, then we have:

  • \mathbf{H} = F(\mathbf{H}, \mathbf{X}), and
  • \mathbf{O} = G(\mathbf{H}, \mathbf{X}_N)

Where F is the global transition function and G is the global output function.

The original GNN iterates to a fixed point using \mathbf{H}^{t+1} = F(\mathbf{H}^t, \mathbf{X}).
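Here is a toy version of that iteration to make it concrete. The transition function is made up for the example (own feature plus a damped neighbour average, scalar states); it is not the learned f from the paper. The 0.5 damping is there so that F is a contraction and the iteration settles.

```javascript
// Toy global transition F: each node's new state is its own feature plus
// a damped average of its neighbours' states. Every node is assumed to
// have at least one neighbour.
function transition(H, X, neighbours) {
  return H.map((_, v) => {
    const ne = neighbours[v];
    const avg = ne.reduce((s, u) => s + H[u], 0) / ne.length;
    return X[v] + 0.5 * avg;
  });
}

// Iterate H^{t+1} = F(H^t, X) until the largest change falls below tol.
function iterateToFixedPoint(X, neighbours, tol = 1e-9, maxSteps = 1000) {
  let H = X.map(() => 0);
  for (let t = 0; t < maxSteps; t++) {
    const next = transition(H, X, neighbours);
    const delta = Math.max(...next.map((h, v) => Math.abs(h - H[v])));
    H = next;
    if (delta < tol) break;
  }
  return H;
}

// A 3-node path graph: 0 - 1 - 2
const pathNeighbours = [[1], [0, 2], [1]];
const fixedPoint = iterateToFixedPoint([1, 0, 1], pathNeighbours);
console.log(fixedPoint); // approximately [4/3, 2/3, 4/3]
```

Applying `transition` once more to the result leaves it (almost) unchanged, which is exactly what being a fixed point of F means.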

Graph types, propagation steps, and training methods

Zhou et al. break their analysis down into three main component parts: the types of graph that a method can work with, the kind of propagation step that is used, and the training method.

Graph types

In the basic GNN, edges are bidirectional, and nodes are homogeneous. Zhou et al. consider three extensions: directed graphs, heterogeneous graphs, and attributed edges.

When edges have attributes there are two basic strategies: you can associate a weight matrix with the edges, or you can split the original edge into two introducing an ‘edge node’ in the middle (which then holds the edge attributes). If there are multiple different types of edges you can use one weight and one adjacency matrix per edge type.
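The 'edge node' trick is easy to picture in code. This is a minimal sketch under my own assumptions about representation (node ids are integers, edges are `{ u, v, attrs }` objects); the helper name is made up.

```javascript
// Replace each attributed edge (u, v, attrs) with a new "edge node" e and
// two plain edges u - e and e - v. The edge node carries the attributes.
function splitAttributedEdges(numNodes, edges) {
  const nodes = Array.from({ length: numNodes }, (_, id) => ({ id, kind: "node" }));
  const plainEdges = [];
  for (const { u, v, attrs } of edges) {
    const e = nodes.length; // id of the new edge node
    nodes.push({ id: e, kind: "edge", attrs });
    plainEdges.push([u, e], [e, v]);
  }
  return { nodes, edges: plainEdges };
}

const split = splitAttributedEdges(3, [
  { u: 0, v: 1, attrs: { type: "cites" } },
  { u: 1, v: 2, attrs: { type: "authored" } },
]);
console.log(split.nodes.length, split.edges);
// 5 nodes; edges [[0,3],[3,1],[1,4],[4,2]]
```

The resulting graph has no edge attributes left, at the cost of one extra node and one extra edge per original edge.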

For graphs with several different kinds of node the simplest strategy is to concatenate a one-hot encoded node type vector to the node features. GraphInception creates one neighbourhood per node type and deals with each as if it were a sub-graph in a homogeneous graph.
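The one-hot strategy, sketched with a made-up set of node types:

```javascript
// Append a one-hot encoding of the node type to the node's feature vector.
function withNodeType(features, type, allTypes) {
  const idx = allTypes.indexOf(type);
  if (idx === -1) throw new Error(`unknown node type: ${type}`);
  const oneHot = allTypes.map((_, i) => (i === idx ? 1 : 0));
  return [...features, ...oneHot];
}

const nodeTypes = ["author", "paper", "venue"]; // example types, not from the paper
const paperFeatures = withNodeType([0.2, 0.7], "paper", nodeTypes);
console.log(paperFeatures); // [0.2, 0.7, 0, 1, 0]
```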

Propagation steps

There are lots of different approaches to propagation, many of which we looked at earlier this week.

Gated approaches use Gated Recurrent Units (GRUs) or LSTM models to try and improve long-term propagation of information across the graph structure. Gated Graph Neural Network (GGNN) uses GRUs with update functions that incorporate information from other nodes and from the previous timestep. Graph LSTM approaches adapt LSTM propagation based on a tree or graph. In Tree-LSTM each unit contains one forget gate and parameter matrix per child, allowing the unit to selectively incorporate per-child information. The tree approach can be generalised to an arbitrary graph, one application of which is in sentence encoding: Sentence-LSTM (S-LSTM) converts the text into a graph and then uses a Graph LSTM to learn a representation.

Another classic approach to try and improve learning across many layers is the introduction of skip connections. Highway GCN uses ‘highway gates’ which sum the output of a layer with its input and gating weights. The Jump Knowledge Network selects from amongst all the intermediate representations for each node at the last layer (the selected representation ‘jumps’ to the output layer).
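My understanding of the highway-gate idea, as a sketch: a gate value between 0 and 1 decides, per feature, how much of the layer's output to take and how much of the input to pass straight through. To keep it short I've used per-feature scalar gate weights `w` and `b`, which is a simplification of my own (the usual formulation uses a weight matrix).

```javascript
const sigmoid = (z) => 1 / (1 + Math.exp(-z));

// Highway-style combination for one node's feature vector:
//   gate = sigmoid(w * input + b)                  (per feature)
//   out  = gate * layerOutput + (1 - gate) * input
function highway(input, layerOutput, w, b) {
  return input.map((x, i) => {
    const gate = sigmoid(w[i] * x + b[i]);
    return gate * layerOutput[i] + (1 - gate) * x;
  });
}

// With w = 0 and b = 0 the gate is exactly 0.5, so we get the midpoint.
const blended = highway([1, 2], [3, 6], [0, 0], [0, 0]);
console.log(blended); // [2, 4]
```

Pushing `b` strongly negative closes the gate, so the input skips the layer more or less untouched, which is what helps when stacking many layers.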

The following monster table shows how some of the different GNN methods map into the abstract model we started with.

Training methods

GNNs struggle with large graphs, so a number of training enhancements have been proposed.

GraphSage, as we saw earlier this week, uses neighbourhood sampling. The control-variate based approach uses historical activations of nodes to limit the receptive field in the 1-hop neighbourhood (I’d need to read the referenced paper to really understand what that means here!). Co-training GCN and Self-training GCN both try to deal with performance drop-off in different ways: co-training enlarges the training data set by finding nearest neighbours, and self-training uses boosting.

The Graph Network

Section 2.3.3 in the paper discusses Graph Networks, which generalise and extend Message-Passing Neural Networks (MPNNs) and Non-Local Neural Networks (NLNNs). Graph networks support flexible representations for global, node, and edge attributes; different functions can be used within blocks; and blocks can be composed to construct complex architectures.

More applications of GNNs

Section 3 discusses many applications of graph neural networks, including the following handy table with references to example papers and algorithms.

Research challenges

There is good agreement on the open challenges with graph neural networks (depth, scale, skew, dynamic graphs with structure that changes over time). There’s also a discussion of the challenges involved in taking unstructured inputs and turning them into structured forms – this is a different kind of graph generation problem.

…finding the best graph generation approach will offer a wider range of fields where GNN could make a contribution.

A comprehensive survey on graph neural networks

February 6, 2019

A comprehensive survey on graph neural networks Wu et al., arXiv’19

Last year we looked at ‘Relational inductive biases, deep learning, and graph networks,’ where the authors made the case for deep learning with structured representations, which are naturally represented as graphs. Today’s paper choice provides us with a broad sweep of the graph neural network landscape. It’s a survey paper, so you’ll find details on the key approaches and representative papers, as well as information on commonly used datasets and benchmark performance on them.

We’ll be talking about graphs as defined by a tuple (V, E, A) where V is the set of nodes (vertices), E is the set of edges, and A is the adjacency matrix. An edge is a pair (v_i, v_j), and the adjacency matrix is an N \times N (for N nodes) matrix where A_{ij} = 0 if nodes v_i and v_j are not directly connected by an edge, and some weight value > 0 if they are.
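For example, building A from an edge list might look like this. The sketch assumes an undirected graph (so A is symmetric) and a default weight of 1; neither is required by the definition above.

```javascript
// Build a weighted adjacency matrix from [i, j, weight] triples.
// Assumes undirected edges; weight defaults to 1 when omitted.
function adjacencyMatrix(n, weightedEdges) {
  const A = Array.from({ length: n }, () => new Array(n).fill(0));
  for (const [i, j, w = 1] of weightedEdges) {
    A[i][j] = w;
    A[j][i] = w;
  }
  return A;
}

const A3 = adjacencyMatrix(3, [[0, 1], [1, 2, 0.5]]);
console.log(A3);
// [[0, 1, 0],
//  [1, 0, 0.5],
//  [0, 0.5, 0]]
```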

In an attributed graph we also have a set of attributes for each node. For node attributes with D dimensions we have a node feature matrix X \in R^{N \times D}.

A spatial-temporal graph is one where the feature matrix X evolves over time. It is defined as G = (V, E, A, X) with X \in R^{T \times N \times D} for T time steps.

Applications of graph networks

I thought I’d start by looking at some of the applications of graph neural networks as motivation for studying them.

One of the biggest application areas for graph neural networks is computer vision. Researchers have explored leveraging graph structures in scene graph generation, point clouds classification and segmentation, action recognition, and many other directions.

Take a scene graph for example. Entities in the scene are nodes in the graph, and edges can capture the relationships between them. We might want a scene graph as an output of image interpretation. Alternatively, we can start with a scene graph as input, and generate realistic looking images from it.

Point clouds from LiDAR scans can be converted into k-nearest-neighbour graphs and graph convolutional networks used to explore the topological structure.
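A brute-force sketch of the k-nearest-neighbour conversion, using a tiny made-up point cloud. Real pipelines would use a spatial index rather than comparing every pair of points.

```javascript
function squaredDistance(p, q) {
  return p.reduce((s, x, i) => s + (x - q[i]) ** 2, 0);
}

// For each point, return the indices of its k nearest other points
// (nearest first). Each index list is that node's outgoing edges.
function knnGraph(points, k) {
  return points.map((p, i) =>
    points
      .map((q, j) => ({ j, d: squaredDistance(p, q) }))
      .filter(({ j }) => j !== i)
      .sort((a, b) => a.d - b.d)
      .slice(0, k)
      .map(({ j }) => j)
  );
}

const cloud = [[0, 0, 0], [1, 0, 0], [0, 2, 0], [5, 5, 5]];
const knn = knnGraph(cloud, 2);
console.log(knn); // [[1, 2], [0, 2], [0, 1], [2, 1]]
```

Note that the resulting graph is directed: the outlier at (5, 5, 5) points at two other nodes, but nothing points back at it.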

In action recognition the human skeleton can be represented as a graph, and spatial-temporal networks can use this representation to learn action patterns. Another application of spatial-temporal networks is modelling and predicting traffic congestion, as well as anticipating taxi demand.

Graphs also naturally model relationships between items and users, making it possible to build graph-based recommender systems. In chemistry, researchers use molecular graphs to study the graph structures of molecules. Here node classification, graph classification, and graph generation are the three main tasks.

The big picture

The fundamental building block of many graph-based neural networks is the graph convolution network or GCN. In a standard image convolution we move a filter over the image and compute some function of the inputs captured by it, with adjacency determining the area captured by the filter. In graph convolution we move a filter over the nodes of the graph, with the adjacency matrix determining the area captured by the filter.
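To make the 'filter over the nodes' idea concrete, here is about the simplest spatial graph convolution I can think of: replace each node's features with the average over itself and its neighbours. It has no learned weights or non-linearity, so treat it as an illustration of the neighbourhood aggregation only, not a GCN layer from the survey.

```javascript
// One mean-aggregation step: each node's new feature vector is the
// average over itself and its neighbours (as given by adjacency matrix A).
function graphConvStep(A, X) {
  const n = A.length;
  return X.map((_, v) => {
    const members = [v];
    for (let u = 0; u < n; u++) {
      if (u !== v && A[v][u] > 0) members.push(u);
    }
    return X[v].map((_, d) =>
      members.reduce((s, u) => s + X[u][d], 0) / members.length
    );
  });
}

// Path graph 0 - 1 - 2 with one feature per node.
const pathA = [[0, 1, 0], [1, 0, 1], [0, 1, 0]];
const smoothed = graphConvStep(pathA, [[1], [4], [7]]);
console.log(smoothed); // [[2.5], [4], [5.5]]
```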

Strictly this is a spatial GCN; there are also spectral-based GCNs which interpret graphs as a signal (normalising the structure and laying the nodes out in order to form the signal). I find the spectral approach much less intuitive, but thankfully spatial models seem to be where all the action is these days.

If we add the concept of attention, we arrive at Graph Attention Networks.

If we add the idea that the graph changes over time then we get to Graph Spatial-Temporal Networks. Graph Auto-encoders combine the familiar encoder-decoder pairs, but using graph representations on both sides. Finally, the class of Graph Generative Networks aim to generate plausible structures from data.

The following table highlights some of the key approaches in these extended application areas.


The following table highlights a selection of GCN approaches.

The earliest networks used a spectral approach, with graph signal filters performing convolution operations. However spectral-based models have a number of limitations: they can’t easily scale to larger graphs as the computational cost increases dramatically with graph size; they assume a fixed graph, so they generalise poorly to different graphs; and they only work on undirected graphs (§4.4). Spatial models have therefore been attracting increasing attention.

In the domain of spatial GCNs there are two further main divisions. Recurrent-based methods apply the same graph convolution layer repeatedly over time until a stable fixed point is reached. Composition-based GCNs stack a number of convolutional layers, much like a classic CNN. A representative example here is Message-Passing Neural Networks.

To help scale to larger graphs, GraphSage introduces a sampling batch training algorithm. It first samples the k-hop neighbourhood of a node, then derives that node’s final state by aggregating the feature information of the sampled neighbours. This derived state is then used to make predictions and backpropagate errors.
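A sketch of the sample-then-aggregate idea. This is a simplification of my own: it uses a plain mean with no learned weights, and the function names and the injectable `rng` parameter are made up for the example.

```javascript
// Sample up to `size` neighbours of v without replacement
// (partial Fisher-Yates shuffle).
function sampleNeighbours(neighbours, v, size, rng = Math.random) {
  const pool = [...neighbours[v]];
  const k = Math.min(size, pool.length);
  for (let i = 0; i < k; i++) {
    const j = i + Math.floor(rng() * (pool.length - i));
    [pool[i], pool[j]] = [pool[j], pool[i]];
  }
  return pool.slice(0, k);
}

// A node's state at depth k is the mean of its own depth k-1 state and
// the depth k-1 states of a sample of its neighbours.
function sageState(v, depth, X, neighbours, size, rng = Math.random) {
  if (depth === 0) return X[v];
  const sampled = sampleNeighbours(neighbours, v, size, rng);
  const states = [
    sageState(v, depth - 1, X, neighbours, size, rng),
    ...sampled.map((u) => sageState(u, depth - 1, X, neighbours, size, rng)),
  ];
  return X[v].map((_, d) =>
    states.reduce((s, h) => s + h[d], 0) / states.length
  );
}

// Star graph: node 0 is connected to nodes 1 and 2.
const starNeighbours = [[1, 2], [0], [0]];
const starX = [[0], [3], [6]];
const fullState = sageState(0, 1, starX, starNeighbours, 2);
console.log(fullState); // [3]: both neighbours are always sampled
```

The point of the sampling is that the work per node is bounded by `size` per hop, however large the true neighbourhood is.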

If we want to throw a few pooling layers into the mix to keep things manageable, then we’ll need a way of doing graph pooling.

Images have a natural structure for pooling; for graphs we can use sub-graphs as the pools. DIFFPOOL provides a general solution for hierarchical pooling of nodes across a broad set of input graphs, and can be used in end-to-end training. It learns a cluster assignment matrix based on cluster node features and a coarsened adjacency matrix.
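As I understand it, once you have a cluster assignment matrix S (n nodes by c clusters), the coarsened graph is obtained as A' = SᵀAS and X' = SᵀX. The sketch below shows just that step with a hand-written hard assignment; in DIFFPOOL itself S is learned and soft, and the pooled features come from a separate embedding network rather than the raw X.

```javascript
function matmul(P, Q) {
  return P.map((row) =>
    Q[0].map((_, j) => row.reduce((s, p, k) => s + p * Q[k][j], 0))
  );
}

function transpose(M) {
  return M[0].map((_, j) => M.map((row) => row[j]));
}

// S is an n x c cluster assignment matrix (each row sums to 1).
function coarsen(A, X, S) {
  const St = transpose(S);
  return { A: matmul(matmul(St, A), S), X: matmul(St, X) };
}

// Path graph 0 - 1 - 2 - 3, pooled into clusters {0, 1} and {2, 3}.
const pathA4 = [[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]];
const assignment = [[1, 0], [1, 0], [0, 1], [0, 1]];
const pooled = coarsen(pathA4, [[1], [2], [3], [4]], assignment);
console.log(pooled.A, pooled.X); // [[2, 1], [1, 2]] and [[3], [7]]
```

The diagonal of the coarsened adjacency counts connectivity inside each cluster, and the off-diagonal entries count the edges between clusters.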

Graph Auto-encoders

Graph auto-encoders learn low dimensional node vectors via an encoder, and then reconstruct the graph via a decoder. They can be used to learn graph embeddings.
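One common choice of decoder (an assumption on my part, the survey covers several variants) is the inner product: the predicted probability of an edge between two nodes is the sigmoid of the dot product of their embeddings. The embeddings below are hand-picked rather than learned.

```javascript
const logistic = (z) => 1 / (1 + Math.exp(-z));

// Inner-product decoder: reconstructed adjacency entry (i, j) is
// logistic(z_i . z_j), where z_i is the embedding of node i.
function decodeEdges(embeddings) {
  return embeddings.map((zi) =>
    embeddings.map((zj) =>
      logistic(zi.reduce((s, a, d) => s + a * zj[d], 0))
    )
  );
}

// Nodes 0 and 1 have similar embeddings; node 2 points the other way.
const embeddings = [[2, 0], [2, 0], [-2, 0]];
const reconstructed = decodeEdges(embeddings);
console.log(reconstructed[0][1] > 0.9, reconstructed[0][2] < 0.1); // true true
```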

Some approaches learn node embeddings based only on topological structure, while others also make use of node content features.

Graph Spatial-Temporal networks

Graph spatial-temporal networks have a global graph structure with inputs to the nodes changing across time. For example, a traffic network (the structure) with traffic arriving over time (the content). This is modelled as input tensors X \in R^{T \times N \times D}. In CNN-GCN, shown below, a 1D-CNN layer slides over X in the time axis while the GCN operates on the spatial information at each time step.
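The temporal half of that can be sketched as a 1D convolution along the time axis of a T × N × D array, applied independently to each node and feature. The kernel here is fixed by hand and the function name is my own; a real model would learn the kernel and interleave this with graph convolutions.

```javascript
// X has shape [T][N][D]; kernel is a length-K array of weights.
// Output has shape [T - K + 1][N][D] (no padding).
function temporalConv(X, kernel) {
  const K = kernel.length;
  const out = [];
  for (let t = 0; t + K <= X.length; t++) {
    out.push(
      X[t].map((node, n) =>
        node.map((_, d) =>
          kernel.reduce((s, w, k) => s + w * X[t + k][n][d], 0)
        )
      )
    );
  }
  return out;
}

// Three time steps, one node, one feature; a 2-step moving average.
const traffic = [[[1]], [[3]], [[5]]];
const averaged = temporalConv(traffic, [0.5, 0.5]);
console.log(averaged); // [[[2]], [[4]]]
```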

Graph generative networks

An example in this category is MolGAN (Molecular Generative Adversarial Networks), which integrates relational GCN, improved GAN, and reinforcement learning.

In MolGAN, the generator tries to propose a fake graph along with its feature matrix while the discriminator aims to distinguish the faked sample from the empirical data. Additionally, a reward network is introduced in parallel with the discriminator to encourage the generated graphs to possess certain properties according to an external evaluator.

Datasets & Benchmarks

The most commonly used datasets to benchmark graph neural networks are shown in the table below.

And for the four most frequently used datasets (Cora, Citeseer, Pubmed, and PPI), here are the leading models:

Research challenges and directions

I’ve only been able to scratch the surface of this 22 page survey in this short write-up. It’s clear that interest in graph neural networks is growing apace, but there are still a number of research challenges to be addressed.

  • We haven’t figured out how to go deep with graph layers without model performance dropping off dramatically. It’s possible that depth is not a good strategy for learning graph-structured data.
  • In many graphs of interest node degrees follow a power-law distribution, and hence some nodes have a huge number of neighbours and many nodes have very few. Researchers are trying to figure out how to deal with this problem through effective sampling strategies. I politely suggest they might want to look into edge partitioning too.
  • Scaling to larger graphs, especially with multiple layers, remains difficult.
  • Most current approaches assume static homogeneous graphs. Neither of those conditions holds in many real world graphs that have different attributes for different node types, and where nodes and edges may come and go over time.

TensorFlow.js: machine learning for the web and beyond

February 4, 2019

TensorFlow.js: machine learning for the web and beyond Smilkov et al., SysML’19

If machine learning and ML models are to pervade all of our applications and systems, then they’d better go to where the applications are rather than the other way round. Increasingly, that means JavaScript – both in the browser and on the server.

TensorFlow.js brings TensorFlow and Keras to the JavaScript ecosystem, supporting both Node.js and browser-based applications. As well as programmer accessibility and ease of integration, running on-device means that in many cases user data never has to leave the device.

On-device computation has a number of benefits, including data privacy, accessibility, and low-latency interactive applications.

TensorFlow.js isn’t just for model serving: you can run training with it as well. Since its launch in March 2018, people have done lots of creative things with it. And since it runs in the browser, these are all accessible to you with just one click! Some examples:

As a desktop example, Node Clinic Doctor, an open source Node.js performance profiling tool, integrated a TensorFlow.js model to separate CPU usage spikes caused by the user from those caused by Node.js internals.

TensorFlow.js also ships with an official repository of pretrained models:

One of the major benefits of the JS ecosystem is the ease at which JS code and static resources can be shared. TensorFlow.js takes advantage of this by hosting an official repository of useful pretrained models, serving the weights on a publicly available Google Cloud Storage bucket.

The model prediction methods are designed for ease-of-use so they always take native JS objects such as DOM elements or primitive arrays, and they return JS objects representing ‘human-friendly’ predictions.

Hopefully all that has whetted your appetite to explore what TensorFlow.js has to offer!

TensorFlow.js from the user’s perspective

TensorFlow.js offers two levels of API, both supported across browser and Node.js environments: the Ops API for lower-level linear algebra operations, and the Layers API for high level model building blocks and best practices.

The Layers API mirrors Keras as closely as possible, enabling users to build a model by assembling a set of pre-defined layers.

This enables a two-way door between Keras and TensorFlow.js; users can load a pretrained Keras model in TensorFlow.js, modify it, serialize it, and load it back in Keras Python.

Here’s an example of a TensorFlow.js program building and training a single-layer linear model:

// Assumes TensorFlow.js is installed as the @tensorflow/tfjs package
import * as tf from '@tensorflow/tfjs';

// A linear model with 1 dense layer
const model = tf.sequential();
model.add(tf.layers.dense({
  units: 1, inputShape: [1]
}));

// Specify the loss and the optimizer
model.compile({
  loss: 'meanSquaredError',
  optimizer: 'sgd'
});

// Generate synthetic data to train
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

// Train the model using the data
model.fit(xs, ys).then(() => {
  // Do inference on an unseen data point and
  // print the result
  const x = tf.tensor2d([5], [1, 1]);
  model.predict(x).print();
});

TensorFlow.js uses an eager differentiation engine, which is more developer-friendly than the alternative graph-based engines.

In eager mode, the computation happens immediately when an operation is called, making it easier to inspect results by printing or using a debugger. Another benefit is that all the functionality of the host language is available while your model is executing; users can use native if and while loops instead of specialized control flow APIs that are hard to use and produce convoluted stack traces.

Debugging tools also come out of the box with TensorFlow.js to help troubleshoot common problems with performance and numerical stability.

The focus on ease of use also shows in the approaches to asynchrony and memory management. Operations are purposefully synchronous where they can be, and asynchronous where they need to be. For WebGL memory, which needs to be explicitly freed, the tf.tidy() wrapper function will take care of this for you.

TensorFlow.js has multiple backend implementations to get the best performance possible out of the execution environment. There is a plain JavaScript implementation that will run everywhere as a fallback; a WebGL-based implementation for browsers; and a server-side implementation for Node.js that binds directly to the TensorFlow C API. The speed-ups over plain JS are definitely worth it: WebGL and Node.js CPU backends are two orders of magnitude faster on a MacBook Pro, and three orders of magnitude faster when using a more capable graphics card on a desktop machine.

Based on data available at, TensorFlow’s WebGL implementation should be able to run on 99% of desktop devices, 98% of iOS and Windows mobile devices, and 52% of Android devices. (Android shows lower due to a long tail of older devices with no GPU hardware).

Under the covers

There are a number of challenges involved in building a JavaScript based ML environment, including: the number of different environments in which JavaScript can execute; extracting good enough performance; cross-browser compatibility issues; and its single-threaded nature.

Good performance for machine learning these days means GPUs. Within the browser, the way to get at the GPU is via WebGL.

In order to utilize the GPU, TensorFlow.js uses WebGL, a cross-platform web standard providing low-level 3D graphics APIs… (there is no explicit support for GPGPU). Among the three TensorFlow.js backends, the WebGL backend has the highest complexity. This complexity is justified by the fact that it is two orders of magnitude faster than our CPU backend written in plain JS. The realization that WebGL can be re-purposed for numerical computation is what fundamentally enabled running real-world ML models in the browser.

Without any GPGPU support, TensorFlow.js has to map everything into graphics operations. Specifically, it exploits fragment shaders, which were originally designed to generate the colors of pixels to be rendered on the screen. Only the red ‘R’ channel is currently used, with a gl.R32F texture type that avoids allocating memory for the green, blue, and alpha channels. To make writing the OpenGL Shading Language (GLSL) code easier, under the hood TensorFlow.js makes use of a shader compiler. The compiler separates logical and physical spaces so that it can emit the most efficient code given the device-specific size limits of WebGL textures.

Since disposing and re-allocating WebGL textures is relatively expensive, textures are reused via a ‘texture recycler’. To avoid out-of-memory conditions, WebGL textures will be paged to the CPU whenever the total amount of GPU memory allocated exceeds a threshold.
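The recycling idea is essentially an object pool keyed by texture shape. Here is a minimal sketch of that pattern in plain JavaScript. It is not the TensorFlow.js implementation: the class, its methods, and the `createTexture` callback (standing in for the expensive WebGL allocation) are all hypothetical, and paging to the CPU is left out entirely.

```javascript
// Minimal pool of released textures, keyed by shape.
class TextureRecycler {
  constructor(createTexture) {
    this.createTexture = createTexture; // stand-in for a real allocation
    this.free = new Map();              // "rows x cols" -> released textures
    this.allocations = 0;
  }

  // Reuse a released texture of the same shape if one exists,
  // otherwise allocate a new one.
  acquire(rows, cols) {
    const pool = this.free.get(`${rows}x${cols}`);
    if (pool && pool.length > 0) return pool.pop();
    this.allocations++;
    return this.createTexture(rows, cols);
  }

  // Instead of disposing, park the texture for later reuse.
  release(texture) {
    const key = `${texture.rows}x${texture.cols}`;
    if (!this.free.has(key)) this.free.set(key, []);
    this.free.get(key).push(texture);
  }
}

const recycler = new TextureRecycler((rows, cols) => ({ rows, cols }));
const t1 = recycler.acquire(4, 4);
recycler.release(t1);
const t2 = recycler.acquire(4, 4); // same shape: reuses t1
const t3 = recycler.acquire(2, 8); // different shape: new allocation
console.log(t1 === t2, recycler.allocations); // true 2
```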

Where next

Two new web standards, WebAssembly and WebGPU, both have potential to improve TensorFlow.js performance. WebGPU is an emerging standard to express general purpose parallel computation on the GPU, enabling more optimised linear algebra kernels than those the WebGL backend can support today.

Future work will focus on improving performance, continued progress on device compatibility (particularly mobile devices), and increasing parity with the Python TensorFlow implementation. We also see a need to provide support for full machine learning workflows, including data input, output, and transformation.