TensorFlow.js: machine learning for the web and beyond

Smilkov et al., SysML’19

If machine learning and ML models are to pervade all of our applications and systems, then they’d better go to where the applications are rather than the other way round. Increasingly, that means JavaScript – both in the browser and on the server.

TensorFlow.js brings TensorFlow and Keras to the JavaScript ecosystem, supporting both Node.js and browser-based applications. As well as programmer accessibility and ease of integration, running on-device means that in many cases user data never has to leave the device.

On-device computation has a number of benefits, including data privacy, accessibility, and low-latency interactive applications.

TensorFlow.js isn’t just for model serving; you can run training with it as well. Since its launch in March 2018, people have done lots of creative things with it, and because it runs in the browser, these are all accessible to you with just one click.

As a desktop example, Node Clinic Doctor, an open source Node.js performance profiling tool, integrated a TensorFlow.js model to separate CPU usage spikes caused by the user from those caused by Node.js internals.

TensorFlow.js also ships with an official repository of pretrained models:

One of the major benefits of the JS ecosystem is the ease at which JS code and static resources can be shared. TensorFlow.js takes advantage of this by hosting an official repository of useful pretrained models, serving the weights on a publicly available Google Cloud Storage bucket.

The model prediction methods are designed for ease-of-use so they always take native JS objects such as DOM elements or primitive arrays, and they return JS objects representing ‘human-friendly’ predictions.

Hopefully all that has whetted your appetite to explore what TensorFlow.js has to offer!

TensorFlow.js from the user’s perspective

TensorFlow.js offers two levels of API, both supported across browser and Node.js environments: the Ops API for lower-level linear algebra operations, and the Layers API for high level model building blocks and best practices.

The Layers API mirrors Keras as closely as possible, enabling users to build a model by assembling a set of pre-defined layers.

This enables a two-way door between Keras and TensorFlow.js; users can load a pretrained Keras model in TensorFlow.js, modify it, serialize it, and load it back in Keras Python.

Here’s an example of a TensorFlow.js program building and training a single-layer linear model:


// A linear model with 1 dense layer
const model = tf.sequential();
model.add(tf.layers.dense({
  units: 1, inputShape: [1]
}));

// Specify the loss and the optimizer
model.compile({
  loss: 'meanSquaredError',
  optimizer: 'sgd'
});

// Generate synthetic data to train
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

// Train the model using the data
model.fit(xs, ys).then(() => {
  // Do inference on an unseen data point and
  // print the result
  const x = tf.tensor2d([5], [1, 1]);
  model.predict(x).print();
});
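
To make concrete what training does for this model, here is a minimal plain-JavaScript sketch of gradient descent on mean squared error for a single weight and bias. It does not use TensorFlow.js and is not how the library implements training; the function name `fitLinear`, the learning rate, and the epoch count are all illustrative assumptions.

```javascript
// Fit y = w * x + b by full-batch gradient descent on mean squared error.
// learningRate and epochs are illustrative choices, not library defaults.
function fitLinear(xs, ys, learningRate = 0.05, epochs = 2000) {
  let w = 0;
  let b = 0;
  const n = xs.length;
  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const error = w * xs[i] + b - ys[i];
      // Derivatives of mean((w*x + b - y)^2) with respect to w and b.
      gradW += (2 / n) * error * xs[i];
      gradB += (2 / n) * error;
    }
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return { w, b };
}

// Same synthetic data as the example above (it lies exactly on y = 2x - 1).
const fitted = fitLinear([1, 2, 3, 4], [1, 3, 5, 7]);
const predictionAt5 = fitted.w * 5 + fitted.b;
console.log(fitted, predictionAt5);
```

With enough passes over the data the fitted line approaches y = 2x - 1, so the prediction for the unseen point x = 5 approaches 9.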

TensorFlow.js uses an eager differentiation engine, which is more developer-friendly than the alternative graph-based engines.

In eager mode, the computation happens immediately when an operation is called, making it easier to inspect results by printing or using a debugger. Another benefit is that all the functionality of the host language is available while your model is executing; users can use native if and while loops instead of specialized control flow APIs that are hard to use and produce convoluted stack traces.
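
The idea can be illustrated with a toy eager-mode differentiation engine in plain JavaScript. This is a sketch of the general technique (reverse-mode differentiation over values computed immediately), not the TensorFlow.js engine; the `Value` class and the function `f` are hypothetical names introduced for the example.

```javascript
// Each operation computes its result immediately and records how to
// propagate gradients back to its inputs.
class Value {
  constructor(data, parents = [], backwardFn = () => {}) {
    this.data = data;
    this.grad = 0;
    this.parents = parents;
    this.backwardFn = backwardFn;
  }
  add(other) {
    const out = new Value(this.data + other.data, [this, other], () => {
      this.grad += out.grad;
      other.grad += out.grad;
    });
    return out;
  }
  mul(other) {
    const out = new Value(this.data * other.data, [this, other], () => {
      this.grad += other.data * out.grad;
      other.grad += this.data * out.grad;
    });
    return out;
  }
  backward() {
    // Topologically order the recorded operations, then apply the chain
    // rule from the output back to the inputs.
    const order = [];
    const seen = new Set();
    const visit = (v) => {
      if (seen.has(v)) return;
      seen.add(v);
      v.parents.forEach(visit);
      order.push(v);
    };
    visit(this);
    this.grad = 1;
    for (const v of order.reverse()) v.backwardFn();
  }
}

// Native control flow decides which operations run; intermediate values
// can be inspected at any point because they already exist.
function f(x) {
  let y = x;
  while (y.data < 10) {
    y = y.mul(y); // keep squaring until the value reaches 10 or more
  }
  if (y.data > 50) {
    y = y.add(x);
  }
  return y;
}

const x = new Value(3);
const y = f(x);
y.backward();
console.log(y.data, x.grad);
```

For x = 3 the loop squares twice (3, 9, 81) and the branch adds x, so the function evaluated is x^4 + x, with value 84 and derivative 4x^3 + 1 = 109.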

Debugging tools also come out of the box with TensorFlow.js to help troubleshoot common problems with performance and numerical stability.

The focus on ease of use also shows in the approaches to asynchrony and memory management. Operations are purposefully synchronous where they can be, and asynchronous where they need to be. For WebGL memory, which needs to be explicitly freed, the tf.tidy() wrapper function will take care of this for you.
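
A wrapper of this kind can be sketched in plain JavaScript as a scope that tracks every allocation made while a function runs and frees everything except the returned result. The `FakeTensor` class, the `liveBuffers` counter, and this `tidy` function are hypothetical stand-ins for illustration, not the library's implementation.

```javascript
// Hypothetical stand-in for a GPU-backed tensor whose memory is only
// released when dispose() is called.
let liveBuffers = 0;
const scopeStack = [];

class FakeTensor {
  constructor(values) {
    this.values = values;
    this.disposed = false;
    liveBuffers++;
    // Register with the innermost active scope, if there is one.
    if (scopeStack.length > 0) scopeStack[scopeStack.length - 1].push(this);
  }
  dispose() {
    if (!this.disposed) {
      this.disposed = true;
      liveBuffers--;
    }
  }
}

// Run fn, then dispose every tensor it allocated except the one it returns.
function tidy(fn) {
  const scope = [];
  scopeStack.push(scope);
  let result;
  try {
    result = fn();
  } finally {
    scopeStack.pop();
    for (const t of scope) {
      if (t !== result) t.dispose();
    }
    // Hand the surviving result to the enclosing scope, if any.
    if (result instanceof FakeTensor && scopeStack.length > 0) {
      scopeStack[scopeStack.length - 1].push(result);
    }
  }
  return result;
}

const sumOfSquares = tidy(() => {
  const a = new FakeTensor([1, 2, 3]);
  const squared = new FakeTensor(a.values.map((v) => v * v));
  return new FakeTensor([squared.values.reduce((s, v) => s + v, 0)]);
});
console.log(sumOfSquares.values, liveBuffers);
```

Only the returned tensor is still alive afterwards; the two intermediates have been released without the caller naming them.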

TensorFlow.js has multiple backend implementations to get the best performance possible out of the execution environment. There is a plain JavaScript implementation that will run everywhere as a fallback; a WebGL-based implementation for browsers; and a server-side implementation for Node.js that binds directly to the TensorFlow C API. The speed-ups over plain JS are definitely worth it: WebGL and Node.js CPU backends are two orders of magnitude faster on a MacBook Pro, and three orders of magnitude faster when using a more capable graphics card on a desktop machine.

Based on data available at WebGLStats.com, the TensorFlow.js WebGL implementation should be able to run on 99% of desktop devices, 98% of iOS and Windows mobile devices, and 52% of Android devices. (Android shows lower due to a long tail of older devices with no GPU hardware.)

Under the covers

There are a number of challenges involved in building a JavaScript based ML environment, including: the number of different environments in which JavaScript can execute; extracting good enough performance; cross-browser compatibility issues; and its single-threaded nature.

Good performance for machine learning these days means GPUs. Within the browser, the way to get at the GPU is via WebGL.

In order to utilize the GPU, TensorFlow.js uses WebGL, a cross-platform web standard providing low-level 3D graphics APIs… (there is no explicit support for GPGPU). Among the three TensorFlow.js backends, the WebGL backend has the highest complexity. This complexity is justified by the fact that it is two orders of magnitude faster than our CPU backend written in plain JS. The realization that WebGL can be re-purposed for numerical computation is what fundamentally enabled running real-world ML models in the browser.

Without any GPGPU support, TensorFlow.js has to map everything into graphics operations. Specifically, it exploits fragment shaders, which were originally designed to generate the colors of pixels to be rendered on the screen. Only the red ‘R’ channel is currently used, with a gl.R32F texture type that avoids allocating memory for the green, blue, and alpha channels. To make writing the OpenGL Shading Language (GLSL) code easier, under the hood TensorFlow.js makes use of a shader compiler. The compiler separates logical and physical spaces so that it can generate code in the most efficient form given the device-specific size limits of WebGL textures.
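
One way to picture the split between logical and physical space is a mapping from an N-dimensional tensor index to a (row, column) position in a 2-D texture. The sketch below assumes row-major flattening and a single texture capped by a maximum side length; the function names and the layout are assumptions made for illustration, and the layout TensorFlow.js actually generates may differ.

```javascript
// Choose a 2-D texture shape that holds all elements of a logical tensor
// without exceeding the device's maximum texture side length.
function textureShapeFor(logicalShape, maxTextureSize) {
  const size = logicalShape.reduce((a, b) => a * b, 1);
  const cols = Math.max(1, Math.min(size, maxTextureSize));
  const rows = Math.max(1, Math.ceil(size / cols));
  if (rows > maxTextureSize) {
    throw new Error('Tensor does not fit in a single texture');
  }
  return [rows, cols];
}

// Translate a logical index such as [i, j, k] into a physical texel
// coordinate [row, col], assuming row-major flattening.
function logicalToPhysical(index, logicalShape, texShape) {
  let flat = 0;
  for (let d = 0; d < logicalShape.length; d++) {
    flat = flat * logicalShape[d] + index[d];
  }
  const cols = texShape[1];
  return [Math.floor(flat / cols), flat % cols];
}

// A 2 x 3 x 4 tensor (24 elements) on a device limited to 8 texels per side.
const logicalShape = [2, 3, 4];
const texShape = textureShapeFor(logicalShape, 8);
const lastTexel = logicalToPhysical([1, 2, 3], logicalShape, texShape);
console.log(texShape, lastTexel);
```

Code written against the logical shape never needs to know the texture dimensions; only the mapping changes when the device limit does.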

Since disposing and re-allocating WebGL textures is relatively expensive, textures are reused via a ‘texture recycler’. To avoid out-of-memory conditions, WebGL textures will be paged to the CPU whenever the total amount of GPU memory allocated exceeds a threshold.
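
The recycling idea can be sketched as a pool of released textures keyed by shape, consulted before anything new is allocated. This is a plain-JavaScript illustration with plain objects standing in for WebGL textures; the `TextureRecycler` class and its allocator callback are hypothetical, and paging to the CPU is left out.

```javascript
class TextureRecycler {
  constructor(createTexture) {
    this.createTexture = createTexture; // hypothetical allocator callback
    this.freeTextures = new Map(); // "rows x cols" key -> released textures
    this.created = 0;
  }
  key(rows, cols) {
    return `${rows}x${cols}`;
  }
  // Reuse a released texture of the same shape if one exists,
  // otherwise allocate a new one.
  acquire(rows, cols) {
    const pool = this.freeTextures.get(this.key(rows, cols));
    if (pool && pool.length > 0) return pool.pop();
    this.created++;
    return this.createTexture(rows, cols);
  }
  // Keep the texture for later reuse instead of destroying it.
  release(texture) {
    const k = this.key(texture.rows, texture.cols);
    if (!this.freeTextures.has(k)) this.freeTextures.set(k, []);
    this.freeTextures.get(k).push(texture);
  }
}

// Stand-in allocator: a Float32Array plays the role of GPU memory.
const recycler = new TextureRecycler((rows, cols) => ({
  rows,
  cols,
  data: new Float32Array(rows * cols),
}));

const first = recycler.acquire(4, 4);
recycler.release(first);
const second = recycler.acquire(4, 4); // same shape: reused
const third = recycler.acquire(2, 8); // different shape: newly allocated
console.log(second === first, recycler.created);
```

Three acquisitions lead to only two allocations, because the second request is served from the pool.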

Where next

Two new web standards, WebAssembly and WebGPU, both have potential to improve TensorFlow.js performance. WebGPU is an emerging standard to express general purpose parallel computation on the GPU, enabling more optimised linear algebra kernels than those the WebGL backend can support today.

Future work will focus on improving performance, continued progress on device compatibility (particularly mobile devices), and increasing parity with the Python TensorFlow implementation. We also see a need to provide support for full machine learning workflows, including data input, output, and transformation.
