{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "e950fa8e", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Train a SKLearn Model using Script Mode\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "0abdc17b", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "---" ] }, { "attachments": {}, "cell_type": "markdown", "id": "90e7cac6", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "\n", "The aim of this notebook is to demonstrate how to train and deploy a scikit-learn model in Amazon SageMaker. The method used is called Script Mode, in which we write a script to train our model and submit it to the SageMaker Python SDK. For more information, feel free to read [Using Scikit-learn with the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/sklearn/using_sklearn.html).\n", "\n", "## Runtime\n", "This notebook takes approximately 15 minutes to run.\n", "\n", "## Contents\n", "1. [Download data](#Download-data)\n", "1. [Prepare data](#Prepare-data)\n", "1. [Train model](#Train-model)\n", "1. [Deploy and test endpoint](#Deploy-and-test-endpoint)\n", "1. [Cleanup](#Cleanup)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a16db1a6", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Download data \n", "Download the [Iris Data Set](https://archive.ics.uci.edu/ml/datasets/iris), which is the data used to train the model in this demo."
] }, { "cell_type": "code", "execution_count": null, "id": "e2d5c27c", "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "!pip install -U sagemaker" ] }, { "cell_type": "code", "execution_count": null, "id": "a670c242", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import boto3\n", "import pandas as pd\n", "import numpy as np\n", "\n", "s3 = boto3.client(\"s3\")\n", "s3.download_file(\n", " f\"sagemaker-example-files-prod-{boto3.session.Session().region_name}\",\n", " \"datasets/tabular/iris/iris.data\",\n", " \"iris.data\",\n", ")\n", "\n", "df = pd.read_csv(\n", " \"iris.data\", header=None, names=[\"sepal_len\", \"sepal_wid\", \"petal_len\", \"petal_wid\", \"class\"]\n", ")\n", "df.head()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7c03b3d2", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Prepare data\n", "Next, we prepare the data for training by first converting the labels from strings to integers. We then split the data into a training dataset (80% of the data) and a test dataset (the remaining 20%) and save each to a CSV file. Finally, the training file is uploaded to S3, where the SageMaker SDK can access it to train the model."
] }, { "cell_type": "code", "execution_count": null, "id": "72748b04", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Convert the three classes from strings to integers in {0,1,2}\n", "df[\"class_cat\"] = df[\"class\"].astype(\"category\").cat.codes\n", "categories_map = dict(enumerate(df[\"class\"].astype(\"category\").cat.categories))\n", "print(categories_map)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "fb5ea6cf", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Shuffle the rows, then split the data 80-20 into train and test sets.\n", "# The UCI iris.data file is ordered by class, so an unshuffled split would put a single class in the test set.\n", "df = df.sample(frac=1, random_state=42).reset_index(drop=True)\n", "num_samples = df.shape[0]\n", "split = round(num_samples * 0.8)\n", "train = df.iloc[:split, :]\n", "test = df.iloc[split:, :]\n", "print(\"{} train, {} test\".format(split, num_samples - split))" ] }, { "cell_type": "code", "execution_count": null, "id": "48770a6b", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Write train and test CSV files\n", "train.to_csv(\"train.csv\", index=False)\n", "test.to_csv(\"test.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "ba40dab3", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Create a sagemaker session to upload data to S3\n", "import sagemaker\n", "\n", "sagemaker_session = sagemaker.Session()\n", "\n", "# Upload data to default S3 bucket\n", "prefix = \"DEMO-sklearn-iris\"\n", "training_input_path = sagemaker_session.upload_data(\"train.csv\", key_prefix=prefix + \"/training\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9d52c534", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Train model\n", "The model is trained using the SageMaker SDK's Estimator class. First, get the execution role for training. This role grants access to the S3 bucket used in the previous step, where the training data is located."
] }, { "cell_type": "code", "execution_count": null, "id": "f7cbdad2", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Use the current execution role for training. It needs access to S3\n", "role = sagemaker.get_execution_role()\n", "print(role)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "10cdcfb6", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Next, we define the SageMaker SDK estimator. We use an Estimator class specifically designed to train scikit-learn models, called `SKLearn`. In this estimator, we define the following parameters:\n", "1. The script that we want to use to train the model (i.e. `entry_point`). This is the heart of the Script Mode method. Additionally, set the `script_mode` parameter to `True`.\n", "1. The role that gives us access to the S3 bucket containing the training data (i.e. `role`)\n", "1. How many instances we want to use in training (i.e. `instance_count`) and what type of instance we want to use in training (i.e. `instance_type`)\n", "1. Which version of scikit-learn to use (i.e. `framework_version`)\n", "1. Training hyperparameters (i.e. `hyperparameters`)\n", "\n", "After setting these parameters, the `fit` function is invoked to train the model."
] }, { "cell_type": "code", "execution_count": null, "id": "ac14dcb7", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Docs: https://sagemaker.readthedocs.io/en/stable/frameworks/sklearn/sagemaker.sklearn.html\n", "\n", "from sagemaker.sklearn import SKLearn\n", "\n", "sk_estimator = SKLearn(\n", " entry_point=\"train.py\",\n", " role=role,\n", " instance_count=1,\n", " instance_type=\"ml.c5.xlarge\",\n", " py_version=\"py3\",\n", " framework_version=\"1.2-1\",\n", " script_mode=True,\n", " hyperparameters={\"estimators\": 20},\n", ")\n", "\n", "# Train the estimator\n", "sk_estimator.fit({\"train\": training_input_path})" ] }, { "attachments": {}, "cell_type": "markdown", "id": "3813b62c", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Deploy and test endpoint\n", "After training the model, it is time to deploy it as an endpoint. To do so, we invoke the `deploy` function within the scikit-learn estimator. As shown in the code below, one can define the number of instances (i.e. `initial_instance_count`) and instance type (i.e. `instance_type`) used to deploy the model." 
] }, { "cell_type": "code", "execution_count": null, "id": "06aace5c", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import time\n", "\n", "sk_endpoint_name = \"sklearn-rf-model\" + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n", "sk_predictor = sk_estimator.deploy(\n", " initial_instance_count=1, instance_type=\"ml.m5.large\", endpoint_name=sk_endpoint_name\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "bbc747e1", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "After the endpoint has been completely deployed, it can be invoked using the [SageMaker Runtime Client](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime.html) (which is the method used in the code cell below) or [Scikit Learn Predictor](https://sagemaker.readthedocs.io/en/stable/frameworks/sklearn/sagemaker.sklearn.html#scikit-learn-predictor). If you plan to use the latter method, make sure to use a [Serializer](https://sagemaker.readthedocs.io/en/stable/api/inference/serializers.html) to serialize your data properly." 
] }, { "cell_type": "code", "execution_count": null, "id": "85491166", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import json\n", "\n", "client = sagemaker_session.sagemaker_runtime_client\n", "\n", "# Serialize the request body to a JSON string for the endpoint\n", "request_body = {\"Input\": [[9.0, 3571, 1976, 0.525]]}\n", "payload = json.dumps(request_body)\n", "\n", "response = client.invoke_endpoint(\n", " EndpointName=sk_endpoint_name, ContentType=\"application/json\", Body=payload\n", ")\n", "\n", "result = json.loads(response[\"Body\"].read().decode())[\"Output\"]\n", "print(\"Predicted class category {} ({})\".format(result, categories_map[result]))" ] }, { "attachments": {}, "cell_type": "markdown", "id": "90f26921", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Cleanup\n", "If the model and endpoint are no longer in use, they should be deleted to save costs and free up resources." ] }, { "cell_type": "code", "execution_count": null, "id": "ec5a3a83", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "sk_predictor.delete_model()\n", "sk_predictor.delete_endpoint()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "454a7ca7", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This us-west-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/sagemaker-script-mode|sklearn|sklearn_byom.ipynb)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (Data Science 3.0)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-310-v1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }