@@ -14,6 +14,7 @@ import (
1414 "time"
1515
1616 "github.com/feast-dev/feast/go/internal/feast/model"
17+ "golang.org/x/sync/errgroup"
1718
1819 gocqltrace "github.com/DataDog/dd-trace-go/contrib/gocql/gocql/v2"
1920 "github.com/feast-dev/feast/go/internal/feast/registry"
@@ -391,16 +392,36 @@ func (c *CassandraOnlineStore) validateUniqueFeatureNames(featureViewNames []str
391392 return nil
392393}
393394
394- func (c * CassandraOnlineStore ) OnlineRead (ctx context.Context , entityKeys []* types.EntityKey , featureViewNames []string , featureNames []string ) ([][]FeatureData , error ) {
395- if err := c .validateUniqueFeatureNames (featureViewNames ); err != nil {
396- return nil , err
395+ func (c * CassandraOnlineStore ) createBatches (serializedEntityKeys []any ) [][]any {
396+ nKeys := len (serializedEntityKeys )
397+ batchSize := c .KeyBatchSize
398+ nBatches := int (math .Ceil (float64 (nKeys ) / float64 (batchSize )))
399+ batches := make ([][]any , nBatches )
400+
401+ nAssigned := 0
402+ for i := 0 ; i < nBatches ; i ++ {
403+ thisBatchSize := int (math .Min (float64 (batchSize ), float64 (nKeys - nAssigned )))
404+ nAssigned += thisBatchSize
405+ batches [i ] = serializedEntityKeys [i * batchSize : (i )* batchSize + thisBatchSize ]
397406 }
398407
399- serializedEntityKeys , serializedEntityKeyToIndex , err := c .buildCassandraEntityKeys (entityKeys )
408+ return batches
409+ }
400410
411+ type BatchJob struct {
412+ ViewName string
413+ TableName string
414+ FeatureNames []string
415+ EntityKeys []any
416+ CQLStatement string
417+ }
418+
419+ func (c * CassandraOnlineStore ) OnlineRead (ctx context.Context , entityKeys []* types.EntityKey , featureViewNames []string , featureNames []string ) ([][]FeatureData , error ) {
420+ serializedEntityKeys , serializedEntityKeyToIndex , err := c .buildCassandraEntityKeys (entityKeys )
401421 if err != nil {
402422 return nil , fmt .Errorf ("error when serializing entity keys for Cassandra: %v" , err )
403423 }
424+
404425 results := make ([][]FeatureData , len (entityKeys ))
405426 for i := range results {
406427 results [i ] = make ([]FeatureData , len (featureNames ))
@@ -411,136 +432,154 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
411432 featureNamesToIdx [name ] = idx
412433 }
413434
414- featureViewName := featureViewNames [0 ]
415-
416- // Prepare the query
417- tableName , err := c .getFqTableName (c .clusterConfigs .Keyspace , c .project , featureViewName , c .tableNameFormatVersion )
418- if err != nil {
419- return nil , err
435+ viewGroups := make (map [string ][]string )
436+ for i , viewName := range featureViewNames {
437+ viewGroups [viewName ] = append (viewGroups [viewName ], featureNames [i ])
420438 }
421439
422- // Key batching
423- nKeys := len (serializedEntityKeys )
424- batchSize := c .KeyBatchSize
425- nBatches := int (math .Ceil (float64 (nKeys ) / float64 (batchSize )))
426- batches := make ([][]any , nBatches )
427- nAssigned := 0
428- for i := 0 ; i < nBatches ; i ++ {
429- thisBatchSize := int (math .Min (float64 (batchSize ), float64 (nKeys - nAssigned )))
430- nAssigned += thisBatchSize
431- batches [i ] = make ([]any , thisBatchSize )
432- for j := 0 ; j < thisBatchSize ; j ++ {
433- batches [i ][j ] = serializedEntityKeys [i * batchSize + j ]
434- }
435- }
440+ g , ctx := errgroup .WithContext (ctx )
441+ jobsChan := make (chan BatchJob )
436442
437- var waitGroup sync.WaitGroup
438- waitGroup .Add (nBatches )
443+ batches := c .createBatches (serializedEntityKeys )
439444
440- errorsChannel := make (chan error , nBatches )
441- var currentBatchLength int
442- var prevBatchLength int
443- var cqlStatement string
444- for _ , batch := range batches {
445- currentBatchLength = len (batch )
446- if currentBatchLength != prevBatchLength {
447- cqlStatement = c .getMultiKeyCQLStatement (tableName , featureNames , currentBatchLength )
448- prevBatchLength = currentBatchLength
449- }
450- go func (keyBatch []any , statement string ) {
451- defer waitGroup .Done ()
452- iter := c .session .Query (statement , keyBatch ... ).WithContext (ctx ).Iter ()
453-
454- scanner := iter .Scanner ()
455- var entityKey string
456- var featureName string
457- var eventTs time.Time
458- var valueStr []byte
459- var deserializedValue * types.Value
460- // key 1: entityKey - key 2: featureName
461- batchFeatures := make (map [string ]map [string ]* FeatureData )
462- for scanner .Next () {
463- err := scanner .Scan (& entityKey , & featureName , & eventTs , & valueStr )
464- if err != nil {
465- errorsChannel <- errors .New ("could not read row in query for (entity key, feature name, value, event ts)" )
466- return
445+ g .Go (func () error {
446+ defer close (jobsChan )
447+
448+ for viewName , currentFeatureNames := range viewGroups {
449+ tableName , err := c .getFqTableName (c .clusterConfigs .Keyspace , c .project , viewName , c .tableNameFormatVersion )
450+ if err != nil {
451+ return err
452+ }
453+
454+ var prevBatchLength int
455+ var cqlStatement string
456+
457+ for i , batch := range batches {
458+ var cqlForBatch string
459+ if i == 0 || len (batch ) != prevBatchLength {
460+ cqlForBatch = c .getMultiKeyCQLStatement (tableName , currentFeatureNames , len (batch ))
461+ prevBatchLength = len (batch )
462+ cqlStatement = cqlForBatch
463+ } else {
464+ cqlForBatch = cqlStatement
467465 }
468- deserializedValue , _ , err = UnmarshalStoredProto (valueStr )
469- if err != nil {
470- errorsChannel <- err
471- return
466+
467+ job := BatchJob {
468+ ViewName : viewName ,
469+ TableName : tableName ,
470+ FeatureNames : currentFeatureNames ,
471+ EntityKeys : batch ,
472+ CQLStatement : cqlForBatch ,
472473 }
473474
474- if deserializedValue .Val != nil {
475- if batchFeatures [entityKey ] == nil {
476- batchFeatures [entityKey ] = make (map [string ]* FeatureData )
477- }
478- batchFeatures [entityKey ][featureName ] = & FeatureData {
479- Reference : serving.FeatureReferenceV2 {
480- FeatureViewName : featureViewName ,
481- FeatureName : featureName ,
482- },
483- Timestamp : timestamppb.Timestamp {Seconds : eventTs .Unix (), Nanos : int32 (eventTs .Nanosecond ())},
484- Value : types.Value {
485- Val : deserializedValue .Val ,
486- },
487- }
475+ select {
476+ case jobsChan <- job :
477+ // Job sent successfully
478+ case <- ctx .Done ():
479+ return ctx .Err ()
488480 }
489481 }
482+ }
483+ return nil
484+ })
490485
491- if err := scanner .Err (); err != nil {
492- errorsChannel <- errors .New ("failed to scan features: " + err .Error ())
493- return
486+ for job := range jobsChan {
487+ g .Go (func (j BatchJob ) func () error {
488+ return func () error {
489+ return c .executeBatch (ctx , j , serializedEntityKeyToIndex , results , featureNamesToIdx )
494490 }
491+ }(job ))
492+ }
495493
496- for _ , serializedEntityKey := range keyBatch {
497- for _ , featName := range featureNames {
498- keyString := serializedEntityKey .(string )
499-
500- if featureData , exists := batchFeatures [keyString ][featName ]; exists {
501- results [serializedEntityKeyToIndex [keyString ]][featureNamesToIdx [featName ]] = FeatureData {
502- Reference : serving.FeatureReferenceV2 {
503- FeatureViewName : featureData .Reference .FeatureViewName ,
504- FeatureName : featureData .Reference .FeatureName ,
505- },
506- Timestamp : timestamppb.Timestamp {Seconds : featureData .Timestamp .Seconds , Nanos : featureData .Timestamp .Nanos },
507- Value : types.Value {
508- Val : featureData .Value .Val ,
509- },
510- }
511- } else {
512- // TODO: return not found status to differentiate between nulls and not found features
513- results [serializedEntityKeyToIndex [keyString ]][featureNamesToIdx [featName ]] = FeatureData {
514- Reference : serving.FeatureReferenceV2 {
515- FeatureViewName : featureViewName ,
516- FeatureName : featName ,
517- },
518- Value : types.Value {
519- Val : & types.Value_NullVal {
520- NullVal : types .Null_NULL ,
521- },
522- },
523- }
524- }
525- }
526- }
527- }(batch , cqlStatement )
494+ if err := g .Wait (); err != nil {
495+ return nil , err
528496 }
529- // wait until all concurrent single-key queries are done
530- waitGroup .Wait ()
531- close (errorsChannel )
532497
533- var collectedErrors []error
534- for err := range errorsChannel {
498+ return results , nil
499+ }
500+
501+ func (c * CassandraOnlineStore ) executeBatch (
502+ ctx context.Context ,
503+ job BatchJob ,
504+ serializedEntityKeyToIndex map [string ]int ,
505+ results [][]FeatureData ,
506+ featureNamesToIdx map [string ]int ,
507+ ) error {
508+ iter := c .session .Query (job .CQLStatement , job .EntityKeys ... ).WithContext (ctx ).Iter ()
509+ defer iter .Close ()
510+
511+ scanner := iter .Scanner ()
512+ var entityKey string
513+ var featureName string
514+ var eventTs time.Time
515+ var valueStr []byte
516+ var deserializedValue * types.Value
517+
518+ batchFeatures := make (map [string ]map [string ]* FeatureData )
519+ for scanner .Next () {
520+ err := scanner .Scan (& entityKey , & featureName , & eventTs , & valueStr )
521+ if err != nil {
522+ return fmt .Errorf ("could not read row in query for (entity key, feature name, value, event ts): %w" , err )
523+ }
524+ deserializedValue , _ , err = UnmarshalStoredProto (valueStr )
535525 if err != nil {
536- collectedErrors = append (collectedErrors , err )
526+ return fmt .Errorf ("error unmarshaling stored proto: %w" , err )
527+ }
528+
529+ if deserializedValue .Val != nil {
530+ if batchFeatures [entityKey ] == nil {
531+ batchFeatures [entityKey ] = make (map [string ]* FeatureData )
532+ }
533+ batchFeatures [entityKey ][featureName ] = & FeatureData {
534+ Reference : serving.FeatureReferenceV2 {
535+ FeatureViewName : job .ViewName ,
536+ FeatureName : featureName ,
537+ },
538+ Timestamp : timestamppb.Timestamp {Seconds : eventTs .Unix (), Nanos : int32 (eventTs .Nanosecond ())},
539+ Value : types.Value {
540+ Val : deserializedValue .Val ,
541+ },
542+ }
537543 }
538544 }
539- if len (collectedErrors ) > 0 {
540- return nil , errors .Join (collectedErrors ... )
545+
546+ if err := scanner .Err (); err != nil {
547+ return fmt .Errorf ("failed to scan features: %w" , err )
541548 }
542549
543- return results , nil
550+ for _ , serializedEntityKey := range job .EntityKeys {
551+ for _ , featName := range job .FeatureNames {
552+ keyString := serializedEntityKey .(string )
553+
554+ if featureData , exists := batchFeatures [keyString ][featName ]; exists {
555+ results [serializedEntityKeyToIndex [keyString ]][featureNamesToIdx [featName ]] = FeatureData {
556+ Reference : serving.FeatureReferenceV2 {
557+ FeatureViewName : featureData .Reference .FeatureViewName ,
558+ FeatureName : featureData .Reference .FeatureName ,
559+ },
560+ Timestamp : timestamppb.Timestamp {Seconds : featureData .Timestamp .Seconds , Nanos : featureData .Timestamp .Nanos },
561+ Value : types.Value {
562+ Val : featureData .Value .Val ,
563+ },
564+ }
565+ } else {
566+ // TODO: return not found status to differentiate between nulls and not found features
567+ results [serializedEntityKeyToIndex [keyString ]][featureNamesToIdx [featName ]] = FeatureData {
568+ Reference : serving.FeatureReferenceV2 {
569+ FeatureViewName : job .ViewName ,
570+ FeatureName : featName ,
571+ },
572+ Value : types.Value {
573+ Val : & types.Value_NullVal {
574+ NullVal : types .Null_NULL ,
575+ },
576+ },
577+ }
578+ }
579+ }
580+ }
581+
582+ return nil
544583}
545584
546585func (c * CassandraOnlineStore ) rangeFilterToCQL (filter * model.SortKeyFilter ) (string , []interface {}) {
0 commit comments