Hands-on guide to Python Optimal Transport toolbox: Part 2

Color transfer, Image editing and Automatic translation

Ievgen Redko
Towards Data Science


As a follow-up to my previous introductory article on optimal transport and to the first part of this guide, written by Aurelie Boisbunon, I will show below how you can solve different tasks in practice with Optimal Transport (OT) using the Python Optimal Transport (POT) toolbox.

To start with, let us install POT using pip from the terminal by simply running

pip3 install POT

And voilà! If everything went well, you now have POT installed and ready to use on your computer. Let me now explain how you can reproduce the results from my previous article.

Color transfer

In this application, our goal is to transfer the color style of one image onto another image as smoothly as possible. To do this, we will follow the example from the official webpage of the POT library and start by defining a few helper functions needed to work with images:

import numpy as np
import matplotlib.pyplot as plt
import ot


r = np.random.RandomState(42)  # seeded generator, used later to sample pixels


def im2mat(img):
    """Converts an image to a matrix (one pixel per line)"""
    return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))


def mat2im(X, shape):
    """Converts a matrix back to an image"""
    return X.reshape(shape)

So, the first three lines here are just imports of the numpy, matplotlib.pyplot and ot packages, and r is a random generator with a fixed seed that we will use to sample pixels. Then, we have two functions: im2mat() converts an image represented by a 3d matrix (some people call it a tensor), whose first dimension is the height of the image, the second one its width and the third one the RGB coordinates of the pixels, into a 2d matrix with one pixel per row, while mat2im() converts such a matrix back into an image. Let’s now load some images to see what this means.

I1 = plt.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
I2 = plt.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256

Here they are: the daytime ocean image and the sunset one, both provided directly with the POT toolbox. Note that the pixel RGB coordinates are originally integers (from 0 to 255), so astype(np.float64) converts them to floats. Each value is then divided by 256 (the number of possible values per channel) to normalize the data to the [0, 1] interval. If we look at the first pixel, we get the following:

print(I1[0,0,:])
[0.0234375 0.2421875 0.53125]

This means that the first pixel, in the top-left corner, has RGB coordinates given by the vector [R = 0.0234375, G = 0.2421875, B = 0.53125] (blue dominates, as expected for the daytime image). We now convert our tensors to 2d matrices where each row is a pixel described by its RGB coordinates:

day = im2mat(I1)
sunset = im2mat(I2)

Note that these matrices are rather large as can be seen by running the following code:

print(day.shape)
(669000, 3)
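The reshaping is easy to check on a tiny synthetic array (a made-up 2×3 RGB image, not one of the ocean pictures). Rows are read in row-major order, so row k of the matrix is the pixel at position (k // width, k % width), and the 669000 rows of day are simply the height times the width of the daytime image. The sketch repeats the two helpers so that it runs on its own:

```python
import numpy as np


def im2mat(img):
    """Converts an image to a matrix (one pixel per row)"""
    return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))


def mat2im(X, shape):
    """Converts a matrix back to an image"""
    return X.reshape(shape)


# synthetic 2x3 "image" with 3 channels (values chosen arbitrarily)
img = np.arange(2 * 3 * 3, dtype=np.float64).reshape((2, 3, 3))

X = im2mat(img)
print(X.shape)  # (6, 3): height * width rows, one column per channel
print(X[1])     # second pixel of the first image row: [3. 4. 5.]

# mapping back restores the original array exactly
img_back = mat2im(X, img.shape)
print(np.array_equal(img, img_back))  # True
```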

Let’s sample 1000 pixels randomly from each image to reduce the size of the matrices that we will apply OT to. We can do it as follows:

nb = 1000
idx1 = r.randint(day.shape[0], size=(nb,))
idx2 = r.randint(sunset.shape[0], size=(nb,))

Xs = day[idx1, :]
Xt = sunset[idx2, :]
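To see why subsampling matters, here is a back-of-the-envelope sketch of the memory taken by a dense float64 coupling matrix, whose size is the number of source samples times the number of target samples. The full-image figure assumes, for illustration only, that the sunset image has as many pixels as the daytime one:

```python
def coupling_bytes(n_source, n_target, bytes_per_entry=8):
    """Memory taken by a dense float64 coupling matrix of shape (n_source, n_target)."""
    return n_source * n_target * bytes_per_entry


# full images (assumption: the sunset image also has 669000 pixels)
full = coupling_bytes(669000, 669000)
# 1000 sampled pixels per image, as above
sampled = coupling_bytes(1000, 1000)

print(full // 10**9, "GB")     # 3580 GB
print(sampled // 10**6, "MB")  # 8 MB
```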

We now have two matrices with only 1000 rows and 3 columns each. Let’s plot them in the RB (red-blue) plane to see which colors we actually sampled:

plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
#plt.axis([0, 1, 0, 1])
plt.xlabel('Red')
plt.ylabel('Blue')
plt.xticks([])
plt.yticks([])
plt.title('Day')

plt.subplot(1, 2, 2)

plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
#plt.axis([0, 1, 0, 1])
plt.xlabel('Red')
plt.ylabel('Blue')
plt.title('Sunset')
plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.show()

The result will look like this:

Everything now seems to be set up to finally run our OT algorithm. To this end, we create an instance of the class that solves the Monge-Kantorovich problem and fit it on our pixel samples:

ot_emd = ot.da.EMDTransport()
ot_emd.fit(Xs=Xs, Xt=Xt)

Note that ot.da.EMDTransport is one of POT's classes for doing domain adaptation with OT. When we call its fit() method, it automatically defines the uniform empirical distributions (each pixel is a point with probability 1/1000) and the cost matrix (squared Euclidean distances between the RGB vectors of the pixels), and then computes the optimal coupling matrix. We can now “transport” one image onto the other using this coupling matrix as follows:

transp_Xt_emd = ot_emd.inverse_transform(Xt=sunset)

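The inverse_transform() method we’ve just called transports the sunset image onto the daytime one. For the sampled sunset pixels Xt, it uses a barycentric mapping: each transported pixel is an average of the sampled daytime pixels Xs, weighted by the corresponding column of the coupling matrix, so it takes on daytime colors. The other pixels of the image were not used to fit the coupling, so for each of them POT finds the nearest sampled sunset pixel and applies the same displacement. You can do the same thing the other way around by calling transform(Xs=day). We now plot the final result as follows: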

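# clip to [0, 1]: the out-of-sample displacement can push some values slightly outside
I2t = np.clip(mat2im(transp_Xt_emd, I2.shape), 0, 1)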

plt.figure()
plt.imshow(I2t)
plt.axis('off')
plt.title('Color transfer')
plt.tight_layout()
plt.show()

And it gives the desired result:

Image editing

We now want to perform a seamless copy, which consists of editing an image by replacing a part of it with a patch from another image. For instance, this can be your face transported onto the face of the Mona Lisa painting. To proceed, we first need to download the poissonblending.py file from this GitHub repository. Then, we load three images from the data folder (you need to put them there beforehand) as follows:

import matplotlib.pyplot as plt
from poissonblending import blend


img_mask = plt.imread('./data/me_mask_copy.png')
img_mask = img_mask[:,:,:3] # remove alpha

img_source = plt.imread('./data/me_blend_copy.jpg')
img_source = img_source[:,:,:3] # remove alpha

img_target = plt.imread('./data/target.png')
img_target = img_target[:,:,:3] # remove alpha
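One detail is worth checking if you use your own pictures: according to the matplotlib documentation, plt.imread returns floats in [0, 1] for PNG files but integer arrays for other formats such as JPEG, so the three images loaded above may not share the same value range. Whether blend() requires a particular range depends on poissonblending.py; the hypothetical helper to_float_rgb below simply brings every image to RGB floats in [0, 1]:

```python
import numpy as np


def to_float_rgb(img):
    """Returns an RGB float64 image with values in [0, 1] (hypothetical helper).

    plt.imread gives floats in [0, 1] for PNG files but integers (usually
    uint8, 0-255) for JPEG files, so integer images are rescaled here.
    """
    img = img[:, :, :3]  # drop the alpha channel if there is one
    if np.issubdtype(img.dtype, np.integer):
        return img.astype(np.float64) / np.iinfo(img.dtype).max
    return img.astype(np.float64)


# synthetic examples: a uint8 "JPEG" and a float RGBA "PNG"
jpg_like = np.full((2, 2, 3), 255, dtype=np.uint8)
png_like = np.full((2, 2, 4), 0.5, dtype=np.float32)

print(to_float_rgb(jpg_like).max())  # 1.0
print(to_float_rgb(png_like).shape)  # (2, 2, 3)
```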

The source image is my portrait, the mask selects the area of my portrait that will be copied onto the Mona Lisa’s face, and the target is the Mona Lisa painting. The preprocessing keeps only the RGB values of each pixel, dropping the transparency (alpha) channel when there is one. Overall, they will look as follows:

You can create or adjust the mask in any image editor by drawing simple geometric shapes.
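If you prefer to build a simple mask in code instead, the sketch below draws a white ellipse on a black background with NumPy. It assumes, without confirmation from poissonblending.py, that blend() treats white pixels as the region to copy and that the mask must have the same height and width as the source image; the function name ellipse_mask and the sizes are hypothetical:

```python
import numpy as np


def ellipse_mask(height, width, center, radii):
    """Hypothetical helper: RGB mask that is white (1.0) inside an axis-aligned
    ellipse and black (0.0) elsewhere."""
    rows, cols = np.mgrid[0:height, 0:width]
    cy, cx = center
    ry, rx = radii
    inside = ((rows - cy) / ry) ** 2 + ((cols - cx) / rx) ** 2 <= 1.0
    # repeat the boolean map on the three RGB channels
    return np.repeat(inside[:, :, None], 3, axis=2).astype(np.float64)


# hypothetical sizes; use the height and width of your own source image
mask = ellipse_mask(200, 150, center=(100, 75), radii=(60, 40))
print(mask.shape)  # (200, 150, 3)
```

The resulting array could then be saved next to the other images, for instance with plt.imsave('./data/me_mask_copy.png', mask).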

The final result can then be obtained using the blend() function called as follows:

nbsample = 500
off = (35,-15)
seamless_copy = blend(img_target, img_source, img_mask, reg=5, eta=1, nbsubsample=nbsample, offset=off, adapt='kernel')

Once again, we apply OT only to a subset of 500 pixels, as doing it for the whole image would take a long time. The code behind this function involves many image preprocessing routines, but what interests us here is the OT part. It is implemented in the adapt_Gradients_kernel() function of poissonblending.py, which contains the following code:

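# excerpt from adapt_Gradients_kernel() in poissonblending.py (G_src, G_tgt, nb, mu, eta, bias come from that function)
Xs, Xt = subsample(G_src, G_tgt, nb)  # nb samples from each set of gradients

# as shown, no kernel argument is passed, so POT's default kernel='linear' applies
ot_mapping = ot.da.MappingTransport(mu=mu, eta=eta, bias=bias, max_inner_iter=10, verbose=True, inner_tol=1e-06)
ot_mapping.fit(Xs=Xs, Xt=Xt)
return ot_mapping.transform(Xs=G_src)  # map all source gradients, not only the samples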

The first line here extracts two samples of 500 pixels from the gradients G_src and G_tgt of the source and target images. Then, ot.da.MappingTransport() learns an explicit transformation that approximates the barycentric mapping we used in the previous example (POT can learn it as a linear map or, with kernel='gaussian', as a non-linear kernelized one). You may wonder why this is needed. Well, the barycentric mapping relies on the coupling matrix, which aligns only the samples it was fitted on (its shape is the number of samples from the first distribution × the number of samples from the second one), so it cannot be applied to out-of-sample points. Finally, the last line applies this approximation to all the gradients of my face, transporting them onto those of the Mona Lisa’s portrait. The final result is then given by:

Rightmost image by Author.

Automatic translation

For this last application, our goal is to find an optimal alignment between the words of two sentences written in different languages. As an example, we will work with the English sentence “the cat sits on the mat” and its French translation “le chat est assis sur le tapis”, with the goal of recovering a matching that provides the correspondences “cat”-“chat”, “sits”-“assis” and “mat”-“tapis”. For this, we will need the nltk library, which can be installed via pip as follows:

pip3 install nltk

We will also need to clone the following GitHub repository and follow its README to download the embeddings used to represent our sentences (I also provide a shortcut here, where you can directly find the embeddings for the considered language pair).

Let us now do some usual imports and add two functions that will be used afterwards.

import numpy as np, sys, codecs
import ot
import nltk
nltk.download('stopwords')  # stopword lists
nltk.download('punkt')      # tokenizer models used by word_tokenize
nltk.download('punkt_tab')  # also required by word_tokenize in recent NLTK versions

from nltk import word_tokenize
import matplotlib.pyplot as plt

def load_embeddings(path, dimension):
    """
    Loads the embeddings from a file with word2vec format.
    The word2vec format is one line per word, followed by its embedding.
    """
    f = codecs.open(path, encoding="utf8").read().splitlines()
    vectors = {}
    for i in f:
        elems = i.split()
        # the last `dimension` fields are the vector, everything before is the word
        vectors[" ".join(elems[:-dimension])] = " ".join(elems[-dimension:])
    return vectors

def clean(embeddings_dico, corpus, vectors, language, stops, instances = 10000):
    """
    Tokenizes and lowercases each document, keeps the words that have an
    embedding and are not stopwords, and returns their embeddings in a dict
    keyed by "word__language".
    """
    clean_corpus, clean_vectors, keys = [], {}, []
    words_we_want = set(embeddings_dico).difference(stops)
    for key, doc in enumerate(corpus):
        clean_doc = []
        words = word_tokenize(doc)
        for word in words:
            word = word.lower()
            if word in words_we_want:
                clean_doc.append(word + "__%s" % language)
                # np.float was removed from NumPy; the builtin float is equivalent
                clean_vectors[word + "__%s" % language] = np.array(vectors[word].split()).astype(float)

        # clean_corpus and keys are kept from the original code but not returned
        if len(clean_doc) > 5:
            keys.append(key)
            clean_corpus.append(" ".join(clean_doc))
    return clean_vectors

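The first function loads the embeddings, while the second one preprocesses the text: it tokenizes and lowercases each sentence, drops stopwords and punctuation (any token without an embedding is discarded), and returns a dictionary mapping each remaining word, suffixed with its language (e.g. "cat__en"), to its embedding vector.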
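To illustrate what clean() returns without downloading NLTK data or embeddings, here is a simplified, hypothetical re-implementation named clean_toy: it splits on whitespace instead of calling word_tokenize and is applied to made-up two-dimensional embeddings stored as strings, the way load_embeddings() stores them:

```python
import numpy as np


def clean_toy(sentence, vectors, language, stops):
    """Simplified version of clean(): whitespace tokenization instead of
    word_tokenize, lowercase, keep words with an embedding that are not stopwords."""
    clean_vectors = {}
    for word in sentence.lower().split():
        if word in vectors and word not in stops:
            clean_vectors[word + "__%s" % language] = np.array(vectors[word].split()).astype(float)
    return clean_vectors


# hypothetical 2-d embeddings stored as strings, as load_embeddings() does
toy_vectors = {"the": "0.1 0.1", "cat": "0.9 0.2", "sits": "0.3 0.8",
               "on": "0.2 0.1", "mat": "0.7 0.7"}
toy_stops = {"the", "on"}

out = clean_toy("The cat sits on the mat", toy_vectors, "en", toy_stops)
print(list(out))       # ['cat__en', 'sits__en', 'mat__en']
print(out["cat__en"])  # [0.9 0.2]
```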

To proceed, we now load the embeddings for English and French languages as follows:

vectors_en = load_embeddings("concept_net_1706.300.en", 300) 
vectors_fr = load_embeddings("concept_net_1706.300.fr", 300)

And define the two sentences to be translated:

en = ["the cat sits on the mat"]
fr = ["le chat est assis sur le tapis"]

Let us now clean our sentences as follows:

clean_en = clean(set(vectors_en.keys()), en, vectors_en, "en", set(nltk.corpus.stopwords.words("english")))
clean_fr = clean(set(vectors_fr.keys()), fr, vectors_fr, "fr", set(nltk.corpus.stopwords.words("french")))

This returns only the embeddings of the meaningful words “cat”, “sits” and “mat” on the English side, and “chat”, “assis” and “tapis” on the French side. Everything is now set for optimal transport. As shown in the image above, we define two uniform empirical distributions over the words and run OT between them, with a cost matrix given by pairwise squared Euclidean distances.

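# stack the embeddings returned by clean() into matrices, one row per word
en_emd = np.array(list(clean_en.values()))
fr_emd = np.array(list(clean_fr.values()))
# uniform empirical distributions over the words of each sentence
emp_en = np.ones((len(en_emd),)) / len(en_emd)
emp_fr = np.ones((len(fr_emd),)) / len(fr_emd)
M = ot.dist(en_emd, fr_emd)  # pairwise squared Euclidean distances

coupling = ot.emd(emp_en, emp_fr, M)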

Once the coupling is obtained, we can project our embeddings to a 2d space with t-SNE and then plot the words along with their matched pairs as follows:

from sklearn.manifold import TSNE

# word labels without the "__en"/"__fr" suffix added by clean()
en_words = [w.split("__")[0] for w in clean_en]
fr_words = [w.split("__")[0] for w in clean_fr]

# 2d projections; t-SNE requires perplexity < number of samples (3 words here),
# and random_state fixes the seed for visualization purposes
en_embedded = TSNE(n_components=2, perplexity=2, random_state=2).fit_transform(en_emd)
fr_embedded = TSNE(n_components=2, perplexity=2, random_state=2).fit_transform(fr_emd)

f, ax = plt.subplots()
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
plt.axis('off')
ax.scatter(en_embedded[:, 0], en_embedded[:, 1], c='blue', s=50, marker='s')

for i, word in enumerate(en_words):
    ax.annotate(word, (en_embedded[i, 0], en_embedded[i, 1]), ha='right', va='bottom', fontsize=30)

ax.scatter(fr_embedded[:, 0], fr_embedded[:, 1], c='red', s=50, marker='s')

for i, word in enumerate(fr_words):
    ax.annotate(word, (fr_embedded[i, 0], fr_embedded[i, 1]), va='top', fontsize=30)

coupling /= np.max(coupling)  # rescale so that the largest entry is 1

# link each English word to the French word it sends the most mass to
for i, j in enumerate(np.argmax(coupling, axis=1)):
    ax.plot([en_embedded[i, 0], fr_embedded[j, 0]], [en_embedded[i, 1], fr_embedded[j, 1]], c='k')

plt.show()

This gives the final result:

Image by Author.

Note that this example can be extended to perform automatic translation on the Wikipedia data available in the GitHub repository from which we took the initial code. For more details, and for ways to further boost the performance, you can also check the corresponding paper and the results therein.
