Commit 1d8e13b

feat: Add grouping over featureViews in scylla/parallelize online-read calls. (#284)
1 parent 66f7f75 commit 1d8e13b

3 files changed: 232 additions & 141 deletions

go/internal/feast/featurestore.go

Lines changed: 37 additions & 17 deletions
```diff
@@ -3,11 +3,13 @@ package feast
 import (
     "context"
     "fmt"
-    "github.com/feast-dev/feast/go/internal/feast/errors"
-    "github.com/feast-dev/feast/go/types"
     "os"
     "strings"
+
+    "github.com/feast-dev/feast/go/internal/feast/errors"
+    "github.com/feast-dev/feast/go/types"
 
+    "golang.org/x/sync/errgroup"
     "github.com/DataDog/dd-trace-go/v2/ddtrace/tracer"
     "github.com/apache/arrow/go/v17/arrow/memory"
 
@@ -237,7 +239,6 @@ func (fs *FeatureStore) GetOnlineFeatures(
         return nil, err
     }
 
-    result := make([]*onlineserving.FeatureVector, 0)
     arrowMemory := memory.NewGoAllocator()
     featureViews := make([]*model.FeatureView, len(requestedFeatureViews))
     index := 0
@@ -254,22 +255,41 @@
         return nil, err
     }
 
+    resultChan := make(chan []*onlineserving.FeatureVector, len(groupedRefs))
+    g, ctx := errgroup.WithContext(ctx) // Can consider adding 'setLimit' and a variable to limit the max number of concurrent reads to prevent thundering herd.
     for _, groupRef := range groupedRefs {
-        featureData, err := fs.readFromOnlineStore(ctx, groupRef.EntityKeys, groupRef.FeatureViewNames, groupRef.FeatureNames)
-        if err != nil {
-            return nil, errors.GrpcFromError(err)
-        }
+        g.Go(func(grpRef *onlineserving.GroupedFeaturesPerEntitySet) func() error {
+            return func() error {
+                featureData, err := fs.readFromOnlineStore(ctx, grpRef.EntityKeys, grpRef.FeatureViewNames, grpRef.FeatureNames)
+                if err != nil {
+                    return err
+                }
 
-        vectors, err := onlineserving.TransposeFeatureRowsIntoColumns(
-            featureData,
-            groupRef,
-            requestedFeatureViews,
-            arrowMemory,
-            numRows,
-        )
-        if err != nil {
-            return nil, err
-        }
+                vectors, err := onlineserving.TransposeFeatureRowsIntoColumns(
+                    featureData,
+                    grpRef,
+                    requestedFeatureViews,
+                    arrowMemory,
+                    numRows,
+                )
+                if err != nil {
+                    return err
+                }
+
+                resultChan <- vectors
+                return nil
+            }
+        }(groupRef))
+    }
+
+    if err := g.Wait(); err != nil {
+        return nil, err
+    }
+    close(resultChan)
+
+    // Flatten channel into 1D
+    var result []*onlineserving.FeatureVector
+    for vectors := range resultChan {
         result = append(result, vectors...)
     }
 
```
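For reference, a minimal, runnable sketch of the errgroup fan-out that GetOnlineFeatures adopts above, with a stand-in `fetch` function and `[]string` payloads in place of `readFromOnlineStore` and the transposed `FeatureVector` slices (those names are the only Feast-specific parts; everything else here is hypothetical). The result channel is buffered to the number of groups so no sender can block once `g.Wait()` has returned.

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// fetch stands in for readFromOnlineStore + TransposeFeatureRowsIntoColumns.
func fetch(ctx context.Context, group string) ([]string, error) {
	return []string{group + ":v1", group + ":v2"}, nil
}

func main() {
	groups := []string{"driver_stats", "customer_stats"}

	g, ctx := errgroup.WithContext(context.Background())
	// Buffered to len(groups): every worker can deposit its result
	// without blocking, so g.Wait() cannot deadlock.
	resultChan := make(chan []string, len(groups))
	// g.SetLimit(8) would cap concurrency, per the 'setLimit' comment in the diff.

	for _, group := range groups {
		group := group // pin loop variable (needed before Go 1.22)
		g.Go(func() error {
			vectors, err := fetch(ctx, group)
			if err != nil {
				return err // first error cancels ctx for the other reads
			}
			resultChan <- vectors
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		panic(err)
	}
	close(resultChan) // safe: all senders have returned

	// Flatten the per-group slices into one result, as the diff does.
	var result []string
	for vectors := range resultChan {
		result = append(result, vectors...)
	}
	fmt.Println(result)
}
```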
go/internal/feast/onlinestore/cassandraonlinestore.go

Lines changed: 154 additions & 115 deletions
```diff
@@ -14,6 +14,7 @@ import (
     "time"
 
     "github.com/feast-dev/feast/go/internal/feast/model"
+    "golang.org/x/sync/errgroup"
 
     gocqltrace "github.com/DataDog/dd-trace-go/contrib/gocql/gocql/v2"
     "github.com/feast-dev/feast/go/internal/feast/registry"
@@ -391,16 +392,36 @@ func (c *CassandraOnlineStore) validateUniqueFeatureNames(featureViewNames []str
     return nil
 }
 
-func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
-    if err := c.validateUniqueFeatureNames(featureViewNames); err != nil {
-        return nil, err
+func (c *CassandraOnlineStore) createBatches(serializedEntityKeys []any) [][]any {
+    nKeys := len(serializedEntityKeys)
+    batchSize := c.KeyBatchSize
+    nBatches := int(math.Ceil(float64(nKeys) / float64(batchSize)))
+    batches := make([][]any, nBatches)
+
+    nAssigned := 0
+    for i := 0; i < nBatches; i++ {
+        thisBatchSize := int(math.Min(float64(batchSize), float64(nKeys-nAssigned)))
+        nAssigned += thisBatchSize
+        batches[i] = serializedEntityKeys[i*batchSize : (i)*batchSize+thisBatchSize]
     }
 
-    serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys)
+    return batches
+}
 
+type BatchJob struct {
+    ViewName     string
+    TableName    string
+    FeatureNames []string
+    EntityKeys   []any
+    CQLStatement string
+}
+
+func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
+    serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys)
     if err != nil {
         return nil, fmt.Errorf("error when serializing entity keys for Cassandra: %v", err)
     }
+
     results := make([][]FeatureData, len(entityKeys))
     for i := range results {
         results[i] = make([]FeatureData, len(featureNames))
```
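As a worked example of the createBatches arithmetic introduced in the hunk above, here is a standalone sketch (a hypothetical free function; the real method reads `c.KeyBatchSize` from the store config). Each batch is a sub-slice view into the input rather than a copy, and only the final batch can be shorter than `batchSize`.

```go
package main

import (
	"fmt"
	"math"
)

// createBatches splits keys into ceil(n/batchSize) sub-slices.
func createBatches(keys []any, batchSize int) [][]any {
	nKeys := len(keys)
	nBatches := int(math.Ceil(float64(nKeys) / float64(batchSize)))
	batches := make([][]any, nBatches)

	nAssigned := 0
	for i := 0; i < nBatches; i++ {
		thisBatchSize := int(math.Min(float64(batchSize), float64(nKeys-nAssigned)))
		nAssigned += thisBatchSize
		// Sub-slice of the original backing array: no per-key copying.
		batches[i] = keys[i*batchSize : i*batchSize+thisBatchSize]
	}
	return batches
}

func main() {
	keys := []any{"k1", "k2", "k3", "k4", "k5", "k6", "k7", "k8", "k9", "k10"}
	// 10 keys with batchSize 4 -> batch lengths 4, 4, 2.
	for i, batch := range createBatches(keys, 4) {
		fmt.Println(i, batch)
	}
}
```

Compared with the old code, which allocated a fresh `[]any` per batch and copied keys into it element by element, the sub-slice version avoids the copy; the trade-off is that the batches keep the original slice's backing array alive.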
```diff
@@ -411,136 +432,154 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
         featureNamesToIdx[name] = idx
     }
 
-    featureViewName := featureViewNames[0]
-
-    // Prepare the query
-    tableName, err := c.getFqTableName(c.clusterConfigs.Keyspace, c.project, featureViewName, c.tableNameFormatVersion)
-    if err != nil {
-        return nil, err
+    viewGroups := make(map[string][]string)
+    for i, viewName := range featureViewNames {
+        viewGroups[viewName] = append(viewGroups[viewName], featureNames[i])
     }
 
-    // Key batching
-    nKeys := len(serializedEntityKeys)
-    batchSize := c.KeyBatchSize
-    nBatches := int(math.Ceil(float64(nKeys) / float64(batchSize)))
-    batches := make([][]any, nBatches)
-    nAssigned := 0
-    for i := 0; i < nBatches; i++ {
-        thisBatchSize := int(math.Min(float64(batchSize), float64(nKeys-nAssigned)))
-        nAssigned += thisBatchSize
-        batches[i] = make([]any, thisBatchSize)
-        for j := 0; j < thisBatchSize; j++ {
-            batches[i][j] = serializedEntityKeys[i*batchSize+j]
-        }
-    }
+    g, ctx := errgroup.WithContext(ctx)
+    jobsChan := make(chan BatchJob)
 
-    var waitGroup sync.WaitGroup
-    waitGroup.Add(nBatches)
+    batches := c.createBatches(serializedEntityKeys)
 
-    errorsChannel := make(chan error, nBatches)
-    var currentBatchLength int
-    var prevBatchLength int
-    var cqlStatement string
-    for _, batch := range batches {
-        currentBatchLength = len(batch)
-        if currentBatchLength != prevBatchLength {
-            cqlStatement = c.getMultiKeyCQLStatement(tableName, featureNames, currentBatchLength)
-            prevBatchLength = currentBatchLength
-        }
-        go func(keyBatch []any, statement string) {
-            defer waitGroup.Done()
-            iter := c.session.Query(statement, keyBatch...).WithContext(ctx).Iter()
-
-            scanner := iter.Scanner()
-            var entityKey string
-            var featureName string
-            var eventTs time.Time
-            var valueStr []byte
-            var deserializedValue *types.Value
-            // key 1: entityKey - key 2: featureName
-            batchFeatures := make(map[string]map[string]*FeatureData)
-            for scanner.Next() {
-                err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr)
-                if err != nil {
-                    errorsChannel <- errors.New("could not read row in query for (entity key, feature name, value, event ts)")
-                    return
+    g.Go(func() error {
+        defer close(jobsChan)
+
+        for viewName, currentFeatureNames := range viewGroups {
+            tableName, err := c.getFqTableName(c.clusterConfigs.Keyspace, c.project, viewName, c.tableNameFormatVersion)
+            if err != nil {
+                return err
+            }
+
+            var prevBatchLength int
+            var cqlStatement string
+
+            for i, batch := range batches {
+                var cqlForBatch string
+                if i == 0 || len(batch) != prevBatchLength {
+                    cqlForBatch = c.getMultiKeyCQLStatement(tableName, currentFeatureNames, len(batch))
+                    prevBatchLength = len(batch)
+                    cqlStatement = cqlForBatch
+                } else {
+                    cqlForBatch = cqlStatement
                 }
-                deserializedValue, _, err = UnmarshalStoredProto(valueStr)
-                if err != nil {
-                    errorsChannel <- err
-                    return
+
+                job := BatchJob{
+                    ViewName:     viewName,
+                    TableName:    tableName,
+                    FeatureNames: currentFeatureNames,
+                    EntityKeys:   batch,
+                    CQLStatement: cqlForBatch,
                 }
 
-                if deserializedValue.Val != nil {
-                    if batchFeatures[entityKey] == nil {
-                        batchFeatures[entityKey] = make(map[string]*FeatureData)
-                    }
-                    batchFeatures[entityKey][featureName] = &FeatureData{
-                        Reference: serving.FeatureReferenceV2{
-                            FeatureViewName: featureViewName,
-                            FeatureName:     featureName,
-                        },
-                        Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())},
-                        Value: types.Value{
-                            Val: deserializedValue.Val,
-                        },
-                    }
+                select {
+                case jobsChan <- job:
+                    // Job sent successfully
+                case <-ctx.Done():
+                    return ctx.Err()
                 }
             }
+        }
+        return nil
+    })
 
-            if err := scanner.Err(); err != nil {
-                errorsChannel <- errors.New("failed to scan features: " + err.Error())
-                return
+    for job := range jobsChan {
+        g.Go(func(j BatchJob) func() error {
+            return func() error {
+                return c.executeBatch(ctx, j, serializedEntityKeyToIndex, results, featureNamesToIdx)
             }
+        }(job))
+    }
 
-            for _, serializedEntityKey := range keyBatch {
-                for _, featName := range featureNames {
-                    keyString := serializedEntityKey.(string)
-
-                    if featureData, exists := batchFeatures[keyString][featName]; exists {
-                        results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = FeatureData{
-                            Reference: serving.FeatureReferenceV2{
-                                FeatureViewName: featureData.Reference.FeatureViewName,
-                                FeatureName:     featureData.Reference.FeatureName,
-                            },
-                            Timestamp: timestamppb.Timestamp{Seconds: featureData.Timestamp.Seconds, Nanos: featureData.Timestamp.Nanos},
-                            Value: types.Value{
-                                Val: featureData.Value.Val,
-                            },
-                        }
-                    } else {
-                        // TODO: return not found status to differentiate between nulls and not found features
-                        results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = FeatureData{
-                            Reference: serving.FeatureReferenceV2{
-                                FeatureViewName: featureViewName,
-                                FeatureName:     featName,
-                            },
-                            Value: types.Value{
-                                Val: &types.Value_NullVal{
-                                    NullVal: types.Null_NULL,
-                                },
-                            },
-                        }
-                    }
-                }
-            }
-        }(batch, cqlStatement)
+    if err := g.Wait(); err != nil {
+        return nil, err
     }
-    // wait until all concurrent single-key queries are done
-    waitGroup.Wait()
-    close(errorsChannel)
 
-    var collectedErrors []error
-    for err := range errorsChannel {
+    return results, nil
+}
+
+func (c *CassandraOnlineStore) executeBatch(
+    ctx context.Context,
+    job BatchJob,
+    serializedEntityKeyToIndex map[string]int,
+    results [][]FeatureData,
+    featureNamesToIdx map[string]int,
+) error {
+    iter := c.session.Query(job.CQLStatement, job.EntityKeys...).WithContext(ctx).Iter()
+    defer iter.Close()
+
+    scanner := iter.Scanner()
+    var entityKey string
+    var featureName string
+    var eventTs time.Time
+    var valueStr []byte
+    var deserializedValue *types.Value
+
+    batchFeatures := make(map[string]map[string]*FeatureData)
+    for scanner.Next() {
+        err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr)
+        if err != nil {
+            return fmt.Errorf("could not read row in query for (entity key, feature name, value, event ts): %w", err)
+        }
+        deserializedValue, _, err = UnmarshalStoredProto(valueStr)
         if err != nil {
-            collectedErrors = append(collectedErrors, err)
+            return fmt.Errorf("error unmarshaling stored proto: %w", err)
+        }
+
+        if deserializedValue.Val != nil {
+            if batchFeatures[entityKey] == nil {
+                batchFeatures[entityKey] = make(map[string]*FeatureData)
+            }
+            batchFeatures[entityKey][featureName] = &FeatureData{
+                Reference: serving.FeatureReferenceV2{
+                    FeatureViewName: job.ViewName,
+                    FeatureName:     featureName,
+                },
+                Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())},
+                Value: types.Value{
+                    Val: deserializedValue.Val,
+                },
+            }
         }
     }
-    if len(collectedErrors) > 0 {
-        return nil, errors.Join(collectedErrors...)
+
+    if err := scanner.Err(); err != nil {
+        return fmt.Errorf("failed to scan features: %w", err)
     }
 
-    return results, nil
+    for _, serializedEntityKey := range job.EntityKeys {
+        for _, featName := range job.FeatureNames {
+            keyString := serializedEntityKey.(string)
+
+            if featureData, exists := batchFeatures[keyString][featName]; exists {
+                results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = FeatureData{
+                    Reference: serving.FeatureReferenceV2{
+                        FeatureViewName: featureData.Reference.FeatureViewName,
+                        FeatureName:     featureData.Reference.FeatureName,
+                    },
+                    Timestamp: timestamppb.Timestamp{Seconds: featureData.Timestamp.Seconds, Nanos: featureData.Timestamp.Nanos},
+                    Value: types.Value{
+                        Val: featureData.Value.Val,
+                    },
+                }
+            } else {
+                // TODO: return not found status to differentiate between nulls and not found features
+                results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = FeatureData{
+                    Reference: serving.FeatureReferenceV2{
+                        FeatureViewName: job.ViewName,
+                        FeatureName:     featName,
+                    },
+                    Value: types.Value{
+                        Val: &types.Value_NullVal{
+                            NullVal: types.Null_NULL,
+                        },
+                    },
+                }
+            }
+        }
+    }
+
+    return nil
 }
 
 func (c *CassandraOnlineStore) rangeFilterToCQL(filter *model.SortKeyFilter) (string, []interface{}) {
```

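Finally, a compact sketch of the producer/consumer shape that OnlineRead and executeBatch now form, with stand-in `job` and `process` types rather than the Feast ones: one errgroup goroutine emits a job per (view, batch) pair and closes the channel when done; the calling goroutine drains the channel and spawns a worker per job, and `g.Wait()` surfaces the first error while the shared context cancels the rest.

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// job stands in for BatchJob; process stands in for executeBatch.
type job struct {
	view string
	keys []string
}

func process(ctx context.Context, j job) error {
	fmt.Printf("querying view %q with %d keys\n", j.view, len(j.keys))
	return nil
}

func main() {
	views := map[string][]string{
		"driver_stats":   {"k1", "k2", "k3"},
		"customer_stats": {"k4", "k5"},
	}

	g, ctx := errgroup.WithContext(context.Background())
	jobs := make(chan job) // unbuffered: producer and consumer stay in lockstep

	// Producer: one job per view, then close the channel so the range below ends.
	g.Go(func() error {
		defer close(jobs)
		for view, keys := range views {
			select {
			case jobs <- job{view: view, keys: keys}:
			case <-ctx.Done():
				return ctx.Err() // stop producing once any worker has failed
			}
		}
		return nil
	})

	// Consumer: spawn one worker per job. Fan-out is unbounded here,
	// so g.SetLimit(n) would be the natural throttle if one were needed.
	for j := range jobs {
		j := j // pin loop variable (needed before Go 1.22)
		g.Go(func() error { return process(ctx, j) })
	}

	if err := g.Wait(); err != nil {
		panic(err)
	}
}
```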