/build/source/nativelink-scheduler/src/awaited_action_db/mod.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use std::cmp; |
16 | | use std::ops::Bound; |
17 | | use std::sync::Arc; |
18 | | |
19 | | pub use awaited_action::{AwaitedAction, AwaitedActionSortKey}; |
20 | | use futures::{Future, Stream}; |
21 | | use nativelink_error::{make_input_err, Error, ResultExt}; |
22 | | use nativelink_metric::MetricsComponent; |
23 | | use nativelink_util::action_messages::{ActionInfo, ActionStage, OperationId}; |
24 | | use serde::{Deserialize, Serialize}; |
25 | | |
26 | | mod awaited_action; |
27 | | |
/// A simple enum to represent the state of an `AwaitedAction`.
///
/// Each variant has a counterpart among the `ActionStage` variants accepted
/// by the `TryFrom<&ActionStage>` conversion defined in this module.
#[derive(Debug, Clone, Copy)]
pub enum SortedAwaitedActionState {
    /// Counterpart of `ActionStage::CacheCheck`.
    CacheCheck,
    /// Counterpart of `ActionStage::Queued`.
    Queued,
    /// Counterpart of `ActionStage::Executing`.
    Executing,
    /// Counterpart of `ActionStage::Completed(_)`; the payload is dropped.
    Completed,
}
36 | | |
37 | | impl TryFrom<&ActionStage> for SortedAwaitedActionState { |
38 | | type Error = Error; |
39 | 2 | fn try_from(value: &ActionStage) -> Result<Self, Error> { |
40 | 2 | match value { |
41 | 0 | ActionStage::CacheCheck => Ok(Self::CacheCheck), |
42 | 1 | ActionStage::Executing => Ok(Self::Executing), |
43 | 0 | ActionStage::Completed(_) => Ok(Self::Completed), |
44 | 1 | ActionStage::Queued => Ok(Self::Queued), |
45 | 0 | _ => Err(make_input_err!("Invalid State")), |
46 | | } |
47 | 2 | } |
48 | | } |
49 | | |
50 | | impl TryFrom<ActionStage> for SortedAwaitedActionState { |
51 | | type Error = Error; |
52 | 0 | fn try_from(value: ActionStage) -> Result<Self, Error> { |
53 | 0 | Self::try_from(&value) |
54 | 0 | } |
55 | | } |
56 | | |
/// A struct pointing to an `AwaitedAction` that can be sorted.
///
/// Instances are ordered by `sort_key` first, with `operation_id` used as
/// the tie-breaker (see the `Ord` implementation in this module).
#[derive(Debug, Clone, Serialize, Deserialize, MetricsComponent)]
pub struct SortedAwaitedAction {
    /// Primary ordering key, copied from the referenced `AwaitedAction`.
    #[metric(help = "The sort key of the AwaitedAction")]
    pub sort_key: AwaitedActionSortKey,
    /// Operation id of the referenced `AwaitedAction`; secondary ordering key.
    #[metric(help = "The operation id")]
    pub operation_id: OperationId,
}
65 | | |
66 | | impl PartialEq for SortedAwaitedAction { |
67 | 0 | fn eq(&self, other: &Self) -> bool { |
68 | 0 | self.sort_key == other.sort_key && self.operation_id == other.operation_id Branch (68:9): [True: 0, False: 0]
Branch (68:9): [Folded - Ignored]
|
69 | 0 | } |
70 | | } |
71 | | |
72 | | impl Eq for SortedAwaitedAction {} |
73 | | |
74 | | impl PartialOrd for SortedAwaitedAction { |
75 | 0 | fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> { |
76 | 0 | Some(self.cmp(other)) |
77 | 0 | } |
78 | | } |
79 | | |
80 | | impl Ord for SortedAwaitedAction { |
81 | 62 | fn cmp(&self, other: &Self) -> cmp::Ordering { |
82 | 62 | self.sort_key |
83 | 62 | .cmp(&other.sort_key) |
84 | 62 | .then_with(|| self.operation_id.cmp(&other.operation_id)54 ) |
85 | 62 | } |
86 | | } |
87 | | |
88 | | impl std::fmt::Display for SortedAwaitedAction { |
89 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
90 | 0 | std::fmt::write( |
91 | 0 | f, |
92 | 0 | format_args!("{}-{}", self.sort_key.as_u64(), self.operation_id), |
93 | 0 | ) |
94 | 0 | } |
95 | | } |
96 | | |
97 | | impl From<&AwaitedAction> for SortedAwaitedAction { |
98 | 2 | fn from(value: &AwaitedAction) -> Self { |
99 | 2 | Self { |
100 | 2 | operation_id: value.operation_id().clone(), |
101 | 2 | sort_key: value.sort_key(), |
102 | 2 | } |
103 | 2 | } |
104 | | } |
105 | | |
106 | | impl From<AwaitedAction> for SortedAwaitedAction { |
107 | 0 | fn from(value: AwaitedAction) -> Self { |
108 | 0 | Self::from(&value) |
109 | 0 | } |
110 | | } |
111 | | |
112 | | impl TryInto<Vec<u8>> for SortedAwaitedAction { |
113 | | type Error = Error; |
114 | 0 | fn try_into(self) -> Result<Vec<u8>, Self::Error> { |
115 | 0 | serde_json::to_vec(&self) |
116 | 0 | .map_err(|e| make_input_err!("{}", e.to_string())) |
117 | 0 | .err_tip(|| "In SortedAwaitedAction::TryInto::<Vec<u8>>") |
118 | 0 | } |
119 | | } |
120 | | |
121 | | impl TryFrom<&[u8]> for SortedAwaitedAction { |
122 | | type Error = Error; |
123 | 0 | fn try_from(value: &[u8]) -> Result<Self, Self::Error> { |
124 | 0 | serde_json::from_slice(value) |
125 | 0 | .map_err(|e| make_input_err!("{}", e.to_string())) |
126 | 0 | .err_tip(|| "In AwaitedAction::TryFrom::&[u8]") |
127 | 0 | } |
128 | | } |
129 | | |
/// Subscriber that can be used to monitor when `AwaitedActions` change.
///
/// Both methods return `Send` futures that resolve to the `AwaitedAction`
/// or an `Error`.
pub trait AwaitedActionSubscriber: Send + Sync + Sized + 'static {
    /// Wait for `AwaitedAction` to change.
    ///
    /// Takes `&mut self`. The resolved value is presumably the action as it
    /// looks after the change; confirm against the implementations.
    fn changed(&mut self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send;

    /// Get the current awaited action.
    fn borrow(&self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send;
}
138 | | |
/// A trait that defines the interface for an `AwaitedActionDb`.
///
/// All methods return `Send` futures. Lookups hand back the implementation's
/// `Subscriber` type rather than the `AwaitedAction` itself.
pub trait AwaitedActionDb: Send + Sync + MetricsComponent + Unpin + 'static {
    /// Subscriber type returned by lookups and by `add_action`.
    type Subscriber: AwaitedActionSubscriber;

    /// Get the `AwaitedAction` by the client operation id.
    ///
    /// Resolves to `Ok(None)` when there is no match (presumably; the exact
    /// conditions are up to the implementation).
    fn get_awaited_action_by_id(
        &self,
        client_operation_id: &OperationId,
    ) -> impl Future<Output = Result<Option<Self::Subscriber>, Error>> + Send;

    /// Get all `AwaitedActions`. This call should be avoided as much as possible.
    ///
    /// The result is a stream whose items are individually fallible.
    fn get_all_awaited_actions(
        &self,
    ) -> impl Future<
        Output = Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error>,
    > + Send;

    /// Get the `AwaitedAction` by the operation id.
    ///
    /// Unlike `get_awaited_action_by_id`, the key here is the operation id
    /// rather than the client operation id.
    fn get_by_operation_id(
        &self,
        operation_id: &OperationId,
    ) -> impl Future<Output = Result<Option<Self::Subscriber>, Error>> + Send;

    /// Get a range of `AwaitedActions` of a specific state in sorted order.
    ///
    /// * `state` - which state's actions to read.
    /// * `start` / `end` - bounds of the range, expressed as
    ///   `SortedAwaitedAction` values.
    /// * `desc` - presumably selects descending order when `true`; confirm
    ///   against the implementations.
    fn get_range_of_actions(
        &self,
        state: SortedAwaitedActionState,
        start: Bound<SortedAwaitedAction>,
        end: Bound<SortedAwaitedAction>,
        desc: bool,
    ) -> impl Future<
        Output = Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error>,
    > + Send;

    /// Process a changed `AwaitedAction` and notify any listeners.
    fn update_awaited_action(
        &self,
        new_awaited_action: AwaitedAction,
    ) -> impl Future<Output = Result<(), Error>> + Send;

    /// Add (or join) an action to the `AwaitedActionDb` and subscribe
    /// to changes.
    ///
    /// Resolves to a subscriber for the action identified by
    /// `client_operation_id` and described by `action_info`.
    fn add_action(
        &self,
        client_operation_id: OperationId,
        action_info: Arc<ActionInfo>,
    ) -> impl Future<Output = Result<Self::Subscriber, Error>> + Send;
}