Image Sharpening with LatticaAI Demo Tutorial

Overview of the Model

Our Image Sharpening model enhances the clarity and detail of an input image by convolving each color channel with a 3x3 sharpening filter.

  • Input Format: RGB image tensor of shape (3, 200, 200), with pixel values in the [0, 255] range.

  • Output: Sharpened image with the same dimensions as the input.
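To see what the filter does, it helps to work through the arithmetic for a single output pixel. The kernel used in this demo is [[0, -1, 0], [-1, 5, -1], [0, -1, 0]], so each output pixel is 5 times the input pixel minus its four direct neighbors. The sketch below uses plain Python. `sharpen_pixel` and the sample patches are illustrative only and are not part of the demo. It shows that uniform regions are left unchanged because the weights sum to 1, while a pixel on the dark side of an edge is pushed even darker.

```python
# Sharpening kernel: same values as sharpen_kernel in the PyTorch code
KERNEL = [[0, -1, 0],
          [-1, 5, -1],
          [0, -1, 0]]


def sharpen_pixel(patch):
    """Return the new center value of a 3x3 patch after applying KERNEL."""
    return sum(KERNEL[i][j] * patch[i][j] for i in range(3) for j in range(3))


# Uniform region: the weights sum to 1, so the value is unchanged
flat = [[100, 100, 100],
        [100, 100, 100],
        [100, 100, 100]]

# Center pixel sits on the dark side of a vertical dark-to-bright edge
edge = [[50, 50, 200],
        [50, 50, 200],
        [50, 50, 200]]

print(sharpen_pixel(flat))  # 100
print(sharpen_pixel(edge))  # 5*50 - 50 - 50 - 50 - 200 = -100
```

The negative result is then clamped to 0, which is why the code below clamps the output to [0, 255].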

The equivalent PyTorch code for the operator is:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


# Load image as numpy ndarray (ignore alpha channel if it exists)
np_img = plt.imread('house.png')[..., :3]  # shape format (H, W, C)
assert np_img.shape[-1] == 3, "Image must have 3 channels"

# Scale values to the range [0, 255] (PNG files are loaded as floats in [0, 1])
if np_img.max() <= 1:
    np_img *= 255

# Convert to PyTorch tensor and arrange dimensions as (C, H, W)
pt = torch.tensor(np_img, dtype=torch.float).permute(2, 0, 1)
# Resize image to expected input size
pt = F.interpolate(pt.unsqueeze(0), size=(200, 200), mode='bilinear').squeeze(0)  # shape (3, 200, 200)

# Define sharpening kernel
sharpen_kernel = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float)

# Apply the kernel to each RGB channel separately (groups=3, one 3x3 filter per channel);
# padding=1 keeps the output at 200x200
res = F.conv2d(pt, sharpen_kernel.expand(3, 1, 3, 3), groups=3, padding=1)
# Clamp to range [0, 255]
res = torch.clamp(res, 0, 255)

# Display the original and sharpened images
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(pt.permute(1, 2, 0) / 255)
plt.axis("off")

plt.subplot(1, 2, 2)
plt.title("Sharpened Image")
plt.imshow(res.permute(1, 2, 0) / 255)
plt.axis("off")

plt.show()
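Another way to read this kernel: it equals the identity kernel minus the discrete Laplacian [[0, 1, 0], [1, -4, 1], [0, 1, 0]]. The operator therefore computes x - ∇²x, subtracting the local second derivative and so amplifying rapid intensity changes. The sketch below checks this numerically on a random stand-in image. The seed, the random input, and the `depthwise` helper are illustrative assumptions, not demo data or part of the demo code.

```python
import torch
import torch.nn.functional as F

sharpen = torch.tensor([[0., -1., 0.], [-1., 5., -1.], [0., -1., 0.]])
laplacian = torch.tensor([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]])
identity = torch.zeros(3, 3)
identity[1, 1] = 1.0

# The sharpening kernel is exactly "identity minus Laplacian"
assert torch.equal(sharpen, identity - laplacian)


def depthwise(kernel):
    """Stack one copy of a 3x3 kernel per RGB channel: shape (3, 1, 3, 3)."""
    return kernel.reshape(1, 1, 3, 3).repeat(3, 1, 1, 1)


torch.manual_seed(0)
x = torch.rand(1, 3, 200, 200) * 255  # random stand-in image, batch of 1

# Convolution is linear, so filtering with (identity - laplacian) equals x - laplacian(x)
a = F.conv2d(x, depthwise(sharpen), groups=3, padding=1)
b = x - F.conv2d(x, depthwise(laplacian), groups=3, padding=1)
print(torch.allclose(a, b, atol=1e-3))  # True
```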

Achieving Full Privacy with LatticaAI

This demo is the only one that uses the BGV encryption scheme (the other demos use the CKKS scheme). BGV fits here because the plaintext tensor (the image pixels) can be represented exactly as integers and the convolution kernel has integer weights, unlike most machine learning models, where the inputs, the weights, or both are floating-point numbers.
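Because both the pixels and the kernel weights are integers, the whole computation stays in the integers and can be checked exactly. The sketch below assumes the pixels have been rounded to 8-bit integers in [0, 255]; how the client actually rounds the interpolated values is an assumption here. It computes the filter with exact int64 NumPy arithmetic (`sharpen_int` is a hypothetical helper) and derives the range of possible outputs: at most 5 x 255 = 1275 and at least -4 x 255 = -1020. In BGV, plaintexts are integers modulo a plaintext modulus, so in general that modulus has to be large enough to hold this range without wrap-around for the decrypted result to match the plaintext one.

```python
import numpy as np

KERNEL = np.array([[0, -1, 0],
                   [-1, 5, -1],
                   [0, -1, 0]], dtype=np.int64)


def sharpen_int(channel):
    """Exact integer 3x3 filtering of one channel, zero-padded like padding=1."""
    h, w = channel.shape
    padded = np.pad(channel.astype(np.int64), 1)  # zero border of width 1
    out = np.zeros((h, w), dtype=np.int64)
    for di in range(3):
        for dj in range(3):
            out += KERNEL[di, dj] * padded[di:di + h, dj:dj + w]
    return out


# Extreme outputs for 8-bit inputs: positive taps at 255 and negative taps at 0,
# or the other way round
hi = int(KERNEL[KERNEL > 0].sum()) * 255  # 5 * 255
lo = int(KERNEL[KERNEL < 0].sum()) * 255  # -4 * 255
print(lo, hi)  # -1020 1275

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(200, 200))  # stand-in 8-bit channel
res = sharpen_int(img)
assert lo <= res.min() and res.max() <= hi
```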

To convert this code to use homomorphic operations, you only need a few extra steps:

  1. Install the Lattica Python package and obtain a JWT token

  2. Generate encryption keys

  3. Replace the plaintext convolution with our function, which will:

    1. preprocess the image and encrypt it

    2. send the encrypted data to the cloud for computation

    3. receive and decrypt the encrypted result using your private key

Everything else remains the same.

First, install our client package. Then create a query client, generate your keys, and run the encrypted query:

from lattica_query.auth import get_demo_token
from lattica_query.lattica_query_client import QueryClient

model_id = "imageEnhancement"
my_token = get_demo_token(model_id)

client = QueryClient(my_token)

context, secret_key, client_blocks = client.generate_key()

# `pt` and `res` are torch.Tensor, same as in the plain example above
res = client.run_query(context, secret_key, pt, client_blocks)

# Display the original and sharpened images...
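After decrypting, it can be useful to confirm that the encrypted result matches the plaintext one. The helper below is a hypothetical sketch, not part of the Lattica package: it assumes both results are (3, 200, 200) tensors on the same [0, 255] scale (the code comment above states that `res` is a torch.Tensor, as in the plaintext example), clamps them as the plaintext code does, and reports the largest per-pixel difference. Small non-zero differences may be expected if the encrypted flow rounds pixels to integers while the plaintext code works on interpolated floats.

```python
import torch


def max_pixel_difference(plain_res: torch.Tensor, he_res: torch.Tensor) -> float:
    """Largest absolute per-pixel gap between two results, after clamping to [0, 255]."""
    if plain_res.shape != he_res.shape:
        raise ValueError(f"shape mismatch: {tuple(plain_res.shape)} vs {tuple(he_res.shape)}")
    a = plain_res.float().clamp(0, 255)
    b = he_res.float().clamp(0, 255)
    return (a - b).abs().max().item()


# Example with synthetic tensors standing in for the two results
plain = torch.full((3, 200, 200), 120.0)
encrypted = plain.clone()
encrypted[0, 10, 10] = 121.0
print(max_pixel_difference(plain, encrypted))  # 1.0
```

In the flow above, `run_query` overwrites `res`, so keep a copy of the plaintext output first (for example as `res_plain`, a hypothetical name) and then call `max_pixel_difference(res_plain, res)`.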

See our step-by-step guide for a detailed explanation of each step in this flow. To use the image sharpening model, specify the imageEnhancement model ID.

