Coverage Report

Created: 2024-12-20 00:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-worker/src/local_worker.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::pin::Pin;
16
use std::process::Stdio;
17
use std::str;
18
use std::sync::atomic::{AtomicU64, Ordering};
19
use std::sync::{Arc, Weak};
20
use std::time::Duration;
21
22
use futures::future::BoxFuture;
23
use futures::stream::FuturesUnordered;
24
use futures::{select, Future, FutureExt, StreamExt, TryFutureExt};
25
use nativelink_config::cas_server::LocalWorkerConfig;
26
use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
27
use nativelink_metric::{MetricsComponent, RootMetricsComponent};
28
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update;
29
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_client::WorkerApiClient;
30
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
31
    execute_result, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForWorker,
32
};
33
use nativelink_store::fast_slow_store::FastSlowStore;
34
use nativelink_util::action_messages::{ActionResult, ActionStage, OperationId};
35
use nativelink_util::common::fs;
36
use nativelink_util::digest_hasher::{DigestHasherFunc, ACTIVE_HASHER_FUNC};
37
use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime};
38
use nativelink_util::origin_context::ActiveOriginContext;
39
use nativelink_util::shutdown_guard::ShutdownGuard;
40
use nativelink_util::store_trait::Store;
41
use nativelink_util::{spawn, tls_utils};
42
use tokio::process;
43
use tokio::sync::{broadcast, mpsc};
44
use tokio::time::sleep;
45
use tokio_stream::wrappers::UnboundedReceiverStream;
46
use tonic::Streaming;
47
use tracing::{event, info_span, instrument, Level};
48
49
use crate::running_actions_manager::{
50
    ExecutionConfiguration, Metrics as RunningActionManagerMetrics, RunningAction,
51
    RunningActionsManager, RunningActionsManagerArgs, RunningActionsManagerImpl,
52
};
53
use crate::worker_api_client_wrapper::{WorkerApiClientTrait, WorkerApiClientWrapper};
54
use crate::worker_utils::make_supported_properties;
55
56
/// Amount of time to wait if we have actions in transit before we try to
57
/// consider an error to have occurred.
58
const ACTIONS_IN_TRANSIT_TIMEOUT_S: f32 = 10.;
59
60
/// If we lose connection to the worker api server we will wait this many seconds
61
/// before trying to connect.
62
const CONNECTION_RETRY_DELAY_S: f32 = 0.5;
63
64
/// Default endpoint timeout. If this value gets modified the documentation in
65
/// `cas_server.rs` must also be updated.
66
const DEFAULT_ENDPOINT_TIMEOUT_S: f32 = 5.;
67
68
/// Default maximum amount of time a task is allowed to run for.
69
/// If this value gets modified the documentation in `cas_server.rs` must also be updated.
70
const DEFAULT_MAX_ACTION_TIMEOUT: Duration = Duration::from_secs(1200); // 20 mins.
71
72
struct LocalWorkerImpl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> {
73
    config: &'a LocalWorkerConfig,
74
    // According to the tonic documentation it is a cheap operation to clone this.
75
    grpc_client: T,
76
    worker_id: String,
77
    running_actions_manager: Arc<U>,
78
    // Number of actions that have been received in `Update::StartAction`, but
79
    // not yet processed by running_actions_manager's spawn. This number should
80
    // always be zero if there are no actions running and no actions being waited
81
    // on by the scheduler.
82
    actions_in_transit: Arc<AtomicU64>,
83
    metrics: Arc<Metrics>,
84
}
85
86
4
async fn preconditions_met(precondition_script: Option<String>) -> Result<(), Error> {
87
4
    let Some(
precondition_script1
) = &precondition_script else {
  Branch (87:9): [True: 1, False: 3]
  Branch (87:9): [Folded - Ignored]
  Branch (87:9): [Folded - Ignored]
88
        // No script means we are always ok to proceed.
89
3
        return Ok(());
90
    };
91
    // TODO: Might want to pass some information about the command to the
92
    //       script, but at this point it's not even been downloaded yet,
93
    //       so that's not currently possible.  Perhaps we'll move this in
94
    //       future to pass useful information through?  Or perhaps we'll
95
    //       have a pre-condition and a pre-execute script instead, although
96
    //       arguably entrypoint already gives us that.
97
1
    let precondition_process = process::Command::new(precondition_script)
98
1
        .kill_on_drop(true)
99
1
        .stdin(Stdio::null())
100
1
        .stdout(Stdio::piped())
101
1
        .stderr(Stdio::null())
102
1
        .env_clear()
103
1
        .spawn()
104
1
        .err_tip(|| 
format!("Could not execute precondition command {precondition_script:?}")0
)
?0
;
105
1
    let output = precondition_process.wait_with_output().await
?0
;
106
1
    if output.status.code() == Some(0) {
  Branch (106:8): [True: 0, False: 1]
  Branch (106:8): [Folded - Ignored]
  Branch (106:8): [Folded - Ignored]
107
0
        Ok(())
108
    } else {
109
1
        Err(make_err!(
110
1
            Code::ResourceExhausted,
111
1
            "Preconditions script returned status {} - {}",
112
1
            output.status,
113
1
            str::from_utf8(&output.stdout).unwrap_or("")
114
1
        ))
115
    }
116
4
}
117
118
impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, T, U> {
119
5
    fn new(
120
5
        config: &'a LocalWorkerConfig,
121
5
        grpc_client: T,
122
5
        worker_id: String,
123
5
        running_actions_manager: Arc<U>,
124
5
        metrics: Arc<Metrics>,
125
5
    ) -> Self {
126
5
        Self {
127
5
            config,
128
5
            grpc_client,
129
5
            worker_id,
130
5
            running_actions_manager,
131
5
            // Number of actions that have been received in `Update::StartAction`, but
132
5
            // not yet processed by running_actions_manager's spawn. This number should
133
5
            // always be zero if there are no actions running and no actions being waited
134
5
            // on by the scheduler.
135
5
            actions_in_transit: Arc::new(AtomicU64::new(0)),
136
5
            metrics,
137
5
        }
138
5
    }
139
140
    /// Starts a background spawn/thread that will send a message to the server every `timeout / 2`.
141
5
    async fn start_keep_alive(&self) -> Result<(), Error> {
142
4
        // According to tonic's documentation this call should be cheap and is the same stream.
143
4
        let mut grpc_client = self.grpc_client.clone();
144
145
        loop {
146
4
            let timeout = self
147
4
                .config
148
4
                .worker_api_endpoint
149
4
                .timeout
150
4
                .unwrap_or(DEFAULT_ENDPOINT_TIMEOUT_S);
151
4
            // We always send 2 keep alive requests per timeout. Http2 should manage most of our
152
4
            // timeout issues, this is a secondary check to ensure we can still send data.
153
4
            sleep(Duration::from_secs_f32(timeout / 2.)).await;
154
0
            if let Err(e) = grpc_client
  Branch (154:20): [True: 0, False: 0]
  Branch (154:20): [Folded - Ignored]
  Branch (154:20): [Folded - Ignored]
155
0
                .keep_alive(KeepAliveRequest {
156
0
                    worker_id: self.worker_id.clone(),
157
0
                })
158
0
                .await
159
            {
160
0
                return Err(make_err!(
161
0
                    Code::Internal,
162
0
                    "Failed to send KeepAlive in LocalWorker : {:?}",
163
0
                    e
164
0
                ));
165
0
            }
166
        }
167
0
    }
168
169
5
    async fn run(
170
5
        &mut self,
171
5
        update_for_worker_stream: Streaming<UpdateForWorker>,
172
5
        shutdown_rx: &mut broadcast::Receiver<ShutdownGuard>,
173
5
    ) -> Result<(), Error> {
174
5
        // This big block of logic is designed to help simplify upstream components. Upstream
175
5
        // components can write standard futures that return a `Result<(), Error>` and this block
176
5
        // will forward the error up to the client and disconnect from the scheduler.
177
5
        // It is a common use case that an item sent through update_for_worker_stream will always
178
5
        // have a response but the response will be triggered through a callback to the scheduler.
179
5
        // This can be quite tricky to manage, so what we have done here is given access to a
180
5
        // `futures` variable which because this is in a single thread as well as a channel that you
181
5
        // send a future into that makes it into the `futures` variable.
182
5
        // This means that if you want to perform an action based on the result of the future
183
5
        // you use the `.map()` method and the new action will always come to live in this spawn,
184
5
        // giving mutable access to stuff in this struct.
185
5
        // NOTE: If you ever return from this function it will disconnect from the scheduler.
186
5
        let mut futures = FuturesUnordered::new();
187
5
        futures.push(self.start_keep_alive().boxed());
188
5
189
5
        let (add_future_channel, add_future_rx) = mpsc::unbounded_channel();
190
5
        let mut add_future_rx = UnboundedReceiverStream::new(add_future_rx).fuse();
191
5
192
5
        let mut update_for_worker_stream = update_for_worker_stream.fuse();
193
194
        loop {
195
16
            select! {
196
16
                
maybe_update6
= update_for_worker_stream.next() => {
197
6
                    match maybe_update
198
6
                        .err_tip(|| 
"UpdateForWorker stream closed early"1
)
?1
199
5
                        .err_tip(|| 
"Got error in UpdateForWorker stream"0
)
?0
200
                        .update
201
5
                        .err_tip(|| 
"Expected update to exist in UpdateForWorker"0
)
?0
202
                    {
203
                        Update::ConnectionResult(_) => {
204
0
                            return Err(make_input_err!(
205
0
                                "Got ConnectionResult in LocalWorker::run which should never happen"
206
0
                            ));
207
                        }
208
                        // TODO(allada) We should possibly do something with this notification.
209
0
                        Update::Disconnect(()) => {
210
0
                            self.metrics.disconnects_received.inc();
211
0
                        }
212
0
                        Update::KeepAlive(()) => {
213
0
                            self.metrics.keep_alives_received.inc();
214
0
                        }
215
1
                        Update::KillOperationRequest(kill_operation_request) => {
216
1
                            let operation_id = OperationId::from(kill_operation_request.operation_id);
217
1
                            if let Err(
err0
) = self.running_actions_manager.kill_operation(&operation_id).await {
  Branch (217:36): [True: 0, False: 1]
  Branch (217:36): [Folded - Ignored]
  Branch (217:36): [Folded - Ignored]
218
0
                                event!(
219
0
                                    Level::ERROR,
220
                                    ?operation_id,
221
                                    ?err,
222
0
                                    "Failed to send kill request for operation"
223
                                );
224
1
                            };
225
                        }
226
4
                        Update::StartAction(start_execute) => {
227
4
                            self.metrics.start_actions_received.inc();
228
4
229
4
                            let execute_request = start_execute.execute_request.as_ref();
230
4
                            let operation_id = start_execute.operation_id.clone();
231
4
                            let maybe_instance_name = execute_request.map(|v| v.instance_name.clone());
232
4
                            let action_digest = execute_request.and_then(|v| v.action_digest.clone());
233
4
                            let digest_hasher = execute_request
234
4
                                .ok_or(make_input_err!("Expected execute_request to be set"))
235
4
                                .and_then(|v| DigestHasherFunc::try_from(v.digest_function))
236
4
                                .err_tip(|| 
"In LocalWorkerImpl::new()"0
)
?0
;
237
238
4
                            let start_action_fut = {
239
4
                                let precondition_script_cfg = self.config.experimental_precondition_script.clone();
240
4
                                let actions_in_transit = self.actions_in_transit.clone();
241
4
                                let worker_id = self.worker_id.clone();
242
4
                                let running_actions_manager = self.running_actions_manager.clone();
243
4
                                self.metrics.clone().wrap(move |metrics| async move {
244
4
                                    metrics.preconditions.wrap(preconditions_met(precondition_script_cfg))
245
4
                                    .and_then(|()| 
running_actions_manager.create_and_add_action(worker_id, start_execute)3
)
246
4
                                    .map(move |r| {
247
4
                                        // Now that we either failed or registered our action, we can
248
4
                                        // consider the action to no longer be in transit.
249
4
                                        actions_in_transit.fetch_sub(1, Ordering::Release);
250
4
                                        r
251
4
                                    })
252
4
                                    .and_then(|action| {
253
3
                                        event!(
254
3
                                            Level::INFO,
255
0
                                            operation_id = ?action.get_operation_id(),
256
0
                                            "Received request to run action"
257
                                        );
258
3
                                        action
259
3
                                            .clone()
260
3
                                            .prepare_action()
261
3
                                            .and_then(RunningAction::execute)
262
3
                                            .and_then(RunningAction::upload_results)
263
3
                                            .and_then(RunningAction::get_finished_result)
264
3
                                            // Note: We need ensure we run cleanup even if one of the other steps fail.
265
3
                                            .then(|result| async move 
{2
266
2
                                                if let Err(
e0
) = action.cleanup().await {
  Branch (266:56): [True: 0, False: 2]
  Branch (266:56): [Folded - Ignored]
  Branch (266:56): [Folded - Ignored]
267
0
                                                    return Result::<ActionResult, Error>::Err(e).merge(result);
268
2
                                                }
269
2
                                                result
270
4
                                            })
271
4
                                    }).await
272
7
                                })
273
4
                            };
274
4
275
4
                            let make_publish_future = {
276
4
                                let mut grpc_client = self.grpc_client.clone();
277
4
278
4
                                let worker_id = self.worker_id.clone();
279
4
                                let running_actions_manager = self.running_actions_manager.clone();
280
3
                                move |res: Result<ActionResult, Error>| async move {
281
3
                                    let instance_name = maybe_instance_name
282
3
                                        .err_tip(|| 
"`instance_name` could not be resolved; this is likely an internal error in local_worker."0
)
?0
;
283
3
                                    match res {
284
2
                                        Ok(mut action_result) => {
285
                                            // Save in the action cache before notifying the scheduler that we've completed.
286
2
                                            if let Some(digest_info) = action_digest.clone().and_then(|action_digest| action_digest.try_into().ok()) {
  Branch (286:52): [True: 2, False: 0]
  Branch (286:52): [Folded - Ignored]
  Branch (286:52): [Folded - Ignored]
287
2
                                                if let Err(
err0
) = running_actions_manager.cache_action_result(digest_info, &mut action_result, digest_hasher).await {
  Branch (287:56): [True: 0, False: 2]
  Branch (287:56): [Folded - Ignored]
  Branch (287:56): [Folded - Ignored]
288
0
                                                    event!(
289
0
                                                        Level::ERROR,
290
                                                        ?err,
291
                                                        ?action_digest,
292
0
                                                        "Error saving action in store",
293
                                                    );
294
2
                                                }
295
0
                                            }
296
2
                                            let action_stage = ActionStage::Completed(action_result);
297
2
                                            grpc_client.execution_response(
298
2
                                                ExecuteResult{
299
2
                                                    worker_id,
300
2
                                                    instance_name,
301
2
                                                    operation_id,
302
2
                                                    result: Some(execute_result::Result::ExecuteResponse(action_stage.into())),
303
2
                                                }
304
2
                                            )
305
2
                                            .await
306
0
                                            .err_tip(|| "Error while calling execution_response")?;
307
                                        },
308
1
                                        Err(e) => {
309
1
                                            grpc_client.execution_response(ExecuteResult{
310
1
                                                worker_id,
311
1
                                                instance_name,
312
1
                                                operation_id,
313
1
                                                result: Some(execute_result::Result::InternalError(e.into())),
314
1
                                            }).await.
err_tip(0
||
"Error calling execution_response with error"0
)0
?0
;
315
                                        },
316
                                    }
317
0
                                    Ok(())
318
3
                                }
319
                            };
320
321
4
                            self.actions_in_transit.fetch_add(1, Ordering::Release);
322
4
                            let futures_ref = &futures;
323
4
324
4
                            let add_future_channel = add_future_channel.clone();
325
4
                            let mut ctx = ActiveOriginContext::fork().err_tip(|| 
"Expected ActiveOriginContext to be set in local_worker::run"0
)
?0
;
326
4
                            ctx.set_value(&ACTIVE_HASHER_FUNC, Arc::new(digest_hasher));
327
4
                            ctx.run(info_span!("worker_start_action_ctx"), move || {
328
4
                                futures_ref.push(
329
4
                                    spawn!("worker_start_action", start_action_fut).map(move |res| 
{3
330
3
                                        let res = res.err_tip(|| 
"Failed to launch spawn"0
)
?0
;
331
3
                                        if let Err(
err1
) = &res {
  Branch (331:48): [True: 1, False: 2]
  Branch (331:48): [Folded - Ignored]
  Branch (331:48): [Folded - Ignored]
332
1
                                            event!(
333
1
                                                Level::ERROR,
334
                                                ?err,
335
1
                                                "Error executing action",
336
                                            );
337
2
                                        }
338
3
                                        add_future_channel
339
3
                                            .send(make_publish_future(res).boxed())
340
3
                                            .map_err(|_| 
make_err!(Code::Internal, "LocalWorker could not send future")0
)
?0
;
341
3
                                        Ok(())
342
4
                                    
}3
)
343
4
                                    .boxed()
344
4
                                );
345
4
                            });
346
4
                        }
347
                    };
348
                },
349
16
                
res3
= add_future_rx.next() => {
350
3
                    let fut = res.err_tip(|| 
"New future stream receives should never be closed"0
)
?0
;
351
3
                    futures.push(fut);
352
                },
353
16
                
res3
= futures.next() =>
res.err_tip(3
||
"Keep-alive should always pending. Likely unable to send data to scheduler"0
)3
?0
?0
,
354
16
                
complete_msg0
= shutdown_rx.recv().fuse() => {
355
0
                    event!(Level::WARN, "Worker loop reveiced shutdown signal. Shutting down worker...",);
356
0
                    let mut grpc_client = self.grpc_client.clone();
357
0
                    let worker_id = self.worker_id.clone();
358
0
                    let running_actions_manager = self.running_actions_manager.clone();
359
0
                    let complete_msg_clone = complete_msg.map_err(|e| make_err!(Code::Internal, "Failed to receive shutdown message: {e:?}"))?.clone();
360
0
                    let shutdown_future = async move {
361
0
                        if let Err(e) = grpc_client.going_away(GoingAwayRequest { worker_id }).await {
  Branch (361:32): [True: 0, False: 0]
  Branch (361:32): [Folded - Ignored]
  Branch (361:32): [Folded - Ignored]
362
0
                            event!(Level::ERROR, "Failed to send GoingAwayRequest: {e}",);
363
0
                            return Err(e.into());
364
0
                        }
365
0
                        running_actions_manager.complete_actions(complete_msg_clone).await;
366
0
                        Ok::<(), Error>(())
367
0
                    };
368
0
                    futures.push(shutdown_future.boxed());
369
0
                },
370
            };
371
        }
372
        // Unreachable.
373
1
    }
374
}
375
376
type ConnectionFactory<T> = Box<dyn Fn() -> BoxFuture<'static, Result<T, Error>> + Send + Sync>;
377
378
pub struct LocalWorker<T: WorkerApiClientTrait, U: RunningActionsManager> {
379
    config: Arc<LocalWorkerConfig>,
380
    running_actions_manager: Arc<U>,
381
    connection_factory: ConnectionFactory<T>,
382
    sleep_fn: Option<Box<dyn Fn(Duration) -> BoxFuture<'static, ()> + Send + Sync>>,
383
    metrics: Arc<Metrics>,
384
}
385
386
/// Creates a new `LocalWorker`. The `cas_store` must be an instance of
387
/// `FastSlowStore` and will be checked at runtime.
388
2
pub async fn new_local_worker(
389
2
    config: Arc<LocalWorkerConfig>,
390
2
    cas_store: Store,
391
2
    ac_store: Option<Store>,
392
2
    historical_store: Store,
393
2
) -> Result<
394
2
    (
395
2
        LocalWorker<WorkerApiClientWrapper, RunningActionsManagerImpl>,
396
2
        Arc<Metrics>,
397
2
    ),
398
2
    Error,
399
2
> {
400
2
    let fast_slow_store = cas_store
401
2
        .downcast_ref::<FastSlowStore>(None)
402
2
        .err_tip(|| 
"Expected store for LocalWorker's store to be a FastSlowStore"0
)
?0
403
2
        .get_arc()
404
2
        .err_tip(|| 
"FastSlowStore's Arc doesn't exist"0
)
?0
;
405
406
2
    if let Ok(
path1
) = fs::canonicalize(&config.work_directory).await {
  Branch (406:12): [True: 1, False: 1]
  Branch (406:12): [Folded - Ignored]
  Branch (406:12): [Folded - Ignored]
407
1
        fs::remove_dir_all(path)
408
1
            .await
409
1
            .err_tip(|| 
"Could not remove work_directory in LocalWorker"0
)
?0
;
410
1
    }
411
412
2
    fs::create_dir_all(&config.work_directory)
413
2
        .await
414
2
        .err_tip(|| 
format!("Could not make work_directory : {}", config.work_directory)0
)
?0
;
415
2
    let entrypoint = if config.entrypoint.is_empty() {
  Branch (415:25): [True: 2, False: 0]
  Branch (415:25): [Folded - Ignored]
  Branch (415:25): [Folded - Ignored]
416
2
        None
417
    } else {
418
0
        Some(config.entrypoint.clone())
419
    };
420
2
    let max_action_timeout = if config.max_action_timeout == 0 {
  Branch (420:33): [True: 2, False: 0]
  Branch (420:33): [Folded - Ignored]
  Branch (420:33): [Folded - Ignored]
421
2
        DEFAULT_MAX_ACTION_TIMEOUT
422
    } else {
423
0
        Duration::from_secs(config.max_action_timeout as u64)
424
    };
425
2
    let running_actions_manager =
426
2
        Arc::new(RunningActionsManagerImpl::new(RunningActionsManagerArgs {
427
2
            root_action_directory: config.work_directory.clone(),
428
2
            execution_configuration: ExecutionConfiguration {
429
2
                entrypoint,
430
2
                additional_environment: config.additional_environment.clone(),
431
2
            },
432
2
            cas_store: fast_slow_store,
433
2
            ac_store,
434
2
            historical_store,
435
2
            upload_action_result_config: &config.upload_action_result,
436
2
            max_action_timeout,
437
2
            timeout_handled_externally: config.timeout_handled_externally,
438
2
        })
?0
);
439
2
    let local_worker = LocalWorker::new_with_connection_factory_and_actions_manager(
440
2
        config.clone(),
441
2
        running_actions_manager,
442
2
        Box::new(move || {
443
0
            let config = config.clone();
444
0
            Box::pin(async move {
445
0
                let timeout = config
446
0
                    .worker_api_endpoint
447
0
                    .timeout
448
0
                    .unwrap_or(DEFAULT_ENDPOINT_TIMEOUT_S);
449
0
                let timeout_duration = Duration::from_secs_f32(timeout);
450
0
                let tls_config =
451
0
                    tls_utils::load_client_config(&config.worker_api_endpoint.tls_config)
452
0
                        .err_tip(|| "Parsing local worker TLS configuration")?;
453
0
                let endpoint =
454
0
                    tls_utils::endpoint_from(&config.worker_api_endpoint.uri, tls_config)
455
0
                        .map_err(|e| make_input_err!("Invalid URI for worker endpoint : {e:?}"))?
456
0
                        .connect_timeout(timeout_duration)
457
0
                        .timeout(timeout_duration);
458
459
0
                let transport = endpoint.connect().await.map_err(|e| {
460
0
                    make_err!(
461
0
                        Code::Internal,
462
0
                        "Could not connect to endpoint {}: {e:?}",
463
0
                        config.worker_api_endpoint.uri
464
0
                    )
465
0
                })?;
466
0
                Ok(WorkerApiClient::new(transport).into())
467
0
            })
468
2
        }),
469
2
        Box::new(move |d| 
Box::pin(sleep(d))0
),
470
2
    );
471
2
    let metrics = local_worker.metrics.clone();
472
2
    Ok((local_worker, metrics))
473
2
}
474
475
impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
476
8
    pub fn new_with_connection_factory_and_actions_manager(
477
8
        config: Arc<LocalWorkerConfig>,
478
8
        running_actions_manager: Arc<U>,
479
8
        connection_factory: ConnectionFactory<T>,
480
8
        sleep_fn: Box<dyn Fn(Duration) -> BoxFuture<'static, ()> + Send + Sync>,
481
8
    ) -> Self {
482
8
        let metrics = Arc::new(Metrics::new(Arc::downgrade(
483
8
            running_actions_manager.metrics(),
484
8
        )));
485
8
        Self {
486
8
            config,
487
8
            running_actions_manager,
488
8
            connection_factory,
489
8
            sleep_fn: Some(sleep_fn),
490
8
            metrics,
491
8
        }
492
8
    }
493
494
0
    pub fn name(&self) -> &String {
495
0
        &self.config.name
496
0
    }
497
498
8
    async fn register_worker(
499
8
        &self,
500
8
        client: &mut T,
501
8
    ) -> Result<(String, Streaming<UpdateForWorker>), Error> {
502
8
        let supported_properties =
503
8
            make_supported_properties(&self.config.platform_properties).await
?0
;
504
8
        let 
mut update_for_worker_stream6
= client
505
8
            .connect_worker(supported_properties)
506
8
            .await
507
6
            .err_tip(|| 
"Could not call connect_worker() in worker"0
)
?0
508
6
            .into_inner();
509
510
6
        let 
first_msg_update5
= update_for_worker_stream
511
6
            .next()
512
6
            .await
513
6
            .err_tip(|| 
"Got EOF expected UpdateForWorker"1
)
?1
514
5
            .err_tip(|| 
"Got error when receiving UpdateForWorker"0
)
?0
515
            .update;
516
517
5
        let worker_id = match first_msg_update {
518
5
            Some(Update::ConnectionResult(connection_result)) => connection_result.worker_id,
519
0
            other => {
520
0
                return Err(make_input_err!(
521
0
                    "Expected first response from scheduler to be a ConnectResult got : {:?}",
522
0
                    other
523
0
                ))
524
            }
525
        };
526
5
        Ok((worker_id, update_for_worker_stream))
527
6
    }
528
529
6
    #[instrument(skip(self), level = Level::INFO)]
530
    pub async fn run(
531
        mut self,
532
        mut shutdown_rx: broadcast::Receiver<ShutdownGuard>,
533
    ) -> Result<(), Error> {
534
        let sleep_fn = self
535
            .sleep_fn
536
            .take()
537
0
            .err_tip(|| "Could not unwrap sleep_fn in LocalWorker::run")?;
538
        let sleep_fn_pin = Pin::new(&sleep_fn);
539
2
        let error_handler = Box::pin(move |err| async move {
540
2
            event!(Level::ERROR, ?err, "Error");
541
2
            (sleep_fn_pin)(Duration::from_secs_f32(CONNECTION_RETRY_DELAY_S)).await;
542
4
        });
543
544
        loop {
545
            // First connect to our endpoint.
546
            let mut client = match (self.connection_factory)().await {
547
                Ok(client) => client,
548
                Err(e) => {
549
                    (error_handler)(e).await;
550
                    continue; // Try to connect again.
551
                }
552
            };
553
554
            // Next register our worker with the scheduler.
555
            let (mut inner, update_for_worker_stream) =
556
                match self.register_worker(&mut client).await {
557
                    Err(e) => {
558
                        (error_handler)(e).await;
559
                        continue; // Try to connect again.
560
                    }
561
                    Ok((worker_id, update_for_worker_stream)) => (
562
                        LocalWorkerImpl::new(
563
                            &self.config,
564
                            client,
565
                            worker_id,
566
                            self.running_actions_manager.clone(),
567
                            self.metrics.clone(),
568
                        ),
569
                        update_for_worker_stream,
570
                    ),
571
                };
572
            event!(
573
                Level::WARN,
574
                worker_id = %inner.worker_id,
575
                "Worker registered with scheduler"
576
            );
577
578
            // Now listen for connections and run all other services.
579
            if let Err(err) = inner.run(update_for_worker_stream, &mut shutdown_rx).await {
580
                'no_more_actions: {
581
                    // Ensure there are no actions in transit before we try to kill
582
                    // all our actions.
583
                    const ITERATIONS: usize = 1_000;
584
585
                    const ERROR_MSG: &str = "Actions in transit did not reach zero before we disconnected from the scheduler";
586
587
                    let sleep_duration = ACTIONS_IN_TRANSIT_TIMEOUT_S / ITERATIONS as f32;
588
                    for _ in 0..ITERATIONS {
589
                        if inner.actions_in_transit.load(Ordering::Acquire) == 0 {
590
                            break 'no_more_actions;
591
                        }
592
                        (sleep_fn_pin)(Duration::from_secs_f32(sleep_duration)).await;
593
                    }
594
                    event!(Level::ERROR, ERROR_MSG);
595
                    return Err(err.append(ERROR_MSG));
596
                }
597
                event!(Level::ERROR, ?err, "Worker disconnected from scheduler");
598
                // Kill off any existing actions because if we re-connect, we'll
599
                // get some more and it might resource lock us.
600
                self.running_actions_manager.kill_all().await;
601
602
                (error_handler)(err).await;
603
                continue; // Try to connect again.
604
            }
605
        }
606
        // Unreachable.
607
    }
608
}
609
610
#[derive(MetricsComponent)]
611
pub struct Metrics {
612
    #[metric(
613
        help = "Total number of actions sent to this worker to process. This does not mean it started them, it just means it received a request to execute it."
614
    )]
615
    start_actions_received: CounterWithTime,
616
    #[metric(help = "Total number of disconnects received from the scheduler.")]
617
    disconnects_received: CounterWithTime,
618
    #[metric(help = "Total number of keep-alives received from the scheduler.")]
619
    keep_alives_received: CounterWithTime,
620
    #[metric(
621
        help = "Stats about the calls to check if an action satisfies the config supplied script."
622
    )]
623
    preconditions: AsyncCounterWrapper,
624
    #[metric]
625
    running_actions_manager_metrics: Weak<RunningActionManagerMetrics>,
626
}
627
628
impl RootMetricsComponent for Metrics {}
629
630
impl Metrics {
631
8
    fn new(running_actions_manager_metrics: Weak<RunningActionManagerMetrics>) -> Self {
632
8
        Self {
633
8
            start_actions_received: CounterWithTime::default(),
634
8
            disconnects_received: CounterWithTime::default(),
635
8
            keep_alives_received: CounterWithTime::default(),
636
8
            preconditions: AsyncCounterWrapper::default(),
637
8
            running_actions_manager_metrics,
638
8
        }
639
8
    }
640
}
641
642
impl Metrics {
643
4
    async fn wrap<U, T: Future<Output = U>, F: FnOnce(Arc<Self>) -> T>(
644
4
        self: Arc<Self>,
645
4
        fut: F,
646
4
    ) -> U {
647
4
        fut(self).await
648
3
    }
649
}