[PATCH v5 19/39] ml/cnxk: enable support to stop an ML model

Srikanth Yalavarthi syalavarthi at marvell.com
Tue Feb 7 17:06:59 CET 2023


Implemented the model stop driver function. A model stop job is
enqueued through scratch registers and is checked for
completion through polling in synchronous mode. OCM pages
are released after the model stop completes.

Signed-off-by: Srikanth Yalavarthi <syalavarthi at marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 115 ++++++++++++++++++++++++++++++++-
 drivers/ml/cnxk/cn10k_ml_ops.h |   1 +
 2 files changed, 114 insertions(+), 2 deletions(-)

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index e8ce65b182..77d3728d8d 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -295,10 +295,14 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
 		/* Re-configure */
 		void **models;
 
-		/* Unload all models */
+		/* Stop and unload all models */
 		for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
 			model = dev->data->models[model_id];
 			if (model != NULL) {
+				if (model->state == ML_CN10K_MODEL_STATE_STARTED) {
+					if (cn10k_ml_model_stop(dev, model_id) != 0)
+						plt_err("Could not stop model %u", model_id);
+				}
 				if (model->state == ML_CN10K_MODEL_STATE_LOADED) {
 					if (cn10k_ml_model_unload(dev, model_id) != 0)
 						plt_err("Could not unload model %u", model_id);
@@ -362,10 +366,14 @@ cn10k_ml_dev_close(struct rte_ml_dev *dev)
 
 	mldev = dev->data->dev_private;
 
-	/* Unload all models */
+	/* Stop and unload all models */
 	for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
 		model = dev->data->models[model_id];
 		if (model != NULL) {
+			if (model->state == ML_CN10K_MODEL_STATE_STARTED) {
+				if (cn10k_ml_model_stop(dev, model_id) != 0)
+					plt_err("Could not stop model %u", model_id);
+			}
 			if (model->state == ML_CN10K_MODEL_STATE_LOADED) {
 				if (cn10k_ml_model_unload(dev, model_id) != 0)
 					plt_err("Could not unload model %u", model_id);
@@ -767,6 +775,108 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
 	return ret;
 }
 
+int
+cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
+{
+	struct cn10k_ml_model *model;
+	struct cn10k_ml_dev *mldev;
+	struct cn10k_ml_ocm *ocm;
+	struct cn10k_ml_req *req;
+
+	bool job_enqueued;
+	bool job_dequeued;
+	bool locked;
+	int ret = 0;
+
+	mldev = dev->data->dev_private;
+	ocm = &mldev->ocm;
+	model = dev->data->models[model_id];
+
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	/* Prepare JD */
+	req = model->req;
+	cn10k_ml_prep_sp_job_descriptor(mldev, model, req, ML_CN10K_JOB_TYPE_MODEL_STOP);
+	req->result.error_code = 0x0;
+	req->result.user_ptr = NULL;
+
+	plt_write64(ML_CN10K_POLL_JOB_START, &req->status);
+	plt_wmb();
+
+	locked = false;
+	while (!locked) {
+		if (plt_spinlock_trylock(&model->lock) != 0) {
+			if (model->state == ML_CN10K_MODEL_STATE_LOADED) {
+				plt_ml_dbg("Model not started, model = 0x%016lx",
+					   PLT_U64_CAST(model));
+				plt_spinlock_unlock(&model->lock);
+				return 1;
+			}
+
+			if (model->state == ML_CN10K_MODEL_STATE_JOB_ACTIVE) {
+				plt_err("A slow-path job is active for the model = 0x%016lx",
+					PLT_U64_CAST(model));
+				plt_spinlock_unlock(&model->lock);
+				return -EBUSY;
+			}
+
+			model->state = ML_CN10K_MODEL_STATE_JOB_ACTIVE;
+			plt_spinlock_unlock(&model->lock);
+			locked = true;
+		}
+	}
+
+	while (model->model_mem_map.ocm_reserved) {
+		if (plt_spinlock_trylock(&ocm->lock) != 0) {
+			cn10k_ml_ocm_free_pages(dev, model->model_id);
+			model->model_mem_map.ocm_reserved = false;
+			model->model_mem_map.tilemask = 0x0;
+			plt_spinlock_unlock(&ocm->lock);
+		}
+	}
+
+	job_enqueued = false;
+	job_dequeued = false;
+	do {
+		if (!job_enqueued) {
+			req->timeout = plt_tsc_cycles() + ML_CN10K_CMD_TIMEOUT * plt_tsc_hz();
+			job_enqueued = roc_ml_scratch_enqueue(&mldev->roc, &req->jd);
+		}
+
+		if (job_enqueued && !job_dequeued)
+			job_dequeued = roc_ml_scratch_dequeue(&mldev->roc, &req->jd);
+
+		if (job_dequeued)
+			break;
+	} while (plt_tsc_cycles() < req->timeout);
+
+	if (job_dequeued) {
+		if (plt_read64(&req->status) == ML_CN10K_POLL_JOB_FINISH) {
+			if (req->result.error_code == 0x0)
+				ret = 0;
+			else
+				ret = -1;
+		}
+	} else {
+		roc_ml_scratch_queue_reset(&mldev->roc);
+		ret = -ETIME;
+	}
+
+	locked = false;
+	while (!locked) {
+		if (plt_spinlock_trylock(&model->lock) != 0) {
+			model->state = ML_CN10K_MODEL_STATE_LOADED;
+			plt_spinlock_unlock(&model->lock);
+			locked = true;
+		}
+	}
+
+	return ret;
+}
+
 struct rte_ml_dev_ops cn10k_ml_ops = {
 	/* Device control ops */
 	.dev_info_get = cn10k_ml_dev_info_get,
@@ -783,4 +893,5 @@ struct rte_ml_dev_ops cn10k_ml_ops = {
 	.model_load = cn10k_ml_model_load,
 	.model_unload = cn10k_ml_model_unload,
 	.model_start = cn10k_ml_model_start,
+	.model_stop = cn10k_ml_model_stop,
 };
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 989af978c4..22576b93c0 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -65,5 +65,6 @@ int cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *para
 			uint16_t *model_id);
 int cn10k_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id);
 int cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id);
+int cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id);
 
 #endif /* _CN10K_ML_OPS_H_ */
-- 
2.17.1



More information about the dev mailing list