[PATCH v5 19/39] ml/cnxk: enable support to stop an ML model
Srikanth Yalavarthi
syalavarthi at marvell.com
Tue Feb 7 17:06:59 CET 2023
Implemented model stop driver function. A model stop job is
enqueued through scratch registers and is checked for
completion through polling in a synchronous mode. OCM pages
are released after model stop completion.
Signed-off-by: Srikanth Yalavarthi <syalavarthi at marvell.com>
---
drivers/ml/cnxk/cn10k_ml_ops.c | 115 ++++++++++++++++++++++++++++++++-
drivers/ml/cnxk/cn10k_ml_ops.h | 1 +
2 files changed, 114 insertions(+), 2 deletions(-)
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index e8ce65b182..77d3728d8d 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -295,10 +295,14 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
/* Re-configure */
void **models;
- /* Unload all models */
+ /* Stop and unload all models */
for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
model = dev->data->models[model_id];
if (model != NULL) {
+ if (model->state == ML_CN10K_MODEL_STATE_STARTED) {
+ if (cn10k_ml_model_stop(dev, model_id) != 0)
+ plt_err("Could not stop model %u", model_id);
+ }
if (model->state == ML_CN10K_MODEL_STATE_LOADED) {
if (cn10k_ml_model_unload(dev, model_id) != 0)
plt_err("Could not unload model %u", model_id);
@@ -362,10 +366,14 @@ cn10k_ml_dev_close(struct rte_ml_dev *dev)
mldev = dev->data->dev_private;
- /* Unload all models */
+ /* Stop and unload all models */
for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
model = dev->data->models[model_id];
if (model != NULL) {
+ if (model->state == ML_CN10K_MODEL_STATE_STARTED) {
+ if (cn10k_ml_model_stop(dev, model_id) != 0)
+ plt_err("Could not stop model %u", model_id);
+ }
if (model->state == ML_CN10K_MODEL_STATE_LOADED) {
if (cn10k_ml_model_unload(dev, model_id) != 0)
plt_err("Could not unload model %u", model_id);
@@ -767,6 +775,108 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
return ret;
}
+int
+cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
+{
+ struct cn10k_ml_model *model;
+ struct cn10k_ml_dev *mldev;
+ struct cn10k_ml_ocm *ocm;
+ struct cn10k_ml_req *req;
+
+ bool job_enqueued;
+ bool job_dequeued;
+ bool locked;
+ int ret = 0;
+
+ mldev = dev->data->dev_private;
+ ocm = &mldev->ocm;
+ model = dev->data->models[model_id];
+
+ if (model == NULL) {
+ plt_err("Invalid model_id = %u", model_id);
+ return -EINVAL;
+ }
+
+ /* Prepare JD */
+ req = model->req;
+ cn10k_ml_prep_sp_job_descriptor(mldev, model, req, ML_CN10K_JOB_TYPE_MODEL_STOP);
+ req->result.error_code = 0x0;
+ req->result.user_ptr = NULL;
+
+ plt_write64(ML_CN10K_POLL_JOB_START, &req->status);
+ plt_wmb();
+
+ locked = false;
+ while (!locked) {
+ if (plt_spinlock_trylock(&model->lock) != 0) {
+ if (model->state == ML_CN10K_MODEL_STATE_LOADED) {
+ plt_ml_dbg("Model not started, model = 0x%016lx",
+ PLT_U64_CAST(model));
+ plt_spinlock_unlock(&model->lock);
+ return 1;
+ }
+
+ if (model->state == ML_CN10K_MODEL_STATE_JOB_ACTIVE) {
+ plt_err("A slow-path job is active for the model = 0x%016lx",
+ PLT_U64_CAST(model));
+ plt_spinlock_unlock(&model->lock);
+ return -EBUSY;
+ }
+
+ model->state = ML_CN10K_MODEL_STATE_JOB_ACTIVE;
+ plt_spinlock_unlock(&model->lock);
+ locked = true;
+ }
+ }
+
+ while (model->model_mem_map.ocm_reserved) {
+ if (plt_spinlock_trylock(&ocm->lock) != 0) {
+ cn10k_ml_ocm_free_pages(dev, model->model_id);
+ model->model_mem_map.ocm_reserved = false;
+ model->model_mem_map.tilemask = 0x0;
+ plt_spinlock_unlock(&ocm->lock);
+ }
+ }
+
+ job_enqueued = false;
+ job_dequeued = false;
+ do {
+ if (!job_enqueued) {
+ req->timeout = plt_tsc_cycles() + ML_CN10K_CMD_TIMEOUT * plt_tsc_hz();
+ job_enqueued = roc_ml_scratch_enqueue(&mldev->roc, &req->jd);
+ }
+
+ if (job_enqueued && !job_dequeued)
+ job_dequeued = roc_ml_scratch_dequeue(&mldev->roc, &req->jd);
+
+ if (job_dequeued)
+ break;
+ } while (plt_tsc_cycles() < req->timeout);
+
+ if (job_dequeued) {
+ if (plt_read64(&req->status) == ML_CN10K_POLL_JOB_FINISH) {
+ if (req->result.error_code == 0x0)
+ ret = 0;
+ else
+ ret = -1;
+ }
+ } else {
+ roc_ml_scratch_queue_reset(&mldev->roc);
+ ret = -ETIME;
+ }
+
+ locked = false;
+ while (!locked) {
+ if (plt_spinlock_trylock(&model->lock) != 0) {
+ model->state = ML_CN10K_MODEL_STATE_LOADED;
+ plt_spinlock_unlock(&model->lock);
+ locked = true;
+ }
+ }
+
+ return ret;
+}
+
struct rte_ml_dev_ops cn10k_ml_ops = {
/* Device control ops */
.dev_info_get = cn10k_ml_dev_info_get,
@@ -783,4 +893,5 @@ struct rte_ml_dev_ops cn10k_ml_ops = {
.model_load = cn10k_ml_model_load,
.model_unload = cn10k_ml_model_unload,
.model_start = cn10k_ml_model_start,
+ .model_stop = cn10k_ml_model_stop,
};
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 989af978c4..22576b93c0 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -65,5 +65,6 @@ int cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *para
uint16_t *model_id);
int cn10k_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id);
int cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id);
+int cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id);
#endif /* _CN10K_ML_OPS_H_ */
--
2.17.1
More information about the dev
mailing list