/build/source/nativelink-service/src/bytestream_server.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use std::collections::hash_map::Entry; |
16 | | use std::collections::HashMap; |
17 | | use std::convert::Into; |
18 | | use std::fmt::{Debug, Formatter}; |
19 | | use std::pin::Pin; |
20 | | use std::sync::atomic::{AtomicU64, Ordering}; |
21 | | use std::sync::Arc; |
22 | | use std::time::Duration; |
23 | | |
24 | | use futures::future::{pending, BoxFuture}; |
25 | | use futures::stream::unfold; |
26 | | use futures::{try_join, Future, Stream, TryFutureExt}; |
27 | | use nativelink_config::cas_server::ByteStreamConfig; |
28 | | use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; |
29 | | use nativelink_proto::google::bytestream::byte_stream_server::{ |
30 | | ByteStream, ByteStreamServer as Server, |
31 | | }; |
32 | | use nativelink_proto::google::bytestream::{ |
33 | | QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest, |
34 | | WriteResponse, |
35 | | }; |
36 | | use nativelink_store::grpc_store::GrpcStore; |
37 | | use nativelink_store::store_manager::StoreManager; |
38 | | use nativelink_util::buf_channel::{ |
39 | | make_buf_channel_pair, DropCloserReadHalf, DropCloserWriteHalf, |
40 | | }; |
41 | | use nativelink_util::common::DigestInfo; |
42 | | use nativelink_util::digest_hasher::{ |
43 | | default_digest_hasher_func, make_ctx_for_hash_func, DigestHasherFunc, |
44 | | }; |
45 | | use nativelink_util::origin_event::OriginEventContext; |
46 | | use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper; |
47 | | use nativelink_util::resource_info::ResourceInfo; |
48 | | use nativelink_util::spawn; |
49 | | use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo}; |
50 | | use nativelink_util::task::JoinHandleDropGuard; |
51 | | use parking_lot::Mutex; |
52 | | use tokio::time::sleep; |
53 | | use tonic::{Request, Response, Status, Streaming}; |
54 | | use tracing::{enabled, error_span, event, instrument, Instrument, Level}; |
55 | | |
56 | | /// If this value changes update the documentation in the config definition. |
57 | | const DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT: Duration = Duration::from_secs(60); |
58 | | |
59 | | /// If this value changes update the documentation in the config definition. |
60 | | const DEFAULT_MAX_BYTES_PER_STREAM: usize = 64 * 1024; |
61 | | |
62 | | /// If this value changes update the documentation in the config definition. |
63 | | const DEFAULT_MAX_DECODING_MESSAGE_SIZE: usize = 4 * 1024 * 1024; |
64 | | |
65 | | type ReadStream = Pin<Box<dyn Stream<Item = Result<ReadResponse, Status>> + Send + 'static>>; |
66 | | type StoreUpdateFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>; |
67 | | |
68 | | struct StreamState { |
69 | | uuid: String, |
70 | | tx: DropCloserWriteHalf, |
71 | | store_update_fut: StoreUpdateFuture, |
72 | | } |
73 | | |
74 | | impl Debug for StreamState { |
75 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
76 | 0 | f.debug_struct("StreamState") |
77 | 0 | .field("uuid", &self.uuid) |
78 | 0 | .finish() |
79 | 0 | } |
80 | | } |
81 | | |
82 | | /// If a stream is in this state, it will automatically be put back into an `IdleStream` and |
83 | | /// placed back into the `active_uploads` map as an `IdleStream` after it is dropped. |
84 | | /// To prevent it from being put back into an `IdleStream` you must call `.graceful_finish()`. |
85 | | struct ActiveStreamGuard<'a> { |
86 | | stream_state: Option<StreamState>, |
87 | | bytes_received: Arc<AtomicU64>, |
88 | | bytestream_server: &'a ByteStreamServer, |
89 | | } |
90 | | |
91 | | impl ActiveStreamGuard<'_> { |
92 | | /// Consumes the guard. The stream will be considered "finished", will |
93 | | /// remove it from the `active_uploads`. |
94 | 8 | fn graceful_finish(mut self) { |
95 | 8 | let stream_state = self.stream_state.take().unwrap(); |
96 | 8 | self.bytestream_server |
97 | 8 | .active_uploads |
98 | 8 | .lock() |
99 | 8 | .remove(&stream_state.uuid); |
100 | 8 | } |
101 | | } |
102 | | |
103 | | impl Drop for ActiveStreamGuard<'_> { |
104 | 14 | fn drop(&mut self) { |
105 | 14 | let Some(stream_state6 ) = self.stream_state.take() else { Branch (105:13): [True: 6, False: 8]
Branch (105:13): [Folded - Ignored]
|
106 | 8 | return; // If None it means we don't want it put back into an IdleStream. |
107 | | }; |
108 | 6 | let weak_active_uploads = Arc::downgrade(&self.bytestream_server.active_uploads); |
109 | 6 | let mut active_uploads = self.bytestream_server.active_uploads.lock(); |
110 | 6 | let uuid = stream_state.uuid.clone(); |
111 | 6 | let Some(active_uploads_slot) = active_uploads.get_mut(&uuid) else { Branch (111:13): [True: 6, False: 0]
Branch (111:13): [Folded - Ignored]
|
112 | 0 | event!( |
113 | 0 | Level::ERROR, |
114 | | err = "Failed to find active upload. This should never happen.", |
115 | | uuid = ?uuid, |
116 | | ); |
117 | 0 | return; |
118 | | }; |
119 | 6 | let sleep_fn = self.bytestream_server.sleep_fn.clone(); |
120 | 6 | active_uploads_slot.1 = Some(IdleStream { |
121 | 6 | stream_state, |
122 | 6 | _timeout_streaam_drop_guard: spawn!("bytestream_idle_stream_timeout", async move { |
123 | 3 | (*sleep_fn)().await; |
124 | 0 | if let Some(active_uploads) = weak_active_uploads.upgrade() { Branch (124:24): [True: 0, False: 0]
Branch (124:24): [Folded - Ignored]
|
125 | 0 | let mut active_uploads = active_uploads.lock(); |
126 | 0 | event!(Level::INFO, msg = "Removing idle stream", uuid = ?uuid); |
127 | 0 | active_uploads.remove(&uuid); |
128 | 0 | } |
129 | 6 | }0 ), |
130 | | }); |
131 | 14 | } |
132 | | } |
133 | | |
134 | | /// Represents a stream that is in the "idle" state. this means it is not currently being used |
135 | | /// by a client. If it is not used within a certain amount of time it will be removed from the |
136 | | /// `active_uploads` map automatically. |
137 | | #[derive(Debug)] |
138 | | struct IdleStream { |
139 | | stream_state: StreamState, |
140 | | _timeout_streaam_drop_guard: JoinHandleDropGuard<()>, |
141 | | } |
142 | | |
143 | | impl IdleStream { |
144 | 3 | fn into_active_stream( |
145 | 3 | self, |
146 | 3 | bytes_received: Arc<AtomicU64>, |
147 | 3 | bytestream_server: &ByteStreamServer, |
148 | 3 | ) -> ActiveStreamGuard<'_> { |
149 | 3 | ActiveStreamGuard { |
150 | 3 | stream_state: Some(self.stream_state), |
151 | 3 | bytes_received, |
152 | 3 | bytestream_server, |
153 | 3 | } |
154 | 3 | } |
155 | | } |
156 | | |
157 | | type BytesWrittenAndIdleStream = (Arc<AtomicU64>, Option<IdleStream>); |
158 | | type SleepFn = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>; |
159 | | |
160 | | pub struct ByteStreamServer { |
161 | | stores: HashMap<String, Store>, |
162 | | // Max number of bytes to send on each grpc stream chunk. |
163 | | max_bytes_per_stream: usize, |
164 | | max_decoding_message_size: usize, |
165 | | active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>, |
166 | | sleep_fn: SleepFn, |
167 | | } |
168 | | |
169 | | impl ByteStreamServer { |
170 | 14 | pub fn new(config: &ByteStreamConfig, store_manager: &StoreManager) -> Result<Self, Error> { |
171 | 14 | let mut persist_stream_on_disconnect_timeout = |
172 | 14 | Duration::from_secs(config.persist_stream_on_disconnect_timeout as u64); |
173 | 14 | if config.persist_stream_on_disconnect_timeout == 0 { Branch (173:12): [True: 14, False: 0]
Branch (173:12): [Folded - Ignored]
|
174 | 14 | persist_stream_on_disconnect_timeout = DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT; |
175 | 14 | }0 |
176 | 14 | Self::new_with_sleep_fn( |
177 | 14 | config, |
178 | 14 | store_manager, |
179 | 14 | Arc::new(move || Box::pin(sleep(persist_stream_on_disconnect_timeout))3 ), |
180 | 14 | ) |
181 | 14 | } |
182 | | |
183 | 14 | pub fn new_with_sleep_fn( |
184 | 14 | config: &ByteStreamConfig, |
185 | 14 | store_manager: &StoreManager, |
186 | 14 | sleep_fn: SleepFn, |
187 | 14 | ) -> Result<Self, Error> { |
188 | 14 | let mut stores = HashMap::with_capacity(config.cas_stores.len()); |
189 | 28 | for (instance_name, store_name14 ) in &config.cas_stores { |
190 | 14 | let store = store_manager |
191 | 14 | .get_store(store_name) |
192 | 14 | .ok_or_else(|| make_input_err!("'cas_store': '{}' does not exist", store_name)0 )?0 ; |
193 | 14 | stores.insert(instance_name.to_string(), store); |
194 | | } |
195 | 14 | let max_bytes_per_stream = if config.max_bytes_per_stream == 0 { Branch (195:39): [True: 1, False: 13]
Branch (195:39): [Folded - Ignored]
|
196 | 1 | DEFAULT_MAX_BYTES_PER_STREAM |
197 | | } else { |
198 | 13 | config.max_bytes_per_stream |
199 | | }; |
200 | 14 | let max_decoding_message_size = if config.max_decoding_message_size == 0 { Branch (200:44): [True: 13, False: 1]
Branch (200:44): [Folded - Ignored]
|
201 | 13 | DEFAULT_MAX_DECODING_MESSAGE_SIZE |
202 | | } else { |
203 | 1 | config.max_decoding_message_size |
204 | | }; |
205 | 14 | Ok(ByteStreamServer { |
206 | 14 | stores, |
207 | 14 | max_bytes_per_stream, |
208 | 14 | max_decoding_message_size, |
209 | 14 | active_uploads: Arc::new(Mutex::new(HashMap::new())), |
210 | 14 | sleep_fn, |
211 | 14 | }) |
212 | 14 | } |
213 | | |
214 | 1 | pub fn into_service(self) -> Server<Self> { |
215 | 1 | let max_decoding_message_size = self.max_decoding_message_size; |
216 | 1 | Server::new(self).max_decoding_message_size(max_decoding_message_size) |
217 | 1 | } |
218 | | |
219 | 14 | fn create_or_join_upload_stream( |
220 | 14 | &self, |
221 | 14 | uuid: String, |
222 | 14 | store: Store, |
223 | 14 | digest: DigestInfo, |
224 | 14 | ) -> Result<ActiveStreamGuard<'_>, Error> { |
225 | 14 | let (uuid, bytes_received11 ) = match self.active_uploads.lock().entry(uuid) { |
226 | 3 | Entry::Occupied(mut entry) => { |
227 | 3 | let maybe_idle_stream = entry.get_mut(); |
228 | 3 | let Some(idle_stream) = maybe_idle_stream.1.take() else { Branch (228:21): [True: 3, False: 0]
Branch (228:21): [Folded - Ignored]
|
229 | 0 | return Err(make_input_err!("Cannot upload same UUID simultaneously")); |
230 | | }; |
231 | 3 | let bytes_received = maybe_idle_stream.0.clone(); |
232 | 3 | event!(Level::INFO, msg = "Joining existing stream", entry = ?entry.key()0 ); |
233 | 3 | return Ok(idle_stream.into_active_stream(bytes_received, self)); |
234 | | } |
235 | 11 | Entry::Vacant(entry) => { |
236 | 11 | let bytes_received = Arc::new(AtomicU64::new(0)); |
237 | 11 | let uuid = entry.key().clone(); |
238 | 11 | // Our stream is "in use" if the key is in the map, but the value is None. |
239 | 11 | entry.insert((bytes_received.clone(), None)); |
240 | 11 | (uuid, bytes_received) |
241 | 11 | } |
242 | 11 | }; |
243 | 11 | |
244 | 11 | // Important: Do not return an error from this point onwards without |
245 | 11 | // removing the entry from the map, otherwise that UUID becomes |
246 | 11 | // unusable. |
247 | 11 | |
248 | 11 | let (tx, rx) = make_buf_channel_pair(); |
249 | 11 | let store_update_fut = Box::pin(async move { |
250 | 8 | // We need to wrap `Store::update()` in a another future because we need to capture |
251 | 8 | // `store` to ensure its lifetime follows the future and not the caller. |
252 | 8 | store |
253 | 8 | // Bytestream always uses digest size as the actual byte size. |
254 | 8 | .update(digest, rx, UploadSizeInfo::ExactSize(digest.size_bytes())) |
255 | 8 | .await |
256 | 11 | }8 ); |
257 | 11 | Ok(ActiveStreamGuard { |
258 | 11 | stream_state: Some(StreamState { |
259 | 11 | uuid, |
260 | 11 | tx, |
261 | 11 | store_update_fut, |
262 | 11 | }), |
263 | 11 | bytes_received, |
264 | 11 | bytestream_server: self, |
265 | 11 | }) |
266 | 14 | } |
267 | | |
268 | 3 | async fn inner_read( |
269 | 3 | &self, |
270 | 3 | store: Store, |
271 | 3 | digest: DigestInfo, |
272 | 3 | read_request: ReadRequest, |
273 | 3 | ) -> Result<impl Stream<Item = Result<ReadResponse, Status>> + Send + 'static, Error> { |
274 | | struct ReaderState { |
275 | | max_bytes_per_stream: usize, |
276 | | rx: DropCloserReadHalf, |
277 | | maybe_get_part_result: Option<Result<(), Error>>, |
278 | | get_part_fut: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>, |
279 | | } |
280 | | |
281 | 3 | let read_limit = u64::try_from(read_request.read_limit) |
282 | 3 | .err_tip(|| "Could not convert read_limit to u64"0 )?0 ; |
283 | | |
284 | 3 | let (tx, rx) = make_buf_channel_pair(); |
285 | | |
286 | 3 | let read_limit = if read_limit != 0 { Branch (286:29): [True: 3, False: 0]
Branch (286:29): [Folded - Ignored]
|
287 | 3 | Some(read_limit) |
288 | | } else { |
289 | 0 | None |
290 | | }; |
291 | | |
292 | | // This allows us to call a destructor when the the object is dropped. |
293 | 3 | let state = Some(ReaderState { |
294 | 3 | rx, |
295 | 3 | max_bytes_per_stream: self.max_bytes_per_stream, |
296 | 3 | maybe_get_part_result: None, |
297 | 3 | get_part_fut: Box::pin(async move { |
298 | 3 | store |
299 | 3 | .get_part( |
300 | 3 | digest, |
301 | 3 | tx, |
302 | 3 | u64::try_from(read_request.read_offset) |
303 | 3 | .err_tip(|| "Could not convert read_offset to u64"0 )?0 , |
304 | 3 | read_limit, |
305 | 3 | ) |
306 | 3 | .await |
307 | 3 | }), |
308 | 3 | }); |
309 | | |
310 | 3 | let read_stream_span = error_span!("read_stream"); |
311 | | |
312 | 9.77k | Ok(Box::pin(unfold(state, 3 move |state| { |
313 | 9.77k | async { |
314 | 9.77k | let mut state = state?0 ; // If None our stream is done. |
315 | 9.77k | let mut response = ReadResponse::default(); |
316 | 9.77k | { |
317 | 9.77k | let consume_fut = state.rx.consume(Some(state.max_bytes_per_stream)); |
318 | 9.77k | tokio::pin!(consume_fut); |
319 | | loop { |
320 | 9.77k | tokio::select! { |
321 | 9.77k | read_result9.77k = &mut consume_fut => { |
322 | 9.77k | match read_result { |
323 | 9.76k | Ok(bytes) => { |
324 | 9.76k | if bytes.is_empty() { Branch (324:40): [True: 2, False: 9.76k]
Branch (324:40): [Folded - Ignored]
|
325 | | // EOF. |
326 | 2 | return None; |
327 | 9.76k | } |
328 | 9.76k | if bytes.len() > state.max_bytes_per_stream { Branch (328:40): [True: 0, False: 9.76k]
Branch (328:40): [Folded - Ignored]
|
329 | 0 | let err = make_err!(Code::Internal, "Returned store size was larger than read size"); |
330 | 0 | return Some((Err(err.into()), None)); |
331 | 9.76k | } |
332 | 9.76k | response.data = bytes; |
333 | 9.76k | if enabled!(Level::DEBUG) { |
334 | 0 | event!(Level::INFO, response = ?response); |
335 | | } else { |
336 | 9.76k | event!(Level::INFO, response.data = format!("<redacted len({})>", response.data.len())0 ); |
337 | | } |
338 | 9.76k | break; |
339 | | } |
340 | 1 | Err(mut e) => { |
341 | | // We may need to propagate the error from reading the data through first. |
342 | | // For example, the NotFound error will come through `get_part_fut`, and |
343 | | // will not be present in `e`, but we need to ensure we pass NotFound error |
344 | | // code or the client won't know why it failed. |
345 | 1 | let get_part_result = if let Some(result) = state.maybe_get_part_result { Branch (345:66): [True: 1, False: 0]
Branch (345:66): [Folded - Ignored]
|
346 | 1 | result |
347 | | } else { |
348 | | // This should never be `future::pending()` if maybe_get_part_result is |
349 | | // not set. |
350 | 0 | state.get_part_fut.await |
351 | | }; |
352 | 1 | if let Err(err) = get_part_result { Branch (352:44): [True: 1, False: 0]
Branch (352:44): [Folded - Ignored]
|
353 | 1 | e = err.merge(e); |
354 | 1 | }0 |
355 | 1 | if e.code == Code::NotFound { Branch (355:40): [True: 1, False: 0]
Branch (355:40): [Folded - Ignored]
|
356 | 1 | // Trim the error code. Not Found is quite common and we don't want to send a large |
357 | 1 | // error (debug) message for something that is common. We resize to just the last |
358 | 1 | // message as it will be the most relevant. |
359 | 1 | e.messages.truncate(1); |
360 | 1 | }0 |
361 | 1 | event!(Level::ERROR, response = ?e); |
362 | 1 | return Some((Err(e.into()), None)) |
363 | | } |
364 | | } |
365 | | }, |
366 | 9.77k | result3 = &mut state.get_part_fut => { |
367 | 3 | state.maybe_get_part_result = Some(result); |
368 | 3 | // It is non-deterministic on which future will finish in what order. |
369 | 3 | // It is also possible that the `state.rx.consume()` call above may not be able to |
370 | 3 | // respond even though the publishing future is done. |
371 | 3 | // Because of this we set the writing future to pending so it never finishes. |
372 | 3 | // The `state.rx.consume()` future will eventually finish and return either the |
373 | 3 | // data or an error. |
374 | 3 | // An EOF will terminate the `state.rx.consume()` future, but we are also protected |
375 | 3 | // because we are dropping the writing future, it will drop the `tx` channel |
376 | 3 | // which will eventually propagate an error to the `state.rx.consume()` future if |
377 | 3 | // the EOF was not sent due to some other error. |
378 | 3 | state.get_part_fut = Box::pin(pending()); |
379 | 3 | }, |
380 | | } |
381 | | } |
382 | | } |
383 | 9.76k | Some((Ok(response), Some(state))) |
384 | 9.77k | }.instrument(read_stream_span.clone()) |
385 | 9.77k | })))3 |
386 | 3 | } |
387 | | |
388 | | // We instrument tracing here as well as below because `stream` has a hash on it |
389 | | // that is extracted from the first stream message. If we only implemented it below |
390 | | // we would not have the hash available to us. |
391 | | #[instrument( |
392 | | ret(level = Level::INFO), |
393 | 14 | level = Level::ERROR, |
394 | | skip(self, store), |
395 | | fields(stream.first_msg = "<redacted>") |
396 | | )] |
397 | | async fn inner_write( |
398 | | &self, |
399 | | store: Store, |
400 | | digest: DigestInfo, |
401 | | stream: WriteRequestStreamWrapper<impl Stream<Item = Result<WriteRequest, Status>> + Unpin>, |
402 | | ) -> Result<Response<WriteResponse>, Error> { |
403 | 14 | async fn process_client_stream( |
404 | 14 | mut stream: WriteRequestStreamWrapper< |
405 | 14 | impl Stream<Item = Result<WriteRequest, Status>> + Unpin, |
406 | 14 | >, |
407 | 14 | tx: &mut DropCloserWriteHalf, |
408 | 14 | outer_bytes_received: &Arc<AtomicU64>, |
409 | 14 | expected_size: u64, |
410 | 14 | ) -> Result<(), Error> { |
411 | | loop { |
412 | 24 | let write_request20 = match stream.next().await { |
413 | | // Code path for when client tries to gracefully close the stream. |
414 | | // If this happens it means there's a problem with the data sent, |
415 | | // because we always close the stream from our end before this point |
416 | | // by counting the number of bytes sent from the client. If they send |
417 | | // less than the amount they said they were going to send and then |
418 | | // close the stream, we know there's a problem. |
419 | | None => { |
420 | 0 | return Err(make_input_err!( |
421 | 0 | "Client closed stream before sending all data" |
422 | 0 | )) |
423 | | } |
424 | | // Code path for client stream error. Probably client disconnect. |
425 | 4 | Some(Err(err)) => return Err(err), |
426 | | // Code path for received chunk of data. |
427 | 20 | Some(Ok(write_request)) => write_request, |
428 | 20 | }; |
429 | 20 | |
430 | 20 | if write_request.write_offset < 0 { Branch (430:20): [True: 1, False: 19]
Branch (430:20): [Folded - Ignored]
|
431 | 1 | return Err(make_input_err!( |
432 | 1 | "Invalid negative write offset in write request: {}", |
433 | 1 | write_request.write_offset |
434 | 1 | )); |
435 | 19 | } |
436 | 19 | let write_offset = write_request.write_offset as u64; |
437 | | |
438 | | // If we get duplicate data because a client didn't know where |
439 | | // it left off from, then we can simply skip it. |
440 | 19 | let data18 = if write_offset < tx.get_bytes_written() { Branch (440:31): [True: 2, False: 17]
Branch (440:31): [Folded - Ignored]
|
441 | 2 | if (write_offset + write_request.data.len() as u64) < tx.get_bytes_written() { Branch (441:24): [True: 0, False: 2]
Branch (441:24): [Folded - Ignored]
|
442 | 0 | if write_request.finish_write { Branch (442:28): [True: 0, False: 0]
Branch (442:28): [Folded - Ignored]
|
443 | 0 | return Err(make_input_err!( |
444 | 0 | "Resumed stream finished at {} bytes when we already received {} bytes.", |
445 | 0 | write_offset + write_request.data.len() as u64, |
446 | 0 | tx.get_bytes_written() |
447 | 0 | )); |
448 | 0 | } |
449 | 0 | continue; |
450 | 2 | } |
451 | 2 | write_request |
452 | 2 | .data |
453 | 2 | .slice((tx.get_bytes_written() - write_offset) as usize..) |
454 | | } else { |
455 | 17 | if write_offset != tx.get_bytes_written() { Branch (455:24): [True: 1, False: 16]
Branch (455:24): [Folded - Ignored]
|
456 | 1 | return Err(make_input_err!( |
457 | 1 | "Received out of order data. Got {}, expected {}", |
458 | 1 | write_offset, |
459 | 1 | tx.get_bytes_written() |
460 | 1 | )); |
461 | 16 | } |
462 | 16 | write_request.data |
463 | | }; |
464 | | |
465 | | // Do not process EOF or weird stuff will happen. |
466 | 18 | if !data.is_empty() { Branch (466:20): [True: 13, False: 5]
Branch (466:20): [Folded - Ignored]
|
467 | | // We also need to process the possible EOF branch, so we can't early return. |
468 | 13 | if let Err(mut err0 ) = tx.send(data).await { Branch (468:28): [True: 0, False: 13]
Branch (468:28): [Folded - Ignored]
|
469 | 0 | err.code = Code::Internal; |
470 | 0 | return Err(err); |
471 | 13 | } |
472 | 13 | outer_bytes_received.store(tx.get_bytes_written(), Ordering::Release); |
473 | 5 | } |
474 | | |
475 | 18 | if expected_size < tx.get_bytes_written() { Branch (475:20): [True: 0, False: 18]
Branch (475:20): [Folded - Ignored]
|
476 | 0 | return Err(make_input_err!("Received more bytes than expected")); |
477 | 18 | } |
478 | 18 | if write_request.finish_write { Branch (478:20): [True: 8, False: 10]
Branch (478:20): [Folded - Ignored]
|
479 | | // Gracefully close our stream. |
480 | 8 | tx.send_eof() |
481 | 8 | .err_tip(|| "Failed to send EOF in ByteStream::write"0 )?0 ; |
482 | 8 | return Ok(()); |
483 | 10 | } |
484 | | // Continue. |
485 | | } |
486 | | // Unreachable. |
487 | 14 | } |
488 | | |
489 | | let uuid = stream |
490 | | .resource_info |
491 | | .uuid |
492 | | .as_ref() |
493 | 0 | .ok_or_else(|| make_input_err!("UUID must be set if writing data"))? |
494 | | .to_string(); |
495 | | let mut active_stream_guard = self.create_or_join_upload_stream(uuid, store, digest)?; |
496 | | let expected_size = stream.resource_info.expected_size as u64; |
497 | | |
498 | | let active_stream = active_stream_guard.stream_state.as_mut().unwrap(); |
499 | | try_join!( |
500 | | process_client_stream( |
501 | | stream, |
502 | | &mut active_stream.tx, |
503 | | &active_stream_guard.bytes_received, |
504 | | expected_size |
505 | | ), |
506 | | (&mut active_stream.store_update_fut) |
507 | 0 | .map_err(|err| { err.append("Error updating inner store") }) |
508 | | )?; |
509 | | |
510 | | // Close our guard and consider the stream no longer active. |
511 | | active_stream_guard.graceful_finish(); |
512 | | |
513 | | Ok(Response::new(WriteResponse { |
514 | | committed_size: expected_size as i64, |
515 | | })) |
516 | | } |
517 | | |
518 | 3 | async fn inner_query_write_status( |
519 | 3 | &self, |
520 | 3 | query_request: &QueryWriteStatusRequest, |
521 | 3 | ) -> Result<Response<QueryWriteStatusResponse>, Error> { |
522 | 3 | let mut resource_info = ResourceInfo::new(&query_request.resource_name, true)?0 ; |
523 | | |
524 | 3 | let store_clone = self |
525 | 3 | .stores |
526 | 3 | .get(resource_info.instance_name.as_ref()) |
527 | 3 | .err_tip(|| { |
528 | 0 | format!( |
529 | 0 | "'instance_name' not configured for '{}'", |
530 | 0 | &resource_info.instance_name |
531 | 0 | ) |
532 | 3 | })?0 |
533 | 3 | .clone(); |
534 | | |
535 | 3 | let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?0 ; |
536 | | |
537 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
538 | 3 | if let Some(grpc_store0 ) = store_clone.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (538:16): [True: 0, False: 3]
Branch (538:16): [Folded - Ignored]
|
539 | 0 | return grpc_store |
540 | 0 | .query_write_status(Request::new(query_request.clone())) |
541 | 0 | .await; |
542 | 3 | } |
543 | | |
544 | 3 | let uuid = resource_info |
545 | 3 | .uuid |
546 | 3 | .take() |
547 | 3 | .ok_or_else(|| make_input_err!("UUID must be set if querying write status")0 )?0 ; |
548 | | |
549 | | { |
550 | 3 | let active_uploads = self.active_uploads.lock(); |
551 | 3 | if let Some((received_bytes, _maybe_idle_stream1 )) = active_uploads.get(uuid.as_ref()) { Branch (551:20): [True: 1, False: 2]
Branch (551:20): [Folded - Ignored]
|
552 | 1 | return Ok(Response::new(QueryWriteStatusResponse { |
553 | 1 | committed_size: received_bytes.load(Ordering::Acquire) as i64, |
554 | 1 | // If we are in the active_uploads map, but the value is None, |
555 | 1 | // it means the stream is not complete. |
556 | 1 | complete: false, |
557 | 1 | })); |
558 | 2 | } |
559 | 2 | } |
560 | 2 | |
561 | 2 | let has_fut = store_clone.has(digest); |
562 | 2 | let Some(item_size1 ) = has_fut.await.err_tip(|| "Failed to call .has() on store"0 )?0 else { Branch (562:13): [True: 1, False: 1]
Branch (562:13): [Folded - Ignored]
|
563 | | // We lie here and say that the stream needs to start over, even though |
564 | | // it was never started. This can happen when the client disconnects |
565 | | // before sending the first payload, but the client thinks it did send |
566 | | // the payload. |
567 | 1 | return Ok(Response::new(QueryWriteStatusResponse { |
568 | 1 | committed_size: 0, |
569 | 1 | complete: false, |
570 | 1 | })); |
571 | | }; |
572 | 1 | Ok(Response::new(QueryWriteStatusResponse { |
573 | 1 | committed_size: item_size as i64, |
574 | 1 | complete: true, |
575 | 1 | })) |
576 | 3 | } |
577 | | } |
578 | | |
579 | | #[tonic::async_trait] |
580 | | impl ByteStream for ByteStreamServer { |
581 | | type ReadStream = ReadStream; |
582 | | |
583 | | #[allow(clippy::blocks_in_conditions)] |
584 | | #[instrument( |
585 | | err, |
586 | | level = Level::ERROR, |
587 | | skip_all, |
588 | | fields(request = ?grpc_request.get_ref()) |
589 | | )] |
590 | | async fn read( |
591 | | &self, |
592 | | grpc_request: Request<ReadRequest>, |
593 | 3 | ) -> Result<Response<Self::ReadStream>, Status> { |
594 | 3 | let read_request = grpc_request.into_inner(); |
595 | 3 | let ctx = OriginEventContext::new(|| &read_request0 ).await; |
596 | | |
597 | 3 | let resource_info = ResourceInfo::new(&read_request.resource_name, false)?0 ; |
598 | 3 | let instance_name = resource_info.instance_name.as_ref(); |
599 | 3 | let store = self |
600 | 3 | .stores |
601 | 3 | .get(instance_name) |
602 | 3 | .err_tip(|| format!("'instance_name' not configured for '{instance_name}'")0 )?0 |
603 | 3 | .clone(); |
604 | | |
605 | 3 | let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?0 ; |
606 | | |
607 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
608 | 3 | if let Some(grpc_store0 ) = store.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (608:16): [True: 0, False: 3]
Branch (608:16): [Folded - Ignored]
|
609 | 0 | let stream = grpc_store.read(Request::new(read_request)).await?; |
610 | 0 | let resp = Ok(Response::new(ctx.wrap_stream(stream))); |
611 | 0 | ctx.emit(|| &resp).await; |
612 | 0 | return resp; |
613 | 3 | } |
614 | | |
615 | 3 | let digest_function = resource_info.digest_function.as_deref().map_or_else( |
616 | 3 | || Ok(default_digest_hasher_func()), |
617 | 3 | DigestHasherFunc::try_from, |
618 | 3 | )?0 ; |
619 | | |
620 | 3 | let resp = make_ctx_for_hash_func(digest_function) |
621 | 3 | .err_tip(|| "In BytestreamServer::read"0 )?0 |
622 | | .wrap_async( |
623 | 3 | error_span!("bytestream_read"), |
624 | 3 | self.inner_read(store, digest, read_request), |
625 | 3 | ) |
626 | 3 | .await |
627 | 3 | .err_tip(|| "In ByteStreamServer::read"0 ) |
628 | 3 | .map(|stream| -> Response<Self::ReadStream> { |
629 | 3 | Response::new(Box::pin(ctx.wrap_stream(stream))) |
630 | 3 | }) |
631 | 3 | .map_err(Into::into); |
632 | 3 | |
633 | 3 | if resp.is_ok() { Branch (633:12): [True: 3, False: 0]
Branch (633:12): [Folded - Ignored]
|
634 | 3 | event!(Level::DEBUG, return = "Ok(<stream>)"); |
635 | 0 | } |
636 | 3 | ctx.emit(|| &resp0 ).await; |
637 | 3 | resp |
638 | 6 | } |
639 | | |
640 | | #[allow(clippy::blocks_in_conditions)] |
641 | | #[instrument( |
642 | | err, |
643 | | level = Level::ERROR, |
644 | | skip_all, |
645 | | fields(request = ?grpc_request.get_ref()) |
646 | | )] |
647 | | async fn write( |
648 | | &self, |
649 | | grpc_request: Request<Streaming<WriteRequest>>, |
650 | 15 | ) -> Result<Response<WriteResponse>, Status> { |
651 | 15 | let request = grpc_request.into_inner(); |
652 | 15 | let ctx = OriginEventContext::new(|| &request0 ).await; |
653 | 15 | let stream14 = WriteRequestStreamWrapper::from(ctx.wrap_stream(request)) |
654 | 15 | .await |
655 | 15 | .err_tip(|| "Could not unwrap first stream message"1 ) |
656 | 15 | .map_err(Into::<Status>::into)?1 ; |
657 | | |
658 | 14 | let instance_name = stream.resource_info.instance_name.as_ref(); |
659 | 14 | let store = self |
660 | 14 | .stores |
661 | 14 | .get(instance_name) |
662 | 14 | .err_tip(|| format!("'instance_name' not configured for '{instance_name}'")0 )?0 |
663 | 14 | .clone(); |
664 | | |
665 | 14 | let digest = DigestInfo::try_new( |
666 | 14 | &stream.resource_info.hash, |
667 | 14 | stream.resource_info.expected_size, |
668 | 14 | ) |
669 | 14 | .err_tip(|| "Invalid digest input in ByteStream::write"0 )?0 ; |
670 | | |
671 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
672 | 14 | if let Some(grpc_store0 ) = store.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (672:16): [True: 0, False: 14]
Branch (672:16): [Folded - Ignored]
|
673 | 0 | let resp = grpc_store.write(stream).await.map_err(Into::into); |
674 | 0 | ctx.emit(|| &resp).await; |
675 | 0 | return resp; |
676 | 14 | } |
677 | | |
678 | 14 | let digest_function = stream |
679 | 14 | .resource_info |
680 | 14 | .digest_function |
681 | 14 | .as_deref() |
682 | 14 | .map_or_else( |
683 | 14 | || Ok(default_digest_hasher_func()), |
684 | 14 | DigestHasherFunc::try_from, |
685 | 14 | )?0 ; |
686 | | |
687 | 14 | let resp = make_ctx_for_hash_func(digest_function) |
688 | 14 | .err_tip(|| "In BytestreamServer::write"0 )?0 |
689 | | .wrap_async( |
690 | 14 | error_span!("bytestream_write"), |
691 | 14 | self.inner_write(store, digest, stream), |
692 | 14 | ) |
693 | 14 | .await |
694 | 14 | .err_tip(|| "In ByteStreamServer::write"6 ) |
695 | 14 | .map_err(Into::into); |
696 | 14 | ctx.emit(|| &resp0 ).await; |
697 | 14 | resp |
698 | 30 | } |
699 | | |
700 | | #[allow(clippy::blocks_in_conditions)] |
701 | | #[instrument( |
702 | | err, |
703 | | ret(level = Level::INFO), |
704 | | level = Level::ERROR, |
705 | | skip_all, |
706 | | fields(request = ?grpc_request.get_ref()) |
707 | | )] |
708 | | async fn query_write_status( |
709 | | &self, |
710 | | grpc_request: Request<QueryWriteStatusRequest>, |
711 | 3 | ) -> Result<Response<QueryWriteStatusResponse>, Status> { |
712 | 3 | let request = grpc_request.into_inner(); |
713 | 3 | let ctx = OriginEventContext::new(|| &request0 ).await; |
714 | 3 | let resp = self |
715 | 3 | .inner_query_write_status(&request) |
716 | 3 | .await |
717 | 3 | .err_tip(|| "Failed on query_write_status() command"0 ) |
718 | 3 | .map_err(Into::into); |
719 | 3 | ctx.emit(|| &resp0 ).await; |
720 | 3 | resp |
721 | 6 | } |
722 | | } |