Digit Recognition with LatticaAI Demo Tutorial

Digit Recognition flow

Overview of the Model

Our digit recognition model is trained on the MNIST dataset, a collection of grayscale images of handwritten digits (0–9), each 28×28 pixels. We added preprocessing and data augmentation to the training data to improve performance on real-world sketches of handwritten digits.

The model is a fully connected neural network (FCNN):

  1. Input Layer: flattens the 28×28 image into a 784-dimensional vector.

  2. Hidden Layer: a fully connected layer with 50 neurons and square activation.

  3. Output Layer: a fully connected layer with 10 neurons (one for each digit) and a softmax activation.
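
The three layers above can be sketched as a PyTorch module. This is only an illustration: the class name `DigitRecognizer` is hypothetical, and the layer names `fc1` and `fc2` are an assumption inferred from the state-dict keys used in the sample code on this page.

```python
import torch
from torch import nn


class DigitRecognizer(nn.Module):  # hypothetical class name
    """784 -> 50 (square activation) -> 10 (softmax)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 50)  # hidden layer, 50 neurons
        self.fc2 = nn.Linear(50, 10)       # output layer, one neuron per digit

    def forward(self, x):
        x = x.flatten()                    # input layer: 28x28 image -> 784 vector
        x = self.fc1(x) ** 2               # square activation
        return torch.softmax(self.fc2(x), dim=0)


model = DigitRecognizer()
probs = model(torch.rand(28, 28))          # random stand-in for a real image
print(probs.shape, probs.sum().item())
```

With untrained weights the probabilities are meaningless; the sketch only shows the shapes flowing through each layer.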

We use a square activation instead of the more common ReLU because homomorphic operations are better suited to polynomial functions, and squaring is the simplest, lowest-degree nonlinear function that can serve as a layer activation.
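
To make the polynomial point concrete, the sketch below computes the network's pre-softmax outputs (logits) using nothing but additions and multiplications. The function names and the toy weights are hypothetical. Softmax itself is not a polynomial, and how it is handled in the encrypted flow is not covered here; note only that softmax preserves ordering, so the largest logit already identifies the predicted digit.

```python
def affine(weights, bias, vec):
    """Compute weights @ vec + bias using only + and *."""
    out = []
    for row, b in zip(weights, bias):
        z = b
        for w, v in zip(row, vec):
            z = z + w * v
        out.append(z)
    return out


def logits_poly(w1, b1, w2, b2, x):
    """Hidden layer, square activation, output layer: a degree-2 polynomial in x."""
    hidden = [z * z for z in affine(w1, b1, x)]  # square activation
    return affine(w2, b2, hidden)


# Toy 2-input, 2-hidden, 2-output network (hypothetical weights)
w1, b1 = [[1, -1], [2, 0]], [0, 1]
w2, b2 = [[1, 0], [0, 1]], [0, 0]
x = [3, 1]

logits = logits_poly(w1, b1, w2, b2, x)
predicted = max(range(len(logits)), key=lambda i: logits[i])
print(logits, predicted)  # [4, 49] 1
```

A ReLU in place of `z * z` would require a comparison against zero, which cannot be written as a finite sequence of additions and multiplications.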

model architecture

Here is sample code for inferring a digit from an image using the trained model:

import torch
import matplotlib.pyplot as plt


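def inference(l1_weight, l1_bias, l2_weight, l2_bias, x):
    x = x.flatten()                  # 28×28 image -> 784-dimensional vector
    x = l1_weight @ x + l1_bias      # hidden layer (50 neurons)
    x = x ** 2                       # square activation
    x = l2_weight @ x + l2_bias      # output layer (10 values, one per digit)
    return torch.nn.functional.softmax(x, dim=0)  # class probabilities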


# load model weights
model_dict = torch.load("digits_recognizer.pth", map_location="cpu")

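# digit inference
# [..., 0] keeps the first channel; this assumes digit.png is stored as RGB(A), not single-channel grayscale
img = plt.imread("digit.png")[..., 0]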
pt = torch.tensor(img)

res = inference(
    model_dict["fc1.weight"], model_dict["fc1.bias"],
    model_dict["fc2.weight"], model_dict["fc2.bias"],
    pt,
)

plt.figure()
plt.imshow(img)
plt.title(f"Prediction: {res.argmax().item()}")
plt.axis("off")

plt.show()
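
The sample above indexes the image with `[..., 0]`, which works for channel-last RGB(A) images but would slice away a column of a single-channel image. A minimal sketch of a more defensive loader step follows; `prepare_image` is a hypothetical helper, not part of any LatticaAI package, and it assumes the model expects exactly 28×28 input as described in the overview.

```python
import torch


def prepare_image(img):
    """Return a 28x28 float32 tensor from a 2-D or channel-last 3-D image array."""
    t = torch.as_tensor(img, dtype=torch.float32)
    if t.ndim == 3:
        t = t[..., 0]  # keep the first channel, as the sample code does
    if t.shape != (28, 28):
        raise ValueError(f"expected a 28x28 image, got {tuple(t.shape)}")
    return t


rgba = torch.zeros(28, 28, 4)   # stand-in for an RGBA image
gray = torch.zeros(28, 28)      # stand-in for a single-channel image
print(prepare_image(rgba).shape, prepare_image(gray).shape)
```

The result can be passed to `inference` in place of `pt`.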

Achieving Full Privacy with LatticaAI

First, install our client package. The following code then runs the same inference on encrypted data:

import torch
import matplotlib.pyplot as plt

from lattica_query.auth import get_demo_token
from lattica_query.lattica_query_client import QueryClient


model_id = "sketchToNumber"
my_token = get_demo_token(model_id)

client = QueryClient(my_token)

context, secret_key, client_blocks = client.generate_key()

# SECURE digit inference
img = plt.imread("digit.png")[..., 0]
pt = torch.tensor(img)

# `res` is a torch.Tensor, same as in the plain example above
res = client.run_query(context, secret_key, pt, client_blocks)

# Display the image and prediction as above...
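
One way to sanity-check the secure flow is to compare its output with the plain inference result for the same image. The helper below is a hypothetical sketch: `same_prediction` is not part of the client package, and the tolerance `atol=1e-3` is an arbitrary assumption, since the numerical precision of the encrypted computation is not stated here.

```python
import torch


def same_prediction(plain, secure, atol=1e-3):
    """Return (same_digit, values_close) for two probability vectors."""
    secure = secure.to(plain.dtype)  # align dtypes before comparing
    same_digit = plain.argmax().item() == secure.argmax().item()
    values_close = torch.allclose(plain, secure, atol=atol)
    return same_digit, values_close


# Stand-in tensors; in practice these would be the two `res` values
plain_res = torch.tensor([0.1, 0.9])
secure_res = torch.tensor([0.1002, 0.8998], dtype=torch.float64)
print(same_prediction(plain_res, secure_res))
```

If the predicted digits match but the values are not close, loosening `atol` may be appropriate.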

See our step-by-step guide for a detailed explanation of each step in this flow. To use the digit recognition model, use the sketchToNumber model ID.
