Sitecore Cortex and ML: Part 6 - How to Implement Training and Evaluation Workers

In this article, we will implement the remaining logic to get a functional solution.

Let's finish the training worker that we described in the previous part. This is a step-by-step diagram of the processes related to training.

Training Process

  1. When TrainingWorker runs, it creates the Cortex model (PurchaseInteractionModel) and calls the model's TrainAsync method, passing the schema name and table definitions as parameters:

    ModelStatistics modelStatistics = await _model.TrainAsync(_options.SchemaName, token, tableDefinitionList.ToArray()).ConfigureAwait(false);
    
  2. The Cortex model retrieves data from blob storage (the projected data with all purchases) and passes it to the Train method of MLNetService:

    public async Task<ModelStatistics> TrainAsync(string schemaName, CancellationToken cancellationToken, params TableDefinition[] tables)
    {
        var tableStore = _tableStoreFactory.Create(schemaName);
        var data = await GetDataRowsAsync(tableStore, tables.First().Name, cancellationToken);
        
        return _mlNetService.Train(data);
    }
    
    
  3. MLNetService maps the data to a plain C# model and calculates RFM values for it. It then sends the RFM values to the ML engine and returns them to the Cortex model:

    public ModelStatistics Train(IReadOnlyList<IDataRow> data)
    {
        var customersData = CustomerMapper.MapToCustomers(data);
    
        var rfmCalculateService = new RfmCalculateService();
        var calculatedScores = rfmCalculateService.CalculateRfmScores(customersData);
    
        var businessData = calculatedScores.Select(x => new Rfm
        {
            R = x.R,
            F = x.F,
            M = x.M
        }).ToList();
    
        var client = new RestClient(_mlServerUrl);
        var request = new RestRequest(_trainUrl, Method.POST);
        request.AddJsonBody(businessData);
        var response = client.Execute<bool>(request);
        var ok = response.Data;
        if (!ok)
        {
            throw new Exception("something is wrong with ML engine, check it");
        }
    
        return new RfmStatistics{ Customers = calculatedScores };
    }
    
    
  4. The Cortex model returns the RFM values to the training worker, which populates each contact's RFM facet with these values:

    public async Task UpdateRfmFacets(RfmStatistics statistics, CancellationToken token)
    {
        using (IServiceScope scope = _serviceProvider.CreateScope())
        {
            using (var xdbContext = scope.ServiceProvider.GetService<IXdbContext>())
            {
                foreach (var identifier in statistics.Customers)
                {
                    var reference = new IdentifiedContactReference(XConnectService.IdentificationSource, identifier.CustomerId.ToString());
                    var contact = await xdbContext.GetContactAsync(reference, new ContactExpandOptions(
                        PersonalInformation.DefaultFacetKey,
                        EmailAddressList.DefaultFacetKey,
                        ContactBehaviorProfile.DefaultFacetKey,
                        RfmContactFacet.DefaultFacetKey
                    ));
                    if (contact != null)
                    {
                        var rfmFacet = contact.GetFacet<RfmContactFacet>(RfmContactFacet.DefaultFacetKey) ?? new RfmContactFacet();
                        rfmFacet.R = identifier.R;
                        rfmFacet.F = identifier.F;
                        rfmFacet.M = identifier.M;
                        rfmFacet.Recency = identifier.Recency;
                        rfmFacet.Frequency = identifier.Frequency;
                        rfmFacet.Monetary = (double)identifier.Monetary;
                        xdbContext.SetFacet(contact, RfmContactFacet.DefaultFacetKey, rfmFacet);
    
                        _logger.LogInformation(string.Format("Update RFM info: customerId={0}, R={1}, F={2}, M={3}, Recency={4}, Frequency={5}, Monetary={6}",
                            identifier.CustomerId, rfmFacet.R, rfmFacet.F, rfmFacet.M, rfmFacet.Recency, rfmFacet.Frequency, rfmFacet.Monetary));
    
                        await xdbContext.SubmitAsync(token);
                    }
                }
    
              
            }
        }
    }
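
The article does not show RfmCalculateService, so the following is only a minimal sketch of how step 3 might turn raw purchases into R, F and M scores. It assumes a simple rank-based split into five buckets. The types Purchase and RfmScore, the method names, and the bucketing rule are hypothetical stand-ins for illustration, not the code used in this series.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical input: one row per purchased product line.
public class Purchase
{
    public int CustomerId { get; set; }
    public string InvoiceId { get; set; }
    public DateTime Timestamp { get; set; }
    public int Quantity { get; set; }
    public decimal UnitPrice { get; set; }
}

// Hypothetical output: raw values plus 1..N scores per customer.
public class RfmScore
{
    public int CustomerId { get; set; }
    public int Recency { get; set; }      // days since the last purchase
    public int Frequency { get; set; }    // number of distinct invoices
    public decimal Monetary { get; set; } // total spend
    public int R { get; set; }
    public int F { get; set; }
    public int M { get; set; }
}

public static class RfmSketch
{
    public static List<RfmScore> Calculate(IEnumerable<Purchase> purchases, DateTime asOf, int buckets = 5)
    {
        // Aggregate the raw recency, frequency and monetary values per customer.
        var scores = purchases
            .GroupBy(p => p.CustomerId)
            .Select(g => new RfmScore
            {
                CustomerId = g.Key,
                Recency = (int)(asOf - g.Max(p => p.Timestamp)).TotalDays,
                Frequency = g.Select(p => p.InvoiceId).Distinct().Count(),
                Monetary = g.Sum(p => p.Quantity * p.UnitPrice)
            })
            .ToList();

        // A small Recency is good, so it is ranked in the opposite direction.
        var r = Rank(scores.Select(s => (double)s.Recency).ToList(), buckets, lowerIsBetter: true);
        var f = Rank(scores.Select(s => (double)s.Frequency).ToList(), buckets, lowerIsBetter: false);
        var m = Rank(scores.Select(s => (double)s.Monetary).ToList(), buckets, lowerIsBetter: false);

        for (int i = 0; i < scores.Count; i++)
        {
            scores[i].R = r[i];
            scores[i].F = f[i];
            scores[i].M = m[i];
        }
        return scores;
    }

    // Sorts from worst to best and assigns scores 1..buckets by position.
    // Ties are broken by input order, so equal values can receive different scores.
    private static int[] Rank(IReadOnlyList<double> values, int buckets, bool lowerIsBetter)
    {
        var order = Enumerable.Range(0, values.Count)
            .OrderBy(i => lowerIsBetter ? -values[i] : values[i])
            .ToArray();
        var result = new int[values.Count];
        for (int pos = 0; pos < order.Length; pos++)
            result[order[pos]] = 1 + pos * buckets / order.Length;
        return result;
    }
}
```

Only the R, F and M scores are sent to the ML engine in the Train method above, while the raw Recency, Frequency and Monetary values are kept so that they can be written to the contact facet in step 4.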
    
    

The training worker is finished, so let's continue with the evaluation worker. The sequence of steps is much the same, but the implementation is simpler because we use the out-of-the-box (OOTB) Sitecore evaluation worker.

Evaluation Process

  1. When EvaluationWorker runs, it creates a Cortex model (ContactModel) and calls the model's EvaluateAsync method with the supplied schema name and table definitions as parameters. This works OOTB, with no additional code.

  2. The Cortex model retrieves data from blob storage (a batch of projected contacts with RFM values) and passes it to the Evaluate method of MLNetService:

    public async Task<IReadOnlyList<object>> EvaluateAsync(string schemaName, CancellationToken cancellationToken, params TableDefinition[] tables)
    {
        var tableStore = _tableStoreFactory.Create(schemaName);
        var data = await GetDataRowsAsync(tableStore, tables.First().Name, cancellationToken);
        return _mlNetService.Evaluate(data);
    }
    
    
  3. MLNetService maps the data to a plain C# model and retrieves the RFM values. It then sends the contacts' RFM values to the ML engine and receives the predicted cluster number for each contact in response:

    public IReadOnlyList<PredictionResult> Evaluate(IReadOnlyList<IDataRow> data)
    {
        var validContacts = data.Where(x => x.Enabled() && !string.IsNullOrEmpty(x.GetContactEmail())).ToList();
        var rfmList = validContacts.Select(x => x.MapToRfmFacet()).Select(rfm => new ClusteringData
        {
            R = rfm.R,
            F = rfm.F,
            M = rfm.M
        }).ToList();
    
        var client = new RestClient(_mlServerUrl);
        var request = new RestRequest(_predictUrl, Method.POST);
        request.AddJsonBody(rfmList);
        var response = client.Execute<List<int>>(request);
    
        var predictions = response.Data;
        return validContacts.Select((t, i) => new PredictionResult {Email = t.GetContactEmail(), Cluster = predictions[i]}).ToList();
    }
    
    
  4. The evaluation worker's ConsumeEvaluationResultsAsync method is then executed. Here we populate the Cluster value of the RFM facet for the corresponding batch of contacts:

    
    protected override async Task ConsumeEvaluationResultsAsync(IReadOnlyList<Contact> entities, IReadOnlyList<object> evaluationResults, CancellationToken token)
    {
        var contactIdentifiers = entities
            .SelectMany(x =>
                x.Identifiers.Where(s => s.Source == XConnectService.IdentificationSourceEmail));
    
        var predictionResults = evaluationResults.ToPredictionResults();
    
        using (IServiceScope scope = _serviceProvider.CreateScope())
        {
            using (var xdbContext = scope.ServiceProvider.GetRequiredService<IXdbContext>())
            {
                foreach (var identifier in contactIdentifiers)
                {
                    var reference = new IdentifiedContactReference(identifier.Source, identifier.Identifier);
                    var contact = await xdbContext.GetContactAsync(reference, new ContactExpandOptions(
                        PersonalInformation.DefaultFacetKey,
                        EmailAddressList.DefaultFacetKey,
                        ContactBehaviorProfile.DefaultFacetKey,
                        RfmContactFacet.DefaultFacetKey
                    ));
                    if (contact != null)
                    {
                        var rfmFacet = contact.GetFacet<RfmContactFacet>(RfmContactFacet.DefaultFacetKey) ?? new RfmContactFacet();
                        rfmFacet.Cluster = predictionResults.First(x => x.Email.Equals(identifier.Identifier)).Cluster;
                        xdbContext.SetFacet(contact, RfmContactFacet.DefaultFacetKey, rfmFacet);
    
                        _logger.LogInformation(string.Format("RFM info: email={0}, R={1}, F={2}, M={3}, Recency={4}, Frequency={5}, Monetary={6}, CLUSTER={7}",
                            identifier.Identifier, rfmFacet.R, rfmFacet.F, rfmFacet.M, rfmFacet.Recency, rfmFacet.Frequency, rfmFacet.Monetary, rfmFacet.Cluster));
                    }
                }
                await xdbContext.SubmitAsync(token);
            }
        }
    }
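
Steps 3 and 4 rely on the ML engine returning exactly one cluster per contact, in the same order, and on every identifier having a matching prediction. The sketch below shows one way these assumptions could be checked, and how the predictions could be indexed by email so that a missing entry does not throw. The PredictionResult class here is a local stand-in with the same two properties as in the article; the PredictionPairing helper, its method names, and the case-insensitive email comparison are hypothetical choices, not part of the original solution.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Local stand-in for the article's PredictionResult type.
public class PredictionResult
{
    public string Email { get; set; }
    public int Cluster { get; set; }
}

// Hypothetical helper, for illustration only.
public static class PredictionPairing
{
    // Pairs each email with the prediction at the same position,
    // failing loudly if the ML engine returned nothing or the counts differ.
    public static List<PredictionResult> Pair(IReadOnlyList<string> emails, IReadOnlyList<int> clusters)
    {
        if (clusters == null)
            throw new InvalidOperationException("The ML engine returned no predictions.");
        if (clusters.Count != emails.Count)
            throw new InvalidOperationException(
                $"Expected {emails.Count} predictions but received {clusters.Count}.");

        return emails
            .Select((email, i) => new PredictionResult { Email = email, Cluster = clusters[i] })
            .ToList();
    }

    // Builds an email-to-cluster lookup. If an email occurs twice, the last entry wins.
    public static Dictionary<string, int> IndexByEmail(IEnumerable<PredictionResult> results)
    {
        var index = new Dictionary<string, int>(StringComparer.OrdinalIgnoreCase);
        foreach (var result in results)
            index[result.Email] = result.Cluster;
        return index;
    }
}
```

With such a lookup, the loop in ConsumeEvaluationResultsAsync could call `index.TryGetValue(identifier.Identifier, out var cluster)` and skip contacts without a prediction, instead of calling `First(...)`, which throws when no element matches.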
    
    

That’s all. We have a functional solution.

Full source code of workers and models:

PurchaseInteractionModel.cs


public class PurchaseInteractionModel : BaseWorker, IModel<Interaction>
{
    private readonly IMLNetService _mlNetService;
    private readonly ITableStoreFactory _tableStoreFactory;
    public PurchaseInteractionModel(IReadOnlyDictionary<string, string> options, IMLNetService mlNetService,  ITableStoreFactory tableStoreFactory) : base (tableStoreFactory)
    {

        _tableStoreFactory = tableStoreFactory;

        this.Projection = Sitecore.Processing.Engine.Projection.Projection.Of<Interaction>()
            .CreateTabular("PurchaseOutcome",
                interaction => interaction.Events.OfType<PurchaseOutcome>(),
                cfg => cfg.Key("ID", x => x.Id)
                    .Attribute("InvoiceId", x => x.InvoiceId)
                    .Attribute("Quantity", x => x.Quantity)
                    .Attribute("Timestamp", x => x.Timestamp)
                    .Attribute("UnitPrice", x => x.UnitPrice)
                    .Attribute("CustomerId", x => x.CustomerId)
                    .Attribute("ProductId", x => x.ProductId)
            );

        this._mlNetService = mlNetService;

    }

    public async Task<ModelStatistics> TrainAsync(string schemaName, CancellationToken cancellationToken, params TableDefinition[] tables)
    {
        var tableStore = _tableStoreFactory.Create(schemaName);
        var data = await GetDataRowsAsync(tableStore, tables.First().Name, cancellationToken);
        
        return _mlNetService.Train(data);
    }

    public Task<IReadOnlyList<object>> EvaluateAsync(string schemaName, CancellationToken cancellationToken, params TableDefinition[] tables)
    {
        throw new NotImplementedException();
    }

    public IProjection<Interaction> Projection { get; set; }

}

RfmTrainingWorker.cs


public class RfmTrainingWorker : IDeferredWorker
{
    private readonly IModel<Interaction> _model;
    private readonly RfmTrainingWorkerOptionsDictionary _options;
    private readonly ITableStore _tableStore;
    private readonly ILogger<RfmTrainingWorker> _logger;
    private readonly IServiceProvider _serviceProvider;

    public RfmTrainingWorker(
        ITableStoreFactory tableStoreFactory,
        IServiceProvider provider,
        ILogger<RfmTrainingWorker> logger,
        AllowedModelsDictionary modelsDictionary,
        RfmTrainingWorkerOptionsDictionary options,
        IServiceProvider serviceProvider)
    {

        this._tableStore = tableStoreFactory.Create(options.SchemaName);
        this._options = options;
        this._logger = logger;
        this._model = modelsDictionary.CreateModel<Interaction>(provider, options.ModelType, options.ModelOptions);
        this._serviceProvider = serviceProvider;
    }

    public RfmTrainingWorker(
        ITableStoreFactory tableStoreFactory,
        IServiceProvider provider,
        ILogger<RfmTrainingWorker> logger,
        AllowedModelsDictionary modelsDictionary,
        IReadOnlyDictionary<string, string> options,
        IServiceProvider serviceProvider)
        : this(tableStoreFactory, provider, logger, modelsDictionary, RfmTrainingWorkerOptionsDictionary.Parse(options), serviceProvider)
    {
    }



    public async Task RunAsync(CancellationToken token)
    {
        _logger.LogInformation("RfmTrainingWorker.RunAsync");
        
        IReadOnlyList<string> tableNames = _options.TableNames;
        List<Task<TableStatistics>> tableStatisticsTasks = new List<Task<TableStatistics>>(tableNames.Count);
        foreach (string tableName in tableNames)
            tableStatisticsTasks.Add(this._tableStore.GetTableStatisticsAsync(tableName, token));
        TableStatistics[] tableStatisticsArray = await Task.WhenAll(tableStatisticsTasks).ConfigureAwait(false);
        List<TableDefinition> tableDefinitionList = new List<TableDefinition>(tableStatisticsTasks.Count);
        for (int index = 0; index < tableStatisticsTasks.Count; ++index)
        {
            TableStatistics result = tableStatisticsTasks[index].Result;
            if (result == null)
                this._logger.LogWarning(string.Format("Statistics data for {0} table could not be retrieved. It will not participate in model training.", (object)tableNames[index]));
            else
                tableDefinitionList.Add(result.Definition);
        }
        ModelStatistics modelStatistics = await _model.TrainAsync(_options.SchemaName, token, tableDefinitionList.ToArray()).ConfigureAwait(false);

        await UpdateRfmFacets(modelStatistics as RfmStatistics, token);
    }

    public async Task UpdateRfmFacets(RfmStatistics statistics, CancellationToken token)
    {
        using (IServiceScope scope = _serviceProvider.CreateScope())
        {
            using (var xdbContext = scope.ServiceProvider.GetService<IXdbContext>())
            {
                foreach (var identifier in statistics.Customers)
                {
                    var reference = new IdentifiedContactReference(XConnectService.IdentificationSource, identifier.CustomerId.ToString());
                    var contact = await xdbContext.GetContactAsync(reference, new ContactExpandOptions(
                        PersonalInformation.DefaultFacetKey,
                        EmailAddressList.DefaultFacetKey,
                        ContactBehaviorProfile.DefaultFacetKey,
                        RfmContactFacet.DefaultFacetKey
                    ));
                    if (contact != null)
                    {
                        var rfmFacet = contact.GetFacet<RfmContactFacet>(RfmContactFacet.DefaultFacetKey) ?? new RfmContactFacet();
                        rfmFacet.R = identifier.R;
                        rfmFacet.F = identifier.F;
                        rfmFacet.M = identifier.M;
                        rfmFacet.Recency = identifier.Recency;
                        rfmFacet.Frequency = identifier.Frequency;
                        rfmFacet.Monetary = (double)identifier.Monetary;
                        xdbContext.SetFacet(contact, RfmContactFacet.DefaultFacetKey, rfmFacet);

                        _logger.LogInformation(string.Format("Update RFM info: customerId={0}, R={1}, F={2}, M={3}, Recency={4}, Frequency={5}, Monetary={6}",
                            identifier.CustomerId, rfmFacet.R, rfmFacet.F, rfmFacet.M, rfmFacet.Recency, rfmFacet.Frequency, rfmFacet.Monetary));

                        await xdbContext.SubmitAsync(token);
                    }
                }

              
            }
        }
    }

    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    protected virtual void Dispose(bool dispose)
    {
    }
}

ContactModel.cs


public class ContactModel : BaseWorker, IModel<Contact>
{
    private readonly IMLNetService _mlNetService;
    private readonly ITableStoreFactory _tableStoreFactory;

    public ContactModel(IReadOnlyDictionary<string, string> options, IMLNetService mlNetService, ITableStoreFactory tableStoreFactory) : base(tableStoreFactory)
    {
        _mlNetService = mlNetService;
        _tableStoreFactory = tableStoreFactory;

        Projection = Sitecore.Processing.Engine.Projection.Projection.Of<Contact>().CreateTabular(
            "ContactModel",
            cfg => cfg
                .Key("ContactId", c => c.Id)
                .Attribute("Enabled", c => c.GetFacet<RfmContactFacet>()==null ? 0 : 1)
                .Attribute("R", c => c.GetFacet<RfmContactFacet>()==null ? 0 : c.GetFacet<RfmContactFacet>().R)
                .Attribute("F", c => c.GetFacet<RfmContactFacet>() == null ? 0 : c.GetFacet<RfmContactFacet>().F)
                .Attribute("M", c => c.GetFacet<RfmContactFacet>() == null ? 0 : c.GetFacet<RfmContactFacet>().M)
                .Attribute("Recency", c => c.GetFacet<RfmContactFacet>() == null ? 0 : c.GetFacet<RfmContactFacet>().Recency)
                .Attribute("Frequency", c => c.GetFacet<RfmContactFacet>() == null ? 0 : c.GetFacet<RfmContactFacet>().Frequency)
                .Attribute("Monetary", c => c.GetFacet<RfmContactFacet>() == null ? 0 : c.GetFacet<RfmContactFacet>().Monetary)
                .Attribute("Email", c => c.Emails()?.PreferredEmail?.SmtpAddress, nullable: true));
    }

    public Task<ModelStatistics> TrainAsync(string schemaName, CancellationToken cancellationToken, params TableDefinition[] tables)
    {
        throw new NotImplementedException();
    }

    public async Task<IReadOnlyList<object>> EvaluateAsync(string schemaName, CancellationToken cancellationToken, params TableDefinition[] tables)
    {
        var tableStore = _tableStoreFactory.Create(schemaName);
        var data = await GetDataRowsAsync(tableStore, tables.First().Name, cancellationToken);
        return _mlNetService.Evaluate(data);
    }

    public IProjection<Contact> Projection { get; }



}

RfmEvaluationWorker.cs


public class RfmEvaluationWorker : EvaluationWorker<Contact>
{
    private readonly ILogger<RfmEvaluationWorker> _logger;
    private readonly IServiceProvider _serviceProvider;
    public RfmEvaluationWorker(IModelEvaluator evaluator, IReadOnlyDictionary<string, string> options, ILogger<RfmEvaluationWorker> logger, IServiceProvider serviceProvider) : base(evaluator, options)
    {
        _logger = logger;
        _serviceProvider = serviceProvider;
    }

    public RfmEvaluationWorker(IModelEvaluator evaluator, EvaluationWorkerOptionsDictionary options, ILogger<RfmEvaluationWorker> logger, IServiceProvider serviceProvider) : base(evaluator, options)
    {
        _logger = logger;
        _serviceProvider = serviceProvider;
    }


    // Update Cluster for Contact
    protected override async Task ConsumeEvaluationResultsAsync(IReadOnlyList<Contact> entities, IReadOnlyList<object> evaluationResults, CancellationToken token)
    {
        var contactIdentifiers = entities
            .SelectMany(x =>
                x.Identifiers.Where(s => s.Source == XConnectService.IdentificationSourceEmail));

        var predictionResults = evaluationResults.ToPredictionResults();

        using (IServiceScope scope = _serviceProvider.CreateScope())
        {
            using (var xdbContext = scope.ServiceProvider.GetRequiredService<IXdbContext>())
            {
                foreach (var identifier in contactIdentifiers)
                {
                    var reference = new IdentifiedContactReference(identifier.Source, identifier.Identifier);
                    var contact = await xdbContext.GetContactAsync(reference, new ContactExpandOptions(
                        PersonalInformation.DefaultFacetKey,
                        EmailAddressList.DefaultFacetKey,
                        ContactBehaviorProfile.DefaultFacetKey,
                        RfmContactFacet.DefaultFacetKey
                    ));
                    if (contact != null)
                    {
                        var rfmFacet = contact.GetFacet<RfmContactFacet>(RfmContactFacet.DefaultFacetKey) ?? new RfmContactFacet();
                        rfmFacet.Cluster = predictionResults.First(x => x.Email.Equals(identifier.Identifier)).Cluster;
                        xdbContext.SetFacet(contact, RfmContactFacet.DefaultFacetKey, rfmFacet);

                        _logger.LogInformation(string.Format("RFM info: email={0}, R={1}, F={2}, M={3}, Recency={4}, Frequency={5}, Monetary={6}, CLUSTER={7}",
                            identifier.Identifier, rfmFacet.R, rfmFacet.F, rfmFacet.M, rfmFacet.Recency, rfmFacet.Frequency, rfmFacet.Monetary, rfmFacet.Cluster));

                    }
                }

                await xdbContext.SubmitAsync(token);
            }
        }
    }
}

MLNetService.cs


public class MLNetService : IMLNetService
{
    private readonly string _trainUrl;
    private readonly string _predictUrl;
    private readonly string _mlServerUrl;

    public MLNetService(IConfiguration configuration)
    {
        _mlServerUrl = configuration.GetValue<string>("MLServerUrl");
        _trainUrl = configuration.GetValue<string>("TrainUrl");
        _predictUrl = configuration.GetValue<string>("PredictUrl");
    }

    public ModelStatistics Train(IReadOnlyList<IDataRow> data)
    {
        var customersData = CustomerMapper.MapToCustomers(data);

        var rfmCalculateService = new RfmCalculateService();
        var calculatedScores = rfmCalculateService.CalculateRfmScores(customersData);

        var businessData = calculatedScores.Select(x => new Rfm
        {
            R = x.R,
            F = x.F,
            M = x.M
        }).ToList();

        var client = new RestClient(_mlServerUrl);
        var request = new RestRequest(_trainUrl, Method.POST);
        request.AddJsonBody(businessData);
        var response = client.Execute<bool>(request);
        var ok = response.Data;
        if (!ok)
        {
            throw new Exception("something is wrong with ML engine, check it");
        }

        return new RfmStatistics{ Customers = calculatedScores };
    }

    public IReadOnlyList<PredictionResult> Evaluate(IReadOnlyList<IDataRow> data)
    {
        var validContacts = data.Where(x => x.Enabled() && !string.IsNullOrEmpty(x.GetContactEmail())).ToList();
        var rfmList = validContacts.Select(x => x.MapToRfmFacet()).Select(rfm => new ClusteringData
        {
            R = rfm.R,
            F = rfm.F,
            M = rfm.M
        }).ToList();

        var client = new RestClient(_mlServerUrl);
        var request = new RestRequest(_predictUrl, Method.POST);
        request.AddJsonBody(rfmList);
        var response = client.Execute<List<int>>(request);

        var predictions = response.Data;
        return validContacts.Select((t, i) => new PredictionResult {Email = t.GetContactEmail(), Cluster = predictions[i]}).ToList();
    }
}

public interface IMLNetService
{
    ModelStatistics Train(IReadOnlyList<IDataRow> data);
    IReadOnlyList<PredictionResult> Evaluate(IReadOnlyList<IDataRow> data);
}

Table of contents: Dive into Sitecore Cortex and Machine Learning - Introduction

Final Part 7 - Configure customers segmentation, live demo

