Building an End-to-End Image Classification Web App with TensorFlow, Keras, and TensorFlow.js
Imagine a web application that identifies objects in a user’s camera feed or uploaded photo without sending the image to a server: no network round trips, no per-request API costs, and the image never leaves the device. This is the promise of client-side machine learning, and TensorFlow.js puts it within reach of any web developer.
In this guide, you’ll use transfer learning to train an image classifier with TensorFlow and Keras in Python, convert it to a browser-friendly format with the TensorFlow.js converter, and deploy it in a web app that runs inference entirely on the client. We’ll cover data preparation, transfer learning and fine-tuning, JavaScript inference, and performance optimization. By the end, you’ll have a working web app that classifies uploaded images in real time.
Table of Contents
- What You’ll Learn
- Prerequisites
- Stage 1: Data Preparation (Python)
- Stage 2: Model Training with Keras (Transfer Learning)
- Stage 3: Converting the Model to TensorFlow.js
- Stage 4: Building the Web Frontend with TensorFlow.js
- Stage 5: Deploying and Optimizing
- Best Practices and Common Mistakes
- Real‑World Use Cases
- Conclusion
- References
What You’ll Learn
- How to prepare an image dataset for classification using tf.keras.utils.image_dataset_from_directory.
- How to perform transfer learning with MobileNetV2 in Keras, including fine-tuning.
- How to convert a Keras model to TensorFlow.js using both the CLI and the Python API.
- How to load the converted model in a browser and run predictions with TensorFlow.js.
- How to build a simple web app (vanilla JavaScript here; the same approach works in React) that lets users upload images and see predictions.
- Best practices for memory management, model size reduction (quantization), and avoiding CORS issues.
Prerequisites
- Python 3.8+ and basic familiarity with Jupyter notebooks or scripts.
- Node.js and npm (for the web app).
- A development environment with TensorFlow 2.x installed (pip install tensorflow).
- Basic knowledge of HTML, CSS, and JavaScript.
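Before starting, it can help to confirm that the two Python packages this guide relies on are installed. The sketch below uses only the standard library (importlib.metadata, available from Python 3.8) and reads installed versions without importing TensorFlow itself; the helper name installed_versions is hypothetical.

```python
import sys

if sys.version_info < (3, 8):
    sys.exit("Python 3.8+ is required")

from importlib import metadata  # in the standard library from Python 3.8


def installed_versions(packages):
    """Return {package: version or None} without importing the packages themselves."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for pkg, ver in installed_versions(["tensorflow", "tensorflowjs"]).items():
        status = ver if ver else f"not installed (try: pip install {pkg})"
        print(f"{pkg}: {status}")
```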
Stage 1: Data Preparation (Python)
The foundation of any good classifier is a well‑organized dataset. For this tutorial we’ll assume you have images sorted into class folders inside a data/ directory:
```
data/
├── train/
│   ├── cats/
│   ├── dogs/
│   └── flowers/
└── validation/
    ├── cats/
    ├── dogs/
    └── flowers/
```
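image_dataset_from_directory assigns label indices by sorting the class subfolder names alphanumerically, and it only picks up files with image extensions. A quick sanity check of the layout above can catch mismatched class folders or empty classes before training. This is a minimal standard-library sketch; the function names summarize_split and check_layout are hypothetical.

```python
from pathlib import Path

# Extensions that tf.keras.utils.image_dataset_from_directory accepts
IMAGE_SUFFIXES = {".bmp", ".gif", ".jpeg", ".jpg", ".png"}


def summarize_split(split_dir):
    """Return {class_name: image_count}, with classes in the order Keras assigns labels."""
    split = Path(split_dir)
    # Keras sorts class subfolder names alphanumerically; label i is the i-th name
    class_dirs = sorted((p for p in split.iterdir() if p.is_dir()), key=lambda p: p.name)
    return {
        d.name: sum(1 for f in d.rglob("*") if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES)
        for d in class_dirs
    }


def check_layout(root="data"):
    """Summarize train/ and validation/ and fail if their class folders differ."""
    train = summarize_split(Path(root) / "train")
    val = summarize_split(Path(root) / "validation")
    if list(train) != list(val):
        raise ValueError(f"class folders differ: train={list(train)}, validation={list(val)}")
    return train, val
```

For the layout above, the classes come back in the order cats, dogs, flowers, which is also the order the JavaScript class-name array in Stage 4 has to follow.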
Loading the Data
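Use the tf.keras.utils.image_dataset_from_directory helper to load both splits. It infers integer labels from the subfolder names and resizes every image to IMG_SIZE (with bilinear interpolation by default) as it loads: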
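```python
import tensorflow as tf
from tensorflow.keras import layers

BATCH_SIZE = 32
IMG_SIZE = (224, 224)  # MobileNetV2's default input size

# Labels are inferred from the class subfolders, sorted alphabetically
train_ds = tf.keras.utils.image_dataset_from_directory(
    'data/train',
    seed=123,            # makes the shuffle order reproducible
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    'data/validation',
    shuffle=False,       # evaluation does not need shuffling
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE
)

# Read the label order now: the attribute is lost once the dataset is transformed with .map()
class_names = train_ds.class_names  # ['cats', 'dogs', 'flowers'] for the layout above
print(class_names)
```
Data Augmentation (to Fight Overfitting)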
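When you have a small dataset, augmentation is critical. Keras provides a pipeline of preprocessing layers. The same step also normalizes the pixels, because MobileNetV2's ImageNet weights expect inputs in [-1, 1], not the [0, 255] values the loader returns:

```python
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.2),    # up to ±20% of a full turn (±72°)
    layers.RandomZoom(0.1),
    layers.RandomBrightness(0.1),  # assumes pixels in [0, 255], so apply before normalizing
])

# preprocess_input maps [0, 255] to [-1, 1] (x / 127.5 - 1)
preprocess = tf.keras.applications.mobilenet_v2.preprocess_input

# Augment only the training data, then normalize both splits identically
train_ds = train_ds.map(
    lambda x, y: (preprocess(data_augmentation(x, training=True)), y),
    num_parallel_calls=tf.data.AUTOTUNE,
).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.map(
    lambda x, y: (preprocess(x), y),
    num_parallel_calls=tf.data.AUTOTUNE,
).prefetch(tf.data.AUTOTUNE)
```

Tip: Whatever normalization you use in training must be reproduced exactly in the browser; the JavaScript code in Stage 4 subtracts 127.5 and divides by 127.5, which is the same mapping. Alternatively, make layers.Rescaling(1/127.5, offset=-1) the first layer of the model (check that your TensorFlow.js version supports that layer), but then remove the normalization from the JavaScript code so it is not applied twice.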
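MobileNetV2's [-1, 1] scaling can be written in two algebraically identical ways: the divide-then-subtract form computed by tf.keras.applications.mobilenet_v2.preprocess_input, and the subtract-then-divide form used by the JavaScript preprocessing in Stage 4. Checking the endpoints confirms the range:

$$
x' = \frac{x}{127.5} - 1 = \frac{x - 127.5}{127.5}, \qquad 0 \mapsto -1, \quad 127.5 \mapsto 0, \quad 255 \mapsto 1
$$

Because the two forms are the same function, training and browser inference see identical inputs as long as each pipeline applies one of them exactly once.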
Stage 2: Model Training with Keras (Transfer Learning)
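For browser deployment, MobileNetV2 is a strong choice of backbone: its float32 weights take about 14 MB with the ImageNet classifier (roughly 3.5 million parameters) and about 9 MB without it, it runs fast on WebGL, and it reaches about 71% top-1 accuracy on ImageNet while transferring well to many other tasks. We’ll load it with ImageNet weights, train a new classification head, and then fine-tune on our data.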
Step 1: Load Pretrained Base and Freeze It
```python
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,  # remove the classification head
    weights='imagenet'
)
base_model.trainable = False  # freeze the base layers
```
Step 2: Add a Custom Classification Head
```python
model = tf.keras.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(3, activation='softmax')  # change to number of classes
])
```
Step 3: Compile and Train the Head
```python
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10
)
```
Step 4: Fine‑Tune (Optional but Recommended)
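Unfreeze the top layers of the base model and retrain with a lower learning rate. Recompile after changing trainable flags so the change takes effect:

```python
# Unfreeze the base model, then re-freeze its lower layers
base_model.trainable = True

# The headless MobileNetV2 has about 154 layers (check with len(base_model.layers)):
# freeze the first 100 and fine-tune the rest
for layer in base_model.layers[:100]:
    layer.trainable = False

# Keep BatchNormalization layers frozen: a frozen BN layer runs in inference mode,
# so small fine-tuning batches cannot overwrite its pretrained statistics
for layer in base_model.layers:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # small LR protects pretrained weights
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

history_fine = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5
)
```
Step 5: Save the Model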
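```python
model.save('image_classifier.h5')  # HDF5 format, read by tensorflowjs_converter --input_format keras
```

Note: Saving as .h5 gives the simplest tensorflowjs_converter workflow. A Keras SavedModel directory can also be converted, either with --input_format keras_saved_model (the result still loads with tf.loadLayersModel) or with --input_format tf_saved_model, which produces a graph model that must be loaded with tf.loadGraphModel instead.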
Stage 3: Converting the Model to TensorFlow.js
Install the tensorflowjs package and convert the model:
```shell
pip install tensorflowjs
tensorflowjs_converter --input_format keras image_classifier.h5 ./tfjs_model
```

Alternatively, do it from Python:
```python
import tensorflowjs as tfjs

tfjs.converters.save_keras_model(model, './tfjs_model')
```

The output folder (./tfjs_model) contains:
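- model.json: the model architecture and the manifest of weight files.
- group1-shard1ofN.bin, group1-shard2ofN.bin, and so on: binary weight shards. By default the converter splits weights into files of about 4 MB, so a MobileNetV2 model usually produces several.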
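The converter writes only the model and its weights; the mapping from output index to class name lives in your training script. One way to keep the two in sync, assuming a hypothetical class_names.json file that the web app fetches instead of hard-coding its class list, is to write the order the dataset loader inferred (train_ds.class_names) next to model.json. The helper name export_class_names is also hypothetical.

```python
import json
from pathlib import Path


def export_class_names(class_names, model_dir="./tfjs_model", filename="class_names.json"):
    """Write the training label order next to model.json so the web app can fetch it."""
    path = Path(model_dir) / filename
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(list(class_names)), encoding="utf-8")
    return path


# In the training script: export_class_names(train_ds.class_names)
```

The browser code could then load the list with fetch('./model/class_names.json') and parse the JSON, so retraining with a new class can never silently shift the labels.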
Optimize Model Size with Quantization
To reduce the download size (critical for mobile networks), add quantization during conversion:
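```shell
tensorflowjs_converter --input_format keras \
    --quantization_bytes 2 \
    image_classifier.h5 ./tfjs_model_quantized
```

This stores each weight as a 16-bit integer using affine (min/max) quantization instead of a 32-bit float, roughly halving the size of the weight files, usually with little accuracy loss; TensorFlow.js converts the values back to floats when it loads the model. Recent converter versions mark --quantization_bytes as deprecated in favor of --quantize_uint16, --quantize_uint8, and --quantize_float16, so check tensorflowjs_converter --help for the flags your version supports.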
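To see why quantization usually costs little accuracy, here is a simplified standard-library sketch of min/max affine quantization of one weight array to 8 bits; the 16-bit case works the same way with 65,535 steps instead of 255. It illustrates the idea only: tfjs-converter's actual implementation may differ in details such as how the range is chosen and stored, and the function names are hypothetical.

```python
def quantize_uint8(weights):
    """Affine (min/max) quantization: map floats in [lo, hi] onto the integers 0..255."""
    lo, hi = min(weights), max(weights)
    scale = (hi - lo) / 255 or 1.0  # avoid division by zero when all weights are equal
    codes = [round((w - lo) / scale) for w in weights]  # 1 byte each instead of 4
    return codes, scale, lo


def dequantize(codes, scale, lo):
    """Recover approximate float weights from the stored codes."""
    return [c * scale + lo for c in codes]


weights = [-0.42, -0.1, 0.0, 0.07, 0.31, 0.9]
codes, scale, lo = quantize_uint8(weights)
restored = dequantize(codes, scale, lo)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
# Rounding to the nearest step bounds the error by half a step (scale / 2)
print(f"step={scale:.5f}  max_error={max_error:.5f}")
```

The error per weight is at most half a quantization step, which for typical weight ranges is small compared with the weights themselves.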
Important: Host model.json and the .bin files on the same origin as your web app to avoid CORS errors. If you must serve them from a CDN or another domain, that server has to send an Access-Control-Allow-Origin header that allows your app’s origin.
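During development, tf.loadLayersModel fetches model.json and its weight shards over HTTP, so the page and the model should be served by a web server. For testing a cross-origin setup locally (for example, the app on one port and the model on another), a minimal sketch with Python's built-in http.server can add the CORS header. The names CORSRequestHandler and make_server are hypothetical, and the wildcard origin is for local testing only.

```python
import functools
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer


class CORSRequestHandler(SimpleHTTPRequestHandler):
    """Static file handler that also sends a permissive CORS header (local testing only)."""

    def end_headers(self):
        self.send_header("Access-Control-Allow-Origin", "*")
        super().end_headers()


def make_server(directory=".", port=8000):
    """Serve files from `directory` on 127.0.0.1; port 0 lets the OS pick a free port."""
    handler = functools.partial(CORSRequestHandler, directory=directory)
    return ThreadingHTTPServer(("127.0.0.1", port), handler)
```

Running make_server("tfjs_model", 8001).serve_forever() from the project folder would serve the converted model at http://127.0.0.1:8001/model.json. In production, configure the equivalent header on your real host or CDN and restrict it to your app's origin.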
Stage 4: Building the Web Frontend with TensorFlow.js
We’ll build a simple HTML/JavaScript app. (The same principles apply to React, Vue, etc.)
Step 1: Set Up HTML
```html
<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <title>Image Classifier</title>
  <!-- Loads the latest TensorFlow.js release; pin a specific version in production -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
</head>
<body>
  <h1>Upload an Image</h1>
  <input type="file" id="imageUpload" accept="image/*">
  <img id="preview" width="224" height="224" style="display:none;">
  <div id="results"></div>
  <script src="app.js"></script>
</body>
</html>
```
Step 2: Load the Model and Run Inference (app.js)
```javascript
let model;

async function loadModel() {
  // Expects the converter output (./tfjs_model) copied into the app's model/ folder
  model = await tf.loadLayersModel('./model/model.json');
  console.log('Model loaded');
}

function preprocessImage(imageElement) {
  // tf.tidy() disposes every intermediate tensor and keeps only the returned one
  return tf.tidy(() =>
    tf.browser.fromPixels(imageElement)  // int32 tensor [height, width, 3]
      .resizeBilinear([224, 224])        // bilinear, like image_dataset_from_directory's default
      .cast('float32')
      .sub(127.5)
      .div(127.5)                        // [0, 255] -> [-1, 1], matching training
      .expandDims(0)                     // add batch dimension -> [1, 224, 224, 3]
  );
}

async function classifyImage() {
  if (!model) await loadModel();
  const image = document.getElementById('preview');

  // Run the model inside tf.tidy() so the input tensor is freed automatically
  const output = tf.tidy(() => model.predict(preprocessImage(image)));
  const predictions = await output.data(); // Float32Array of class probabilities
  output.dispose();                        // tidy() cannot wrap the async .data() call

  // Class names must follow the training label order (alphabetical folder names)
  const classNames = ['cat', 'dog', 'flower']; // update accordingly
  const topK = 3;

  // Sort predictions and display top-K
  const top = Array.from(predictions)
    .map((p, i) => ({ prob: p, idx: i }))
    .sort((a, b) => b.prob - a.prob)
    .slice(0, topK);

  const resultsDiv = document.getElementById('results');
  resultsDiv.innerHTML = top.map(item =>
    `<p>${classNames[item.idx]}: ${(item.prob * 100).toFixed(2)}%</p>`
  ).join('');
}

// Handle file upload
document.getElementById('imageUpload').addEventListener('change', (event) => {
  const file = event.target.files[0];
  if (!file) return;

  const reader = new FileReader();
  reader.onload = (e) => {
    const img = document.getElementById('preview');
    // Attach the handler before setting src so the load event cannot be missed
    img.onload = classifyImage;
    img.src = e.target.result;
    img.style.display = 'block';
  };
  reader.readAsDataURL(file);
});
```
Memory Management
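Always dispose of tensors you no longer need. With the WebGL backend, tensor memory is not released by JavaScript garbage collection, and leaks can eventually crash the browser tab. When you manage tensors by hand, dispose both the input and the output of predict():

```javascript
const tensor = preprocessImage(image);
const output = model.predict(tensor);
const predictions = await output.data();
tensor.dispose(); // free the input tensor's GPU memory
output.dispose(); // the prediction tensor must be freed too
```

For synchronous chains of operations, wrap them in tf.tidy(), which disposes every tensor created inside its callback except the one it returns. tf.tidy() cannot wrap asynchronous code, so read results with await tensor.data() outside it.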
Using WebGL Backend
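TensorFlow.js automatically uses WebGL when it is available. To request it explicitly, run the following inside an async function such as loadModel() (top-level await only works in ES modules):

```javascript
await tf.setBackend('webgl');
await tf.ready();                         // wait for the backend to finish initializing
console.log('Backend:', tf.getBackend()); // verify which backend is actually active
```
Stage 5: Deploying and Optimizing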
Hosting
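- Copy the converter's output (model.json plus every .bin shard from ./tfjs_model) into a model/ folder in your web server’s public directory, which is the path the JavaScript code loads from.
- Use a static hosting service like Netlify, Vercel, Firebase Hosting, or a simple Nginx server. Opening the HTML file directly from disk generally fails, because most browsers block fetch() requests from file:// pages.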
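A missing or misnamed weight shard shows up in the browser only as a failed request when the model loads. Before deploying, you could check that every file listed in model.json is present next to it. This sketch assumes the layout the converter writes (a top-level weightsManifest list whose entries each have a paths list, resolved relative to model.json); the function name missing_weight_files is hypothetical.

```python
import json
from pathlib import Path


def missing_weight_files(model_json_path):
    """Return shard paths referenced by model.json that do not exist on disk."""
    model_json = Path(model_json_path)
    manifest = json.loads(model_json.read_text(encoding="utf-8"))["weightsManifest"]
    missing = []
    for group in manifest:
        for rel_path in group["paths"]:
            # Shard paths are resolved relative to the folder that holds model.json
            if not (model_json.parent / rel_path).is_file():
                missing.append(rel_path)
    return missing
```

Running missing_weight_files on the deployed copy (for example, public/model/model.json) should return an empty list.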
Performance Tips
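- Model size: Keep it under 50 MB for acceptable load times, since every visitor downloads the whole model before the first prediction. MobileNet-based models fit this easily.
- Quantization: Use --quantization_bytes 1 (8-bit; --quantize_uint8 in newer converter versions) for a 4× smaller download if accuracy allows.
- Load the model only once: store it in a global variable or a React ref.
- Process one image at a time so large batch predictions don’t stall the UI.
- Consider the WebAssembly backend (@tensorflow/tfjs-backend-wasm) as a fallback when WebGL is not supported; it is much faster than the plain JavaScript CPU backend, though usually slower than WebGL for a model of this size.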
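A back-of-the-envelope estimate shows where a model built in this guide lands relative to that budget. The sketch assumes the 2,257,984 parameters that model.summary() reports for the headless MobileNetV2 base (including BatchNormalization statistics) and the Step 2 head with three classes on top of the 1,280-channel pooled features; the real download also includes model.json and the TensorFlow.js library, and may be compressed by the server. The helper name dense_params is hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


BASE_PARAMS = 2_257_984  # headless MobileNetV2, as reported by model.summary()
# GlobalAveragePooling2D outputs 1280 features; Dropout adds no parameters
head_params = dense_params(1280, 128) + dense_params(128, 3)
total_params = BASE_PARAMS + head_params

for label, bytes_per_weight in [("float32", 4), ("16-bit", 2), ("8-bit", 1)]:
    size_mb = total_params * bytes_per_weight / 1e6
    print(f"{label:>7}: ~{size_mb:.1f} MB of weights")
```

At roughly 2.4 million parameters, the unquantized weights come to under 10 MB, and 8-bit quantization brings them to about 2.4 MB.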
Best Practices and Common Mistakes
| Practice | Why |
|---|---|
| Use tf.browser.fromPixels instead of manual canvas conversion | It’s optimized and handles memory better. |
| Normalize the same way as during training | MobileNetV2 expects pixel values in [-1, 1] (subtract 127.5, divide by 127.5). Wrong normalization will break predictions. |
| Freeze the base model first, then fine-tune | If you train the whole model from the start, large gradients from the randomly initialized head can destroy the pretrained weights. |
| Use the same input size and resizing | 224×224 for MobileNetV2; prefer resizeBilinear, since image_dataset_from_directory resizes with bilinear interpolation by default. |
| Host model files on the same origin | CORS errors are among the most common deployment issues. |
| Don’t forget await when loading the model or reading predictions | tf.loadLayersModel() and tensor.data() return Promises; without await you get a Promise object instead of the result. |
| Monitor the browser console for memory warnings | Use tf.memory() for debugging; dispose tensors manually if not using tf.tidy(). |
Real‑World Use Cases
- Medical image screening – Detect fractures or anomalies from X‑rays directly in the browser, keeping patient data private.
- E‑commerce product tagging – Let users snap a photo of an item and instantly see similar products.
- Content moderation – Flag inappropriate images client‑side before they are uploaded to a server.
- Educational tools – Interactive ML demos where students can upload their own images.
- Plant disease identification – Farmers in rural areas can use a lightweight web app offline (with a Service Worker caching the model).
Conclusion
You now have a complete blueprint for building an end‑to‑end image classification web app that runs entirely in the browser. By combining TensorFlow/Keras for training, transfer learning with MobileNetV2, the TensorFlow.js converter, and the TensorFlow.js library, you can deliver fast, private, and cost‑effective AI experiences.
Key takeaways:
- Transfer learning lets you train accurate models with very little data.
- MobileNetV2 is the sweet spot between accuracy and browser performance.
- Quantization and proper memory management keep your app responsive.
- With client-side ML, you avoid inference-server costs and network round trips, and user images never leave the device.
Now go ahead and build something amazing – perhaps a “Not Hotdog” classifier of your own or a plant disease detector that works without an internet connection. The tools are in your hands.
References
- TensorFlow.js official documentation
- TensorFlow Image Classification tutorial
- Keras Transfer Learning guide
- TensorFlow.js Models & pre‑trained options
- tfjs-converter README
- Image Classification with TensorFlow.js and React
- Classifying images using Keras MobileNet and TensorFlow.js
- Not Hotdog with Keras and TensorFlow.js