
Conquering ControlNet

Harness the Power of Diffusion Models with Higher-Quality Data

ControlNet has been one of the biggest success stories in ML in 2023. The project, which has racked up 21,000+ stars on GitHub, was all the rage at CVPR – and for good reason: it’s an easy, interpretable way to exert influence over the outputs of diffusion models. 

Rather than running the same diffusion model on the same prompt over and over again, hoping for a reasonable result, you can guide the model via an input map. Hence ControlNet’s cheeky tagline: “Let us control diffusion models!” There are distinct ControlNet models to ‘control’ the output via Canny edge maps, segmentation masks, pose keypoints, and even scribbles.

Controlling stable diffusion via scribble maps with prompt “turtle”. Image from the ControlNet 1.0 GitHub repository.

One of the features that makes ControlNet so popular is its accessibility. In an era of hundred-billion-parameter foundation models, ControlNet models are just 1.45GB (the same size as the underlying diffusion model). At a time when models like GPT-3.5 are being trained on tens of thousands of GPUs at a cost of hundreds of thousands or even millions of USD, a ControlNet model can be trained at home on a single GPU in just 600 GPU hours! In other words, you can train your own ControlNet model.

Despite ControlNet 1.0’s remarkable success, the model suffered from a few rather unfortunate bugs. Here’s an example:

Illustration of a failure mode of ControlNet 1.0. Left: input image. Right: outputs with high ControlNet “weight”, leading to oversaturated colors.

For most inputs, the model produced stunning, realistic images, but in some cases, such as the one above, its output was significantly oversaturated.

When ControlNet’s creator Lvmin Zhang published ControlNet 1.1, which resolved these issues, the changes were so substantial that he created an entirely new GitHub repository.

Issue resolution in ControlNet 1.1. Left: same base image from the previous figure. Right: outputs when inputting the same prompt and metadata as in the oversaturated ControlNet 1.0 case above.

The craziest part: there were NO CHANGES to the model architecture. 

What changed? Data quality!

It turns out that the data used to train ControlNet 1.0 had a few insidious flaws, including a small group of grayscale images of people that was somehow duplicated thousands of times. The ControlNet 1.1 repo explicitly mentions this and other problems.

The lesson: 

Data reigns supreme. State-of-the-art performance requires high-quality data.

In this blog post, I’ll show you how to clean and curate high-quality data so you can train your own state-of-the-art ControlNet model.

All of the code required to follow along and curate your own image-caption dataset can be found here.


Setup

The only libraries we will need to clean and curate this data are pandas (for tabular data) and FiftyOne (for unstructured image data):

pip install pandas fiftyone

You will also use Python’s built-in hashlib module for a helper function, and you will probably want tqdm to track progress while downloading images. The CLIPScore step near the end also uses SciPy.

You can import all of the required modules as follows:

import hashlib
import pandas as pd
from tqdm.notebook import tqdm

import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.brain as fob
from fiftyone import ViewField as F

Select the Dataset

According to the paper that introduced ControlNet, Adding Conditional Control to Text-to-Image Diffusion Models (CVPR 2023), the original ControlNet models were trained on “3M image-caption pairs from the internet”. 

Unfortunately, Zhang et al. stop short of revealing precisely what data they used:

“Given the current complicated situation outside research community, we refrain from disclosing more details about data. Nevertheless, researchers may take a look at that dataset project everyone know.”

Lvmin Zhang

That being said, the information they do reveal lines up closely with Google’s Conceptual Captions Dataset: a dataset “consisting of ~3.3M images annotated with captions”. Regardless of whether this is the dataset the ControlNet team used to train their models, Conceptual Captions will provide us with an illustrative example, and the dataset — when properly cleaned — should allow for training ControlNet models from scratch.

Download the Dataset

Google’s proposed dataset download process is too cumbersome for my taste. First, you need to download a tab-separated values (`.tsv`) file containing the captions and the URLs where the corresponding images can be found. Then you need to download the images from those URLs. Lucky for you, I’ve written this code so you don’t have to.

Download the tsv file by clicking the “Download” button at the bottom of Google’s Conceptual Captions webpage, or by clicking on this link.

We can load the TSV file into a pandas DataFrame just as we would a CSV file, passing sep='\t' to specify that the separator is a tab.

df = pd.read_csv("Train_GCC-training.tsv", sep='\t')

Give the columns of the DataFrame descriptive names:

df.columns = ['caption', 'url']
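Before trusting the row count, it is worth checking one detail. By default, `pd.read_csv` treats the first line of the file as a header. If your copy of the TSV has no header row (running `head -n 2 Train_GCC-training.tsv` will tell you), the first caption-URL pair gets consumed as column names. The sketch below loads the data with `header=None` and names the columns in the same call. It uses a tiny in-memory stand-in for the real file, and the example captions and URLs are made up:

```python
import io

import pandas as pd

# Tiny in-memory stand-in for Train_GCC-training.tsv: caption<TAB>url, no header row.
tsv_text = (
    "a dog running on the beach\thttp://example.com/dog.jpg\n"
    "a red car parked on a street\thttp://example.com/car.jpg\n"
)

# header=None keeps the first line as data; names= assigns the column labels directly.
df = pd.read_csv(io.StringIO(tsv_text), sep="\t", header=None, names=["caption", "url"])
print(len(df))           # -> 2 (no row lost to the header)
print(list(df.columns))  # -> ['caption', 'url']
```

On the real file, the equivalent call would be `pd.read_csv("Train_GCC-training.tsv", sep="\t", header=None, names=["caption", "url"])`. That call also makes the separate `df.columns` assignment unnecessary.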

Then hash each URL to generate a unique ID:

def hash_url(url):
    # Short, filesystem-safe ID: the first 12 hex characters of the URL's MD5 digest
    return hashlib.md5(url.encode()).hexdigest()[:12]

df['url_hash'] = df['url'].apply(hash_url)

The DataFrame looks like this:

	caption	                                                url	                                                url_hash
0	sierra looked stunning in this top and this sk...	http://78.media.tumblr.com/3b133294bdc7c7784b7...	e7023a8dfcd2
1	young confused girl standing in front of a war...	https://media.gettyimages.com/photos/young-con...	92679c323fc6
2	interior design of modern living room with fir...	https://thumb1.shutterstock.com/display_pic_wi...	74c4fa5539f4
3	cybernetic scene isolated on white background .	        https://thumb1.shutterstock.com/display_pic_wi...	f1ea388e05e1
4	gangsta rap artist attends sports team vs play...	https://media.gettyimages.com/photos/jayz-atte...	9a6f8026f593
...	...	                                                ...	                                                ...
3318327	the teams line up for a photo after kick - off	        https://i0.wp.com/i.dailymail.co.uk/i/pix/2015...	6aec77a477f9
3318328	stickers given to delegates at the convention .	        http://cdn.radioiowa.com/wp-content/uploads/20...	7d42aea90652
3318329	this is my very favourite design that i recent...	https://i.pinimg.com/736x/96/f0/77/96f07728efe...	f6dd151121c0
3318330	man driving a car through the mountains	                https://www.quickenloans.com/blog/wp-content/u...	ee4244df5c55
3318331	a longtail boat with a flag goes by spectacula...	http://l7.alamy.com/zooms/338c4740f7b2480dbb72...	7625946297b7
3318332 rows × 3 columns

We will use these IDs to name the downloaded image files (their filepaths), so that we can associate each caption with its corresponding image.
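The IDs double as filenames, so two rows that share an ID would overwrite each other's image. This can happen if the same URL appears more than once. Much less likely, it can also happen if two different URLs agree in the first 12 hex characters of their MD5 digests. The sketch below checks for both cases on toy data. The helper `report_id_clashes` is hypothetical and not part of the original walkthrough:

```python
import hashlib

import pandas as pd

def hash_url(url):
    # Same helper as above: first 12 hex characters of the URL's MD5 digest.
    return hashlib.md5(url.encode()).hexdigest()[:12]

def report_id_clashes(df):
    """Return (repeated_urls, colliding_rows) for a DataFrame with url and url_hash columns."""
    # Rows whose exact URL already appeared earlier in the DataFrame.
    repeated_urls = int(df["url"].duplicated().sum())
    # Among distinct URLs, rows whose hash is shared with another URL (true collisions).
    unique = df.drop_duplicates("url")
    colliding_rows = int(unique["url_hash"].duplicated(keep=False).sum())
    return repeated_urls, colliding_rows

toy = pd.DataFrame({"url": ["http://a.com/1.jpg", "http://a.com/2.jpg", "http://a.com/1.jpg"]})
toy["url_hash"] = toy["url"].apply(hash_url)
print(report_id_clashes(toy))  # -> (1, 0): one repeated URL, no hash collisions
```

If the first number is nonzero on the full DataFrame, one simple fix is to run `df.drop_duplicates("url")` before downloading.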

If we want to download the images in batches, we can do so as follows:

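def download_batch(df, batch_size=10000, start_index=0):
    batch = df.iloc[start_index:start_index+batch_size]
    # iterate over len(batch), not batch_size, so a short final batch does not raise an IndexError
    for j in tqdm(range(len(batch))):
        url, uh = batch.iloc[j][['url', 'url_hash']]
        # IPython shell escape (notebooks only); --create-dirs makes the images/ folder if needed
        !curl -s --create-dirs --connect-timeout 3 --max-time 3 "{url}" -o images/{uh}.jpg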

Here we download batch_size images, starting at row start_index, into the images folder, naming each file after the URL hash we generated above. We use curl for the downloads and cap the time spent on each image at 3 seconds, because some of the links are no longer valid.

To download a total of num_images images, run the following:

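def download_images(df, batch_size=10000, num_images=100000):
    # num_images should be a multiple of batch_size; any remainder is skipped
    for i in range(num_images // batch_size):
        download_batch(df, batch_size=batch_size, start_index=i*batch_size)

download_images(df, batch_size=10000, num_images=100000)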
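The `!curl` line is IPython shell syntax, so `download_batch` only runs inside a Jupyter or IPython session. As an alternative, here is a minimal standard-library sketch that downloads in parallel threads with `urllib`. It is not the code from the original walkthrough. It also differs from the curl version in two ways worth knowing. First, `urlopen` raises on HTTP error responses instead of saving the error page under a `.jpg` name. Second, its `timeout` applies to each blocking socket operation rather than capping the total transfer time, as `--max-time` does. The function names and the `max_workers` default are illustrative choices:

```python
import os
import urllib.request
from concurrent.futures import ThreadPoolExecutor

def download_one(url, out_path, timeout=3):
    """Fetch url and write it to out_path. Returns True on success, False on any failure."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            data = resp.read()
    except Exception:
        # Dead links, HTTP errors, timeouts, malformed URLs: skip this image.
        return False
    with open(out_path, "wb") as f:
        f.write(data)
    return True

def download_all(pairs, out_dir="images", max_workers=16, timeout=3):
    """Download (url, url_hash) pairs to out_dir/<url_hash>.jpg using a thread pool.

    Returns the number of images that were downloaded successfully.
    """
    os.makedirs(out_dir, exist_ok=True)

    def task(pair):
        url, uh = pair
        return download_one(url, os.path.join(out_dir, f"{uh}.jpg"), timeout=timeout)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Each task returns a bool; summing counts the successes.
        return sum(pool.map(task, pairs))
```

For example, `download_all(zip(df["url"].iloc[:100000], df["url_hash"].iloc[:100000]))` would fetch the first 100,000 images into `images/` using the same filenames as the curl version.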

Load and Visualize the Data

Once the images are downloaded into the images folder, we can load them and their captions into a FiftyOne Dataset:

num_images = 100000  # must match the number of images you downloaded

dataset = fo.Dataset(name="gcc", persistent=True)
dataset.add_sample_field("caption", fo.StringField)

samples = []
for i in tqdm(range(num_images)):
    caption, uh = df.iloc[i]['caption'], df.iloc[i]['url_hash']
    filepath = f"images/{uh}.jpg"
    # failed downloads still get a sample here; they are filtered out in the next section
    sample = fo.Sample(filepath=filepath, caption=caption)
    samples.append(sample)

dataset.add_samples(samples)

This code creates a Dataset named “gcc”, which is persisted to the underlying database, and then iterates through the first num_images rows of the pandas DataFrame, creating a Sample with the appropriate filepath and caption.

For this walkthrough, I downloaded roughly the first 310,000 images.

The first step we should take when inspecting a new computer vision dataset is to visualize it! We can do this by launching the FiftyOne App:

session = fo.launch_app(dataset)
All 310,000+ images scraped from Google’s Conceptual Captions Dataset, visualized in the FiftyOne App.

Remove Corrupted Samples

When we look at the data, we can immediately see that some of the images are not valid. This may be due to links which are no longer working, interruptions during downloading, or some other issue entirely.

Fortunately, we can filter out these invalid images easily. In FiftyOne, the compute_metadata() method computes media-type-specific metadata for each sample. For image-based samples, this includes image width, height, and size in bytes. 

When the media file is nonexistent or corrupted, the metadata will be left as null. We can thus filter out the corrupted images by running compute_metadata() and matching for samples where the metadata exists:

dataset.compute_metadata()

## view containing only valid images
valid_image_view = dataset.exists("metadata")

session = fo.launch_app(valid_image_view)
DatasetView containing just the uncorrupted images and their metadata.
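Some of the invalid files may not be images at all. Because the curl command above does not pass `-f`, curl writes the server's response body to disk even for HTTP errors, so an HTML error page can end up saved as a `.jpg`. A cheap heuristic for spotting such files is to look at their leading bytes. The sketch below uses the standard signatures for JPEG, PNG, GIF, and WebP. It is an optional extra check, not part of the original pipeline. It only inspects the first few bytes, so it catches files that are not images but says nothing about whether an image decodes fully:

```python
# Leading-byte signatures for common web image formats.
IMAGE_SIGNATURES = {
    b"\xff\xd8\xff": "jpeg",
    b"\x89PNG\r\n\x1a\n": "png",
    b"GIF87a": "gif",
    b"GIF89a": "gif",
}

def sniff_image_format(path):
    """Return a format name if the file starts with a known image signature, else None."""
    with open(path, "rb") as f:
        head = f.read(12)
    for signature, fmt in IMAGE_SIGNATURES.items():
        if head.startswith(signature):
            return fmt
    # WebP: "RIFF" + 4-byte little-endian size + "WEBP".
    if head[:4] == b"RIFF" and head[8:12] == b"WEBP":
        return "webp"
    return None
```

For example, `[fp for fp in dataset.values("filepath") if os.path.exists(fp) and sniff_image_format(fp) is None]` would list downloaded files that do not look like images.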

Filter by Aspect Ratio

A next step we may want to take is filtering out samples with unusual aspect ratios. If our goal is to control the outputs of a diffusion model, we will likely only be working with images within a certain range of reasonable aspect ratios.

We can do this using FiftyOne’s ViewField, which allows us to apply arbitrary expressions to attributes of our samples and then filter based on them. For instance, to discard all images that are more than twice as long in one dimension as in the other, we can use the following code:

from fiftyone import ViewField as F

long_filter = F("metadata.width") > 2*F("metadata.height")
tall_filter = F("metadata.height") > 2*F("metadata.width")
aspect_ratio_filter = (~long_filter) & (~tall_filter)

good_aspect_view = valid_image_view.match(aspect_ratio_filter)

For the sake of clarity, this is what the discarded samples look like:

bad_aspect_view = valid_image_view.match(~aspect_ratio_filter)

session = fo.launch_app(bad_aspect_view)
View containing images with atypical aspect ratios, which we remove from the training data.

If you so choose, you can use a more or less stringent aspect ratio filter!
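For reference, the ViewField expression above amounts to the per-image test below. This plain-Python sketch adds a hypothetical `max_ratio` parameter so you can see what a more or less stringent filter means. Because the comparisons are strict, an image exactly twice as wide as it is tall is kept:

```python
def keep_aspect_ratio(width, height, max_ratio=2.0):
    """Mirror of aspect_ratio_filter: keep unless one side exceeds max_ratio times the other."""
    too_long = width > max_ratio * height
    too_tall = height > max_ratio * width
    return not too_long and not too_tall

print(keep_aspect_ratio(1200, 600))                 # -> True: exactly 2:1 is still kept
print(keep_aspect_ratio(1201, 600))                 # -> False: slightly wider than 2:1
print(keep_aspect_ratio(400, 1000))                 # -> False: 2.5 times taller than wide
print(keep_aspect_ratio(1200, 600, max_ratio=1.5))  # -> False under a stricter filter
```

In FiftyOne, the stricter variant is the same expression with `1.5*F("metadata.height")` and `1.5*F("metadata.width")` in place of the factor of 2.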

Filter by Resolution

In a similar vein, we might want to remove low-resolution images. We want to generate stunning, photorealistic images, so there is no sense in including low-resolution images in the training data.

This filter is similar to the aspect ratio filter. If we require both the width and the height to exceed 300 pixels, the filter takes the form:

hires_filter = (F("metadata.width") > 300) & (F("metadata.height") > 300)
view = good_aspect_view.match(hires_filter)

Once again, you can choose whatever thresholds you like. For clarity, here is a representative view of the discarded images:

lowres_view = good_aspect_view.match(~hires_filter)
session = fo.launch_app(lowres_view)
View containing small images and images with low resolution, which are removed from the training data.
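You can also pick the thresholds with some evidence rather than by eye. One option is to pull the image dimensions out of FiftyOne, for example with `values("metadata.width")` and `values("metadata.height")` on the view of valid images. You can then check how many images each combination of settings would keep. The pandas sketch below runs that sweep on made-up dimensions. `retention_table` is a hypothetical helper, and it uses the same strict comparisons as the filters above:

```python
import pandas as pd

def retention_table(widths, heights, min_sides=(256, 300, 512), max_ratios=(1.5, 2.0)):
    """Fraction of images kept for each (min_side, max_ratio) combination."""
    dims = pd.DataFrame({"w": widths, "h": heights})
    rows = []
    for min_side in min_sides:
        for max_ratio in max_ratios:
            keep = (
                (dims["w"] > min_side) & (dims["h"] > min_side)  # resolution filter
                & (dims["w"] <= max_ratio * dims["h"])           # not too long
                & (dims["h"] <= max_ratio * dims["w"])           # not too tall
            )
            rows.append({"min_side": min_side, "max_ratio": max_ratio, "kept": keep.mean()})
    return pd.DataFrame(rows)

# Made-up dimensions standing in for real image metadata.
table = retention_table([640, 1024, 200, 900, 1500], [480, 1024, 150, 500, 500])
print(table)
```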

Ensure Color Palette

Looking at the low-resolution images, we might also notice that some of the images in our dataset are grayscale. We likely want to generate images that are as vibrant as possible, so we should discard the black-and-white images.

In FiftyOne, one of the attributes logged in image metadata is the number of channels: color images have three channels (RGB), whereas grayscale images only have one channel. Removing grayscale images is as simple as matching for images with three channels!

## gray images to discard (select these before view is overwritten below)
gray_view = view.match(F("metadata.num_channels") == 1)
## color images to keep
view = view.match(F("metadata.num_channels") == 3)
session = fo.launch_app(gray_view)
DatasetView consisting of grayscale images, which are subsequently removed from the training data.
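Matching on `num_channels` only catches images that are stored with a single channel. A grayscale photo saved as an ordinary three-channel JPEG has `num_channels == 3` and passes this filter. If that matters for your data, one heuristic is to decode the pixels and check whether the three channels are nearly identical. This check is an addition, not part of the original pipeline. The NumPy sketch below uses a tolerance of 3 intensity levels, which is an arbitrary assumption meant to absorb JPEG compression noise:

```python
import numpy as np

def looks_grayscale(rgb, tol=3):
    """Heuristic: True if an HxWx3 array has (nearly) identical R, G, and B channels.

    tol is the largest per-pixel channel spread still treated as gray;
    the default of 3 is an arbitrary allowance for JPEG compression noise.
    """
    rgb = np.asarray(rgb, dtype=np.int16)
    # Per-pixel spread between the largest and smallest channel value.
    spread = rgb.max(axis=-1) - rgb.min(axis=-1)
    return bool(spread.max() <= tol)

# A gray gradient copied into all three channels, and a pure red image.
gray = np.repeat(np.linspace(0, 255, 64, dtype=np.uint8)[None, :, None], 3, axis=2)
red = np.zeros((8, 8, 3), dtype=np.uint8)
red[..., 0] = 255
print(looks_grayscale(gray), looks_grayscale(red))  # -> True False
```

To apply it to a file, you could decode the image with Pillow, for example `np.asarray(Image.open(filepath).convert("RGB"))`. Decoding every image this way is much slower than the metadata check, though.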

Deduplicate the Dataset

Our next task in our data curation quest is to remove duplicate images. When an image is exactly or approximately duplicated in a training dataset, the resulting model may be biased by this small set of overrepresented samples – not to mention the added training costs.

We can find approximate duplicates in our dataset by using a model to generate embeddings for our images (we will use a CLIP model for illustration):

## load CLIP model from the FiftyOne Model Zoo
model = foz.load_zoo_model("clip-vit-base32-torch")
## Compute embeddings and store them in embeddings_field
view.compute_embeddings(model, embeddings_field="image_clip_embedding")

Then we create a similarity index based on these embeddings:

results = fob.compute_similarity(view, embeddings="image_clip_embedding")

Finally, we set a distance threshold below which we consider two images approximate duplicates (here we choose 0.3), and retain only one representative from each group of approximate duplicates:

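results.find_duplicates(thresh=0.3)

# view the duplicates, paired up
dup_view = results.duplicates_view()
session = fo.launch_app(dup_view, auto=False)

# keep one representative from each group of near duplicates, plus all unique images
view = view.select(results.unique_ids)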
View containing exact and approximate duplicates in our dataset. To deduplicate the data, we take one representative image from each group of near duplicates, as well as all of the highly unique images.
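The NumPy sketch below makes the idea of "one representative per group" concrete with a simplified, threshold-based deduplication on cosine distance. It walks through the samples in order and keeps each one only if it is farther than `thresh` from every sample already kept. This illustrates the concept only. It is not FiftyOne's implementation, which may group samples and choose representatives differently. Its cost also grows with the number of kept samples, so treat it as a toy rather than something to run on 300,000 embeddings:

```python
import numpy as np

def greedy_dedup(embeddings, thresh=0.3):
    """Return indices of rows to keep so that no two kept rows are within
    cosine distance `thresh` of each other (first occurrence wins).
    Assumes no all-zero embedding rows."""
    X = np.asarray(embeddings, dtype=np.float64)
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # unit-normalize rows
    kept = []
    for i, x in enumerate(X):
        if kept:
            # Cosine distance from sample i to every representative kept so far.
            dists = 1.0 - X[kept] @ x
            if dists.min() <= thresh:
                continue  # near duplicate of an existing representative
        kept.append(i)
    return kept

emb = [
    [1.0, 0.0, 0.0],
    [0.99, 0.05, 0.0],  # nearly the same direction as row 0
    [0.0, 1.0, 0.0],
    [0.0, 0.98, 0.1],   # nearly the same direction as row 2
    [0.0, 0.0, 1.0],
]
print(greedy_dedup(emb))  # -> [0, 2, 4]
```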

Validate Image-Caption Alignment

Okay, now you’re in luck, because we saved the coolest step for last!

Google’s Conceptual Captions Dataset consists of image-caption pairs from the internet. More precisely, “the raw descriptions are harvested from the Alt-text HTML attribute associated with web images”. This is great as an initial pass, but there are bound to be some low-quality captions in there.

We may not be able to ensure that all of our captions perfectly describe their images, but we can certainly filter out some poorly aligned image-caption pairs!

We will do so using CLIPScore, which is a “reference-free evaluation metric for image captioning”. In other words, you just need the image and the caption. CLIPScore is easy to implement. First, we use SciPy’s cosine distance function to define a cosine similarity function:

from scipy.spatial.distance import cosine as cosine_distance

def cosine(vector1, vector2):
    return 1. - cosine_distance(vector1, vector2)

Then we define a function that takes a Sample and computes the CLIPScore between the image and caption embeddings stored on it:

def compute_clip_score(sample):
    image_embedding = sample["image_clip_embedding"]
    caption_embedding = sample["caption_clip_embedding"]
    return max(100.*cosine(image_embedding, caption_embedding), 0.)

The max(..., 0.) clamps negative cosine similarities, so the score is bounded below at zero. The scaling factor of 100 is the same one used by the TorchMetrics implementation of CLIPScore.
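For reference, the score that `compute_clip_score` returns can be written as follows, where E_I and E_C are the CLIP embeddings of the image and the caption. The CLIPScore paper by Hessel et al. uses a weight of w = 2.5 in place of 100. The weight only rescales scores, so it changes the numeric value of any threshold (such as the one we pick below) but not the ranking of samples.

```latex
\mathrm{CLIPScore}(I, C) = \max\!\left(100 \cdot \frac{E_I \cdot E_C}{\lVert E_I \rVert \, \lVert E_C \rVert},\; 0\right)
```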

We can then compute the CLIPScore – our measure of alignment between images and captions – by adding the fields to our dataset and iterating over our samples:

dataset.add_sample_field("caption_clip_embedding", fo.VectorField)
dataset.add_sample_field("clip_score", fo.FloatField)

for sample in view.iter_samples(autosave=True, progress=True):
    sample["caption_clip_embedding"] = model.embed_prompt(sample["caption"])
    sample["clip_score"] = compute_clip_score(sample)
view.save()
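The loop above embeds and scores one caption at a time. An alternative is to compute all scores at once with NumPy. The sketch below assumes you have already collected the image embeddings and caption embeddings as two equally long lists of vectors. In FiftyOne, you might collect them with `view.values("image_clip_embedding")` and `model.embed_prompts(view.values("caption"))`, then store the result with `view.set_values("clip_score", scores)`. Treat those calls as a suggestion to verify against your FiftyOne version. The `scale` parameter is a hypothetical addition:

```python
import numpy as np

def batch_clip_scores(image_embeddings, caption_embeddings, scale=100.0):
    """Vectorized version of compute_clip_score: max(scale * cos(image, caption), 0) per pair."""
    A = np.asarray(image_embeddings, dtype=np.float64)
    B = np.asarray(caption_embeddings, dtype=np.float64)
    # Row-wise cosine similarity between matching image/caption pairs.
    cos = np.sum(A * B, axis=1) / (np.linalg.norm(A, axis=1) * np.linalg.norm(B, axis=1))
    return np.maximum(scale * cos, 0.0)

# Pairs pointing in the same direction, opposite directions, and 45 degrees apart.
images = [[1.0, 0.0], [1.0, 0.0], [1.0, 1.0]]
captions = [[2.0, 0.0], [-1.0, 0.0], [1.0, 0.0]]
print(batch_clip_scores(images, captions))
```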

If we want to see the “least aligned” samples, we can sort by “clip_score”. 

## 100 least aligned samples
least_aligned_view = view.sort_by("clip_score")[:100]
DatasetView displaying samples with the lowest image-caption alignment. Captions are displayed on the images.

To see the most aligned samples, we can do the same, but passing in reverse=True:

## 100 most aligned samples
most_aligned_view = view.sort_by("clip_score", reverse=True)[:100]
DatasetView displaying samples with the highest image-caption alignment. Captions are displayed on the images.

We can then set a CLIPScore threshold depending on how closely we require images and captions to align. To my taste, a threshold of 21.8 seemed good enough:

view = view.match(F("clip_score") > 21.8)
gcc_clean = view.clone(name="gcc_clean", persistent=True)

The second line clones the view into a new persistent Dataset named “gcc_clean”.
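The 21.8 cutoff was chosen by eye. If you would rather reason about how much data to keep, one option is to set the threshold at a quantile of the observed scores. The sketch below does this on made-up scores. In practice, the scores might come from `view.values("clip_score")`. `threshold_for_keep_fraction` is a hypothetical helper:

```python
import numpy as np

def threshold_for_keep_fraction(scores, keep_fraction):
    """Return the score below which roughly (1 - keep_fraction) of samples fall."""
    return float(np.quantile(np.asarray(scores, dtype=np.float64), 1.0 - keep_fraction))

scores = [18.0, 20.5, 21.0, 22.0, 23.5, 25.0, 26.0, 27.5, 29.0, 31.0]
t = threshold_for_keep_fraction(scores, keep_fraction=0.5)
kept = [s for s in scores if s > t]  # same strict comparison as the filter above
print(t, len(kept))  # -> 24.25 5
```

Because the filter uses `>`, samples tied exactly at the threshold are dropped. If many samples share the same score, the kept fraction can therefore come out slightly below the target.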

Conclusion

After our data cleaning and curation, we have turned a relatively mediocre initial dataset of more than 310,000 samples into a high-quality dataset with 83,181 samples. The fruits of our labor look like this:

Final view displaying samples in cleaned and curated selection from the Google Conceptual Captions Dataset.

We surely haven’t created a perfect dataset; a perfect dataset does not exist. What we have done is address all of the data quality issues that plagued ControlNet 1.0, plus a few more, just for good measure.

Now you are ready to train your own state-of-the-art ControlNet model! 

Note: this post is adapted from a flash session that I presented at CVPR last week!
