Our digit recognition model is trained on the MNIST dataset, a collection of 28×28-pixel grayscale images of handwritten digits (0–9). We apply additional preprocessing and data augmentation to the training data to improve performance on real-world sketches of handwritten digits.
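The specific preprocessing and augmentations are not listed here. The following is only a sketch of one common MNIST augmentation, random translation, assuming 8-bit input images. `random_shift` and `augment` are hypothetical names, not the project's training code.

```python
import numpy as np


def random_shift(img: np.ndarray, max_shift: int, rng: np.random.Generator) -> np.ndarray:
    """Translate a 2-D image by a random integer offset in [-max_shift, max_shift]
    along each axis, filling the vacated border with zeros (background)."""
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    h, w = img.shape
    out = np.zeros_like(img)
    # Copy the region that overlaps between the original and the shifted image.
    out[max(0, dy):min(h, h + dy), max(0, dx):min(w, w + dx)] = \
        img[max(0, -dy):min(h, h - dy), max(0, -dx):min(w, w - dx)]
    return out


def augment(img_uint8: np.ndarray, rng: np.random.Generator, max_shift: int = 2) -> np.ndarray:
    """Hypothetical pipeline: scale 8-bit pixels to [0, 1], then jitter the position."""
    img = img_uint8.astype(np.float32) / 255.0
    return random_shift(img, max_shift, rng)
```

Small shifts like this make the model less sensitive to where a user draws the digit inside the canvas, which is one gap between centered MNIST digits and free-hand sketches.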
The model is a fully connected neural network (FCNN) with three layers:
- Input layer: flattens the 28×28 image into a 784-dimensional vector.
- Hidden layer: a fully connected layer with 50 neurons and a square activation, f(x) = x².
- Output layer: a fully connected layer with 10 neurons (one per digit) and a softmax activation.
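A NumPy sketch of the same forward pass, annotated with the tensor shape at each layer. The weight layout `(out_features, in_features)` is an assumption, chosen because it matches PyTorch's `nn.Linear`, which the `fc1.weight`/`fc2.weight` keys in the loading code suggest. `forward` and `count_parameters` are illustrative names.

```python
import numpy as np

# Layer sizes taken from the architecture list above.
INPUT_DIM, HIDDEN_DIM, NUM_CLASSES = 28 * 28, 50, 10


def count_parameters(in_dim: int, hidden: int, out: int) -> int:
    """Weights plus biases of the two fully connected layers."""
    return (in_dim * hidden + hidden) + (hidden * out + out)


def forward(image, w1, b1, w2, b2):
    x = image.reshape(-1)               # input layer: (28, 28) -> (784,)
    h = (w1 @ x + b1) ** 2              # hidden layer: (50, 784) @ (784,) -> (50,), then squared
    logits = w2 @ h + b2                # output layer: (10, 50) @ (50,) -> (10,)
    z = logits - logits.max()           # shift for numerical stability; softmax is unchanged
    return np.exp(z) / np.exp(z).sum()  # softmax probabilities, shape (10,)


print(count_parameters(INPUT_DIM, HIDDEN_DIM, NUM_CLASSES))  # 39760
```

The model is small: 39,760 trainable parameters, almost all of them in the 784×50 hidden-layer weight matrix.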
We chose a square activation instead of the more common ReLU because homomorphic encryption schemes evaluate additions and multiplications efficiently, which makes polynomial functions a natural fit. Squaring is the simplest, lowest-degree non-linear polynomial that can serve as an activation.
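Written out per neuron, every logit is a degree-2 polynomial in the pixels, built only from additions and multiplications. Here $x$ is the flattened 784-pixel input, $W^{(1)}, b^{(1)}$ are the hidden-layer parameters and $W^{(2)}, b^{(2)}$ the output-layer parameters (this notation is ours, not the project's):

```latex
\begin{aligned}
h_j &= \Big(\sum_{i=1}^{784} W^{(1)}_{ji}\, x_i + b^{(1)}_j\Big)^{2}, && j = 1,\dots,50,\\
z_k &= \sum_{j=1}^{50} W^{(2)}_{kj}\, h_j + b^{(2)}_k, && k = 0,\dots,9,\\
\hat{y} &= \arg\max_k \operatorname{softmax}(z)_k = \arg\max_k z_k .
\end{aligned}
```

A ReLU, $\max(0, t)$, is not a polynomial and would have to be approximated by one. The final softmax is not polynomial either, but because it preserves the ordering of the logits, the predicted digit can be read from $z$ directly.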
[Figure: model architecture]
Here is sample code for predicting the digit in an image with the trained model, running in plaintext (without encryption):
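```python
import torch
import matplotlib.pyplot as plt


def inference(l1_weight, l1_bias, l2_weight, l2_bias, x):
    """Plaintext forward pass: flatten -> FC(50) -> square -> FC(10) -> softmax."""
    x = x.flatten()                  # (28, 28) -> (784,)
    x = l1_weight @ x + l1_bias      # hidden layer, shape (50,)
    x = x ** 2                       # square activation
    x = l2_weight @ x + l2_bias      # output layer, shape (10,)
    return torch.nn.functional.softmax(x, dim=0)


# Load the trained weights (a state dict with fc1.* and fc2.* entries)
model_dict = torch.load("digits_recognizer.pth", map_location="cpu")

# Read a 28x28 digit image; keep a single channel if the file is RGB/RGBA
img = plt.imread("digit.png")
if img.ndim == 3:
    img = img[..., 0]
pt = torch.tensor(img)

# Digit inference
res = inference(
    model_dict["fc1.weight"], model_dict["fc1.bias"],
    model_dict["fc2.weight"], model_dict["fc2.bias"],
    pt,
)

# Show the image together with the predicted digit
plt.figure()
plt.imshow(img, cmap="gray")
plt.title(f"Prediction: {res.argmax().item()}")
plt.axis("off")
plt.show()
```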
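The example above assumes `digit.png` is already a 28×28 image whose pixels look like the training data. For arbitrary sketches, a helper along these lines could bring an image closer to the MNIST convention of light strokes on a dark background with values in [0, 1]. This is a hypothetical sketch (`to_model_input` is not part of the project), and the inversion rule is a simple heuristic; the project's actual preprocessing is not described here.

```python
import numpy as np


def to_model_input(img: np.ndarray, size: int = 28) -> np.ndarray:
    """Hypothetical helper: convert an RGB/RGBA or grayscale image to a
    size x size float array in [0, 1] with light strokes on a dark background."""
    # 1. Collapse color channels to grayscale (alpha, if present, is dropped).
    if img.ndim == 3:
        img = img[..., :3].mean(axis=-1)
    img = img.astype(np.float32)
    # 2. Scale 8-bit images (e.g. JPEGs read by plt.imread) to [0, 1].
    if img.max() > 1.0:
        img = img / 255.0
    # 3. Heuristic: a mostly bright image is dark-on-light, so invert it.
    if img.mean() > 0.5:
        img = 1.0 - img
    # 4. Downsample by block averaging; requires sides that are multiples of `size`.
    h, w = img.shape
    if h % size or w % size:
        raise ValueError(f"image shape {img.shape} must be a multiple of {size}")
    return img.reshape(size, h // size, size, w // size).mean(axis=(1, 3))
```

Usage: `pt = torch.tensor(to_model_input(plt.imread("digit.png")))`. Block averaging keeps the helper dependency-free; a real pipeline would more likely use a proper resize and re-center the digit.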
The same inference can also be run on encrypted data through the Lattica query client (Python):

```python
import torch
import matplotlib.pyplot as plt
from lattica_query.auth import get_demo_token
from lattica_query.lattica_query_client import QueryClient

# Authenticate with a demo token for the hosted model
model_id = "sketchToNumber"
my_token = get_demo_token(model_id)
client = QueryClient(my_token)

# Generate the encryption context and keys
context, secret_key, client_blocks = client.generate_key()

# SECURE digit inference
img = plt.imread("digit.png")
if img.ndim == 3:
    img = img[..., 0]
pt = torch.tensor(img)
# `res` is a torch.Tensor, same as in the plaintext example above
res = client.run_query(context, secret_key, pt, client_blocks)

# Display the image and prediction as above...
```
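To sanity-check the encrypted query against the plaintext model, the two outputs can be compared. This sketch assumes you kept the plaintext result as `res_plain` and the encrypted-query result as `res_secure` (hypothetical names; both examples call it `res`). If the encrypted computation is approximate, as in CKKS-style schemes, small numerical differences are expected, so the helper reports the largest deviation rather than requiring exact equality; the `1e-3` tolerance is an arbitrary placeholder.

```python
import numpy as np


def compare_predictions(plain, secure, atol=1e-3):
    """Return (same predicted digit?, max absolute difference, within tolerance?)."""
    plain = np.asarray(plain, dtype=np.float64)
    secure = np.asarray(secure, dtype=np.float64)
    same_label = int(plain.argmax()) == int(secure.argmax())
    max_abs_err = float(np.abs(plain - secure).max())
    return same_label, max_abs_err, max_abs_err <= atol
```

Usage: `compare_predictions(res_plain.numpy(), res_secure.numpy())`. Agreement on the predicted digit is what matters for the demo; the error value shows how much numerical noise the encrypted pipeline introduces.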
The equivalent encrypted query from JavaScript:

```javascript
import { getDemoToken, LatticaQueryClient } from '@lattica-ai/lattica-query-client';

const modelId = "sketchToNumber";
const token = await getDemoToken(modelId);
const client = new LatticaQueryClient(token);
await client.init();

// pt is an MNIST image in the form of number[], as in the Python example above
const result = await client.runQuery(pt);
```