Noah Trenaman

Cooperative And Adversarial Language Games With Neural Networks

July 03, 2020
Code on GitHub

In 1948, Claude Shannon published A Mathematical Theory of Communication, which introduces a way to think about systems where some kind of communication is taking place. It discusses the idea of "information" and how it can be represented or transformed. To communicate a message correctly, we don't want to lose any of the important information that it contains.

At a high level, the patterns we use to communicate information tend to be called "protocols" or "languages." For example, HTTP is a protocol for passing information through the Internet, and English is a language for passing ideas from person to person. A general model of communication can help describe how protocols and languages work[1]. By applying these ideas to deep learning systems, we can see how neural networks could generate rudimentary protocols or "languages" based on given constraints. By shaping these constraints, we can create interesting and potentially useful representations of information.

The Shannon-Weaver Model of Communication

Communication systems share five abstract parts, each of which plays an important role:

  1. An information source
    the message we want to communicate
  2. A transmitter
    a process for representing that message
  3. A channel
    how the message is represented as it travels from the source to the destination
  4. A receiver
    a process for reconstructing the message
  5. A destination
    the message that was received
(Figure: an original representation passes through a transformation to become a modified representation.)
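The five parts can be wired together as plain Python functions. This is a toy illustration of my own, not taken from the post's code. The bit-string message, the repetition code used as the transmitter, the binary symmetric channel, and every function name are assumptions chosen for simplicity:

```python
import random

def information_source():
    """The message we want to communicate: here, a short bit string."""
    return [1, 0, 1, 1, 0]

def transmitter(message, repeats=3):
    """Represent each bit redundantly (a repetition code)."""
    return [bit for bit in message for _ in range(repeats)]

def channel(signal, flip_prob, rng):
    """Binary symmetric channel: each symbol is flipped with probability flip_prob."""
    return [bit ^ 1 if rng.random() < flip_prob else bit for bit in signal]

def receiver(signal, repeats=3):
    """Reconstruct each bit by majority vote over its repeated copies."""
    chunks = [signal[i:i + repeats] for i in range(0, len(signal), repeats)]
    return [1 if 2 * sum(chunk) > len(chunk) else 0 for chunk in chunks]

def communicate(flip_prob=0.05, seed=0):
    """Run source -> transmitter -> channel -> receiver; return (sent, received)."""
    rng = random.Random(seed)
    message = information_source()
    received = receiver(channel(transmitter(message), flip_prob, rng))
    return message, received  # `received` is what reaches the destination
```

With `flip_prob` above zero, the majority vote corrects a single flipped copy of a bit. It cannot correct two flips within the same group of three. Adding redundancy in the transmitter is what lets the receiver cope with a noisy channel.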

Using Neural Networks To Generate Simple Writing Systems

The transmitter and receiver have the special role of transforming information from one representation to another. Neural networks are very good at this kind of transformation, so we can imagine systems where they take on these roles as communicators. An autoencoder is probably the most relevant example. It is trained so that its output reproduces its input, much like a communication system in which the message received should be the same as the one sent. The interesting part of an autoencoder is its bottleneck, or latent space, which lines up with the idea of a communication channel. The latent space holds the information in the middle of the autoencoder: a modified representation of the original message.

There has been more technical and rigorous work exploring the idea of using neural networks to communicate information over a noisy or limited channel[2]. I've been interested in exploring how these systems could generate patterns that are human-interpretable. In particular, I mean visual patterns, generated by neural networks, that work similarly to writing systems. This could have many interesting use cases, such as generating highly readable fonts, turning source code into pictures, or transforming entire sentences or paragraphs into images using a language model such as BERT. Joel Simon did some really interesting experiments in generating writing systems using neural networks in 2018. His work inspired my initial curiosity, and it is worth taking a look at.

This line of thinking could also be an interesting approach to machine learning interpretability. High-dimensional vectors are hard to visualize, so if they could be turned into perceptible images, maybe we could have a better understanding of the patterns that neural networks are picking up on.

To start with a toy problem, we could generate an alphabet of a particular size. For an alphabet with 10 characters, one can train an autoencoder to produce 10 unique visual patterns. The inputs and outputs are simply the numbers 1-10 (as one-hot vectors), and the latent space is an image. The loss function should encourage the autoencoder to reconstruct the original message by predicting, from the image alone, what the input was. A simple cross-entropy loss does this by encouraging the predicted message to match the input message. We can also perturb the latent image during training, for example by randomly rotating or blurring it. This produces patterns that are more coherent and distinct, because each symbol has to stay recognizable under those distortions. Similarly, you can still read the words in a sentence if you tilt your head or squint your eyes.

The result of training an autoencoder this way may look like this:

This animation interpolates between the different symbols. We get a very distinct visual pattern for each character in the artificial alphabet.
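Here is a minimal, pure-Python sketch of the alphabet task. It is not the code from the GitHub repository, and all of the following are assumptions:

- the 16-pixel flattened latent "image";
- the lookup-table transmitter (equivalent to a linear layer applied to a one-hot input);
- the single linear-layer receiver and the hyperparameters;
- additive Gaussian noise as a stand-in for rotation and blur.

Every name is hypothetical.

```python
import math
import random

NUM_SYMBOLS = 10  # size of the toy alphabet
LATENT_SIZE = 16  # pixels in a flattened (hypothetical) 4x4 latent "image"

def log_softmax(logits):
    shift = max(logits)
    log_total = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return [x - log_total for x in logits]

def decoder_logits(latent, weights, bias):
    return [sum(w * x for w, x in zip(row, latent)) + b for row, b in zip(weights, bias)]

def train_alphabet(steps=4000, lr=0.05, noise=0.3, seed=0):
    rng = random.Random(seed)
    # Transmitter: one learnable latent image per symbol.
    encoder = [[rng.gauss(0, 1) for _ in range(LATENT_SIZE)] for _ in range(NUM_SYMBOLS)]
    # Receiver: a linear layer from latent pixels to one logit per symbol.
    weights = [[rng.gauss(0, 0.1) for _ in range(LATENT_SIZE)] for _ in range(NUM_SYMBOLS)]
    bias = [0.0] * NUM_SYMBOLS
    losses = []
    for _ in range(steps):
        message = rng.randrange(NUM_SYMBOLS)
        # Channel: additive Gaussian noise stands in for rotation/blur augmentations.
        z = [pixel + rng.gauss(0, noise) for pixel in encoder[message]]
        log_probs = log_softmax(decoder_logits(z, weights, bias))
        losses.append(-log_probs[message])  # cross-entropy against one_hot(message)
        # d(loss)/d(logits) = softmax(logits) - one_hot(message)
        grad = [math.exp(lp) - (1.0 if k == message else 0.0) for k, lp in enumerate(log_probs)]
        # Backpropagate to the latent image using the pre-update decoder weights.
        grad_z = [sum(grad[k] * weights[k][i] for k in range(NUM_SYMBOLS))
                  for i in range(LATENT_SIZE)]
        for k in range(NUM_SYMBOLS):
            bias[k] -= lr * grad[k]
            for i in range(LATENT_SIZE):
                weights[k][i] -= lr * grad[k] * z[i]
        for i in range(LATENT_SIZE):
            # The noise does not depend on the encoder, so d(z)/d(encoder) is 1.
            encoder[message][i] -= lr * grad_z[i]
    return encoder, weights, bias, losses

def decode(latent, weights, bias):
    logits = decoder_logits(latent, weights, bias)
    return max(range(NUM_SYMBOLS), key=lambda k: logits[k])

encoder, weights, bias, losses = train_alphabet()
print("mean loss, first 200 steps:", sum(losses[:200]) / 200)
print("mean loss, last 200 steps:", sum(losses[-200:]) / 200)
print("decoded clean symbols:", [decode(encoder[m], weights, bias) for m in range(NUM_SYMBOLS)])
```

A real version would use a convolutional decoder and 2D image augmentations. The structure stays the same: a transmitter, a noisy channel, a receiver, and a cross-entropy loss on the reconstructed message.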

These generative systems tend to grow in complexity along with the complexity of the environment they are trained in. We can set additional constraints on the communication channel to encourage more interesting patterns to emerge. We could also define different kinds of messages to represent visually. Essentially, we can define (with code) whatever environment we want to experiment with and then see what interesting patterns are generated!

Attempting To Learn A Math Notation System

One kind of message I enjoyed experimenting with was arithmetic expressions. I was curious whether an autoencoder that turns syntax trees into images would learn a kind of mathematical notation system. Just like with the simple 10-character alphabet, we can train an autoencoder to learn representations of the message by defining a reconstruction loss function. The loss function for reconstructing syntax trees is a little trickier. A graph variational autoencoder (GraphVAE) represents a syntax tree as an adjacency matrix plus a set of node feature matrices (in this case, one for node type and one for node order). So the message is a collection of matrices (a binary adjacency matrix and one-hot feature matrices) that the graph decoder predicts all at once.
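The post doesn't spell out the exact encoding. Here is one hypothetical way to turn a small arithmetic expression into the three matrices. It assumes:

- binary operators and single-digit leaves;
- undirected edges;
- a fixed maximum of 7 nodes;
- "node order" meaning a node's position among its siblings.

`encode_tree` and the vocabulary are illustrative names, not from the original code.

```python
OPERATORS = ["+", "-", "*", "/"]
NODE_TYPES = OPERATORS + [str(digit) for digit in range(10)]

def encode_tree(tree, max_nodes=7, max_children=2):
    """Turn a nested-tuple syntax tree into (adjacency, node_type, node_order) matrices.

    Operators are tuples like ("+", left, right); leaves are digit strings.
    Rows past the last real node are left as all zeros (padding).
    """
    nodes = []  # (label, position among siblings), in depth-first order
    edges = []  # (parent index, child index)

    def visit(node, order, parent):
        index = len(nodes)
        label = node[0] if isinstance(node, tuple) else node
        nodes.append((label, order))
        if parent is not None:
            edges.append((parent, index))
        if isinstance(node, tuple):
            for child_order, child in enumerate(node[1:]):
                visit(child, child_order, index)

    visit(tree, 0, None)
    if len(nodes) > max_nodes:
        raise ValueError(f"tree has {len(nodes)} nodes; max_nodes is {max_nodes}")

    adjacency = [[0] * max_nodes for _ in range(max_nodes)]
    for parent, child in edges:
        # Undirected edges, so the matrix is symmetric.
        adjacency[parent][child] = adjacency[child][parent] = 1

    node_type = [[0] * len(NODE_TYPES) for _ in range(max_nodes)]
    node_order = [[0] * max_children for _ in range(max_nodes)]
    for i, (label, order) in enumerate(nodes):
        node_type[i][NODE_TYPES.index(label)] = 1
        node_order[i][order] = 1
    return adjacency, node_type, node_order

# (1 + 2) * 3 -> nodes "*", "+", "1", "2", "3"; rows 5 and 6 are padding
adjacency, node_type, node_order = encode_tree(("*", ("+", "1", "2"), "3"))
```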

The adjacency matrix is predicted with binary cross-entropy loss, and the node features are predicted with categorical cross-entropy loss. The two are summed together to give the full reconstruction loss (with a caveat[3]).
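Here is a plain-Python sketch of this summed loss. It assumes the decoder outputs probabilities: a sigmoid for each adjacency entry and a softmax for each node's feature row. It also assumes predicted and true nodes are already aligned, which skips the graph matching described in the footnote. Function names are hypothetical.

```python
import math

EPS = 1e-7  # clip probabilities so log() never sees 0

def binary_cross_entropy(pred, target):
    """Mean BCE over every entry of a predicted vs. true adjacency matrix."""
    total, count = 0.0, 0
    for pred_row, true_row in zip(pred, target):
        for p, t in zip(pred_row, true_row):
            p = min(max(p, EPS), 1 - EPS)
            total -= t * math.log(p) + (1 - t) * math.log(1 - p)
            count += 1
    return total / count

def categorical_cross_entropy(pred, target):
    """Mean categorical CE over nodes; pred rows are distributions, target rows one-hot."""
    total = 0.0
    for pred_row, true_row in zip(pred, target):
        total -= sum(t * math.log(max(p, EPS)) for p, t in zip(pred_row, true_row))
    return total / len(pred)

def reconstruction_loss(pred_adj, pred_features, true_adj, true_features):
    """BCE on the adjacency matrix plus categorical CE on each node-feature matrix.

    Assumes predicted and true nodes are already aligned (no graph matching).
    """
    loss = binary_cross_entropy(pred_adj, true_adj)
    for pred_f, true_f in zip(pred_features, true_features):
        loss += categorical_cross_entropy(pred_f, true_f)
    return loss

adj = [[0, 1], [1, 0]]
types = [[1, 0], [0, 1]]
print(reconstruction_loss(adj, [types], adj, [types]))  # close to 0 for a perfect prediction
```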

These patterns are visually distinct, but much less interpretable than the simple alphabet. At first glance, there don't seem to be any clear patterns shared between expressions.

Inspiration from GANs: Introducing A Spy

Maybe one of the most interesting constraints on an environment is another intelligent agent that is actively trying to learn the same language and to deceive the decoder network. This is very similar to a GAN, except that there are two generators and the images are not constrained to any human-curated dataset. In this system, an honest (true) generator designs our language, and a spy generator tries to discover it. The true generator and the decoder will (ideally) collaborate to outsmart the spy. (Note: The rest of this post includes a bit of math notation.)

Our system has two generators and one decoder. One is the "true" generator $G_t$, which wants to generate symbols that the decoder $D$ can interpret. Their relationship is cooperative. At the same time, there is a "spy" generator $G_s$, which attempts to imitate $G_t$ and thus trick $D$. Their relationship is adversarial.

The spy has no access to the true generator, but it can learn to imitate $G_t$ via backpropagation of its loss function: $G_s$ is trained to trick $D$ by making symbols that are indistinguishable from those of $G_t$. Thus, both generators attempt to transform a message $M$ into a representation that $D$ can understand:

$$z_t = G_t(M)$$

$$z_s = G_s(M)$$

And if both generators succeed, $D$ will be able to reconstruct the original message from the generated representations:

$$D(z_t) = D(z_s) = M$$

Catching The Spy

So far, there's nothing in this system that would be able to catch the insidious spy. We need to train $G_t$ and $D$ to collaborate in outsmarting $G_s$. We can train $D$ itself to tell the difference between $z_t$ and $z_s$, but couldn't we help the true generator more directly?

We can borrow the idea of adversarial training from GANs, but update $G_t$ and $D$ based on how much $D$ is being deceived. The closest analogue to this loss function that I know of is the Wasserstein GAN. A WGAN has a critic network (in place of the usual discriminator) that scores how real an image is. In our case, rather than predicting a scalar "realness" score, we can encourage the decoder to be more certain about real symbols ($z_t$) and more uncertain when exposed to fake ones ($z_s$).

Cooperative And Adversarial Loss

Like in a GAN, we have two loss functions described by competing networks.

We write $\mathcal{L}(M', M)$ for the loss of a prediction $M'$ of a message $M$.

Cooperative loss (with $G_s$ frozen):

$$\mathcal{L}_c = \mathcal{L}(D(G_t(M)), M) - \mathcal{L}(D(G_s(M)), M)$$

Read as: try to reconstruct messages from the true generator as accurately as possible, and try to reconstruct messages from the spy generator as inaccurately as possible (with maximum uncertainty).
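Here is a sketch of $\mathcal{L}_c$ written directly from the equation, to make the sign convention concrete. It assumes $D$ outputs a probability distribution over messages and that $\mathcal{L}$ is cross-entropy. The probability vectors are made-up inputs, not results:

```python
import math

def cross_entropy(probs, message):
    """L(M', M): negative log-probability the decoder gives the true message index."""
    return -math.log(probs[message])

def cooperative_loss(probs_true, probs_spy, message):
    """L_c = L(D(G_t(M)), M) - L(D(G_s(M)), M); minimized by G_t and D with G_s frozen."""
    return cross_entropy(probs_true, message) - cross_entropy(probs_spy, message)

uniform = [0.25] * 4
confident = [0.97, 0.01, 0.01, 0.01]
print(cooperative_loss(confident, uniform, 0))                       # about -1.356
print(cooperative_loss(confident, [1e-9, 1 - 1e-9, 0.0, 0.0], 0))   # about -20.7
```

The spy term enters with a minus sign. So the loss keeps falling as the probability $D$ assigns to the true message for a spy symbol approaches zero. The second call shows that this term has no lower bound.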

Adversarial loss (with $D$ frozen):

$$\mathcal{L}_a = \mathcal{L}(D(G_s(M)), M)$$

Read as: try to trick the decoder into thinking that the spy generator's representations are actually from the true generator.
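Here is a sketch of the adversarial step under strong simplifying assumptions. The spy "generator" is just a free latent vector per message, and $D$ is a fixed linear layer plus softmax. The spy never sees $G_t$; it only follows the gradient of $\mathcal{L}_a$ through the frozen decoder. Names and numbers are illustrative:

```python
import math

def softmax(logits):
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def decoder(z, weights):
    """Frozen decoder D: logits = weights @ z, returned as probabilities."""
    return softmax([sum(w * x for w, x in zip(row, z)) for row in weights])

def adversarial_loss(z_spy, message, weights):
    """L_a = cross-entropy of D(G_s(M)) against the true message M."""
    return -math.log(decoder(z_spy, weights)[message])

def spy_step(z_spy, message, weights, lr=0.5):
    """One gradient step on the spy's symbol; only z_spy changes (D is frozen)."""
    probs = decoder(z_spy, weights)
    grad_logits = [p - (1.0 if k == message else 0.0) for k, p in enumerate(probs)]
    grad_z = [sum(grad_logits[k] * weights[k][i] for k in range(len(weights)))
              for i in range(len(z_spy))]
    return [x - lr * g for x, g in zip(z_spy, grad_z)]

weights = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # frozen D
z_spy = [0.0, 0.0, 0.0]  # spy's symbol for message 2, initially uninformative
for _ in range(50):
    z_spy = spy_step(z_spy, 2, weights)
print(adversarial_loss(z_spy, 2, weights))  # falls from ln(3) toward 0
```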

Cross-entropy loss can be unstable and produce very large gradients when the cooperative networks are doing well at outsmarting the spy generator. The gradients for outsmarting the spy can even outweigh the cooperative objective itself. The spy term is unbounded: $\mathcal{L}(D(G_s(M)), M) = -\log p$ grows without limit as the probability $p$ that $D$ assigns to the true message goes to zero. We can fix this by making the spy term of the cooperative loss linear rather than logarithmic. To do this, replace the cross-entropy-based loss $\mathcal{L}$ in that term with an $L_1$ (mean absolute error) loss:

$$\mathcal{L}_c = \mathcal{L}(D(G_t(M)), M) - L_1(D(G_s(M)), M)$$
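For a probability vector $p$ over $n$ messages, the $L_1$ term works out to $2(1 - p_M)/n$. It is linear in the probability of the true message and never exceeds $2/n$, so the spy term stays bounded. The sketch below assumes $D$ outputs a probability distribution over messages; names and inputs are illustrative:

```python
import math

def cross_entropy(probs, message):
    return -math.log(probs[message])

def l1_loss(probs, message):
    """Mean absolute error between decoder probabilities and one_hot(message)."""
    return sum(abs(p - (1.0 if k == message else 0.0))
               for k, p in enumerate(probs)) / len(probs)

def cooperative_loss_l1(probs_true, probs_spy, message):
    """L_c = L(D(G_t(M)), M) - L_1(D(G_s(M)), M)."""
    return cross_entropy(probs_true, message) - l1_loss(probs_spy, message)

confident = [0.97, 0.01, 0.01, 0.01]
fooled = [1e-9, 1 - 1e-9, 0.0, 0.0]  # D is confidently wrong on a spy symbol
print(l1_loss(fooled, 0))  # at most 2 / 4 = 0.5
print(cooperative_loss_l1(confident, fooled, 0))
```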

Here are some examples of spies attempting to copy the true generator. The top row is the original generator, and the bottom row is the spy (the middle row is a visualization of some image augmentations also used during training[4]):

A GAN uses a training dataset as its source of ground truth, but we're using the symbols generated by $G_t$. This makes our "dataset" differentiable.

What's The Equilibrium?

In competitive games, there's a Nash equilibrium that the system ideally converges towards. For a GAN, this is the point where the discriminator's guesses are as good as random (it has been thoroughly deceived by the generator). For this system, the equilibrium is different: the cooperative pair of networks should be able to outsmart the adversarial generator to the point where the spy generator's symbols are as good as random. In this winning case, the decoder has maximum certainty for true symbols and maximum uncertainty for counterfeit symbols.

Thus, later in training, the spy should be doing poorly:

Example Results

Cooperative (left) and adversarial (right) generators trained to represent syntax trees.

Generators trained to represent unique symbols (the alphabet task).

(In both cases, the adversarial generator collapsed.)
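You can check whether training is heading toward the winning case described in "What's The Equilibrium?" by tracking two numbers for true and spy symbols: the decoder's accuracy and its mean entropy. This helper is a hypothetical sketch. It assumes you can collect $D$'s output probabilities for a batch of messages:

```python
import math

def entropy(probs):
    """Shannon entropy in nats; zero-probability terms contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def argmax(probs):
    return max(range(len(probs)), key=lambda k: probs[k])

def equilibrium_report(true_probs, spy_probs, messages):
    """Decoder statistics on true and spy symbols for the same batch of messages."""
    n = len(true_probs[0])

    def accuracy(batch):
        return sum(argmax(p) == m for p, m in zip(batch, messages)) / len(messages)

    def mean_entropy(batch):
        return sum(entropy(p) for p in batch) / len(batch)

    return {
        "true_accuracy": accuracy(true_probs),      # should approach 1.0
        "spy_accuracy": accuracy(spy_probs),
        "chance": 1.0 / n,
        "true_entropy": mean_entropy(true_probs),   # should approach 0
        "spy_entropy": mean_entropy(spy_probs),     # maximum uncertainty is ln(n)
        "max_entropy": math.log(n),
    }
```

Depending on the loss, spy accuracy may fall below chance rather than settle at it. Both the cross-entropy and the $L_1$ spy terms reward $D$ for pushing the true message's probability toward zero.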

I haven't done an in-depth study of the hyperparameters of the two generators, but I have found good results by giving the spy fewer parameters than the true generator, for example by cutting the number of feature channels in half. (When a layer's input and output widths are both halved, that layer keeps only about a quarter of its parameters, because its weights scale with input channels times output channels.) This might encourage the true generator to use its extra parameters to outsmart the adversarial generator. It could also be a reasonable way to encourage more complex and diverse visual patterns.
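Counting parameters shows what halving the channels does. This sketch assumes the generators are plain stacks of 3×3 convolutions with biases; the channel widths are hypothetical, not the ones used in the post:

```python
def conv_params(c_in, c_out, kernel=3):
    """Weights (kernel x kernel x c_in x c_out) plus one bias per output channel."""
    return kernel * kernel * c_in * c_out + c_out

def stack_params(channels, kernel=3):
    """Total parameters of a plain stack of conv layers with the given channel widths."""
    return sum(conv_params(a, b, kernel) for a, b in zip(channels, channels[1:]))

def scale_hidden(channels, factor):
    """Scale every hidden width, keeping the input and output channels fixed."""
    return [channels[0]] + [max(1, round(c * factor)) for c in channels[1:-1]] + [channels[-1]]

true_channels = [1, 32, 64, 64, 1]  # hypothetical widths for the true generator
spy_channels = scale_hidden(true_channels, 0.5)
print(spy_channels)  # [1, 16, 32, 32, 1]
print(stack_params(spy_channels) / stack_params(true_channels))  # about 0.25
```

Scaling the hidden widths by about $1/\sqrt{2}$ (`scale_hidden(true_channels, 0.5 ** 0.5)`) instead comes close to an actual halving of the parameter count.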

  1. Natural language is very complex, and information theory can only describe it at a very abstract level.
  2. Neural Communication Systems with Bandwidth-limited Channel, Karen Ullrich et al., 2020
  3. Calculating the loss for GraphVAE is a little more complicated than this, because a prediction $M'$ could be isomorphic to $M$ and still be completely valid. We need to do graph matching before calculating the loss. If you're interested, see section 3.3 of the GraphVAE paper for more details.
  4. The code for doing these image augmentations (which are differentiable and TPU-compatible) can be found here.