Predicting pneumonia in chest X-Rays
In this post, I’ll be training a model to predict pneumonia in chest X-Rays. This will be the first time that I’ll be using SageMaker. Let’s see how it goes!
A) Load the Data
We’ll install the necessary libraries and import those that are already locally available.
!pip3 install tqdm
!pip3 install s3fs
!pip3 install seaborn==0.9.0
!pip3 install imbalanced-learn
%matplotlib inline
import numpy as np
import pandas as pd
import seaborn as sns
import os
import io
import shutil
import json
import random
import boto3
import s3fs
import sagemaker
from PIL import Image
from itertools import islice
import requests
from io import BytesIO
from imblearn.pipeline import Pipeline
from tqdm import tqdm
from sagemaker.amazon.amazon_estimator import get_image_uri
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot as plt
from sagemaker.predictor import json_serializer, json_deserializer, Predictor
# Global variables
role = sagemaker.get_execution_role()
s3_client = boto3.client('s3')
s3 = boto3.resource('s3')
fs = s3fs.S3FileSystem()
OUTCOMES = ["NORMAL", "PNEUMONIA"]  # label 0 = NORMAL, label 1 = PNEUMONIA
SESS = sagemaker.Session()
BUCKET = SESS.default_bucket()
PREFIX = "flatiron-module-4"
IMAGE_PREFIX = PREFIX + "/images"
LOAD_PATH = f"s3://{BUCKET}/{IMAGE_PREFIX}/"
Create annotations
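To use SageMaker’s built-in image classification algorithm, we need to store the labels in a specific format. Ultimately, we will store this information in ‘.lst’ files. These are tab-separated image lists that the algorithm reads alongside the raw images, and they use the same list format that MXNet’s im2rec tool uses to build RecordIO ‘.rec’ files. As an intermediate step, we will create CSV files listing each image file path and its corresponding outcome.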
I’ve already manually uploaded the image folders (train, val, and test, each with NORMAL and PNEUMONIA subfolders) to S3.
# https://www.coursera.org/learn/image-classification-sagemaker/ungradedLti/fsCCK/image-classification-with-sagemaker
def convert_dic_to_df(classify_dic):
    # DataFrame.append was removed in pandas 2.0, so collect the rows in a list first
    rows = []
    for outcome, pngs in classify_dic.items():
        for png in pngs:
            filename = png.split('/')[-1]
            # Skip macOS metadata files and "folder" placeholder keys ending in "/"
            if not filename or filename == ".DS_Store":
                continue
            rows.append({
                'image': f"{outcome}/{filename}",
                'value': 1 if outcome == "PNEUMONIA" else 0,
            })
    return pd.DataFrame(rows, columns=['image', 'value'])
def get_s3_dir_contents(path):
    # List only the objects under this prefix instead of scanning the whole bucket
    objects = s3.Bucket(BUCKET).objects.filter(Prefix=path)
    return [obj.key for obj in objects]
def get_classifications_as_df(dataset_path):
    classify_dic = {}
    for outcome in OUTCOMES:
        outcome_path = f"{dataset_path}/{outcome}"
        classify_dic[outcome] = get_s3_dir_contents(outcome_path)
    return convert_dic_to_df(classify_dic)

def create_annotations_by_type(dataset_type):
    dataset_path = f"{IMAGE_PREFIX}/{dataset_type}"
    df = get_classifications_as_df(dataset_path)
    # reset_index adds an "index" column, used later as the .lst image index
    df.reset_index(inplace=True)
    df.to_csv(f"annotations_{dataset_type}.csv", index=False)
    return df

def create_annotations():
    for i in ["val", "train", "test"]:
        create_annotations_by_type(i)

create_annotations()
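The key-to-label step does not depend on S3, so it can be checked offline. The sketch below mirrors the mapping in convert_dic_to_df. It keeps the last path component, prefixes it with the class folder, and encodes PNEUMONIA as 1 and NORMAL as 0. It also skips .DS_Store files and folder placeholder keys. The helper name keys_to_rows and the example keys are hypothetical.

```python
def keys_to_rows(keys_by_outcome):
    """Turn {outcome: [S3 keys]} into annotation rows (hypothetical helper)."""
    rows = []
    for outcome, keys in keys_by_outcome.items():
        for key in keys:
            filename = key.split("/")[-1]
            # Skip macOS metadata files and "folder" placeholder keys ending in "/"
            if not filename or filename == ".DS_Store":
                continue
            rows.append({
                "image": f"{outcome}/{filename}",
                "value": 1 if outcome == "PNEUMONIA" else 0,
            })
    return rows


# Hypothetical keys laid out as <IMAGE_PREFIX>/<split>/<outcome>/<file>
example_keys = {
    "NORMAL": [
        "flatiron-module-4/images/val/NORMAL/a.jpeg",
        "flatiron-module-4/images/val/NORMAL/.DS_Store",
    ],
    "PNEUMONIA": [
        "flatiron-module-4/images/val/PNEUMONIA/",
        "flatiron-module-4/images/val/PNEUMONIA/b.jpeg",
    ],
}
annotation_rows = keys_to_rows(example_keys)
print(annotation_rows)
# → [{'image': 'NORMAL/a.jpeg', 'value': 0}, {'image': 'PNEUMONIA/b.jpeg', 'value': 1}]
```

The resulting list can be passed straight to pd.DataFrame(annotation_rows, columns=['image', 'value']).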
Visualize outcome distribution
As the chart below shows, pneumonia cases are overrepresented in the training data. To adjust for this imbalance, we will try weighting the classes when we train the model.
cnts = pd.read_csv("annotations_train.csv")["value"].astype(int).value_counts()
pneu = cnts.get(1, 0)
normal = cnts.get(0, 0)
plt.bar(OUTCOMES, [normal, pneu])
title = f"NORMAL: {normal}. PNEUMONIA: {pneu}."
plt.title(title)
plt.show()
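One common way to turn counts like these into per-class weights is the "balanced" heuristic, which scikit-learn's compute_class_weight also uses: weight_c = n_samples / (n_classes * count_c), so the rarer class gets the larger weight. The sketch below uses hypothetical counts. The helper name is illustrative, and this is not necessarily how SageMaker's built-in weighting computes its weights.

```python
def balanced_class_weights(counts):
    """Inverse-frequency weights: n_samples / (n_classes * count) for each class."""
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}


# Hypothetical counts (not the real dataset): 1,000 NORMAL vs 3,000 PNEUMONIA
weights = balanced_class_weights({"NORMAL": 1000, "PNEUMONIA": 3000})
print(weights)  # → {'NORMAL': 2.0, 'PNEUMONIA': 0.6666666666666666}
```

With these weights, each class contributes the same total weight (here 2,000) to the loss, regardless of how many examples it has.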
Create .lst files
Now that our annotation files are created, we’ll use this data to write the .lst files to our s3 bucket. SageMaker will access the labels on the s3 bucket when fitting the model.
# https://mxnet.apache.org/versions/1.8.0/api/faq/recordio.html
# https://docs.aws.amazon.com/sagemaker/latest/dg/image-classification.html
def create_lst_file_by_type(dataset_type):
    df = pd.read_csv(f"annotations_{dataset_type}.csv")
    # Each .lst line is: image index <TAB> class label <TAB> relative image path
    text_col = df.apply(lambda row: f"{int(row['index'])}\t{int(row['value'])}\t{row['image']}\n", axis=1)
    text = "".join(text_col.values.tolist())
    # Upload to S3 for the training job, and keep a local copy for inspection
    s3_client.put_object(Body=text,
                         Bucket=BUCKET,
                         Key=f"{IMAGE_PREFIX}/{dataset_type}_lst/{dataset_type}.lst")
    with open(f"{dataset_type}.lst", "w") as f:
        f.write(text)

def create_lst_files():
    for i in ["val", "train", "test"]:
        create_lst_file_by_type(i)

create_lst_files()
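Each line of the .lst files written above holds three tab-separated fields: an integer image index, the class label, and the image path relative to the channel's root. The hypothetical format_lst_line and parse_lst_line helpers below round-trip one entry. This offers a cheap way to sanity-check the files before paying for a training job.

```python
def format_lst_line(index, label, path):
    """Build one .lst entry: index <TAB> label <TAB> relative path."""
    return f"{int(index)}\t{int(label)}\t{path}\n"


def parse_lst_line(line):
    """Split a .lst entry back into (index, label, path), checking the field count."""
    fields = line.rstrip("\n").split("\t")
    if len(fields) != 3:
        raise ValueError(f"expected 3 tab-separated fields, got {len(fields)}: {line!r}")
    index, label, path = fields
    return int(index), int(label), path


line = format_lst_line(0, 1, "PNEUMONIA/example.jpeg")
print(repr(line))            # → '0\t1\tPNEUMONIA/example.jpeg\n'
print(parse_lst_line(line))  # → (0, 1, 'PNEUMONIA/example.jpeg')
```

Running parse_lst_line over every line of a downloaded .lst file would catch, for example, paths that contain stray tabs or lines that were written with spaces instead of tabs.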
Visualize a sample of the training data
We can view a sample of our images here in the notebook to make sure that everything is correctly referenced. Using the local annotation csvs, we’ll extract the image file paths.
def get_annotations_as_dic(dataset_type):
    dic = {}
    df = pd.read_csv(f"annotations_{dataset_type}.csv")
    for i, row in df.iterrows():
        if not row['image'].endswith(".DS_Store"):
            dic[row['image']] = int(row['value'])
    return dic

def get_annotations_by_type(data_type, dev=False, n=10):
    annotations = get_annotations_as_dic(data_type)
    annotation_keys = list(annotations.keys())
    random.shuffle(annotation_keys)
    if dev:
        # In dev mode, keep only a random subset of n images
        return {k: annotations[k] for k in annotation_keys[:n]}
    return annotations
train_annotations = get_annotations_by_type("train")
val_annotations = get_annotations_by_type("val")
test_annotations = get_annotations_by_type("test")
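In dev mode, get_annotations_by_type takes the first n shuffled keys. Because of this, a small sample can end up containing only one class. If you want every class represented and the same sample on every run, one option is a seeded, per-class (stratified) sample. The helper name stratified_sample and the seed value below are hypothetical.

```python
import random
from collections import defaultdict


def stratified_sample(annotations, n_per_class, seed=0):
    """Pick up to n_per_class keys from each label, reproducibly."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    by_label = defaultdict(list)
    for key, label in annotations.items():
        by_label[label].append(key)
    sample = {}
    for label, keys in sorted(by_label.items()):
        # Sort first so the result does not depend on dict insertion order
        for key in rng.sample(sorted(keys), min(n_per_class, len(keys))):
            sample[key] = label
    return sample


# Hypothetical annotations in the same {path: label} shape as train_annotations
annotations = {f"NORMAL/n{i}.jpeg": 0 for i in range(3)}
annotations.update({f"PNEUMONIA/p{i}.jpeg": 1 for i in range(10)})
sample = stratified_sample(annotations, n_per_class=2, seed=42)
print(sorted(sample.values()))  # → [0, 0, 1, 1]
```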
Now let’s use those file paths to display a sample of training images.
def load_s3_png(path):
    response = s3_client.get_object(Bucket=BUCKET, Key=path)
    img_data = response['Body'].read()
    return Image.open(io.BytesIO(img_data))

def viz_data_sample(data_type, annotations):
    plt.figure(figsize=(8, 8))
    image_paths = list(annotations.keys())
    random.shuffle(image_paths)
    subset = image_paths[:8]
    # Plot 8 random images in a 2 x 4 grid, labelled with their outcome
    for plt_i, image in enumerate(subset):
        outcome = "PNEUMONIA" if annotations[image] == 1 else "NORMAL"
        path = f"{IMAGE_PREFIX}/{data_type}/{image}"
        img = load_s3_png(path)
        plt.subplot(2, 4, plt_i + 1)
        # cmap only affects single-channel (grayscale) images
        plt.imshow(img, cmap="gray")
        plt.xlabel(outcome)
    plt.show()

viz_data_sample("train", train_annotations)
B) Build and Deploy SageMaker Model
Now we can get started with the SageMaker configuration!
Define the data locations
We’ll define the S3 locations of the training, validation, and test images and their .lst files, as well as the directory where SageMaker should write the training artifacts.
s3_train_path = LOAD_PATH + "train"
s3_train_lst_path = LOAD_PATH + "train_lst"
s3_validation_path = LOAD_PATH + "val"
s3_validation_lst_path = LOAD_PATH + "val_lst"
s3_test_path = LOAD_PATH + "test"
s3_test_lst_path = LOAD_PATH + "test_lst"
output_location = f"s3://{BUCKET}/{PREFIX}/output"
print(f"Training artifacts will be uploaded to: {output_location}")
Build the model
Next, using the SageMaker image-classification container, we’ll define our model.
As hinted at above, I decided to try training the model with weighting on (‘use_weighted_loss=1’) to address the imbalance in outcome class frequency. Unfortunately, this only succeeded in ensuring that my model predicted every case as normal instead of every case as pneumonia.
training_image = get_image_uri(SESS.boto_region_name, 'image-classification', repo_version='latest')
model = sagemaker.estimator.Estimator(
    training_image,
    role=role,
    train_instance_count=1,
    train_volume_size=100,
    train_max_run=36000,
    input_mode="File",
    train_instance_type='ml.p2.xlarge',
    output_path=output_location,
    sagemaker_session=SESS,
)
Now we set the hyperparameters for our model.
# https://sunjackson.github.io/2018/09/13/84e31f8847e234adb7a5c3eadb253a14/
model.set_hyperparameters(
    num_layers=34,
    image_shape='3,224,224',
    num_classes=2,
    mini_batch_size=16,
    resize=224,
    epochs=100,
    learning_rate=0.01,
    # Should match the number of entries in the training .lst file
    num_training_samples=len(train_annotations),
    augmentation_type='crop_color_transform',
    optimizer="adam",
    use_pretrained_model=1,
    # use_weighted_loss is a hyperparameter of the built-in algorithm, not an Estimator
    # argument; the SageMaker docs describe it as applying to multi-label training (multi_label=1)
    use_weighted_loss=1,
)
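To make "weighting" concrete: in a weighted cross-entropy loss, each example's loss, -log p(true class), is multiplied by the weight of its class, so mistakes on the rarer class cost more. The pure-Python sketch below illustrates the idea with hypothetical probabilities and weights, averaging over the total weight of the batch. It is an illustration of the general technique, not SageMaker's internal implementation.

```python
import math


def weighted_cross_entropy(probs, labels, class_weights):
    """Sum of -w[y] * log p[y], normalised by the total weight of the batch."""
    total_loss = 0.0
    total_weight = 0.0
    for p, y in zip(probs, labels):
        w = class_weights[y]
        total_loss += -w * math.log(p[y])
        total_weight += w
    return total_loss / total_weight


# Hypothetical batch: a model that always leans "PNEUMONIA" (index 1)
probs = [[0.2, 0.8], [0.2, 0.8], [0.2, 0.8], [0.2, 0.8]]
labels = [1, 1, 1, 0]  # three pneumonia cases, one normal case

unweighted = weighted_cross_entropy(probs, labels, {0: 1.0, 1: 1.0})
weighted = weighted_cross_entropy(probs, labels, {0: 3.0, 1: 1.0})
print(round(unweighted, 3), round(weighted, 3))  # → 0.57 0.916
```

Raising the NORMAL weight makes the always-PNEUMONIA guess more expensive, which is the pressure that class weighting is meant to apply.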
Specify the data locations
First, we wrap each S3 path in a SageMaker input channel so that the training job can locate the images and their .lst files.
def get_from_s3(path):
    # Each channel reads raw images (or .lst files) from an S3 prefix
    return sagemaker.session.s3_input(
        path,
        distribution="FullyReplicated",
        content_type="application/x-image",
        s3_data_type="S3Prefix"
    )

train_data = get_from_s3(s3_train_path)
train_lst_data = get_from_s3(s3_train_lst_path)
validation_data = get_from_s3(s3_validation_path)
validation_lst_data = get_from_s3(s3_validation_lst_path)
# Channel names the image classification algorithm expects when training from .lst files
data_channels = {
    'train': train_data,
    'train_lst': train_lst_data,
    'validation': validation_data,
    'validation_lst': validation_lst_data,
}
Fit the model
We’ll fit the model.
model.fit(inputs=data_channels, logs=True)
Deploy the model
And now we get to deploy our model. We’ll save the endpoint name so that we can reference the endpoint later and delete the related resources when we clean up. This will help us avoid AWS charges for lingering, unused resources.
predictor = model.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")
endpoint_name = predictor.endpoint_name
print("\nEndpoint name: ", endpoint_name)
C) Predict Pneumonia Cases
Finally! It’s time to find out how accurately our model classifies chest X-rays.
Get test images
We’ll load the test image s3 paths from the annotations file we created earlier.
# endpoint_name = "image-classification-2021-06-22-01-42-51-569"
predictor = Predictor(endpoint_name=endpoint_name, sagemaker_session=SESS, accept='application/json')
image_names = pd.read_csv("annotations_test.csv")["image"].values.tolist()
image_paths = [f"{IMAGE_PREFIX}/test/{image}" for image in image_names if not image.endswith(".DS_Store")]
Calculate model accuracy
Now we’ll use our model to classify each image and compare the predicted outcome with the actual outcome.
# https://github.com/aws/amazon-sagemaker-examples/issues/1885
PREDS = []
ACTUALS = []

def get_prediction_accuracy(deployed, image_path):
    byte_str = s3_client.get_object(Bucket=BUCKET, Key=image_path)['Body'].read()
    b = bytearray(byte_str)
    # The endpoint returns a JSON list of class probabilities, in label order
    result = deployed.predict(b, initial_args={'ContentType': 'image/jpeg'})
    result = json.loads(result)
    prediction = OUTCOMES[np.argmax(result)]
    PREDS.append(prediction)
    # The parent folder name (NORMAL or PNEUMONIA) is the true label
    actual_outcome = image_path.split("/")[-2]
    ACTUALS.append(actual_outcome)
    print(actual_outcome, prediction)
    return actual_outcome == prediction

results = [get_prediction_accuracy(predictor, path) for path in image_paths]
accuracy = sum(results) / len(results)
acc_msg = f"Accuracy rate: {round(accuracy * 100, 2)}%"
print(acc_msg)
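The prediction loop above assumes the endpoint returns a JSON list of class probabilities in label order. Under that assumption, index 0 is NORMAL and index 1 is PNEUMONIA, matching the labels written to the .lst files. The decoding and accuracy steps can be checked offline with hypothetical response bodies, as in the sketch below. The helper name decode_prediction is hypothetical, and it uses a plain-Python argmax in place of np.argmax.

```python
import json

OUTCOMES = ["NORMAL", "PNEUMONIA"]


def decode_prediction(body):
    """Map a JSON list of class probabilities to the most likely label."""
    probs = json.loads(body)
    best = max(range(len(probs)), key=lambda i: probs[i])  # argmax; ties go to the lower index
    return OUTCOMES[best]


# Hypothetical response bodies and ground-truth labels
bodies = [b"[0.1, 0.9]", b"[0.7, 0.3]", b"[0.4, 0.6]"]
actuals = ["PNEUMONIA", "NORMAL", "NORMAL"]
preds = [decode_prediction(b) for b in bodies]
accuracy = sum(p == a for p, a in zip(preds, actuals)) / len(actuals)
print(preds, f"{accuracy:.2%}")  # → ['PNEUMONIA', 'NORMAL', 'PNEUMONIA'] 66.67%
```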
Our confusion matrix sheds more light on our model’s predictive power.
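# https://medium.com/@dtuk81/confusion-matrix-visualization-fc31e3f30fea
conf = confusion_matrix(ACTUALS, PREDS, labels=OUTCOMES)
ax = sns.heatmap(conf / np.sum(conf), annot=True, fmt='.0%', cmap='Blues', xticklabels=OUTCOMES, yticklabels=OUTCOMES)
ax.set(xlabel="Predicted", ylabel="Actual")  # confusion_matrix rows are the actual labels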
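Dividing by the grand total, as the heatmap above does, shows each cell's share of all test images. An alternative is to normalize each row, which shows how often each actual class was predicted correctly; the diagonal then holds per-class recall. The pure-Python sketch below computes both views from hypothetical label lists. Rows are actual classes and columns are predictions, the same orientation that scikit-learn's confusion_matrix uses. The helper names are hypothetical.

```python
OUTCOMES = ["NORMAL", "PNEUMONIA"]


def confusion(actuals, preds, labels):
    """Counts matrix: rows are actual labels, columns are predicted labels."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for a, p in zip(actuals, preds):
        matrix[index[a]][index[p]] += 1
    return matrix


def row_normalise(matrix):
    """Divide each row by its total, so the diagonal becomes per-class recall."""
    return [[c / sum(row) if sum(row) else 0.0 for c in row] for row in matrix]


# Hypothetical labels: 2 NORMAL and 4 PNEUMONIA cases
actuals = ["NORMAL", "NORMAL", "PNEUMONIA", "PNEUMONIA", "PNEUMONIA", "PNEUMONIA"]
preds = ["NORMAL", "PNEUMONIA", "PNEUMONIA", "PNEUMONIA", "PNEUMONIA", "NORMAL"]
counts = confusion(actuals, preds, OUTCOMES)
print(counts)                 # → [[1, 1], [1, 3]]
print(row_normalise(counts))  # → [[0.5, 0.5], [0.25, 0.75]]
```

With the NumPy matrix returned by confusion_matrix, the row-normalized equivalent is conf / conf.sum(axis=1, keepdims=True).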
Evaluate model accuracy
In a first read-out, without weighting, we reached an overall accuracy rate of 36.38%.
In a second read-out, with weighting, we reached an overall accuracy rate of 62.5%.
Notice anything odd about these numbers? Added together, the two accuracy rates come to roughly 100%, and the confusion matrix shows why.
The unweighted model automatically classified everything as “normal,” while the weighted model automatically classified everything as “pneumonia”!
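The arithmetic behind that pattern is straightforward. A model that predicts a single class for every image is right exactly as often as that class occurs. Its accuracy therefore equals that class's share of the test set, and the accuracies of the two constant models add up to 100%. Here is a quick check with hypothetical counts; the helper name is illustrative.

```python
def constant_classifier_accuracy(counts, predicted):
    """Accuracy of a model that outputs `predicted` for every image."""
    return counts[predicted] / sum(counts.values())


# Hypothetical test-set counts (not the real split)
counts = {"NORMAL": 25, "PNEUMONIA": 75}
acc_normal = constant_classifier_accuracy(counts, "NORMAL")
acc_pneumonia = constant_classifier_accuracy(counts, "PNEUMONIA")
print(acc_normal, acc_pneumonia, acc_normal + acc_pneumonia)  # → 0.25 0.75 1.0
```

This is also why accuracy alone is a weak metric on imbalanced data. A constant model always has a recall of 0 for the class it never predicts, which a per-class view exposes immediately.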
I’ve continued to tweak the hyperparameters according to the various SageMaker tutorials linked above. At this point, I don’t seem to be achieving much discerning power no matter how much computing power I throw at it. Please chime in in the comments with your SageMaker expertise and suggestions!
D) Disconnect
Before we wrap up, we need to make sure that we delete all of the resources created above. This will ensure that we don’t get billed for unused resources.
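sm = boto3.client('sagemaker')
# Find the model(s) behind the endpoint before deleting its config
config = sm.describe_endpoint_config(EndpointConfigName=endpoint_name)
model_names = [v['ModelName'] for v in config['ProductionVariants']]
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_name)
for model_name in model_names:
    sm.delete_model(ModelName=model_name)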