/*--------------------------------------------------------------------------------------------- * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ use std::{ collections::HashMap, future, sync::{ atomic::{AtomicU32, Ordering}, Arc, Mutex, }, }; use crate::log; use futures::{future::BoxFuture, Future, FutureExt}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf}, sync::{mpsc, oneshot}, }; use crate::util::errors::AnyError; pub type SyncMethod = Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> Option<Vec<u8>>>; pub type AsyncMethod = Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> BoxFuture<'static, Option<Vec<u8>>>>; pub type Duplex = Arc< dyn Send + Sync + Fn(Option<u32>, &[u8]) -> (Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>), >; pub enum Method { Sync(SyncMethod), Async(AsyncMethod), Duplex(Duplex), } /// Serialization is given to the RpcBuilder and defines how data gets serialized /// when callinth methods. pub trait Serialization: Send + Sync + 'static { fn serialize(&self, value: impl Serialize) -> Vec<u8>; fn deserialize<P: DeserializeOwned>(&self, b: &[u8]) -> Result<P, AnyError>; } /// RPC is a basic, transport-agnostic builder for RPC methods. You can /// register methods to it, then call `.build()` to get a "dispatcher" type. pub struct RpcBuilder<S> { serializer: Arc<S>, methods: HashMap<&'static str, Method>, calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>, } impl<S: Serialization> RpcBuilder<S> { /// Creates a new empty RPC builder. pub fn new(serializer: S) -> Self { Self { serializer: Arc::new(serializer), methods: HashMap::new(), calls: Arc::new(std::sync::Mutex::new(HashMap::new())), } } /// Creates a caller that will be connected to any eventual dispatchers, /// and that sends data to the "tx" channel. pub fn get_caller(&mut self, sender: mpsc::UnboundedSender<Vec<u8>>) -> RpcCaller<S> { RpcCaller { serializer: self.serializer.clone(), calls: self.calls.clone(), sender, } } /// Gets a method builder. pub fn methods<C: Send + Sync + 'static>(self, context: C) -> RpcMethodBuilder<S, C> { RpcMethodBuilder { context: Arc::new(context), serializer: self.serializer, methods: self.methods, calls: self.calls, } } } pub struct RpcMethodBuilder<S, C> { context: Arc<C>, serializer: Arc<S>, methods: HashMap<&'static str, Method>, calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>, } #[derive(Serialize)] struct DuplexStreamStarted { pub for_request_id: u32, pub stream_ids: Vec<u32>, } impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> { /// Registers a synchronous rpc call that returns its result directly. pub fn register_sync<P, R, F>(&mut self, method_name: &'static str, callback: F) where P: DeserializeOwned, R: Serialize, F: Fn(P, &C) -> Result<R, AnyError> + Send + Sync + 'static, { if self.methods.contains_key(method_name) { panic!("Method already registered: {method_name}"); } let serial = self.serializer.clone(); let context = self.context.clone(); self.methods.insert( method_name, Method::Sync(Arc::new(move |id, body| { let param = match serial.deserialize::<RequestParams<P>>(body) { Ok(p) => p, Err(err) => { return id.map(|id| { serial.serialize(ErrorResponse { id, error: ResponseError { code: 0, message: format!("{err:?}"), }, }) }) } }; match callback(param.params, &context) { Ok(result) => id.map(|id| serial.serialize(&SuccessResponse { id, result })), Err(err) => id.map(|id| { serial.serialize(ErrorResponse { id, error: ResponseError { code: -1, message: format!("{err:?}"), }, }) }), } })), ); } /// Registers an async rpc call that returns a Future. pub fn register_async<P, R, Fut, F>(&mut self, method_name: &'static str, callback: F) where P: DeserializeOwned + Send + 'static, R: Serialize + Send + Sync + 'static, Fut: Future<Output = Result<R, AnyError>> + Send, F: (Fn(P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static, { let serial = self.serializer.clone(); let context = self.context.clone(); self.methods.insert( method_name, Method::Async(Arc::new(move |id, body| { let param = match serial.deserialize::<RequestParams<P>>(body) { Ok(p) => p, Err(err) => { return future::ready(id.map(|id| { serial.serialize(ErrorResponse { id, error: ResponseError { code: 0, message: format!("{err:?}"), }, }) })) .boxed(); } }; let callback = callback.clone(); let serial = serial.clone(); let context = context.clone(); let fut = async move { match callback(param.params, context).await { Ok(result) => { id.map(|id| serial.serialize(&SuccessResponse { id, result })) } Err(err) => id.map(|id| { serial.serialize(ErrorResponse { id, error: ResponseError { code: -1, message: format!("{err:?}"), }, }) }), } }; fut.boxed() })), ); } /// Registers an async rpc call that returns a Future containing a duplex /// stream that should be handled by the client. pub fn register_duplex<P, R, Fut, F>( &mut self, method_name: &'static str, streams: usize, callback: F, ) where P: DeserializeOwned + Send + 'static, R: Serialize + Send + Sync + 'static, Fut: Future<Output = Result<R, AnyError>> + Send, F: (Fn(Vec<DuplexStream>, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static, { let serial = self.serializer.clone(); let context = self.context.clone(); self.methods.insert( method_name, Method::Duplex(Arc::new(move |id, body| { let param = match serial.deserialize::<RequestParams<P>>(body) { Ok(p) => p, Err(err) => { return ( None, future::ready(id.map(|id| { serial.serialize(ErrorResponse { id, error: ResponseError { code: 0, message: format!("{err:?}"), }, }) })) .boxed(), ); } }; let callback = callback.clone(); let serial = serial.clone(); let context = context.clone(); let mut dto = StreamDto { req_id: id.unwrap_or(0), streams: Vec::with_capacity(streams), }; let mut servers = Vec::with_capacity(streams); for _ in 0..streams { let (client, server) = tokio::io::duplex(8192); servers.push(server); dto.streams.push((next_message_id(), client)); } let fut = async move { match callback(servers, param.params, context).await { Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })), Err(err) => id.map(|id| { serial.serialize(ErrorResponse { id, error: ResponseError { code: -1, message: format!("{err:?}"), }, }) }), } }; (Some(dto), fut.boxed()) })), ); } /// Builds into a usable, sync rpc dispatcher. pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> { let streams = Streams::default(); let s1 = streams.clone(); self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| { let s1 = s1.clone(); async move { s1.remove(m.stream).await; Ok(()) } }); let s2 = streams.clone(); self.register_sync(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| { s2.write(m.stream, m.segment); Ok(()) }); RpcDispatcher { log, context: self.context, calls: self.calls, serializer: self.serializer, methods: Arc::new(self.methods), streams, } } } type DispatchMethod = Box<dyn Send + Sync + FnOnce(Outcome)>; /// Dispatcher returned from a Builder that provides a transport-agnostic way to /// deserialize and dispatch RPC calls. This structure may get more advanced as /// time goes on... #[derive(Clone)] pub struct RpcCaller<S: Serialization> { serializer: Arc<S>, calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>, sender: mpsc::UnboundedSender<Vec<u8>>, } impl<S: Serialization> RpcCaller<S> { pub fn serialize_notify<M, A>(serializer: &S, method: M, params: A) -> Vec<u8> where S: Serialization, M: AsRef<str> + serde::Serialize, A: Serialize, { serializer.serialize(&FullRequest { id: None, method, params, }) } /// Enqueues an outbound call. Returns whether the message was enqueued. pub fn notify<M, A>(&self, method: M, params: A) -> bool where M: AsRef<str> + serde::Serialize, A: Serialize, { self.sender .send(Self::serialize_notify(&self.serializer, method, params)) .is_ok() } /// Enqueues an outbound call, returning its result. pub fn call<M, A, R>(&self, method: M, params: A) -> oneshot::Receiver<Result<R, ResponseError>> where M: AsRef<str> + serde::Serialize, A: Serialize, R: DeserializeOwned + Send + 'static, { let (tx, rx) = oneshot::channel(); let id = next_message_id(); let body = self.serializer.serialize(&FullRequest { id: Some(id), method, params, }); if self.sender.send(body).is_err() { drop(tx); return rx; } let serializer = self.serializer.clone(); self.calls.lock().unwrap().insert( id, Box::new(move |body| { match body { Outcome::Error(e) => tx.send(Err(e)).ok(), Outcome::Success(r) => match serializer.deserialize::<SuccessResponse<R>>(&r) { Ok(r) => tx.send(Ok(r.result)).ok(), Err(err) => tx .send(Err(ResponseError { code: 0, message: err.to_string(), })) .ok(), }, }; }), ); rx } } /// Dispatcher returned from a Builder that provides a transport-agnostic way to /// deserialize and handle RPC calls. This structure may get more advanced as /// time goes on... #[derive(Clone)] pub struct RpcDispatcher<S, C> { log: log::Logger, context: Arc<C>, serializer: Arc<S>, methods: Arc<HashMap<&'static str, Method>>, calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>, streams: Streams, } static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0); fn next_message_id() -> u32 { MESSAGE_ID_COUNTER.fetch_add(1, Ordering::SeqCst) } impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> { /// Runs the incoming request, returning the result of the call synchronously /// or in a future. (The caller can then decide whether to run the future /// sequentially in its receive loop, or not.) /// /// The future or return result will be optional bytes that should be sent /// back to the socket. pub fn dispatch(&self, body: &[u8]) -> MaybeSync { match self.serializer.deserialize::<PartialIncoming>(body) { Ok(partial) => self.dispatch_with_partial(body, partial), Err(_err) => { warning!(self.log, "Failed to deserialize request, hex: {:X?}", body); MaybeSync::Sync(None) } } } /// Like dispatch, but allows passing an existing PartialIncoming. pub fn dispatch_with_partial(&self, body: &[u8], partial: PartialIncoming) -> MaybeSync { let id = partial.id; if let Some(method_name) = partial.method { let method = self.methods.get(method_name.as_str()); match method { Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)), Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)), Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)), None => MaybeSync::Sync(id.map(|id| { self.serializer.serialize(ErrorResponse { id, error: ResponseError { code: -1, message: format!("Method not found: {method_name}"), }, }) })), } } else if let Some(err) = partial.error { if let Some(cb) = self.calls.lock().unwrap().remove(&id.unwrap()) { cb(Outcome::Error(err)); } MaybeSync::Sync(None) } else { if let Some(cb) = self.calls.lock().unwrap().remove(&id.unwrap()) { cb(Outcome::Success(body.to_vec())); } MaybeSync::Sync(None) } } /// Registers a stream call returned from dispatch(). pub async fn register_stream( &self, write_tx: mpsc::Sender<impl 'static + From<Vec<u8>> + Send>, dto: StreamDto, ) { let r = write_tx .send( self.serializer .serialize(&FullRequest { id: None, method: METHOD_STREAMS_STARTED, params: DuplexStreamStarted { stream_ids: dto.streams.iter().map(|(id, _)| *id).collect(), for_request_id: dto.req_id, }, }) .into(), ) .await; if r.is_err() { return; } for (stream_id, duplex) in dto.streams { let (mut read, write) = tokio::io::split(duplex); self.streams.insert(stream_id, write); let write_tx = write_tx.clone(); let serial = self.serializer.clone(); tokio::spawn(async move { let mut buf = vec![0; 4096]; loop { match read.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => { let r = write_tx .send( serial .serialize(&FullRequest { id: None, method: METHOD_STREAM_DATA, params: StreamDataParams { segment: &buf[..n], stream: stream_id, }, }) .into(), ) .await; if r.is_err() { return; } } } } let _ = write_tx .send( serial .serialize(&FullRequest { id: None, method: METHOD_STREAM_ENDED, params: StreamEndedParams { stream: stream_id }, }) .into(), ) .await; }); } } pub fn context(&self) -> Arc<C> { self.context.clone() } } struct StreamRec { write: Option<WriteHalf<DuplexStream>>, q: Vec<Vec<u8>>, ended: bool, } #[derive(Clone, Default)] struct Streams { map: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>, } impl Streams { pub async fn remove(&self, id: u32) { let mut remove = None; { let mut map = self.map.lock().unwrap(); if let Some(s) = map.get_mut(&id) { if let Some(w) = s.write.take() { map.remove(&id); remove = Some(w); } else { s.ended = true; // will shut down in write loop } } } // do this outside of the sync lock: if let Some(mut w) = remove { let _ = w.shutdown().await; } } pub fn write(&self, id: u32, buf: Vec<u8>) { let mut map = self.map.lock().unwrap(); if let Some(s) = map.get_mut(&id) { s.q.push(buf); if let Some(w) = s.write.take() { tokio::spawn(write_loop(id, w, self.map.clone())); } } } pub fn insert(&self, id: u32, stream: WriteHalf<DuplexStream>) { self.map.lock().unwrap().insert( id, StreamRec { write: Some(stream), q: Vec::new(), ended: false, }, ); } } /// Write loop started by `Streams.write`. It takes the WriteHalf, and /// runs until there's no more items in the 'write queue'. At that point, if the /// record still exists in the `streams` (i.e. we haven't shut down), it'll /// return the WriteHalf so that the next `write` call starts /// the loop again. Otherwise, it'll shut down the WriteHalf. /// /// This is the equivalent of the same write_loop in the server_multiplexer. /// I couldn't figure out a nice way to abstract it without introducing /// performance overhead... async fn write_loop( id: u32, mut w: WriteHalf<DuplexStream>, streams: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>, ) { let mut items_vec = vec![]; loop { { let mut lock = streams.lock().unwrap(); let stream_rec = match lock.get_mut(&id) { Some(b) => b, None => break, }; if stream_rec.q.is_empty() { if stream_rec.ended { lock.remove(&id); break; } else { stream_rec.write = Some(w); return; } } std::mem::swap(&mut stream_rec.q, &mut items_vec); } for item in items_vec.drain(..) { if w.write_all(&item).await.is_err() { break; } } } let _ = w.shutdown().await; // got here from `break` above, meaning our record got cleared. Close the bridge if so } const METHOD_STREAMS_STARTED: &str = "streams_started"; const METHOD_STREAM_DATA: &str = "stream_data"; const METHOD_STREAM_ENDED: &str = "stream_ended"; #[allow(dead_code)] // false positive trait AssertIsSync: Sync {} impl<S: Serialization, C: Send + Sync> AssertIsSync for RpcDispatcher<S, C> {} /// Approximate shape that is used to determine what kind of data is incoming. #[derive(Deserialize, Debug)] pub struct PartialIncoming { pub id: Option<u32>, pub method: Option<String>, pub error: Option<ResponseError>, } #[derive(Deserialize)] struct StreamDataIncomingParams { #[serde(with = "serde_bytes")] pub segment: Vec<u8>, pub stream: u32, } #[derive(Serialize, Deserialize)] struct StreamDataParams<'a> { #[serde(with = "serde_bytes")] pub segment: &'a [u8], pub stream: u32, } #[derive(Serialize, Deserialize)] struct StreamEndedParams { pub stream: u32, } #[derive(Serialize)] pub struct FullRequest<M: AsRef<str>, P> { pub id: Option<u32>, pub method: M, pub params: P, } #[derive(Deserialize)] struct RequestParams<P> { pub params: P, } #[derive(Serialize, Deserialize)] struct SuccessResponse<T> { pub id: u32, pub result: T, } #[derive(Serialize, Deserialize)] struct ErrorResponse { pub id: u32, pub error: ResponseError, } #[derive(Serialize, Deserialize, Debug)] pub struct ResponseError { pub code: i32, pub message: String, } enum Outcome { Success(Vec<u8>), Error(ResponseError), } pub struct StreamDto { req_id: u32, streams: Vec<(u32, DuplexStream)>, } pub enum MaybeSync { Stream((Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)), Future(BoxFuture<'static, Option<Vec<u8>>>), Sync(Option<Vec<u8>>), } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_remove() { let streams = Streams::default(); let (writer, mut reader) = tokio::io::duplex(1024); streams.insert(1, tokio::io::split(writer).1); streams.remove(1).await; assert!(streams.map.lock().unwrap().get(&1).is_none()); let mut buffer = Vec::new(); assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 0); } #[tokio::test] async fn test_write() { let streams = Streams::default(); let (writer, mut reader) = tokio::io::duplex(1024); streams.insert(1, tokio::io::split(writer).1); streams.write(1, vec![1, 2, 3]); let mut buffer = [0; 3]; assert_eq!(reader.read_exact(&mut buffer).await.unwrap(), 3); assert_eq!(buffer, [1, 2, 3]); } #[tokio::test] async fn test_write_with_immediate_end() { let streams = Streams::default(); let (writer, mut reader) = tokio::io::duplex(1); streams.insert(1, tokio::io::split(writer).1); streams.write(1, vec![1, 2, 3]); // spawn write loop streams.write(1, vec![4, 5, 6]); // enqueued while writing streams.remove(1).await; // end stream let mut buffer = Vec::new(); assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 6); assert_eq!(buffer, vec![1, 2, 3, 4, 5, 6]); } }