Coverage Report

Created: 2024-12-20 00:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/awaited_action_db/mod.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::cmp;
16
use std::ops::Bound;
17
use std::sync::Arc;
18
19
pub use awaited_action::{AwaitedAction, AwaitedActionSortKey};
20
use futures::{Future, Stream};
21
use nativelink_error::{make_input_err, Error, ResultExt};
22
use nativelink_metric::MetricsComponent;
23
use nativelink_util::action_messages::{ActionInfo, ActionStage, OperationId};
24
use serde::{Deserialize, Serialize};
25
26
mod awaited_action;
27
28
/// Lifecycle bucket of an `AwaitedAction`.
///
/// Passed to `AwaitedActionDb::get_range_of_actions` to choose which state's
/// actions should be listed. Derives `PartialEq`/`Eq`/`Hash` so callers can
/// compare states directly or use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SortedAwaitedActionState {
    /// Built from `ActionStage::CacheCheck`.
    CacheCheck,
    /// Built from `ActionStage::Queued`.
    Queued,
    /// Built from `ActionStage::Executing`.
    Executing,
    /// Built from `ActionStage::Completed(_)`.
    Completed,
}
36
37
impl TryFrom<&ActionStage> for SortedAwaitedActionState {
38
    type Error = Error;
39
2
    fn try_from(value: &ActionStage) -> Result<Self, Error> {
40
2
        match value {
41
0
            ActionStage::CacheCheck => Ok(Self::CacheCheck),
42
1
            ActionStage::Executing => Ok(Self::Executing),
43
0
            ActionStage::Completed(_) => Ok(Self::Completed),
44
1
            ActionStage::Queued => Ok(Self::Queued),
45
0
            _ => Err(make_input_err!("Invalid State")),
46
        }
47
2
    }
48
}
49
50
impl TryFrom<ActionStage> for SortedAwaitedActionState {
51
    type Error = Error;
52
0
    fn try_from(value: ActionStage) -> Result<Self, Error> {
53
0
        Self::try_from(&value)
54
0
    }
55
}
56
57
/// A struct pointing to an `AwaitedAction` that can be sorted.
58
#[derive(Debug, Clone, Serialize, Deserialize, MetricsComponent)]
59
pub struct SortedAwaitedAction {
60
    #[metric(help = "The sort key of the AwaitedAction")]
61
    pub sort_key: AwaitedActionSortKey,
62
    #[metric(help = "The operation id")]
63
    pub operation_id: OperationId,
64
}
65
66
impl PartialEq for SortedAwaitedAction {
67
0
    fn eq(&self, other: &Self) -> bool {
68
0
        self.sort_key == other.sort_key && self.operation_id == other.operation_id
  Branch (68:9): [True: 0, False: 0]
  Branch (68:9): [Folded - Ignored]
69
0
    }
70
}
71
72
/// Marks `SortedAwaitedAction` equality as a total equivalence relation.
/// `Ord` requires this so entries can live in ordered collections.
impl Eq for SortedAwaitedAction {}
73
74
impl PartialOrd for SortedAwaitedAction {
75
0
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
76
0
        Some(self.cmp(other))
77
0
    }
78
}
79
80
impl Ord for SortedAwaitedAction {
81
62
    fn cmp(&self, other: &Self) -> cmp::Ordering {
82
62
        self.sort_key
83
62
            .cmp(&other.sort_key)
84
62
            .then_with(|| 
self.operation_id.cmp(&other.operation_id)54
)
85
62
    }
86
}
87
88
impl std::fmt::Display for SortedAwaitedAction {
89
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90
0
        std::fmt::write(
91
0
            f,
92
0
            format_args!("{}-{}", self.sort_key.as_u64(), self.operation_id),
93
0
        )
94
0
    }
95
}
96
97
impl From<&AwaitedAction> for SortedAwaitedAction {
    /// Builds a sortable handle from a borrowed `AwaitedAction`.
    ///
    /// The operation id is cloned; the sort key is taken from
    /// `AwaitedAction::sort_key()`.
    fn from(value: &AwaitedAction) -> Self {
        let operation_id = value.operation_id().clone();
        let sort_key = value.sort_key();
        Self {
            sort_key,
            operation_id,
        }
    }
}
105
106
impl From<AwaitedAction> for SortedAwaitedAction {
    /// Owned-value convenience wrapper around `From<&AwaitedAction>`.
    fn from(value: AwaitedAction) -> Self {
        let borrowed: &AwaitedAction = &value;
        <Self as From<&AwaitedAction>>::from(borrowed)
    }
}
111
112
impl TryInto<Vec<u8>> for SortedAwaitedAction {
113
    type Error = Error;
114
0
    fn try_into(self) -> Result<Vec<u8>, Self::Error> {
115
0
        serde_json::to_vec(&self)
116
0
            .map_err(|e| make_input_err!("{}", e.to_string()))
117
0
            .err_tip(|| "In SortedAwaitedAction::TryInto::<Vec<u8>>")
118
0
    }
119
}
120
121
impl TryFrom<&[u8]> for SortedAwaitedAction {
122
    type Error = Error;
123
0
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
124
0
        serde_json::from_slice(value)
125
0
            .map_err(|e| make_input_err!("{}", e.to_string()))
126
0
            .err_tip(|| "In AwaitedAction::TryFrom::&[u8]")
127
0
    }
128
}
129
130
/// Subscriber that can be used to monitor when `AwaitedActions` change.
131
pub trait AwaitedActionSubscriber: Send + Sync + Sized + 'static {
132
    /// Wait for `AwaitedAction` to change.
133
    fn changed(&mut self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send;
134
135
    /// Get the current awaited action.
136
    fn borrow(&self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send;
137
}
138
139
/// A trait that defines the interface for an `AwaitedActionDb`.
140
pub trait AwaitedActionDb: Send + Sync + MetricsComponent + Unpin + 'static {
141
    type Subscriber: AwaitedActionSubscriber;
142
143
    /// Get the `AwaitedAction` by the client operation id.
144
    fn get_awaited_action_by_id(
145
        &self,
146
        client_operation_id: &OperationId,
147
    ) -> impl Future<Output = Result<Option<Self::Subscriber>, Error>> + Send;
148
149
    /// Get all `AwaitedActions`. This call should be avoided as much as possible.
150
    fn get_all_awaited_actions(
151
        &self,
152
    ) -> impl Future<
153
        Output = Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error>,
154
    > + Send;
155
156
    /// Get the `AwaitedAction` by the operation id.
157
    fn get_by_operation_id(
158
        &self,
159
        operation_id: &OperationId,
160
    ) -> impl Future<Output = Result<Option<Self::Subscriber>, Error>> + Send;
161
162
    /// Get a range of `AwaitedActions` of a specific state in sorted order.
163
    fn get_range_of_actions(
164
        &self,
165
        state: SortedAwaitedActionState,
166
        start: Bound<SortedAwaitedAction>,
167
        end: Bound<SortedAwaitedAction>,
168
        desc: bool,
169
    ) -> impl Future<
170
        Output = Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error>,
171
    > + Send;
172
173
    /// Process a change changed `AwaitedAction` and notify any listeners.
174
    fn update_awaited_action(
175
        &self,
176
        new_awaited_action: AwaitedAction,
177
    ) -> impl Future<Output = Result<(), Error>> + Send;
178
179
    /// Add (or join) an action to the `AwaitedActionDb` and subscribe
180
    /// to changes.
181
    fn add_action(
182
        &self,
183
        client_operation_id: OperationId,
184
        action_info: Arc<ActionInfo>,
185
    ) -> impl Future<Output = Result<Self::Subscriber, Error>> + Send;
186
}