Coverage Report

Created: 2024-12-20 00:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-util/src/origin_event_middleware.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::sync::Arc;
16
17
use base64::prelude::BASE64_STANDARD_NO_PAD;
18
use base64::Engine;
19
use futures::future::BoxFuture;
20
use futures::task::{Context, Poll};
21
use hyper::http::{self, StatusCode};
22
use nativelink_config::cas_server::IdentityHeaderSpec;
23
use nativelink_proto::build::bazel::remote::execution::v2::RequestMetadata;
24
use nativelink_proto::com::github::trace_machina::nativelink::events::OriginEvent;
25
use prost::Message;
26
use tokio::sync::mpsc;
27
use tower::layer::Layer;
28
use tower::Service;
29
use tracing::trace_span;
30
31
use crate::origin_context::{ActiveOriginContext, ORIGIN_IDENTITY};
32
use crate::origin_event::{OriginEventCollector, ORIGIN_EVENT_COLLECTOR};
33
34
/// Default identity header name.
35
/// Note: If this is changed, update the default value in [`IdentityHeaderSpec`] to match.
36
// TODO(allada) This has a mirror in bep_server.rs.
37
// We should consolidate these.
38
const DEFAULT_IDENTITY_HEADER: &str = "x-identity";
39
40
/// Origin metadata for a request: an identity string plus optional Bazel
/// [`RequestMetadata`].
// NOTE(review): this type is not referenced anywhere else in the visible
// portion of this file — confirm it is still used by other modules.
#[derive(Default, Clone)]
41
pub struct OriginRequestMetadata {
42
    /// Identity associated with the request.
    // NOTE(review): presumably the value of the identity header — confirm against callers.
    pub identity: String,
43
    /// Bazel request metadata, when available.
    pub bazel_metadata: Option<RequestMetadata>,
44
}
45
46
/// Tower [`Layer`] that wraps a service in an [`OriginEventMiddleware`].
///
/// Holds the state that is copied into every middleware instance it creates.
#[derive(Clone)]
47
pub struct OriginEventMiddlewareLayer {
48
    /// Optional sender for [`OriginEvent`]s; cloned into each middleware.
    maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>,
49
    /// Identity header settings, shared between middleware instances through the `Arc`.
    // NOTE(review): `idenity` is a misspelling of `identity`; a rename has to
    // touch every use of this field in the file at once.
    idenity_header_config: Arc<IdentityHeaderSpec>,
50
}
51
52
impl OriginEventMiddlewareLayer {
53
0
    /// Creates a layer from an optional origin event sender and the identity
    /// header settings.
    ///
    /// The header settings are moved into an [`Arc`] so that `layer()` can
    /// give each middleware a reference-counted handle instead of a deep copy.
    pub fn new(
54
0
        maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>,
55
0
        idenity_header_config: IdentityHeaderSpec,
56
0
    ) -> Self {
57
0
        Self {
58
0
            maybe_origin_event_tx,
59
0
            idenity_header_config: Arc::new(idenity_header_config),
60
0
        }
61
0
    }
62
}
63
64
/// [`Layer`] implementation: decorates any service `S` with
/// [`OriginEventMiddleware`].
impl<S> Layer<S> for OriginEventMiddlewareLayer {
65
    type Service = OriginEventMiddleware<S>;
66
67
0
    /// Wraps `service`, handing the new middleware clones of this layer's
    /// event sender and identity header settings.
    fn layer(&self, service: S) -> Self::Service {
68
0
        OriginEventMiddleware {
69
0
            inner: service,
70
0
            maybe_origin_event_tx: self.maybe_origin_event_tx.clone(),
71
0
            idenity_header_config: self.idenity_header_config.clone(),
72
0
        }
73
0
    }
74
}
75
76
/// Service wrapper created by [`OriginEventMiddlewareLayer`].
///
/// For every request it reads the identity header, records the identity in an
/// origin context and, when an event sender is configured, attaches an
/// [`OriginEventCollector`] before calling the wrapped service.
#[derive(Clone)]
77
pub struct OriginEventMiddleware<S> {
78
    /// The wrapped service that ultimately handles the request.
    inner: S,
79
    /// Optional sender for [`OriginEvent`]s; when `None`, no collector is attached.
    maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>,
80
    /// Identity header settings (`header_name` and `required` are read in `call`).
    idenity_header_config: Arc<IdentityHeaderSpec>,
81
}
82
83
/// [`Service`] implementation that populates an origin context for each HTTP
/// request and then forwards the request to the wrapped service.
///
/// `ResBody: From<String>` is needed because the middleware itself builds a
/// plain-text response body when it rejects a request.
impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for OriginEventMiddleware<S>
84
where
85
    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone + Send + 'static,
86
    S::Future: Send + 'static,
87
    ReqBody: std::fmt::Debug + Send + 'static,
88
    ResBody: From<String> + Send + 'static,
89
{
90
    type Response = S::Response;
91
    type Error = S::Error;
92
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
93
94
0
    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95
0
        self.inner.poll_ready(cx)
96
0
    }
97
98
0
    /// Handles one request.
    ///
    /// Steps:
    /// 1. Read the identity from the configured header (falling back to
    ///    [`DEFAULT_IDENTITY_HEADER`]); a missing or non-string value yields an
    ///    empty identity.
    /// 2. If the identity is empty and the configuration marks it as required,
    ///    return a `401 Unauthorized` response without calling the inner service.
    /// 3. Store the identity in the origin context and, when an event sender is
    ///    configured, also store an [`OriginEventCollector`] built from the
    ///    sender, the identity and the decoded Bazel request metadata.
    /// 4. Call the inner service wrapped in that context.
    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
99
0
        // We must take the current `inner` and not the clone.
100
0
        // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
101
0
        let clone = self.inner.clone();
102
0
        let mut inner = std::mem::replace(&mut self.inner, clone);
103
0
104
0
        // Start from a fork of the active origin context; fall back to a
        // default context when `fork()` does not produce one.
        let mut context = ActiveOriginContext::fork().unwrap_or_default();
105
0
        let identity = {
106
0
            // Header to read: the configured `header_name`, or the default
            // when none is configured.
            let identity_header = self
107
0
                .idenity_header_config
108
0
                .header_name
109
0
                .as_deref()
110
0
                .unwrap_or(DEFAULT_IDENTITY_HEADER);
111
0
            // An explicitly empty header name skips the lookup, leaving the
            // identity empty.
            let identity = if !identity_header.is_empty() {
  Branch (111:31): [Folded - Ignored]
  Branch (111:31): [Folded - Ignored]
112
0
                req.headers()
113
0
                    .get(identity_header)
114
0
                    // A value for which `to_str()` fails is treated like a missing header.
                    .and_then(|header| header.to_str().ok().map(str::to_string))
115
0
                    .unwrap_or_default()
116
            } else {
117
0
                String::new()
118
            };
119
120
0
            // Reject early, without touching the inner service, when an
            // identity is required but none was supplied.
            if identity.is_empty() && self.idenity_header_config.required {
  Branch (120:16): [Folded - Ignored]
  Branch (120:39): [Folded - Ignored]
  Branch (120:16): [Folded - Ignored]
  Branch (120:39): [Folded - Ignored]
121
0
                return Box::pin(async move {
122
0
                    Ok(http::Response::builder()
123
0
                        .status(StatusCode::UNAUTHORIZED)
124
0
                        // NOTE(review): the message names the literal 'identity_header'
                        // rather than the header that was actually looked up
                        // (`identity_header` above) — confirm this is intended.
                        .body("'identity_header' header is required".to_string().into())
125
0
                        .unwrap())
126
0
                });
127
0
            }
128
0
            context.set_value(&ORIGIN_IDENTITY, Arc::new(identity.clone()));
129
0
            identity
130
        };
131
0
        // The event collector is only attached when an event sender was configured.
        if let Some(origin_event_tx) = &self.maybe_origin_event_tx {
  Branch (131:16): [Folded - Ignored]
  Branch (131:16): [Folded - Ignored]
132
0
            // Best-effort decode of the Bazel request metadata header: base64
            // (`BASE64_STANDARD_NO_PAD`) first, then protobuf. A missing header
            // or a failure at either step yields `None`.
            let bazel_metadata = req
133
0
                .headers()
134
0
                .get("build.bazel.remote.execution.v2.requestmetadata-bin")
135
0
                .and_then(|header| BASE64_STANDARD_NO_PAD.decode(header.as_bytes()).ok())
136
0
                .and_then(|data| RequestMetadata::decode(data.as_slice()).ok());
137
0
            context.set_value(
138
0
                &ORIGIN_EVENT_COLLECTOR,
139
0
                Arc::new(OriginEventCollector::new(
140
0
                    origin_event_tx.clone(),
141
0
                    identity,
142
0
                    bazel_metadata,
143
0
                )),
144
0
            );
145
0
        }
146
147
0
        // Forward the request to the inner service through `wrap_async`, under
        // an "OriginEventMiddleware" trace span.
        // NOTE(review): presumably `wrap_async` makes `context` the active
        // origin context for the duration of the inner call — confirm.
        Box::pin(async move {
148
0
            Arc::new(context)
149
0
                .wrap_async(trace_span!("OriginEventMiddleware"), inner.call(req))
150
0
                .await
151
0
        })
152
0
    }
153
}