Retrain a machine learning model

Today’s leading companies use machine learning (ML) models to better understand and predict customer needs. Training an ML model involves choosing an ML algorithm and providing it with training data from which to learn. From time to time, you will need to retrain your model, because the distribution of incoming data can deviate significantly from that of the original training set. ksqlDB makes this easy by triggering the retraining process whenever the prediction error exceeds a defined threshold. This recipe walks through how to take an existing ML pipeline, with results stored in MongoDB, and trigger retraining once the average prediction error exceeds a threshold of 15%.
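As a worked definition, the retraining condition implemented by this recipe’s diff_weight and retrain_weight queries can be written as below. This is our reading of those queries, where $\hat{y}_i$ is the predicted weight of fish $i$, $y_i$ its actual weight, and $W$ the set of joined fish in one one-minute tumbling window; $\operatorname{round}_k$ rounds to $k$ decimal places, mirroring the ROUND calls in the SQL.

```latex
e_i = \operatorname{round}_3\!\left(\frac{\lvert \hat{y}_i - y_i \rvert}{y_i}\right),
\qquad
\text{retrain} \iff \operatorname{round}_2\!\left(\frac{1}{\lvert W \rvert} \sum_{i \in W} e_i\right) > 0.15
```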

The example is based on a fish-processing factory. In the first step, the size of each fish (length and height) is measured. The model then predicts the fish’s weight from its size and species, and that weight determines its selling price. Retraining the model when its predictions drift keeps those prices accurate, which helps maximize revenue. This recipe is based on the blog post Apache Kafka and R: Real-Time Prediction and Model (Re)training, by Patrick Neff.

Run it

Set up your environment

1

Provision a Kafka cluster in Confluent Cloud.

Once your Confluent Cloud cluster is available, create a ksqlDB application and navigate to the ksqlDB editor to execute this tutorial. ksqlDB uses a SQL dialect for extracting, transforming, and loading events within your Kafka cluster.

Execute ksqlDB code

2

ksqlDB processes data in real time, and you can also import data from, and export data to, popular data sources and end systems in the cloud directly from ksqlDB. This tutorial shows you how to run the recipe in one of two ways: using connectors to any supported data source, or using ksqlDB’s INSERT INTO functionality to mock the data.

If you cannot connect to a real data source with properly formatted data, or if you just want to execute this tutorial without external dependencies, no worries! Remove the CREATE SOURCE CONNECTOR commands and insert mock data into the streams.

The existing pipeline, which predicts the weight of fish based on size and species, stores its results in two MongoDB collections that are used by other downstream processes. One collection contains the data fed to the model, along with the prediction. The other contains the actual weight, as determined by a later step in the process. For this tutorial, we’ll use Kafka Connect to make this data available to our ksqlDB application.
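The two MongoDB Atlas source connectors in this recipe publish each collection to a topic whose name we assume is built as `<topic.prefix>.<database>.<collection>`; that assumption matches the KAFKA_TOPIC values the CREATE STREAM statements expect. The hypothetical Python sketch below checks that the connector settings and the stream definitions agree, which is a quick way to catch a typo before the streams sit empty.

```python
# Hypothetical consistency check between the source connector settings and the
# KAFKA_TOPIC of each stream. Assumption: the connector names its output topic
# <topic.prefix>.<database>.<collection>.

# Settings copied from the two CREATE SOURCE CONNECTOR statements.
CONNECTORS = {
    "weight_predictions": {"topic.prefix": "kt", "database": "mdb", "collection": "weight-prediction"},
    "actual_weights": {"topic.prefix": "kt", "database": "mdb", "collection": "machine-weight"},
}

# KAFKA_TOPIC values from the two CREATE STREAM statements.
STREAM_TOPICS = {
    "predicted_weight": "kt.mdb.weight-prediction",
    "actual_weight": "kt.mdb.machine-weight",
}


def source_topic(config: dict) -> str:
    """Build the topic name a MongoDB source connector writes a collection to."""
    return ".".join([config["topic.prefix"], config["database"], config["collection"]])


def unmatched_streams() -> list[str]:
    """Return streams whose KAFKA_TOPIC no connector produces."""
    produced = {source_topic(cfg) for cfg in CONNECTORS.values()}
    return [name for name, topic in STREAM_TOPICS.items() if topic not in produced]


if __name__ == "__main__":
    for name, cfg in CONNECTORS.items():
        print(name, "->", source_topic(cfg))
    print("streams without a producing connector:", unmatched_streams())
```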

Start by creating a ksqlDB stream for each of the two input topics coming from Kafka Connect. Then create another stream that joins those two streams on fish_id. Finally, create a ksqlDB table that aggregates the joined stream over one-minute tumbling windows and keeps only the windows whose average error rate is over 15%. This table will be used to trigger the model retraining process.
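To make the join step concrete, here is a minimal in-memory Python sketch of what the diff_weight stream computes: each prediction is paired with the actual weight for the same fish_id when the two records’ timestamps are at most one minute apart, and the relative error is rounded to three decimals. This is not how ksqlDB executes the join; the record classes are hypothetical, the time bound is treated as inclusive, and state stores and the grace period are ignored.

```python
# Hypothetical in-memory model of the diff_weight stream-stream join.
# Assumptions: records are plain tuples, the WITHIN bound is inclusive,
# and ksqlDB's state stores and grace period are ignored.
from typing import NamedTuple

WITHIN_MS = 60_000  # WITHIN 1 MINUTE


class Prediction(NamedTuple):
    fish_id: str
    species: str
    ts_ms: int  # record timestamp in epoch milliseconds
    prediction: float


class Actual(NamedTuple):
    fish_id: str
    ts_ms: int
    weight: float


def join_within(predictions: list[Prediction], actuals: list[Actual]) -> list[dict]:
    """Inner-join on fish_id, keeping pairs whose timestamps are within WITHIN_MS."""
    actuals_by_id: dict[str, list[Actual]] = {}
    for a in actuals:
        actuals_by_id.setdefault(a.fish_id, []).append(a)

    joined = []
    for p in predictions:
        for a in actuals_by_id.get(p.fish_id, []):
            if abs(p.ts_ms - a.ts_ms) <= WITHIN_MS:
                # Same formula as the Error column; Python's round() may differ
                # from ksqlDB's ROUND on exact halfway values.
                error = round(abs(p.prediction - a.weight) / a.weight, 3)
                joined.append({
                    "fish_id": p.fish_id,
                    "species": p.species,
                    "prediction": p.prediction,
                    "actual": a.weight,
                    "error": error,
                })
    return joined


if __name__ == "__main__":
    preds = [Prediction("101", "Salmon", 0, 3.78), Prediction("102", "Salmon", 0, 4.17)]
    # Fish 102's actual weight arrives two minutes later, outside the join window.
    acts = [Actual("101", 30_000, 4.38), Actual("102", 120_000, 3.17)]
    print(join_within(preds, acts))
```

In the usage example only fish 101 is joined, with an error of 0.137; fish 102 is dropped because its two records are more than one minute apart, just as an INNER JOIN with WITHIN 1 MINUTE would drop it.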

When creating the initial STREAM or TABLE, if the backing Kafka topic already exists, then the PARTITIONS property may be omitted.

-- Substitute your parameter values in the connector configurations below.
-- If you do not want to connect to a real data source, remove the CREATE SOURCE CONNECTOR commands,
-- and add the INSERT INTO commands to insert mock data into the streams

-- Stream of fish weight predictions
CREATE SOURCE CONNECTOR IF NOT EXISTS weight_predictions WITH (
  'connector.class'        = 'MongoDbAtlasSource',
  'name'                   = 'model-retrain-weight-predictions',
  'kafka.api.key'          = '<my-kafka-api-key>',
  'kafka.api.secret'       = '<my-kafka-api-secret>',
  'connection.host'        = '<database-host-address>',
  'connection.user'        = '<database-username>',
  'connection.password'    = '<database-password>',
  'topic.prefix'           = 'kt',
  'database'               = 'mdb',
  'collection'             = 'weight-prediction',
  'poll.await.time.ms'     = '5000',
  'poll.max.batch.size'    = '1000',
  'copy.existing'          = 'true',
  'output.data.format'     = 'JSON',
  'tasks.max'              = '1'
);

-- Stream of actual fish weights
CREATE SOURCE CONNECTOR IF NOT EXISTS actual_weights WITH (
  'connector.class'        = 'MongoDbAtlasSource',
  'name'                   = 'model-retrain-actual-weights',
  'kafka.api.key'          = '<my-kafka-api-key>',
  'kafka.api.secret'       = '<my-kafka-api-secret>',
  'connection.host'        = '<database-host-address>',
  'connection.user'        = '<database-username>',
  'connection.password'    = '<database-password>',
  'topic.prefix'           = 'kt',
  'database'               = 'mdb',
  'collection'             = 'machine-weight',
  'poll.await.time.ms'     = '5000',
  'poll.max.batch.size'    = '1000',
  'copy.existing'          = 'true',
  'output.data.format'     = 'JSON',
  'tasks.max'              = '1'
);


SET 'auto.offset.reset' = 'earliest';

-- Create stream of predictions
CREATE STREAM predicted_weight(
  fish_id VARCHAR KEY,
  species VARCHAR,
  height DOUBLE,
  length DOUBLE,
  prediction DOUBLE
) WITH (
  KAFKA_TOPIC = 'kt.mdb.weight-prediction',
  VALUE_FORMAT = 'JSON',
  PARTITIONS = 6
);

-- Create stream of actual weights
CREATE STREAM actual_weight(
  fish_id VARCHAR KEY,
  species VARCHAR,
  weight DOUBLE
) WITH (
  KAFKA_TOPIC = 'kt.mdb.machine-weight',
  VALUE_FORMAT = 'JSON',
  PARTITIONS = 6
);

-- Create stream joining predictions with actual weights
CREATE STREAM diff_weight WITH (KAFKA_TOPIC = 'diff_weight') AS
  SELECT
   -- This fake key field will give us something to group by in the next step
   'key' AS key,
   predicted_weight.fish_id AS fish_id,
   predicted_weight.species AS species,
   predicted_weight.length AS length,
   predicted_weight.height AS height,
   predicted_weight.prediction AS prediction,
   actual_weight.weight AS actual,
   ROUND(ABS(predicted_weight.prediction - actual_weight.weight) / actual_weight.weight, 3) AS Error
FROM predicted_weight
INNER JOIN actual_weight
WITHIN 1 MINUTE
GRACE PERIOD 1 MINUTE
ON predicted_weight.fish_id = actual_weight.fish_id;

-- Create a table of one-minute aggregates whose average error rate exceeds 15%
CREATE TABLE retrain_weight WITH (KAFKA_TOPIC = 'retrain_weight') AS
 SELECT
   key,
   COLLECT_SET(species) AS species,
   EARLIEST_BY_OFFSET(fish_id) AS fish_id_start,
   LATEST_BY_OFFSET(fish_id) AS fish_id_end,
   AVG(Error) AS ErrorAvg
FROM diff_weight
WINDOW TUMBLING (SIZE 1 MINUTE, GRACE PERIOD 1 MINUTE)
GROUP BY key
HAVING ROUND(AVG(diff_weight.Error), 2) > 0.15;
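The retrain_weight table combines a tumbling window with a HAVING filter. The Python sketch below models the final state of that aggregation, assuming epoch-aligned windows (consistent with the WINDOWSTART and WINDOWEND values in the sample output later in this recipe) and input rows already in offset order. It is illustrative only: ksqlDB updates the table as each record arrives rather than computing one result per window at the end.

```python
# Hypothetical model of the retrain_weight table: group joined records into
# one-minute tumbling windows and keep windows whose rounded average error
# exceeds 15%. Assumptions: windows are aligned to the Unix epoch, input rows
# are in offset order, and only the final state of each window is computed.

SIZE_MS = 60_000   # WINDOW TUMBLING (SIZE 1 MINUTE)
THRESHOLD = 0.15   # HAVING ROUND(AVG(Error), 2) > 0.15

Row = tuple[int, str, str, float]  # (ts_ms, fish_id, species, error)


def tumbling_window(ts_ms: int, size_ms: int = SIZE_MS) -> tuple[int, int]:
    """Return (WINDOWSTART, WINDOWEND) in epoch milliseconds for a timestamp."""
    start = ts_ms - ts_ms % size_ms
    return start, start + size_ms


def retrain_windows(rows: list[Row]) -> list[dict]:
    """Aggregate rows per window and keep only windows above the threshold."""
    windows: dict[tuple[int, int], list[Row]] = {}
    for row in rows:
        windows.setdefault(tumbling_window(row[0]), []).append(row)

    result = []
    for (start, end), group in sorted(windows.items()):
        avg = sum(r[3] for r in group) / len(group)
        if round(avg, 2) > THRESHOLD:
            result.append({
                "WINDOWSTART": start,
                "WINDOWEND": end,
                # COLLECT_SET: distinct species, kept here in first-seen order
                "SPECIES": list(dict.fromkeys(r[2] for r in group)),
                "FISH_ID_START": group[0][1],  # EARLIEST_BY_OFFSET(fish_id)
                "FISH_ID_END": group[-1][1],   # LATEST_BY_OFFSET(fish_id)
                "ERRORAVG": avg,
            })
    return result
```

Because the window boundaries depend only on the record timestamp and the window size, every record in the same minute lands in the same window regardless of its fish_id, which is why the recipe can group by a constant key.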

Test with mock data

3

If you are not running source connectors to produce events, you can use ksqlDB INSERT INTO statements to insert mock data into the source topics:

INSERT INTO predicted_weight VALUES ('101', 'Salmon', 17.33, 74.55, 3.78);
INSERT INTO predicted_weight VALUES ('102', 'Salmon', 19.11, 82.19, 4.17);
INSERT INTO predicted_weight VALUES ('103', 'Salmon', 21.07, 90.62, 4.6);
INSERT INTO predicted_weight VALUES ('104', 'Bass', 15.44, 56.23, 2.54);
INSERT INTO predicted_weight VALUES ('105', 'Bass', 17.02, 62, 2.8);
INSERT INTO predicted_weight VALUES ('106', 'Bass', 18.76, 68.34, 3.09);
INSERT INTO predicted_weight VALUES ('107', 'Trout', 13.34, 64.05, 1.47);
INSERT INTO predicted_weight VALUES ('108', 'Trout', 14.71, 70.61, 1.62);
INSERT INTO predicted_weight VALUES ('109', 'Trout', 16.22, 77.85, 1.79);
INSERT INTO predicted_weight VALUES ('110', 'Trout', 17.03, 81.74, 1.88);

INSERT INTO actual_weight VALUES ('101', 'Salmon', 4.38);
INSERT INTO actual_weight VALUES ('102', 'Salmon', 3.17);
INSERT INTO actual_weight VALUES ('103', 'Salmon', 5.6);
INSERT INTO actual_weight VALUES ('104', 'Bass', 5.54);
INSERT INTO actual_weight VALUES ('105', 'Bass', 1.8);
INSERT INTO actual_weight VALUES ('106', 'Bass', 4.09);
INSERT INTO actual_weight VALUES ('107', 'Trout', 2.47);
INSERT INTO actual_weight VALUES ('108', 'Trout', 2.62);
INSERT INTO actual_weight VALUES ('109', 'Trout', 2.79);
INSERT INTO actual_weight VALUES ('110', 'Trout', 2.88);

To validate that this recipe is working, run the following query:

SELECT * FROM retrain_weight;

Your output should resemble:

+------------------------+------------------------+------------------------+------------------------+------------------------+------------------------+------------------------+
|KEY                     |WINDOWSTART             |WINDOWEND               |SPECIES                 |FISH_ID_START           |FISH_ID_END             |ERRORAVG                |
+------------------------+------------------------+------------------------+------------------------+------------------------+------------------------+------------------------+
|key                     |1646327820000           |1646327880000           |[Salmon, Bass, Trout]   |101                     |110                     |0.3465000000000001      |
Query terminated
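The ERRORAVG value above can be checked by hand. This short Python computation applies the per-fish formula from diff_weight to the mock INSERT INTO values and averages the results, assuming (as the sample output shows) that all ten joined records fall into the same one-minute window.

```python
# Recompute ERRORAVG from the mock INSERT INTO values, using the same
# per-fish formula as diff_weight: ROUND(ABS(prediction - actual) / actual, 3).
# Assumes all ten joined records land in the same one-minute window.
mock = {
    # fish_id: (prediction, actual weight)
    "101": (3.78, 4.38), "102": (4.17, 3.17), "103": (4.6, 5.6),
    "104": (2.54, 5.54), "105": (2.8, 1.8), "106": (3.09, 4.09),
    "107": (1.47, 2.47), "108": (1.62, 2.62), "109": (1.79, 2.79),
    "110": (1.88, 2.88),
}

errors = {fid: round(abs(p - a) / a, 3) for fid, (p, a) in mock.items()}
error_avg = sum(errors.values()) / len(errors)

for fid, e in errors.items():
    print(fid, e)
print("ERRORAVG =", error_avg)
print("retrain:", round(error_avg, 2) > 0.15)
```

The ten errors sum to 3.465, so the average is 0.3465 up to floating-point noise, matching the ERRORAVG column; rounded to two decimals that is 0.35, well above the 0.15 threshold.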

Write the data out

4

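Now we’ll use a MongoDB sink connector to write the joined predictions and actual weights from the diff_weight topic to a database collection that can serve as training data, and an HTTP sink connector to send each record from the retrain_weight topic to your training endpoint, triggering the retraining process.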

CREATE SINK CONNECTOR IF NOT EXISTS training_data WITH (
    'connector.class'          = 'MongoDbAtlasSink',
    'name'                     = 'weight-data',
    'kafka.auth.mode'          = 'KAFKA_API_KEY',
    'kafka.api.key'            = '<my-kafka-api-key>',
    'kafka.api.secret'         = '<my-kafka-api-secret>',
    'input.data.format'        = 'JSON',
    'connection.host'          = '<database-host-address>',
    'connection.user'          = '<database-username>',
    'connection.password'      = '<database-password>',
    'topics'                   = 'diff_weight',
    'max.num.retries'          = '3',
    'retries.defer.timeout'    = '5000',
    'max.batch.size'           = '0',
    'database'                 = 'mdb',
    'collection'               = 'training_data',
    'tasks.max'                = '1'
);

CREATE SINK CONNECTOR IF NOT EXISTS retraining_trigger WITH (
    'connector.class'          = 'HttpSink',
    'input.data.format'        = 'JSON',
    'name'                     = 'retrain-trigger',
    'kafka.auth.mode'          = 'KAFKA_API_KEY',
    'kafka.api.key'            = '<my-kafka-api-key>',
    'kafka.api.secret'         = '<my-kafka-api-secret>',
    'topics'                   = 'retrain_weight',
    'tasks.max'                = '1',
    'http.api.url'             = '<training-endpoint-url>',
    'request.method'           = 'POST'
);
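On the receiving side, `<training-endpoint-url>` needs a service that accepts the connector’s POST requests and starts retraining. The standard-library Python sketch below is one hypothetical shape for it. It assumes each request body is JSON holding either one retrain_weight row or a list of rows with ksqlDB’s upper-case column names; check how your HTTP sink connector formats and batches request bodies before relying on that. The `retrain()` function is a placeholder for your own training job.

```python
# Hypothetical endpoint for '<training-endpoint-url>' that receives the HTTP
# sink connector's POST requests. Assumptions: each request body is JSON that
# holds one retrain_weight row or a list of rows, using ksqlDB's upper-case
# column names; retrain() is a placeholder for your own training job.
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


def parse_rows(body: bytes) -> list[dict]:
    """Return the retrain_weight rows contained in a request body."""
    payload = json.loads(body)
    rows = payload if isinstance(payload, list) else [payload]
    return [r for r in rows if isinstance(r, dict) and "ERRORAVG" in r]


def retrain(rows: list[dict]) -> None:
    """Placeholder: start retraining, e.g. on the training_data collection."""
    for r in rows:
        print(f"retrain requested for fish {r.get('FISH_ID_START')}"
              f"..{r.get('FISH_ID_END')}, average error {r['ERRORAVG']}")


class RetrainHandler(BaseHTTPRequestHandler):
    def do_POST(self) -> None:
        length = int(self.headers.get("Content-Length", 0))
        try:
            rows = parse_rows(self.rfile.read(length))
        except ValueError:  # malformed JSON or undecodable bytes
            self.send_response(400)
            self.end_headers()
            return
        retrain(rows)
        self.send_response(200)
        self.end_headers()


def serve(port: int = 8080) -> None:
    """Run the endpoint until interrupted."""
    HTTPServer(("", port), RetrainHandler).serve_forever()
```

Call `serve()` to start listening. In practice you would hand the actual training off to a background job and return 200 immediately, so that a long training run does not cause the connector’s request to time out and retry.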

Cleanup

5

To clean up the ksqlDB resources created by this tutorial, use the ksqlDB commands shown below (substitute the stream or table name, as appropriate). Including the DELETE TOPIC clause also deletes the topic backing the stream or table; this deletion happens asynchronously.

DROP STREAM IF EXISTS <stream_name> DELETE TOPIC;
DROP TABLE IF EXISTS <table_name> DELETE TOPIC;

If you also created connectors, remove those as well (substitute connector name).

DROP CONNECTOR IF EXISTS <connector_name>;
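The order of the DROP statements matters when one object is built on another. Assuming ksqlDB refuses to drop a stream or table while another persistent query still reads from it, objects should be dropped from the most derived to the base streams. The Python sketch below derives that order for this tutorial’s streams and table with a topological sort; the dependency map is transcribed from the CREATE statements, and connectors are left out.

```python
# Sketch: derive a safe DROP order for this tutorial's ksqlDB objects.
# Assumption: ksqlDB refuses to drop a source while another persistent query
# still reads from it, so dependents must be dropped first.
from graphlib import TopologicalSorter  # Python 3.9+

# object -> (kind, objects it reads from), from the CREATE statements
OBJECTS = {
    "predicted_weight": ("STREAM", set()),
    "actual_weight": ("STREAM", set()),
    "diff_weight": ("STREAM", {"predicted_weight", "actual_weight"}),
    "retrain_weight": ("TABLE", {"diff_weight"}),
}


def drop_statements(objects: dict[str, tuple[str, set[str]]]) -> list[str]:
    """Return DROP statements with dependents ahead of the sources they read."""
    graph = {name: deps for name, (_, deps) in objects.items()}
    # static_order() yields sources before the objects built on them;
    # reverse it so that dependents are dropped first.
    order = list(TopologicalSorter(graph).static_order())[::-1]
    return [f"DROP {objects[n][0]} IF EXISTS {n} DELETE TOPIC;" for n in order]


for stmt in drop_statements(OBJECTS):
    print(stmt)
```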