How I Made My Doodle-Predicting Server

This documents the server part of my thesis project, a game played by doodling. In the game, you can create anything you can imagine by doodling it. It’s like Scribblenauts, but instead of typing the name of an item, you draw it.

In the beginning, I thought Google might provide an API for classifying doodles, but after Googling around, I found out they don’t. The good news is that they open-sourced the whole dataset and have a tutorial on how to build your own classifier.

So, I made my classifier, and here is how.

My source code: Github

An online client to test the server: Here

I’m not a specialist in machine learning, so I borrowed some code and a model from other people. I list all my references later.

Step 0, planning.

  • I need a model running somewhere that can take an image or the strokes of a doodle as input, then output the prediction.
  • I guess this would be hard for me to do inside Unity, since the mainstream ML solutions are based on Python.
  • So this will probably live in a Linux server running in the cloud. Then I need to wrap it into an API and request it in Unity.

Step 1, Tensorflow and model.

Luckily, Kaggle ran a competition for training Quick Draw models, and a lot of great programmers participated and shared their solutions, along with their code and models. I read the notebook of the No.1 solution and stole his model.

Reference: His Kaggle notebook

His solution:

  1. Load Keras and the pre-trained Mobilenet model.
  2. Read the dataset, and render the doodles by drawing their strokes on a canvas with OpenCV.
  3. Re-train Keras’s Mobilenet on those doodles to get new weights.
  4. Some functions to test and visualize the performance of his new model.
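
Step 2 of his solution is the part that turns Quick Draw data into pictures. Here is a minimal sketch of the idea, assuming the simplified Quick Draw stroke format, where each stroke is a pair of lists (x coordinates, then y coordinates) in a 256 by 256 space. The names draw_line and strokes_to_image and the sample doodle are made up for illustration, and a plain Bresenham line routine stands in for the OpenCV line drawing, so the sketch runs with nothing but the standard library:

```python
def draw_line(canvas, x0, y0, x1, y1, value=255):
    """Bresenham line: set every pixel between (x0, y0) and (x1, y1)."""
    dx, dy = abs(x1 - x0), -abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx + dy
    while True:
        canvas[y0][x0] = value
        if x0 == x1 and y0 == y1:
            break
        e2 = 2 * err
        if e2 >= dy:
            err += dy
            x0 += sx
        if e2 <= dx:
            err += dx
            y0 += sy


def strokes_to_image(strokes, size=64, base_size=256):
    """Rasterize strokes as white lines (255) on a black (0) canvas."""
    canvas = [[0] * size for _ in range(size)]
    # Map coordinates 0..base_size-1 onto pixels 0..size-1.
    scale = (size - 1) / (base_size - 1)
    for xs, ys in strokes:
        points = [(round(x * scale), round(y * scale)) for x, y in zip(xs, ys)]
        # Connect each point to the next one in the same stroke.
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            draw_line(canvas, x0, y0, x1, y1)
    return canvas


# Hypothetical doodle: one diagonal stroke and one horizontal stroke.
doodle = [[[0, 255], [0, 255]], [[0, 255], [128, 128]]]
image = strokes_to_image(doodle)
```

The result is a 64 by 64 grid with a black background and white strokes, which is the kind of picture his model was trained on.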

To quickly test whether his solution works for me, I first downloaded his weight file model.h5 and copy/pasted his functions for preparing pictures.

Note: his model was trained on pictures with a black background and white strokes.

My environment: Python 3.5.5, Tensorflow 1.13.1.

1. Now, create a Python script (Step 2 imports it as predictor).

As is Python tradition, import a hell of a lot of things into this script, though you don’t need some of them.

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import os
import json
import datetime as dt
import cv2
import base64
import io
from PIL import Image
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation
from tensorflow.keras.metrics import categorical_accuracy, top_k_categorical_accuracy, categorical_crossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input
from tensorflow.keras import backend as K

Core libraries:
1. numpy and pandas, for preparing raw data.
2. cv2 AKA OpenCV, for preparing pictures.
3. tensorflow and keras.

2. Now, some variables.

No need to understand them now; let’s talk about them when we use them.

NCSVS = 100
NCATS = 340

def top_3_accuracy(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=3)

STEPS = 800
size = 64
batchsize = 680
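
One of these is easy to explain right away: top_3_accuracy counts a prediction as correct when the true class is among the three highest-scoring classes. Here is a rough plain-Python sketch of that idea. The helper name top_k_accuracy and the toy scores are invented for illustration, and Keras works on tensors rather than lists:

```python
def top_k_accuracy(true_labels, scores, k=3):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for label, row in zip(true_labels, scores):
        # Indexes of the k largest scores in this row, best first.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label in top_k:
            hits += 1
    return hits / len(true_labels)


# Toy example: 3 samples, 4 classes (made-up numbers).
scores = [
    [0.10, 0.60, 0.25, 0.05],  # true class 1 is the top guess
    [0.40, 0.30, 0.20, 0.10],  # true class 2 is the third guess
    [0.50, 0.30, 0.15, 0.05],  # true class 3 is the last guess
]
labels = [1, 2, 3]
top1 = top_k_accuracy(labels, scores, k=1)
top3 = top_k_accuracy(labels, scores, k=3)
```

With these numbers only the first sample counts for k=1, while the first two count for k=3.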

3. Load and compile the model and weight, and create our session and graph

Let’s put model.h5 inside ./model/. This model.h5 is not the whole model, only the weights for Mobilenet, so we need to load it after we create Mobilenet.

def init():
    sess = tf.InteractiveSession()
    loaded_model = MobileNet(input_shape=(size, size, 1), alpha=1., weights=None, classes=NCATS)
    loaded_model.load_weights("./model/model.h5")
    loaded_model.compile(optimizer=Adam(lr=0.002), loss='categorical_crossentropy',
                         metrics=[categorical_crossentropy, categorical_accuracy, top_3_accuracy])
    print(loaded_model.summary())
    graph = tf.get_default_graph()
    return loaded_model, sess, graph

global model, sess, graph
model, sess, graph = init()

4. A function to prepare images into the desired format.

The input of this model must be an image that is 64*64, and with a black background and white strokes.

def prepareImage(im):
    im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    thresh = 127
    im = cv2.threshold(im, thresh, 255, cv2.THRESH_BINARY)[1]
    # see if need to invert
    n_white_pix = np.sum(im == 255)
    n_black_pix = np.sum(im == 0)
    if n_white_pix > n_black_pix:
        im = cv2.bitwise_not(im)
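    # trim1, move content to the upper-left corner
    size = len(im[0])
    sum0 = im.sum(axis = 0)
    sum1 = im.sum(axis = 1)
    for i in range(len(sum0)):
        if sum0[i] == 0:
            # drop an empty column on the left, pad one on the right
            im = np.delete(im, 0, 1)
            zero = np.zeros((size,1))
            im = np.append(im,zero,1)
        else :
            break  # reached the first column with content
    for i in range(len(sum1)):
        if sum1[i] == 0:
            # drop an empty row on top, pad one at the bottom
            im = np.delete(im, 0, 0)
            zero = np.zeros((1,size))
            im = np.append(im,zero,0)
        else :
            break  # reached the first row with content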
    # trim2 crop content
    sum3 = im.sum(axis = 0)
    sum4 = im.sum(axis = 1)
    x2 = 1
    y2 = 1
    while x2 < len(sum3) and sum3[-x2] ==0:
        x2 += 1
    while y2 < len(sum4) and sum4[-y2] ==0:
        y2 += 1
    w = size - x2
    h = size - y2
    contentSize = w if w > h  else  h
    # only crop if there is really content
    if contentSize > 16:
        im = im[0:contentSize, 0:contentSize]
    return im

What this function does:

  • Turn whatever image into grayscale.
  • Then threshold it into black and white.
  • Count the white pixels and the black pixels; whichever color has more is the background color.
  • If the background is white, invert it. Otherwise, do nothing.
  • Move content of image to the upper left corner.
  • Cut spare background.
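
The invert and crop steps in the list above can also be sketched with NumPy alone. This is an illustration under assumptions, not a copy of the function above: normalize_doodle is a made-up name, the input is assumed to be already thresholded to 0 and 255, and the crop uses the bounding box of the strokes instead of shifting rows and columns one at a time:

```python
import numpy as np


def normalize_doodle(im):
    """Black-background crop around the strokes of a 0/255 binary image.

    The crop is square unless it runs into the edge of the image.
    """
    # The more common pixel value is treated as the background.
    if np.sum(im == 255) > np.sum(im == 0):
        im = 255 - im  # white background: invert to black
    rows = np.flatnonzero(im.sum(axis=1))  # rows that contain strokes
    cols = np.flatnonzero(im.sum(axis=0))  # columns that contain strokes
    if rows.size == 0:
        return im  # blank image: nothing to crop
    top, left = rows[0], cols[0]
    # Side of the square that covers the content in both directions.
    side = max(rows[-1] - top, cols[-1] - left) + 1
    return im[top:top + side, left:left + side]


# Made-up test image: white canvas with a black 3 x 5 block drawn on it.
canvas = np.full((20, 20), 255, dtype=np.uint8)
canvas[4:7, 10:15] = 0
cropped = normalize_doodle(canvas)
```

Here the block becomes white on black, and the crop is the 5 by 5 square that starts at the block's upper-left corner.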

I rescale the image to 64 by 64 later because I want to store the high-resolution raw drawings first and show them on my website.

5. Create a function to make the prediction.

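def prepareImageAndPredict(model, cv2ImageData, size=64):
    try:
        # downsize to 64
        image64 = cv2.resize(cv2ImageData, (64, 64))
        # batch of one grayscale image: (1, size, size, 1)
        x = np.zeros((1, size, size, 1))
        x[0, :, :, 0] = image64
        x = preprocess_input(x).astype(np.float32)
        prediction = model.predict(x, batch_size=128, verbose=1)
        # indexes of the 5 highest scores, best first
        top5 = np.argsort(-prediction, axis=1)[:, :5]
        return top5[0]
    except Exception as e:
        # The body of this handler is missing from the post;
        # logging and re-raising is an assumption.
        print("prediction failed: ", e)
        raise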

Inside this function:

  • Resize the image to 64*64
  • Create a 4d array to store our picture. The reason why it’s not a 3d array is that this model was designed to take multiple images. Though we only have 1 image, we need to wrap it in another array to make the model work.
  • Call model.predict to get the prediction.
  • Get the indexes of the 5 classes with the highest confidence.
  • Again, since we only have one image, return top5[0]
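
The array shapes and the argsort trick from the list above are easier to see with fake data. This is a small sketch, assuming a blank stand-in image and made-up scores for 8 imaginary classes in place of a real model.predict call:

```python
import numpy as np

size = 64
image64 = np.zeros((size, size))        # stand-in for a resized doodle
batch = np.zeros((1, size, size, 1))    # (images, height, width, channels)
batch[0, :, :, 0] = image64

# Stand-in for model.predict(batch): one row of scores per image.
# Made-up numbers for 8 imaginary classes instead of the real 340.
prediction = np.array([[0.01, 0.40, 0.05, 0.20, 0.02, 0.25, 0.04, 0.03]])

# Negating the scores makes argsort return the largest ones first.
top5 = np.argsort(-prediction, axis=1)[:, :5]
best = top5[0].tolist()
```

top5 has one row per image in the batch, which is why the function returns top5[0] for a single doodle.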

6. Let’s do a quick test.

imagePath = "./whateverDoodle.jpg"
image = cv2.imread(imagePath)
image = prepareImage(image)

with sess.as_default():
    with graph.as_default():
        prediction = prepareImageAndPredict(model, image).tolist()

print(prediction)


After running it, the script will print 5 numbers, which are indexes into this list:

categories = ['airplane', 'alarm clock', 'ambulance', 'angel', 'animal migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball bat', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling fan', 'cell phone', 'cello', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise ship', 'cup', 'diamond', 'dishwasher', 'diving board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip flops', 'floor lamp', 'flower', 'flying saucer', 'foot', 'fork', 'frog', 'frying pan', 'garden', 'garden hose', 'giraffe', 'goatee', 'golf club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey puck', 'hockey stick', 'horse', 'hospital', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass', 'house', 'house plant', 'hurricane', 'ice cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light bulb', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave', 'monkey', 
'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint can', 'paintbrush', 'palm tree', 'panda', 'pants', 'paper clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup truck', 'picture frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote control', 'rhinoceros', 'river', 'roller coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 'saxophone', 'school bus', 'scissors', 'scorpion', 'screwdriver', 'sea turtle', 'see saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 'smiley face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop sign', 'stove', 'strawberry', 'streetlight', 'string bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing set', 'sword', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis racquet', 'tent', 'The Eiffel Tower', 'The Great Wall of China', 'The Mona Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']

for i in range(5):
    print(categories[prediction[i]])

Now we should see a bunch of names!

Step 2, now that the model is working, let’s create an API server.

Flask is an ideal framework for this job.

1. Create a new Python script for the server and import a shitload of libraries.

from flask import Flask , jsonify, request, render_template, send_from_directory
from flask_cors import CORS
from predictor import *
import random
import json
import pandas as pd
import numpy as np
from tensorflow.keras import models
import time
import datetime
import cv2
import sys, getopt
import os
import base64
import io
from PIL import Image

2. Create a server.

At this moment let’s not consider https.

app = Flask(__name__)
CORS(app)

CORS allows ajax requests from other domains.

3. Create the first route.

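def stringToRGB(base64_string):
    imgdata = base64.b64decode(str(base64_string))
    # assumed: the right-hand side of this line is cut off in the original
    image = Image.open(io.BytesIO(imgdata))
    return cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)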

@app.route("/api/doodlePredict", methods=["POST"])
def predictAPI():
    global model, graph
    print("this is the request: ", request.form.to_dict())
    image_raw = request.form.to_dict()["data"]
    image_raw = stringToRGB(image_raw)
    image = prepareImage(image_raw)
    response = {'prediction': {'numbers': [], 'names': []}}
    with sess.as_default():
        with graph.as_default():
            response['prediction']['numbers'] = prepareImageAndPredict(model, image).tolist()
    # translate each class index into its category name
    for i in range(len(response['prediction']['numbers'])):
        response['prediction']['names'].append(categories[response['prediction']['numbers'][i]])
    print("this is the response: ", response['prediction']['names'])
    # assumed: the timestamp call is cut off in the original
    timestamp = datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    cv2.imwrite("./doodleHistory/" + ', '.join(response['prediction']['names']) + ", " + timestamp + ".jpg", image_raw)
    return jsonify(response)

This route assumes that the image is sent to this server as a base64 string, under the key data inside a form, via a POST request.

So it:
– Read the form,
– Convert it into a dict,
– Get data of the dict,
– Convert the base64 string into an RGB image.

– Call prepareImage() to make sure the image fits the requirement of the model.
– Use the session and graph we created, call model.predict() to get the prediction.
– Return the prediction in the format of JSON.
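
From the client's point of view, the contract above is a form-encoded POST body with one field, data, holding the base64 text of the image file. Here is a minimal sketch using only the standard library. build_form_body and read_form_image are hypothetical helper names, the bytes are a placeholder rather than a real JPEG, and nothing is actually sent over the network:

```python
import base64
from urllib.parse import urlencode, parse_qs


def build_form_body(image_bytes):
    """Client side: wrap raw image bytes as the form field 'data'."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return urlencode({"data": encoded})


def read_form_image(body):
    """Server side: pull 'data' out of the form and decode it to bytes."""
    form = parse_qs(body)
    return base64.b64decode(form["data"][0])


fake_image = b"\xff\xd8 not a real JPEG \xff\xd9"  # placeholder bytes
body = build_form_body(fake_image)
recovered = read_form_image(body)
```

A real client would POST body to /api/doodlePredict with the content type application/x-www-form-urlencoded, and the server would hand the decoded bytes to an image library instead of comparing them.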

4. Lastly, start the server.

if __name__ == "__main__":
    app.run(host = "0.0.0.0", port = 5800, debug = True)

The host must be 0.0.0.0 in order to request this API from outside your local environment.

Step 3, create a demo to test this API

Let’s do it in the p5.js online editor.

– Create a p5 canvas
– Create an empty 2d array to store all the strokes, which are 1d arrays.
– Starting from stroke 0, while the mouse is pressed, push the mouse position into the current stroke array.
– When the mouse is released, create a new empty stroke.
– Also, send the current canvas to our API.
– After getting a response, print the prediction on the page.
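
The stroke bookkeeping in the list above does not depend on p5. Here is a minimal model of it, written in Python only to stay consistent with the server code. StrokeRecorder and its method names are invented for illustration and are not part of p5:

```python
class StrokeRecorder:
    """Keeps a list of strokes; each stroke is a list of (x, y) points."""

    def __init__(self):
        self.strokes = [[]]  # start with stroke 0, still empty

    def mouse_dragged(self, x, y):
        # While the mouse is pressed, extend the current stroke.
        self.strokes[-1].append((x, y))

    def mouse_released(self):
        # Close the current stroke and open a new empty one.
        if self.strokes[-1]:
            self.strokes.append([])

    def finished_strokes(self):
        # All strokes that contain at least one point.
        return [stroke for stroke in self.strokes if stroke]


recorder = StrokeRecorder()
for point in [(10, 10), (20, 15), (30, 30)]:
    recorder.mouse_dragged(*point)
recorder.mouse_released()
recorder.mouse_dragged(50, 50)
recorder.mouse_released()
```

In the demo, the release handler is also the moment to send the canvas to the API.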

Source code

Step 4, more

After this, I created a bunch of new APIs to serve my game, including an API showing all the doodles people have drawn, an API returning a Unity sprite to help spawn items in my game, and more.

However, I found that this documentation can also serve as a beginner tutorial on how to remake Quick Draw, so I decided not to make this article more complex.

Enjoy! I’m gonna document everything else I have done for this game in future articles.
