diff --git a/sdk/python/feast/infra/contrib/spark_kafka_processor.py b/sdk/python/feast/infra/contrib/spark_kafka_processor.py
index d98366c1a4c..a76cb52e8aa 100644
--- a/sdk/python/feast/infra/contrib/spark_kafka_processor.py
+++ b/sdk/python/feast/infra/contrib/spark_kafka_processor.py
@@ -1,6 +1,17 @@
 import time
+from datetime import datetime
 from types import MethodType
-from typing import List, Optional, Set, Union, no_type_check
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    no_type_check,
+)
 
 import pandas as pd
 import pyarrow
@@ -10,8 +21,15 @@
 from pyspark.sql.column import Column, _to_java_column
 from pyspark.sql.functions import col, from_json
 from pyspark.sql.streaming import StreamingQuery
+from pyspark.sql.types import (
+    BinaryType,
+    StringType,
+    StructField,
+    StructType,
+    TimestampType,
+)
 
-from feast import FeatureView
+from feast import FeatureView, RepoConfig
 from feast.data_format import AvroFormat, ConfluentAvroFormat, JsonFormat, StreamFormat
 from feast.data_source import KafkaSource, PushMode
 from feast.feature_store import FeatureStore
@@ -20,10 +38,16 @@
     StreamProcessor,
     StreamTable,
 )
+from feast.infra.key_encoding_utils import serialize_entity_key
 from feast.infra.materialization.contrib.spark.spark_materialization_engine import (
     _SparkSerializedArtifacts,
 )
+from feast.infra.online_stores.contrib.cassandra_online_store.cassandra_online_store import (
+    CassandraOnlineStore,
+)
 from feast.infra.provider import get_provider
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
+from feast.protos.feast.types.Value_pb2 import Value as ValueProto
 from feast.stream_feature_view import StreamFeatureView
 from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping
 
@@ -272,6 +296,83 @@ def _write_stream_data_expedia(self, df: StreamTable, to: PushMode):
         # TODO: Support writing to offline store and preprocess_fn. Remove _write_stream_data method
         # Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
 
+        def online_write_with_connector(
+            config: RepoConfig,
+            table: FeatureView,
+            data: List[
+                Tuple[
+                    EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]
+                ]
+            ],
+            progress: Optional[Callable[[int], Any]],
+        ) -> None:
+            """
+            Write a batch of features for several entities to the online store using the Spark Cassandra Connector.
+
+            Args:
+                config: The RepoConfig for the current FeatureStore.
+                table: The Feast FeatureView being written to.
+                data: A list of quadruplets containing feature data. Each
+                    quadruplet contains an entity key, a dict of feature
+                    values, an event timestamp for the row, and the created
+                    timestamp for the row if it exists.
+                progress: Optional function called once every mini-batch of
+                    rows is written to the online store. Can be used to
+                    display progress.
+            """
+            keyspace = config.online_store.keyspace
+
+            fqtable = CassandraOnlineStore._fq_table_name(
+                keyspace, config.project, table
+            )
+            cassandra_keyspace = keyspace
+            cassandra_table = fqtable
+
+            def create_spark_dataframe():
+                """
+                Convert the input data into a Spark DataFrame of feature rows.
+ """ + rows = [] + for entity_key, values, timestamp, created_ts in data: + entity_key_bin = serialize_entity_key( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ).hex() + for feature_name, val in values.items(): + rows.append( + ( + feature_name, + val.SerializeToString(), + entity_key_bin, + timestamp, + created_ts, + ) + ) + + schema = StructType( + [ + StructField("feature_name", StringType(), False), + StructField("feature_value", BinaryType(), False), + StructField("entity_key", StringType(), False), + StructField("event_timestamp", TimestampType(), False), + StructField("created_timestamp", TimestampType(), True), + ] + ) + + return self.spark.createDataFrame(rows, schema) + + # Create a DataFrame from the input data + df = create_spark_dataframe() + + # Write DataFrame to Cassandra + df.write.format("org.apache.spark.sql.cassandra").options( + keyspace=cassandra_keyspace, table=cassandra_table + ).mode("append").save() + + # Call progress function if provided + if progress: + progress(len(data)) + def batch_write_pandas_df(iterator, spark_serialized_artifacts, join_keys): for pdf in iterator: ( @@ -305,7 +406,7 @@ def batch_write_pandas_df(iterator, spark_serialized_artifacts, join_keys): rows_to_write = _convert_arrow_to_proto( table, feature_view, join_key_to_value_type ) - online_store.online_write_batch( + online_write_with_connector( repo_config, feature_view, rows_to_write,