Introduction to JumpStart - Text Classification




  1. Set Up

  2. Select a pre-trained model

  3. Finetune the pre-trained model on a custom dataset

1. Set Up


Before executing the notebook, there are some initial steps required for setup. This notebook requires the latest versions of sagemaker and ipywidgets.

[ ]:
!pip install sagemaker ipywidgets --upgrade --quiet
[ ]:
import sagemaker, boto3, json
from sagemaker import get_execution_role

aws_role = get_execution_role()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

2. Select a pre-trained model


You can continue with the default model or choose a different model from the dropdown generated when you run the next cell. A complete list of JumpStart models is also available at JumpStart Models.

[ ]:
model_id = "tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2"
[ ]:
import IPython
from ipywidgets import Dropdown

# Download the JumpStart models manifest file.
boto3.client("s3").download_file(
    f"jumpstart-cache-prod-{aws_region}", "models_manifest.json", "models_manifest.json"
)
with open("models_manifest.json", "rb") as json_file:
    model_list = json.load(json_file)

# Filter out all the Text Classification models from the manifest list.
tc_models_all_versions = [
    model["model_id"] for model in model_list if "-tc-" in model["model_id"]
]
# A model ID can appear once per version; keep each ID once, preserving order.
tc_models = list(dict.fromkeys(tc_models_all_versions))

# display the model-ids in a dropdown, for user to select a model.
dropdown = Dropdown(
    value=model_id,
    options=tc_models,
    description="JumpStart Text Classification Models:",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)
display(IPython.display.Markdown("## Select a JumpStart pre-trained model from the dropdown below"))
display(dropdown)
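The filter-and-deduplicate step in the cell above can be tried offline on a small, made-up manifest. The entries below are hypothetical and only assume what the cell itself relies on: a list of dicts with a `model_id` key, where the same ID can appear once per version (the `version` field here is an illustrative assumption). `text_classification_model_ids` is an illustrative helper, not part of the SageMaker SDK.

```python
# Hypothetical stand-in for models_manifest.json; IDs and versions are made up.
manifest = [
    {"model_id": "tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2", "version": "1.0.0"},
    {"model_id": "tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2", "version": "2.0.0"},
    {"model_id": "tensorflow-ic-imagenet-mobilenet-v2-100-224-classification-4", "version": "1.0.0"},
    {"model_id": "huggingface-tc-bert-base-cased", "version": "1.0.0"},
]


def text_classification_model_ids(manifest):
    """Return unique text classification model IDs, in manifest order."""
    # Keep only IDs containing the "-tc-" marker, as the notebook cell does.
    ids = [entry["model_id"] for entry in manifest if "-tc-" in entry["model_id"]]
    # dict.fromkeys keeps the first occurrence of each key, so order is preserved.
    return list(dict.fromkeys(ids))


print(text_classification_model_ids(manifest))
```

Using `dict.fromkeys` instead of a membership test inside a loop keeps the de-duplication linear in the number of manifest entries.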
[ ]:
model_id = dropdown.value

3. Finetune the pre-trained model on a custom dataset


This section discusses how to fine-tune the model on a custom dataset with any number of classes.

The Text Embedding model can be fine-tuned on any text classification dataset, in the same way that the model available for inference was fine-tuned on the SST2 movie review dataset.

The model available for fine-tuning attaches a classification layer to the Text Embedding model and initializes the layer parameters to random values. The output dimension of the classification layer is determined by the number of classes detected in the input data. Fine-tuning updates all the model parameters to minimize prediction error on the input data and returns the fine-tuned model, which can then be deployed for inference. The instructions below describe how the training data should be formatted for input to the model.
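As a rough illustration of attaching a randomly initialized classification layer whose output size equals the number of classes, the sketch below builds such a dense layer in plain Python and applies a softmax to get per-class probabilities. It is not the TensorFlow code JumpStart actually runs; the embedding size, the Gaussian initialization, and counting classes as distinct labels are all assumptions made for the example.

```python
import math
import random


def init_classification_head(embedding_dim, num_classes, seed=0):
    """Randomly initialize a dense layer mapping an embedding to num_classes logits."""
    rng = random.Random(seed)
    # Small Gaussian weights (an assumed initialization scheme), zero biases.
    weights = [[rng.gauss(0.0, 0.02) for _ in range(num_classes)] for _ in range(embedding_dim)]
    biases = [0.0] * num_classes
    return weights, biases


def predict_proba(embedding, weights, biases):
    """Apply the dense layer, then a softmax, to get per-class probabilities."""
    num_classes = len(biases)
    logits = [
        sum(embedding[i] * weights[i][k] for i in range(len(embedding))) + biases[k]
        for k in range(num_classes)
    ]
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Assumption: the number of classes is the number of distinct labels in the data.
labels = [0, 0, 1]
num_classes = len(set(labels))
weights, biases = init_classification_head(embedding_dim=8, num_classes=num_classes)
probs = predict_proba([0.1] * 8, weights, biases)
print(len(probs), round(sum(probs), 6))
```

During fine-tuning, both this head and the underlying embedding model are updated; the sketch only shows the shape of the head at initialization.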

  • Input: A directory containing a ‘data.csv’ file.

    • Each row of the first column of ‘data.csv’ should have an integer class label from 0 to N - 1, where N is the number of classes.

    • Each row of the second column should have the corresponding text.

  • Output: A trained model that can be deployed for inference.
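One way to produce and sanity-check a file in this format is sketched below with Python's standard `csv` module. The helper names are hypothetical, and the sketch assumes the training script parses standard CSV quoting (a text containing a comma, like the second example row, gets wrapped in double quotes). Requiring every label from 0 to N - 1 to appear at least once is a stricter check than the format description states, chosen here to catch off-by-one labeling.

```python
import csv
import os
import tempfile


def write_training_csv(directory, rows):
    """Write (label, text) rows to directory/data.csv with no header row."""
    path = os.path.join(directory, "data.csv")
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        for label, text in rows:
            writer.writerow([label, text])
    return path


def check_training_csv(path):
    """Return the number of classes, raising if labels are not exactly 0..N-1."""
    with open(path, newline="", encoding="utf-8") as f:
        labels = {int(row[0]) for row in csv.reader(f)}
    num_classes = len(labels)
    if labels != set(range(num_classes)):
        raise ValueError(f"labels must be 0..{num_classes - 1}, got {sorted(labels)}")
    return num_classes


rows = [
    (0, "hide new secretions from the parental units"),
    (0, "contains no wit , only labored gags"),
    (1, "that loves its characters and communicates something rather beautiful about human nature"),
]
with tempfile.TemporaryDirectory() as d:
    path = write_training_csv(d, rows)
    print(check_training_csv(path))  # 2
```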

Below is an example of a ‘data.csv’ file, showing the values in its first two columns. Note that the file should not have a header row.

0    hide new secretions from the parental units
0    contains no wit , only labored gags
1    that loves its characters and communicates something rather beautiful about human nature

Source: TensorFlow Hub. License: Apache 2.0.

The SST2 dataset is downloaded from TensorFlow. License: Apache 2.0. Dataset Homepage.

Set Training parameters

Now that the setup is done, we are ready to fine-tune our model. To begin, we create a JumpStartEstimator object, which extends sagemaker.estimator.Estimator (https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html). This estimator will launch the training job.

There are two kinds of parameters that need to be set for training.

The first set are the parameters for the training job. These include the training data path: the S3 folder in which the input data is stored. The second set are the algorithm-specific training hyperparameters.
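The sketch below keeps the two kinds of parameters apart: a channel mapping that points at an S3 prefix, and a dictionary of hyperparameters. The channel name `training` and the string-valued hyperparameters mirror the calls made later in this notebook; the helper functions and the `example-bucket` URI are hypothetical.

```python
from urllib.parse import urlparse


def training_inputs(s3_uri):
    """Build the channel mapping passed to fit(), after checking the URI is s3://bucket/..."""
    parsed = urlparse(s3_uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"expected an s3://bucket/prefix URI, got {s3_uri!r}")
    # "training" matches the channel name used in this notebook's fit() call.
    return {"training": s3_uri}


def as_hyperparameters(**values):
    """Convert algorithm hyperparameters to strings, as this notebook passes them."""
    return {name: str(value) for name, value in values.items()}


print(training_inputs("s3://example-bucket/training-datasets/SST/"))
print(as_hyperparameters(epochs=1, batch_size=64))
```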

[ ]:
import json
from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket


# Sample training data is available in this bucket
training_data_bucket = get_jumpstart_content_bucket()
training_data_prefix = "training-datasets/SST/"

training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}"

Start Training


We start by creating the estimator object with all the required assets and then launch the training job. Because default hyperparameter values are model-specific, inspect estimator.hyperparameters() to view the values for your selected model.
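When overriding hyperparameters, a misspelled name (for example `epoch` instead of `epochs`) is easy to miss. The sketch below overlays overrides onto a defaults dictionary and rejects names that are not among the defaults. The defaults shown, including `learning_rate`, are made up; in practice they would come from your selected model, for example by inspecting `estimator.hyperparameters()` on an estimator created without overrides. `merge_hyperparameters` is a hypothetical helper.

```python
def merge_hyperparameters(defaults, overrides):
    """Overlay overrides onto defaults, rejecting names the defaults do not define."""
    unknown = sorted(set(overrides) - set(defaults))
    if unknown:
        raise KeyError(f"unknown hyperparameters: {unknown}")
    merged = dict(defaults)  # copy so the defaults dict is left unchanged
    merged.update({name: str(value) for name, value in overrides.items()})
    return merged


# Hypothetical defaults for illustration only.
defaults = {"epochs": "3", "batch_size": "32", "learning_rate": "0.001"}
print(merge_hyperparameters(defaults, {"epochs": 1, "batch_size": 64}))
```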

[ ]:
estimator = JumpStartEstimator(
    model_id=model_id,
    hyperparameters={"epochs": "1", "batch_size": "64"},
)
[ ]:
# You can now fit the estimator by providing training data to the train channel

estimator.fit({"training": training_dataset_s3_path}, logs=True)

Deploy & run Inference on the fine-tuned model


A trained model does nothing on its own. We now use the model to perform inference, which in this example means predicting the class label of an input sentence.
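Turning per-class scores into a label is an argmax. The sketch below assumes you have already extracted a list of per-class probabilities from `query_response`; the exact response schema depends on the model and its serializer, so print a response first to see its shape. The probabilities are made up, the `predicted_label` helper is hypothetical, and the `negative`/`positive` names follow SST2's convention of 0 for negative and 1 for positive.

```python
def predicted_label(probabilities, class_names=None):
    """Return the index (or name) of the highest-probability class."""
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    return class_names[best] if class_names else best


# Hypothetical probabilities for one sentence.
probs = [0.08, 0.92]
print(predicted_label(probs))  # 1
print(predicted_label(probs, ["negative", "positive"]))  # positive
```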

[ ]:
# You can deploy the fine-tuned model to an endpoint directly from the estimator.
predictor = estimator.deploy()
[ ]:
text1 = "astonishing ... ( frames ) profound ethical and philosophical questions in the form of dazzling pop entertainment"
text2 = "simply stupid , irrelevant and deeply , truly , bottomlessly cynical "
[ ]:
for text in [text1, text2]:
    query_response = predictor.predict(text)
    print(query_response)
[ ]:
# Delete the SageMaker endpoint and the attached resources
predictor.delete_predictor()
