Health Analysis with LatticaAI Demo Tutorial

Overview of the Model

Our Health Analysis model is trained on the Disease Prediction Kaggle dataset. This dataset is designed to facilitate the application of machine learning to the medical field, aiding physicians by automating disease diagnosis based on symptoms.

The dataset consists of 131 binary columns, one for each symptom a person may experience, and maps each combination of symptoms to one of 41 diseases, so a model can classify the disease from the input symptoms.

We trained a multi-class logistic regression model, implemented the equivalent fully homomorphic model, and deployed it to our cloud service for secure inference.

  • Input: binary vector of length 131, one entry per symptom.

  • Output: probability vector of length 41, one entry per possible disease.
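
The input format above can be sketched as a small encoder that turns symptom names into the binary vector. This is only an illustration: `encode_symptoms` and `toy_mapping` are hypothetical names, and the sketch assumes a name-to-index mapping like the one the tutorial loads from symptoms.json. It uses the standard library only, so it runs without the model files.

```python
from typing import Dict, Iterable, List

NUM_SYMPTOMS = 131  # input length stated in the model description


def encode_symptoms(names: Iterable[str], index_of: Dict[str, int],
                    size: int = NUM_SYMPTOMS) -> List[float]:
    """Return a 0/1 vector with a 1.0 at the index of each named symptom."""
    vec = [0.0] * size
    for name in names:
        if name not in index_of:
            raise KeyError(f"unknown symptom: {name!r}")
        idx = index_of[name]
        if not 0 <= idx < size:
            raise ValueError(f"index {idx} for {name!r} is outside 0..{size - 1}")
        vec[idx] = 1.0
    return vec


# Hypothetical excerpt of the mapping; real indices would come from symptoms.json.
toy_mapping = {"Chills": 0, "Shivering": 1, "Continuous Sneezing": 2}
x = encode_symptoms(["Chills", "Continuous Sneezing"], toy_mapping)
print(len(x), sum(x))
```

Rejecting unknown names up front avoids silently sending an all-zero or partially filled vector to the model.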

The equivalent PyTorch code for the inference is:

import torch
import json
import numpy as np


# load the symptoms and diseases mappings
with open('symptoms.json', 'r') as f:
    symptoms = json.load(f)

with open('diseases.json', 'r') as f:
    diseases = json.load(f)

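# load the coef matrix and intercept vector of the trained logistic regression model
# and convert them to torch tensors so they can be combined with `pt` below
W = torch.from_numpy(np.load('coef.npy')).to(torch.float64)  # shape (41,131)
b = torch.from_numpy(np.load('intercept.npy')).to(torch.float64)  # shape (41,)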

# create binary vector of length 131, corresponding to the mapping symptoms.json
pt = torch.zeros(131, dtype=torch.float64)
pt[symptoms['Continuous Sneezing']] = 1
pt[symptoms['Shivering']] = 1
pt[symptoms['Chills']] = 1
pt[symptoms['Watering From Eyes']] = 1

# calculate the logistic regression prediction
res = torch.nn.functional.softmax(pt @ W.T + b, dim=0)

# the predicted disease according to the mapping diseases.json
disease = diseases[res.argmax().item()]
print(f"Predicted disease: {disease}")
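
Because the output is a full probability vector, it can be useful to look at more than the single most likely class. The following is a minimal sketch of ranking the top candidates; `top_k` and the toy values are hypothetical, and it assumes the disease labels are available as a sequence ordered by class index. With the tensor from the example above, one would pass `res.tolist()` as `probs`.

```python
from typing import List, Sequence, Tuple


def top_k(probs: Sequence[float], labels: Sequence[str],
          k: int = 3) -> List[Tuple[str, float]]:
    """Return the k (label, probability) pairs with the highest probability."""
    if len(probs) != len(labels):
        raise ValueError("probs and labels must have the same length")
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Toy values for illustration only; they are not model output.
toy_probs = [0.05, 0.70, 0.20, 0.05]
toy_labels = ["Disease A", "Disease B", "Disease C", "Disease D"]
for name, p in top_k(toy_probs, toy_labels, k=2):
    print(f"{name}: {p:.2f}")
```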

Achieving Full Privacy with LatticaAI

First, install our client package. Then request a demo token, generate keys, and run an encrypted query:

import torch

from lattica_query.auth import get_demo_token
from lattica_query.lattica_query_client import QueryClient

model_id = "healthPrediction"
my_token = get_demo_token(model_id)

client = QueryClient(my_token)

# generate the encryption context and keys on the client side
context, secret_key, client_blocks = client.generate_key()

# `pt` and `res` are torch.Tensor, same as in the plain example above
pt = torch.randint(0, 2, (131,), dtype=torch.float64)

res = client.run_query(context, secret_key, pt, client_blocks)
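
One way to sanity-check the encrypted flow is to run the same input through the plain PyTorch example and compare the two results. The sketch below assumes both results are length-41 probability vectors converted with `.tolist()`; `same_prediction` is a hypothetical helper, and the tolerance is an illustrative value, not a documented accuracy bound.

```python
from typing import Sequence


def same_prediction(plain: Sequence[float], encrypted: Sequence[float],
                    atol: float = 1e-2) -> bool:
    """True when both vectors pick the same class and agree elementwise within atol."""
    if len(plain) != len(encrypted):
        raise ValueError("vectors must have the same length")
    argmax_plain = max(range(len(plain)), key=plain.__getitem__)
    argmax_enc = max(range(len(encrypted)), key=encrypted.__getitem__)
    max_diff = max(abs(a - b) for a, b in zip(plain, encrypted))
    return argmax_plain == argmax_enc and max_diff <= atol


# Toy vectors for illustration only; they are not real query results.
toy_plain = [0.10, 0.85, 0.05]
toy_encrypted = [0.101, 0.848, 0.051]
print(same_prediction(toy_plain, toy_encrypted))
```

Checking the predicted class separately from the elementwise difference keeps the comparison meaningful even if the two computations differ slightly in their numerical values.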

See our step-by-step guide for a detailed explanation of each step in this flow. To use the health analysis model, use the healthPrediction model ID.
