diff --git a/index/scorch/train_vector.go b/index/scorch/train_vector.go index 6a58fd06c..60db17883 100644 --- a/index/scorch/train_vector.go +++ b/index/scorch/train_vector.go @@ -37,10 +37,11 @@ import ( ) type trainRequest struct { - finalSample bool - sampleSize int - ackCh chan error - sample segment.Segment + finalSample bool + sampleSize int + ackCh chan error + sample segment.Segment + trainingParams *index.TrainingParams } type vectorTrainer struct { @@ -154,6 +155,10 @@ func (t *vectorTrainer) trainLoop() { return case trainReq := <-t.trainCh: sampleSeg := trainReq.sample + if trainReq.trainingParams != nil { + t.config[index.TrainingKey] = trainReq.trainingParams + } + // no sample segment: just persist state if this is the final sample and move on. if sampleSeg == nil { if trainReq.finalSample { @@ -180,10 +185,8 @@ func (t *vectorTrainer) trainLoop() { // merge the new segment with the existing one into a .tmp file, then // atomically rename it into place (Os.Open on the live path is unsafe // during the merge). - t.config[index.TrainingKey] = true _, _, err := t.parent.segPlugin.MergeUsing([]segment.Segment{t.trainedIndex.segment, sampleSeg}, []*roaring.Bitmap{nil, nil}, path+".tmp", t.parent.closeCh, nil, t.config) - t.config[index.TrainingKey] = false if err != nil { trainReq.ackCh <- fmt.Errorf("error merging trained index: %v", err) close(trainReq.ackCh) @@ -207,7 +210,7 @@ func (t *vectorTrainer) trainLoop() { return } - trainedIndex, err := t.parent.segPlugin.OpenUsing(path, t.parent.segmentConfig) + trainedIndex, err := t.parent.segPlugin.OpenUsing(path, t.config) if err != nil { trainReq.ackCh <- fmt.Errorf("error opening trained index: %v", err) close(trainReq.ackCh) @@ -303,6 +306,23 @@ func (t *vectorTrainer) train(batch *index.Batch) error { sampleSize: len(trainData), ackCh: make(chan error), } + // setting the training params using the internal value before the actual + // training has started + config := t.config + if atomic.LoadUint64(&t.trainedSamples) == 0 { + trainingParamsBytes := batch.InternalOps[index.TrainingKey] + var trainingParams index.TrainingParams + if trainingParamsBytes != nil { + err = util.UnmarshalJSON(trainingParamsBytes, &trainingParams) + if err != nil { + return fmt.Errorf("error parsing training params: %v", err) + } + trainReq.trainingParams = &trainingParams + config = maps.Clone(t.config) + config[index.TrainingKey] = &trainingParams + } + } + // just builds a new vector index out of the train data provided // this is not necessarily the final train data since this is submitted // as a request to the trainer component to be merged. once the training @@ -312,7 +332,7 @@ func (t *vectorTrainer) train(batch *index.Batch) error { // note: this might index text data too, how to handle this? s.segmentConfig? // todo: updates/deletes -> data drift detection if len(trainData) > 0 { - trainReq.sample, _, err = t.parent.segPlugin.NewUsing(trainData, t.parent.segmentConfig) + trainReq.sample, _, err = t.parent.segPlugin.NewUsing(trainData, config) if err != nil { return err }