-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathremote_job.py
More file actions
237 lines (201 loc) · 7.09 KB
/
remote_job.py
File metadata and controls
237 lines (201 loc) · 7.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
from feast.core.JobService_pb2 import CancelJobRequest, GetJobRequest
from feast.core.JobService_pb2 import Job as JobProto
from feast.core.JobService_pb2 import JobStatus, JobType
from feast.core.JobService_pb2_grpc import JobServiceStub
from feast_spark.pyspark.abc import (
BatchIngestionJob,
RetrievalJob,
SparkJob,
SparkJobFailure,
SparkJobStatus,
StreamIngestionJob,
)
GrpcExtraParamProvider = Callable[[], Dict[str, Any]]
class RemoteJobMixin:
    """Shared client-side implementation for jobs running behind the job service.

    Every operation is delegated to the job service over gRPC; the kwargs
    returned by ``grpc_extra_param_provider`` (e.g. call metadata, timeout)
    are attached to each gRPC call.
    """

    def __init__(
        self,
        service: JobServiceStub,
        grpc_extra_param_provider: GrpcExtraParamProvider,
        job_id: str,
        start_time: datetime,
        log_uri: Optional[str],
    ):
        """
        Args:
            service: Job service GRPC stub
            grpc_extra_param_provider: Callable producing extra kwargs to pass
                with every gRPC call (e.g. metadata, timeout)
            job_id: job reference
            start_time: time at which the job was started
            log_uri: optional URI pointing at the job's logs
        """
        self._job_id = job_id
        self._service = service
        self._grpc_extra_param_provider = grpc_extra_param_provider
        self._start_time = start_time
        self._log_uri = log_uri

    def get_id(self) -> str:
        """Return the job id used to reference this job on the job service."""
        return self._job_id

    def get_status(self) -> SparkJobStatus:
        """Fetch the job's current status from the job service.

        Raises:
            Exception: if the service reports a status unknown to this client.
        """
        response = self._service.GetJob(
            GetJobRequest(job_id=self._job_id), **self._grpc_extra_param_provider()
        )
        # Translate the proto JobStatus into the client-side SparkJobStatus enum
        if response.job.status == JobStatus.JOB_STATUS_RUNNING:
            return SparkJobStatus.IN_PROGRESS
        elif response.job.status == JobStatus.JOB_STATUS_PENDING:
            return SparkJobStatus.STARTING
        elif response.job.status == JobStatus.JOB_STATUS_DONE:
            return SparkJobStatus.COMPLETED
        elif response.job.status == JobStatus.JOB_STATUS_ERROR:
            return SparkJobStatus.FAILED
        else:
            # we should never get here
            raise Exception(f"Invalid remote job state {response.job.status}")

    def get_start_time(self) -> datetime:
        """Return the time at which the job was started."""
        return self._start_time

    def cancel(self):
        """Ask the job service to cancel this job."""
        self._service.CancelJob(
            CancelJobRequest(job_id=self._job_id), **self._grpc_extra_param_provider()
        )

    def _wait_for_job_status(
        self, goal_status: List[SparkJobStatus], timeout_seconds=90
    ) -> SparkJobStatus:
        """Poll the job service (once per second) until the job reaches one of
        ``goal_status``.

        Raises:
            TimeoutError: if no goal status is reached within ``timeout_seconds``.
        """
        start_time = time.time()
        while time.time() < (start_time + timeout_seconds):
            status = self.get_status()
            if status in goal_status:
                return status
            time.sleep(1.0)
        raise TimeoutError("Timed out waiting for job status")

    def get_log_uri(self) -> Optional[str]:
        """Return the URI of the job's logs, if the service provided one."""
        return self._log_uri

    def get_error_message(self) -> str:
        """Fetch the error message the job service recorded for this job."""
        job = self._service.GetJob(
            GetJobRequest(job_id=self._job_id), **self._grpc_extra_param_provider()
        ).job
        return job.error_message

    def wait_termination(self, timeout_sec=None):
        """Block until the job terminates (completes or fails).

        Args:
            timeout_sec: maximum seconds to wait; defaults to 600 when None.

        Raises:
            SparkJobFailure: if the job terminated in a non-completed state.
            TimeoutError: if the job did not terminate within the timeout.
        """
        # BUG FIX: the previous `timeout_sec or 600` silently replaced an
        # explicit timeout of 0 with the 600s default; compare against None.
        timeout = 600 if timeout_sec is None else timeout_sec
        status = self._wait_for_job_status(
            goal_status=[SparkJobStatus.COMPLETED, SparkJobStatus.FAILED],
            timeout_seconds=timeout,
        )
        if status != SparkJobStatus.COMPLETED:
            raise SparkJobFailure(
                f"Spark job failed; Reason: {self.get_error_message()}"
            )
class RemoteRetrievalJob(RemoteJobMixin, RetrievalJob):
    """Historical feature retrieval job result, where the job itself is run
    remotely by the job service."""

    def __init__(
        self,
        service: JobServiceStub,
        grpc_extra_param_provider: GrpcExtraParamProvider,
        job_id: str,
        output_file_uri: str,
        start_time: datetime,
        log_uri: Optional[str],
    ):
        """
        This is the job object representing the historical retrieval job.

        Args:
            output_file_uri (str): Uri to the historical feature retrieval job
                output file.
        """
        super().__init__(
            service, grpc_extra_param_provider, job_id, start_time, log_uri
        )
        self._output_file_uri = output_file_uri

    def get_output_file_uri(self, timeout_sec=None, block=True):
        """Return the output file URI once the job has completed.

        When ``block`` is False and the job is not yet complete, ``None`` is
        returned immediately instead of waiting.
        """
        job_done = self.get_status() == SparkJobStatus.COMPLETED if not block else True
        if not job_done:
            return None
        self.wait_termination(timeout_sec)
        return self._output_file_uri
class RemoteBatchIngestionJob(RemoteJobMixin, BatchIngestionJob):
    """Batch ingestion job result, run remotely by the job service."""

    def __init__(
        self,
        service: JobServiceStub,
        grpc_extra_param_provider: GrpcExtraParamProvider,
        job_id: str,
        feature_table: str,
        start_time: datetime,
        log_uri: Optional[str],
    ):
        # Record which feature table this job ingests into, then initialize
        # the shared remote-job machinery.
        self._feature_table = feature_table
        super().__init__(
            service, grpc_extra_param_provider, job_id, start_time, log_uri
        )

    def get_feature_table(self) -> str:
        """Name of the feature table this ingestion job targets."""
        return self._feature_table
class RemoteStreamIngestionJob(RemoteJobMixin, StreamIngestionJob):
    """Stream ingestion job result, run remotely by the job service."""

    def __init__(
        self,
        service: JobServiceStub,
        grpc_extra_param_provider: GrpcExtraParamProvider,
        job_id: str,
        feature_table: str,
        start_time: datetime,
        log_uri: Optional[str],
    ):
        # Record which feature table this job ingests into, then initialize
        # the shared remote-job machinery.
        self._feature_table = feature_table
        super().__init__(
            service, grpc_extra_param_provider, job_id, start_time, log_uri
        )

    def get_hash(self) -> str:
        """Fetch this job's hash from the job service."""
        request = GetJobRequest(job_id=self._job_id)
        extra_params = self._grpc_extra_param_provider()
        response = self._service.GetJob(request, **extra_params)
        return response.job.hash

    def get_feature_table(self) -> str:
        """Name of the feature table this ingestion job targets."""
        return self._feature_table
def get_remote_job_from_proto(
    service: JobServiceStub,
    grpc_extra_param_provider: GrpcExtraParamProvider,
    job: JobProto,
) -> SparkJob:
    """Get the remote job python object from Job proto.

    Args:
        service (JobServiceStub): Reference to Job Service
        grpc_extra_param_provider (GrpcExtraParamProvider): Callable for providing extra parameters to grpc requests
        job (JobProto): Proto object describing the Job

    Returns:
        (SparkJob): A remote job object for the given job

    Raises:
        ValueError: if the proto carries a job type this client cannot wrap.
    """
    known_types = (
        JobType.RETRIEVAL_JOB,
        JobType.BATCH_INGESTION_JOB,
        JobType.STREAM_INGESTION_JOB,
    )
    # Reject unknown types up front so the dispatch below is exhaustive.
    if job.type not in known_types:
        raise ValueError(
            f"Invalid Job Type {job.type}, has to be one of "
            f"{known_types}"
        )

    start_time = job.start_time.ToDatetime()
    if job.type == JobType.RETRIEVAL_JOB:
        return RemoteRetrievalJob(
            service,
            grpc_extra_param_provider,
            job.id,
            job.retrieval.output_location,
            start_time,
            job.log_uri,
        )
    if job.type == JobType.BATCH_INGESTION_JOB:
        return RemoteBatchIngestionJob(
            service,
            grpc_extra_param_provider,
            job.id,
            job.batch_ingestion.table_name,
            start_time,
            job.log_uri,
        )
    return RemoteStreamIngestionJob(
        service,
        grpc_extra_param_provider,
        job.id,
        job.stream_ingestion.table_name,
        start_time,
        job.log_uri,
    )