Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions index/scorch/train_vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ import (
)

type trainRequest struct {
finalSample bool
sampleSize int
ackCh chan error
sample segment.Segment
finalSample bool
sampleSize int
ackCh chan error
sample segment.Segment
trainingParams *index.TrainingParams
}

type vectorTrainer struct {
Expand Down Expand Up @@ -154,6 +155,10 @@ func (t *vectorTrainer) trainLoop() {
return
case trainReq := <-t.trainCh:
sampleSeg := trainReq.sample
if trainReq.trainingParams != nil {
t.config[index.TrainingKey] = trainReq.trainingParams
}

// no sample segment: just persist state if this is the final sample and move on.
if sampleSeg == nil {
if trainReq.finalSample {
Expand All @@ -180,10 +185,8 @@ func (t *vectorTrainer) trainLoop() {
// merge the new segment with the existing one into a .tmp file, then
// atomically rename it into place (Os.Open on the live path is unsafe
// during the merge).
t.config[index.TrainingKey] = true
_, _, err := t.parent.segPlugin.MergeUsing([]segment.Segment{t.trainedIndex.segment, sampleSeg},
[]*roaring.Bitmap{nil, nil}, path+".tmp", t.parent.closeCh, nil, t.config)
t.config[index.TrainingKey] = false
if err != nil {
trainReq.ackCh <- fmt.Errorf("error merging trained index: %v", err)
close(trainReq.ackCh)
Expand All @@ -207,7 +210,7 @@ func (t *vectorTrainer) trainLoop() {
return
}

trainedIndex, err := t.parent.segPlugin.OpenUsing(path, t.parent.segmentConfig)
trainedIndex, err := t.parent.segPlugin.OpenUsing(path, t.config)
if err != nil {
trainReq.ackCh <- fmt.Errorf("error opening trained index: %v", err)
close(trainReq.ackCh)
Expand Down Expand Up @@ -303,6 +306,23 @@ func (t *vectorTrainer) train(batch *index.Batch) error {
sampleSize: len(trainData),
ackCh: make(chan error),
}
// setting the training params using the internal value before the actual
// training has started
config := t.config
if atomic.LoadUint64(&t.trainedSamples) == 0 {
trainingParamsBytes := batch.InternalOps[index.TrainingKey]
var trainingParams index.TrainingParams
if trainingParamsBytes != nil {
err = util.UnmarshalJSON(trainingParamsBytes, &trainingParams)
if err != nil {
return fmt.Errorf("error parsing training params: %v", err)
}
trainReq.trainingParams = &trainingParams
config = maps.Clone(t.config)
config[index.TrainingKey] = &trainingParams
}
}

// just builds a new vector index out of the train data provided
// this is not necessarily the final train data since this is submitted
// as a request to the trainer component to be merged. once the training
Expand All @@ -312,7 +332,7 @@ func (t *vectorTrainer) train(batch *index.Batch) error {
// note: this might index text data too, how to handle this? s.segmentConfig?
// todo: updates/deletes -> data drift detection
if len(trainData) > 0 {
trainReq.sample, _, err = t.parent.segPlugin.NewUsing(trainData, t.parent.segmentConfig)
trainReq.sample, _, err = t.parent.segPlugin.NewUsing(trainData, config)
if err != nil {
return err
}
Expand Down
Loading