Work with ML Training Jobs Using Viam's ML Training API

The ML training API lets you submit, inspect, cancel, and delete ML training jobs in the Viam app.

The ML training client API supports the following methods:

| Method Name | Description |
| --- | --- |
| SubmitTrainingJob | Submit a training job. |
| SubmitCustomTrainingJob | Submit a training job from a custom training script. |
| GetTrainingJob | Get training job metadata. |
| ListTrainingJobs | Get training job metadata for all jobs within an organization. |
| CancelTrainingJob | Cancel the specified training job. |
| DeleteCompletedTrainingJob | Delete a completed training job from the database, whether the job succeeded or failed. |

Establish a connection

To use the Viam ML training client API, you first need to instantiate a ViamClient and then instantiate an MLTrainingClient. See the following example for reference.

You can create an API key on your settings page.

import asyncio

from viam.rpc.dial import DialOptions, Credentials
from viam.app.viam_client import ViamClient


async def connect() -> ViamClient:
    dial_options = DialOptions(
      credentials=Credentials(
        type="api-key",
        # Replace "<API-KEY>" (including brackets) with your API key
        payload='<API-KEY>',
      ),
      # Replace "<API-KEY-ID>" (including brackets) with your API key ID
      auth_entity='<API-KEY-ID>'
    )
    return await ViamClient.create_from_dial_options(dial_options)


async def main():

    # Make a ViamClient
    viam_client = await connect()
    # Instantiate an MLTrainingClient to run ML training client API methods on
    ml_training_client = viam_client.ml_training_client

    viam_client.close()

if __name__ == '__main__':
    asyncio.run(main())
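
Hardcoding the API key and key ID in a script makes them easy to leak, for example by committing the file to version control. A minimal sketch, assuming you store both values in environment variables; the variable names VIAM_API_KEY and VIAM_API_KEY_ID and the helper load_viam_api_key are hypothetical choices, not part of the Viam SDK:

```python
import os
from typing import Tuple


def load_viam_api_key(
    key_var: str = "VIAM_API_KEY",
    key_id_var: str = "VIAM_API_KEY_ID",
) -> Tuple[str, str]:
    """Return (api_key, api_key_id) read from environment variables.

    The default variable names are hypothetical; use whatever names you set.
    """
    api_key = os.environ.get(key_var, "")
    api_key_id = os.environ.get(key_id_var, "")
    # Report every missing variable at once instead of failing on the first.
    missing = [
        name
        for name, value in ((key_var, api_key), (key_id_var, api_key_id))
        if not value
    ]
    if missing:
        raise RuntimeError("Missing environment variable(s): " + ", ".join(missing))
    return api_key, api_key_id
```

Inside connect(), you could then call api_key, api_key_id = load_viam_api_key() and pass api_key as payload and api_key_id as auth_entity in place of the placeholder strings.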

Once you have instantiated an MLTrainingClient, you can run the following API methods against the MLTrainingClient object (named ml_training_client in the examples).

API

SubmitTrainingJob

Submit a training job.

Parameters:

  • org_id (str) (required): The ID of the organization to submit the training job to. To retrieve this, expand your organization’s dropdown in the top right corner of the Viam app, select Settings, and copy Organization ID.
  • dataset_id (str) (required): The ID of the dataset to train the ML model on. To retrieve this, navigate to your dataset’s page in the Viam app, click in the left-hand menu, and click Copy dataset ID.
  • model_name (str) (required): The name of the ML model.
  • model_version (str) (required): The version of the ML model you’re training. This string must be unique from any previous versions you’ve set.
  • model_type (viam.proto.app.mltraining.ModelType.ValueType) (required): The type of the ML model. Options: ModelType.MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION, ModelType.MODEL_TYPE_MULTI_LABEL_CLASSIFICATION, ModelType.MODEL_TYPE_OBJECT_DETECTION.
  • tags (List[str]) (required): The labels to train the model on.

Returns:

  • (str): the ID of the training job.

Example:

from viam.proto.app.mltraining import ModelType

job_id = await ml_training_client.submit_training_job(
    org_id="<organization-id>",
    dataset_id="<dataset-id>",
    model_name="<your-model-name>",
    model_version="1",
    model_type=ModelType.MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION,
    tags=["tag1", "tag2"]
)

For more information, see the Python SDK Docs.
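
Because model_version must differ from every version you have set before, a timestamp is a convenient source of unique values. The helper below, make_model_version, is a hypothetical sketch rather than part of the Viam SDK; it formats the current UTC time in the same style as the registry_item_version string in the SubmitCustomTrainingJob example, but the format itself is an arbitrary choice:

```python
from datetime import datetime, timezone
from typing import Optional


def make_model_version(now: Optional[datetime] = None) -> str:
    """Build a timestamp-based version string such as "2024-08-13T12-11-54".

    Hypothetical helper: `now` defaults to the current UTC time and can be
    passed explicitly to get a reproducible value.
    """
    now = now or datetime.now(timezone.utc)
    return now.strftime("%Y-%m-%dT%H-%M-%S")
```

You could then pass model_version=make_model_version() to submit_training_job. The versions stay distinct only as long as you submit at most one job per second for the same model.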

SubmitCustomTrainingJob

Submit a training job from a custom training script. Follow the guide to Train a Model with a Custom Python Training Script.

Parameters:

  • org_id (str) (required): the ID of the org to submit the training job to.
  • dataset_id (str) (required): the ID of the dataset to train the model on.
  • registry_item_id (str) (required): the ID of the training script from the registry.
  • registry_item_version (str) (required): the version of the training script from the registry.
  • model_name (str) (required): the model name.
  • model_version (str) (required): the model version.

Returns:

  • (str): the ID of the training job.

Example:

job_id = await ml_training_client.submit_custom_training_job(
    org_id="<organization-id>",
    dataset_id="<dataset-id>",
    registry_item_id="viam:classification-tflite",
    registry_item_version="2024-08-13T12-11-54",
    model_name="<your-model-name>",
    model_version="1"
)

For more information, see the Python SDK Docs.

GetTrainingJob

Get training job metadata.

Parameters:

  • id (str) (required): the ID of the requested training job.

Returns:

  • (viam.proto.app.mltraining.TrainingJobMetadata): the training job's metadata.

Example:

job_metadata = await ml_training_client.get_training_job(
    id="<job-id>")

For more information, see the Python SDK Docs.
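
submit_training_job and submit_custom_training_job return a job ID rather than waiting for training to finish, so a script often needs to wait for the job itself. The following sketch polls get_training_job until the job reaches a terminal status. wait_for_training_job is a hypothetical helper; it assumes the returned metadata exposes a status field, takes the set of terminal statuses from the caller so it does not hard-code enum values, and uses arbitrary default polling and timeout values:

```python
import asyncio
from typing import Any, Collection


async def wait_for_training_job(
    ml_training_client: Any,
    job_id: str,
    terminal_statuses: Collection[Any],
    poll_interval: float = 30.0,
    timeout: float = 3600.0,
) -> Any:
    """Poll get_training_job until the job's status is in terminal_statuses.

    Returns the last metadata object fetched. Raises asyncio.TimeoutError if
    the job has not reached a terminal status within `timeout` seconds.
    """

    async def poll() -> Any:
        while True:
            metadata = await ml_training_client.get_training_job(id=job_id)
            # Assumption: the metadata object exposes a `status` field.
            if metadata.status in terminal_statuses:
                return metadata
            await asyncio.sleep(poll_interval)

    return await asyncio.wait_for(poll(), timeout=timeout)
```

With the real SDK you would likely pass values from the TrainingStatus enum in viam.proto.app.mltraining, such as TrainingStatus.TRAINING_STATUS_COMPLETED, TrainingStatus.TRAINING_STATUS_FAILED, and TrainingStatus.TRAINING_STATUS_CANCELED; confirm the exact names in the Python SDK Docs.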

ListTrainingJobs

Get training job metadata for all jobs within an organization.

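Parameters:

  • org_id (str) (required): the ID of the organization to list training jobs for.

Returns:

  • (List[viam.proto.app.mltraining.TrainingJobMetadata]): the metadata of each training job in the organization.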

Example:

jobs_metadata = await ml_training_client.list_training_jobs(
    org_id="<organization-id>")

# Lists are zero-indexed, so the first job is at index 0
first_job_id = jobs_metadata[0].id

For more information, see the Python SDK Docs.
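
ListTrainingJobs returns every job in the organization, so it can help to organize the results before acting on them. This hypothetical helper, group_jobs_by_model, maps each model name to the IDs of its training jobs, assuming each metadata object exposes id and model_name fields:

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List


def group_jobs_by_model(jobs: Iterable[Any]) -> Dict[str, List[str]]:
    """Map each model name to the IDs of its training jobs, in list order."""
    grouped: Dict[str, List[str]] = defaultdict(list)
    for job in jobs:
        # Assumption: metadata objects expose `model_name` and `id` fields.
        grouped[job.model_name].append(job.id)
    return dict(grouped)
```

Calling group_jobs_by_model(jobs_metadata) returns an empty dict when the organization has no jobs, which avoids the IndexError that indexing into an empty list would raise.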

CancelTrainingJob

Cancel the specified training job.

Parameters:

  • id (str) (required): ID of the training job you wish to cancel. Retrieve this value with ListTrainingJobs().

Returns:

  • None.

Raises:

  • (GRPCError): if no training job exists with the given ID.

Example:

await ml_training_client.cancel_training_job(
    id="<job-id>")

For more information, see the Python SDK Docs.
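
To cancel several jobs at once, you can combine ListTrainingJobs and CancelTrainingJob. A minimal sketch, assuming each metadata object exposes id and status fields; cancel_jobs_with_status is a hypothetical helper, and the caller supplies the statuses to cancel (for example, pending and in-progress values from the SDK's TrainingStatus enum):

```python
from typing import Any, Collection, List


async def cancel_jobs_with_status(
    ml_training_client: Any,
    org_id: str,
    statuses: Collection[Any],
) -> List[str]:
    """Cancel every job in the organization whose status is in `statuses`.

    Returns the IDs of the jobs that were canceled, in list order.
    """
    canceled: List[str] = []
    jobs = await ml_training_client.list_training_jobs(org_id=org_id)
    for job in jobs:
        # Assumption: metadata objects expose `id` and `status` fields.
        if job.status in statuses:
            await ml_training_client.cancel_training_job(id=job.id)
            canceled.append(job.id)
    return canceled
```

The jobs are canceled one at a time, so if cancel_training_job raises (for example, a GRPCError for a job that no longer exists), the helper stops and the jobs already canceled stay canceled.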

DeleteCompletedTrainingJob

Delete a completed training job from the database, whether the job succeeded or failed.

Parameters:

  • id (str) (required): the ID of the training job to delete.

Returns:

  • None.

Example:

await ml_training_client.delete_completed_training_job(
    id="<job-id>")

For more information, see the Python SDK Docs.
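
DeleteCompletedTrainingJob applies only to finished jobs, so a cleanup script should skip jobs that are still running. This hypothetical delete_finished_jobs helper lists the organization's jobs, deletes those whose status is in a caller-supplied set of finished statuses (for example, completed and failed), and records failures instead of stopping at the first error. It assumes each metadata object exposes id and status fields:

```python
from typing import Any, Collection, Dict, List, Tuple


async def delete_finished_jobs(
    ml_training_client: Any,
    org_id: str,
    finished_statuses: Collection[Any],
) -> Tuple[List[str], Dict[str, Exception]]:
    """Delete every finished job in the organization.

    Returns the IDs that were deleted and a map of job ID to error for any
    deletions that failed, so one bad job does not stop the cleanup.
    """
    deleted: List[str] = []
    failed: Dict[str, Exception] = {}
    jobs = await ml_training_client.list_training_jobs(org_id=org_id)
    for job in jobs:
        # Assumption: metadata objects expose `id` and `status` fields.
        if job.status not in finished_statuses:
            continue
        try:
            await ml_training_client.delete_completed_training_job(id=job.id)
        except Exception as err:  # e.g. an error returned by the server
            failed[job.id] = err
        else:
            deleted.append(job.id)
    return deleted, failed
```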