vscode/cli/src/rpc.rs

756 lines
19 KiB
Rust
Raw Normal View History

2024-11-15 06:29:18 +00:00
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::{
collections::HashMap,
future,
sync::{
atomic::{AtomicU32, Ordering},
Arc, Mutex,
},
};
use crate::log;
use futures::{future::BoxFuture, Future, FutureExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf},
sync::{mpsc, oneshot},
};
use crate::util::errors::AnyError;
/// Handler for a synchronous method. It receives the request id (`None` for
/// notifications) and the raw request bytes, and returns the serialized
/// response to send back, if any.
pub type SyncMethod = Arc<dyn Fn(Option<u32>, &[u8]) -> Option<Vec<u8>> + Send + Sync>;

/// Handler for an asynchronous method: same inputs as [`SyncMethod`], but the
/// serialized response is produced by a boxed future.
pub type AsyncMethod =
    Arc<dyn Fn(Option<u32>, &[u8]) -> BoxFuture<'static, Option<Vec<u8>>> + Send + Sync>;

/// Handler for a duplex method: returns the client ends of any streams it
/// opened for the call, together with a boxed future for the response.
pub type Duplex = Arc<
    dyn Fn(Option<u32>, &[u8]) -> (Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)
        + Send
        + Sync,
>;
/// A registered RPC handler, classified by how it produces its response.
pub enum Method {
/// Called directly inside `dispatch`; returns the response bytes immediately.
Sync(SyncMethod),
/// Returns a future that resolves to the response bytes.
Async(AsyncMethod),
/// Returns the (optional) client stream ends to hand to the peer, plus a future for the response bytes.
Duplex(Duplex),
}
/// Serialization is given to the RpcBuilder and defines how data gets serialized
/// when calling methods.
///
/// The same implementation encodes outgoing messages and decodes incoming ones,
/// including the partial peek (`PartialIncoming`) used to route a message.
pub trait Serialization: Send + Sync + 'static {
/// Encodes `value` into the bytes that are written to the transport.
fn serialize(&self, value: impl Serialize) -> Vec<u8>;
/// Decodes `b` into a `P`, returning an error if decoding fails.
fn deserialize<P: DeserializeOwned>(&self, b: &[u8]) -> Result<P, AnyError>;
}
/// RPC is a basic, transport-agnostic builder for RPC methods. Call
/// `.methods(context)` to register methods, then `.build()` on the result to get
/// a "dispatcher" type.
pub struct RpcBuilder<S> {
/// Shared with every caller and with the method builder returned by `methods`.
serializer: Arc<S>,
/// Handlers registered so far, keyed by method name.
methods: HashMap<&'static str, Method>,
/// Response handlers for outbound calls still awaiting a reply, keyed by request id.
/// Shared between callers and the eventual dispatcher.
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
}
impl<S: Serialization> RpcBuilder<S> {
/// Creates a new empty RPC builder.
pub fn new(serializer: S) -> Self {
Self {
serializer: Arc::new(serializer),
methods: HashMap::new(),
calls: Arc::new(std::sync::Mutex::new(HashMap::new())),
}
}
/// Creates a caller that will be connected to any eventual dispatchers,
/// and that sends data to the "tx" channel.
pub fn get_caller(&mut self, sender: mpsc::UnboundedSender<Vec<u8>>) -> RpcCaller<S> {
RpcCaller {
serializer: self.serializer.clone(),
calls: self.calls.clone(),
sender,
}
}
/// Gets a method builder.
pub fn methods<C: Send + Sync + 'static>(self, context: C) -> RpcMethodBuilder<S, C> {
RpcMethodBuilder {
context: Arc::new(context),
serializer: self.serializer,
methods: self.methods,
calls: self.calls,
}
}
}
/// Builder returned by `RpcBuilder::methods` for registering handlers that are
/// given the shared `context`. Call `build` to turn it into an `RpcDispatcher`.
pub struct RpcMethodBuilder<S, C> {
/// Passed to every registered handler.
context: Arc<C>,
serializer: Arc<S>,
/// Handlers registered so far, keyed by method name.
methods: HashMap<&'static str, Method>,
/// Pending outbound calls, shared with any `RpcCaller` from the originating builder.
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
}
/// Params of the `streams_started` notification sent by `register_stream`,
/// announcing the stream ids opened for the request `for_request_id`.
#[derive(Serialize)]
struct DuplexStreamStarted {
pub for_request_id: u32,
pub stream_ids: Vec<u32>,
}
impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
/// Registers a synchronous rpc call that returns its result directly.
pub fn register_sync<P, R, F>(&mut self, method_name: &'static str, callback: F)
where
P: DeserializeOwned,
R: Serialize,
F: Fn(P, &C) -> Result<R, AnyError> + Send + Sync + 'static,
{
if self.methods.contains_key(method_name) {
panic!("Method already registered: {method_name}");
}
let serial = self.serializer.clone();
let context = self.context.clone();
self.methods.insert(
method_name,
Method::Sync(Arc::new(move |id, body| {
let param = match serial.deserialize::<RequestParams<P>>(body) {
Ok(p) => p,
Err(err) => {
return id.map(|id| {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: 0,
message: format!("{err:?}"),
},
})
})
}
};
match callback(param.params, &context) {
Ok(result) => id.map(|id| serial.serialize(&SuccessResponse { id, result })),
Err(err) => id.map(|id| {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: -1,
message: format!("{err:?}"),
},
})
}),
}
})),
);
}
/// Registers an async rpc call that returns a Future.
pub fn register_async<P, R, Fut, F>(&mut self, method_name: &'static str, callback: F)
where
P: DeserializeOwned + Send + 'static,
R: Serialize + Send + Sync + 'static,
Fut: Future<Output = Result<R, AnyError>> + Send,
F: (Fn(P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
{
let serial = self.serializer.clone();
let context = self.context.clone();
self.methods.insert(
method_name,
Method::Async(Arc::new(move |id, body| {
let param = match serial.deserialize::<RequestParams<P>>(body) {
Ok(p) => p,
Err(err) => {
return future::ready(id.map(|id| {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: 0,
message: format!("{err:?}"),
},
})
}))
.boxed();
}
};
let callback = callback.clone();
let serial = serial.clone();
let context = context.clone();
let fut = async move {
match callback(param.params, context).await {
Ok(result) => {
id.map(|id| serial.serialize(&SuccessResponse { id, result }))
}
Err(err) => id.map(|id| {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: -1,
message: format!("{err:?}"),
},
})
}),
}
};
fut.boxed()
})),
);
}
/// Registers an async rpc call that returns a Future containing a duplex
/// stream that should be handled by the client.
pub fn register_duplex<P, R, Fut, F>(
&mut self,
method_name: &'static str,
streams: usize,
callback: F,
) where
P: DeserializeOwned + Send + 'static,
R: Serialize + Send + Sync + 'static,
Fut: Future<Output = Result<R, AnyError>> + Send,
F: (Fn(Vec<DuplexStream>, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
{
let serial = self.serializer.clone();
let context = self.context.clone();
self.methods.insert(
method_name,
Method::Duplex(Arc::new(move |id, body| {
let param = match serial.deserialize::<RequestParams<P>>(body) {
Ok(p) => p,
Err(err) => {
return (
None,
future::ready(id.map(|id| {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: 0,
message: format!("{err:?}"),
},
})
}))
.boxed(),
);
}
};
let callback = callback.clone();
let serial = serial.clone();
let context = context.clone();
let mut dto = StreamDto {
req_id: id.unwrap_or(0),
streams: Vec::with_capacity(streams),
};
let mut servers = Vec::with_capacity(streams);
for _ in 0..streams {
let (client, server) = tokio::io::duplex(8192);
servers.push(server);
dto.streams.push((next_message_id(), client));
}
let fut = async move {
match callback(servers, param.params, context).await {
Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
Err(err) => id.map(|id| {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: -1,
message: format!("{err:?}"),
},
})
}),
}
};
(Some(dto), fut.boxed())
})),
);
}
/// Builds into a usable, sync rpc dispatcher.
pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> {
let streams = Streams::default();
let s1 = streams.clone();
self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
let s1 = s1.clone();
async move {
s1.remove(m.stream).await;
Ok(())
}
});
let s2 = streams.clone();
self.register_sync(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
s2.write(m.stream, m.segment);
Ok(())
});
RpcDispatcher {
log,
context: self.context,
calls: self.calls,
serializer: self.serializer,
methods: Arc::new(self.methods),
streams,
}
}
}
/// One-shot handler for the response to an outbound call. It is stored in the
/// shared `calls` table under the request id and invoked once with the outcome.
type DispatchMethod = Box<dyn FnOnce(Outcome) + Send + Sync>;
/// Handle for sending outbound RPC requests and notifications, created by
/// `RpcBuilder::get_caller`. Serialized messages are pushed to `sender`.
/// Responses to `call`s are delivered through the `calls` table, which this
/// caller shares with the dispatcher built from the same builder.
pub struct RpcCaller<S: Serialization> {
    serializer: Arc<S>,
    calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
    sender: mpsc::UnboundedSender<Vec<u8>>,
}

// Implemented by hand: `#[derive(Clone)]` would require `S: Clone`, even though
// the serializer is only ever held behind an `Arc`.
impl<S: Serialization> Clone for RpcCaller<S> {
    fn clone(&self) -> Self {
        Self {
            serializer: Arc::clone(&self.serializer),
            calls: Arc::clone(&self.calls),
            sender: self.sender.clone(),
        }
    }
}
impl<S: Serialization> RpcCaller<S> {
pub fn serialize_notify<M, A>(serializer: &S, method: M, params: A) -> Vec<u8>
where
S: Serialization,
M: AsRef<str> + serde::Serialize,
A: Serialize,
{
serializer.serialize(&FullRequest {
id: None,
method,
params,
})
}
/// Enqueues an outbound call. Returns whether the message was enqueued.
pub fn notify<M, A>(&self, method: M, params: A) -> bool
where
M: AsRef<str> + serde::Serialize,
A: Serialize,
{
self.sender
.send(Self::serialize_notify(&self.serializer, method, params))
.is_ok()
}
/// Enqueues an outbound call, returning its result.
pub fn call<M, A, R>(&self, method: M, params: A) -> oneshot::Receiver<Result<R, ResponseError>>
where
M: AsRef<str> + serde::Serialize,
A: Serialize,
R: DeserializeOwned + Send + 'static,
{
let (tx, rx) = oneshot::channel();
let id = next_message_id();
let body = self.serializer.serialize(&FullRequest {
id: Some(id),
method,
params,
});
if self.sender.send(body).is_err() {
drop(tx);
return rx;
}
let serializer = self.serializer.clone();
self.calls.lock().unwrap().insert(
id,
Box::new(move |body| {
match body {
Outcome::Error(e) => tx.send(Err(e)).ok(),
Outcome::Success(r) => match serializer.deserialize::<SuccessResponse<R>>(&r) {
Ok(r) => tx.send(Ok(r.result)).ok(),
Err(err) => tx
.send(Err(ResponseError {
code: 0,
message: err.to_string(),
}))
.ok(),
},
};
}),
);
rx
}
}
/// Dispatcher returned from `RpcMethodBuilder::build` that provides a
/// transport-agnostic way to deserialize and handle RPC calls. This structure
/// may get more advanced as time goes on...
pub struct RpcDispatcher<S, C> {
    log: log::Logger,
    context: Arc<C>,
    serializer: Arc<S>,
    methods: Arc<HashMap<&'static str, Method>>,
    calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
    streams: Streams,
}

// Implemented by hand: `#[derive(Clone)]` would require `S: Clone` and
// `C: Clone`, even though both are only held behind `Arc`s.
impl<S, C> Clone for RpcDispatcher<S, C> {
    fn clone(&self) -> Self {
        Self {
            log: self.log.clone(),
            context: Arc::clone(&self.context),
            serializer: Arc::clone(&self.serializer),
            methods: Arc::clone(&self.methods),
            calls: Arc::clone(&self.calls),
            streams: self.streams.clone(),
        }
    }
}
/// Process-wide counter backing [`next_message_id`]. Request ids and duplex
/// stream ids are both drawn from it, so they never collide with each other
/// until the counter wraps.
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);

/// Returns the next id from [`MESSAGE_ID_COUNTER`]. `fetch_add` wraps around on
/// overflow, so ids repeat after 2^32 allocations.
fn next_message_id() -> u32 {
    let counter = &MESSAGE_ID_COUNTER;
    counter.fetch_add(1, Ordering::SeqCst)
}
impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
    /// Runs the incoming request, returning the result of the call synchronously
    /// or in a future. (The caller can then decide whether to run the future
    /// sequentially in its receive loop, or not.)
    ///
    /// The future or return result will be optional bytes that should be sent
    /// back to the socket. Messages that cannot be decoded at all are logged
    /// and ignored.
    pub fn dispatch(&self, body: &[u8]) -> MaybeSync {
        match self.serializer.deserialize::<PartialIncoming>(body) {
            Ok(partial) => self.dispatch_with_partial(body, partial),
            Err(_err) => {
                warning!(self.log, "Failed to deserialize request, hex: {:X?}", body);
                MaybeSync::Sync(None)
            }
        }
    }

    /// Like dispatch, but allows passing an existing PartialIncoming.
    ///
    /// Routing rules:
    /// - A message with a `method` goes to the registered handler. An unknown
    ///   method yields a `-1` error response when the request has an id.
    /// - A message without a `method` is a response to a call made through an
    ///   `RpcCaller`, and it resolves the pending call with the same id.
    /// - Such a response with no `id` cannot be matched, so it is logged and
    ///   dropped.
    pub fn dispatch_with_partial(&self, body: &[u8], partial: PartialIncoming) -> MaybeSync {
        let id = partial.id;
        if let Some(method_name) = partial.method {
            return match self.methods.get(method_name.as_str()) {
                Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)),
                Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)),
                Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)),
                None => MaybeSync::Sync(id.map(|id| {
                    self.serializer.serialize(ErrorResponse {
                        id,
                        error: ResponseError {
                            code: -1,
                            message: format!("Method not found: {method_name}"),
                        },
                    })
                })),
            };
        }

        // A response must carry the id of the call it answers. The peer is not
        // trusted, so a missing id is logged and ignored instead of panicking.
        let id = match id {
            Some(id) => id,
            None => {
                warning!(self.log, "Ignoring response without an id, hex: {:X?}", body);
                return MaybeSync::Sync(None);
            }
        };

        // Take the pending callback in its own statement so the `calls` lock
        // is released before the callback runs.
        let pending = self.calls.lock().unwrap().remove(&id);
        if let Some(cb) = pending {
            match partial.error {
                Some(err) => cb(Outcome::Error(err)),
                None => cb(Outcome::Success(body.to_vec())),
            }
        }
        MaybeSync::Sync(None)
    }

    /// Registers a stream call returned from dispatch().
    ///
    /// The steps are:
    /// 1. Announce the stream ids to the peer with `streams_started`.
    /// 2. For each stream, keep its write half so incoming `stream_data` can
    ///    be written to it.
    /// 3. Spawn a task per stream that forwards everything read from it as
    ///    `stream_data` notifications, then sends `stream_ended`.
    pub async fn register_stream(
        &self,
        write_tx: mpsc::Sender<impl 'static + From<Vec<u8>> + Send>,
        dto: StreamDto,
    ) {
        let r = write_tx
            .send(
                self.serializer
                    .serialize(&FullRequest {
                        id: None,
                        method: METHOD_STREAMS_STARTED,
                        params: DuplexStreamStarted {
                            stream_ids: dto.streams.iter().map(|(id, _)| *id).collect(),
                            for_request_id: dto.req_id,
                        },
                    })
                    .into(),
            )
            .await;

        if r.is_err() {
            return;
        }

        for (stream_id, duplex) in dto.streams {
            let (mut read, write) = tokio::io::split(duplex);
            self.streams.insert(stream_id, write);

            let write_tx = write_tx.clone();
            let serial = self.serializer.clone();
            tokio::spawn(async move {
                let mut buf = vec![0; 4096];
                loop {
                    match read.read(&mut buf).await {
                        Ok(0) | Err(_) => break,
                        Ok(n) => {
                            let r = write_tx
                                .send(
                                    serial
                                        .serialize(&FullRequest {
                                            id: None,
                                            method: METHOD_STREAM_DATA,
                                            params: StreamDataParams {
                                                segment: &buf[..n],
                                                stream: stream_id,
                                            },
                                        })
                                        .into(),
                                )
                                .await;

                            // The outbound channel is closed; nothing more can be sent.
                            if r.is_err() {
                                return;
                            }
                        }
                    }
                }

                let _ = write_tx
                    .send(
                        serial
                            .serialize(&FullRequest {
                                id: None,
                                method: METHOD_STREAM_ENDED,
                                params: StreamEndedParams { stream: stream_id },
                            })
                            .into(),
                    )
                    .await;
            });
        }
    }

    /// Returns the context shared by all registered handlers.
    pub fn context(&self) -> Arc<C> {
        self.context.clone()
    }
}
/// Per-stream state for data that arrives from the peer via `stream_data`.
struct StreamRec {
/// The local writer while it is idle; `None` while a `write_loop` owns it.
write: Option<WriteHalf<DuplexStream>>,
/// Chunks received but not yet written.
q: Vec<Vec<u8>>,
/// Set by `Streams::remove` when the writer is held by a `write_loop`; that loop
/// then removes the record and shuts the writer down once the queue is drained.
ended: bool,
}
/// Registry of duplex stream records by stream id. Clones share the same map,
/// which is also handed to each spawned `write_loop`.
#[derive(Clone, Default)]
struct Streams {
map: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
}
impl Streams {
pub async fn remove(&self, id: u32) {
let mut remove = None;
{
let mut map = self.map.lock().unwrap();
if let Some(s) = map.get_mut(&id) {
if let Some(w) = s.write.take() {
map.remove(&id);
remove = Some(w);
} else {
s.ended = true; // will shut down in write loop
}
}
}
// do this outside of the sync lock:
if let Some(mut w) = remove {
let _ = w.shutdown().await;
}
}
pub fn write(&self, id: u32, buf: Vec<u8>) {
let mut map = self.map.lock().unwrap();
if let Some(s) = map.get_mut(&id) {
s.q.push(buf);
if let Some(w) = s.write.take() {
tokio::spawn(write_loop(id, w, self.map.clone()));
}
}
}
pub fn insert(&self, id: u32, stream: WriteHalf<DuplexStream>) {
self.map.lock().unwrap().insert(
id,
StreamRec {
write: Some(stream),
q: Vec::new(),
ended: false,
},
);
}
}
/// Write loop started by `Streams.write`. It takes the WriteHalf, and
/// runs until there's no more items in the 'write queue'. At that point, if the
/// record still exists in the `streams` (i.e. we haven't shut down), it'll
/// return the WriteHalf so that the next `write` call starts
/// the loop again. Otherwise, it'll shut down the WriteHalf.
///
/// If the record was flagged `ended` by `Streams::remove` while this loop held
/// the writer, the record is removed and the writer is shut down once the queue
/// has drained.
///
/// This is the equivalent of the same write_loop in the server_multiplexer.
/// I couldn't figure out a nice way to abstract it without introducing
/// performance overhead...
async fn write_loop(
id: u32,
mut w: WriteHalf<DuplexStream>,
streams: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
) {
// Queued items are swapped into this reused buffer so the lock is released before writing.
let mut items_vec = vec![];
loop {
{
let mut lock = streams.lock().unwrap();
let stream_rec = match lock.get_mut(&id) {
Some(b) => b,
None => break,
};
if stream_rec.q.is_empty() {
if stream_rec.ended {
lock.remove(&id);
break;
} else {
// Hand the writer back so the next `Streams::write` restarts the loop.
stream_rec.write = Some(w);
return;
}
}
std::mem::swap(&mut stream_rec.q, &mut items_vec);
}
// NOTE(review): this `break` only leaves the `for`, not the outer `loop`. After a failed
// write, the rest of the batch is discarded (dropping the `Drain` empties `items_vec`) and
// the loop goes back to the queue. Confirm that this best-effort behavior is intended.
for item in items_vec.drain(..) {
if w.write_all(&item).await.is_err() {
break;
}
}
}
let _ = w.shutdown().await; // got here from `break` above, meaning our record got cleared. Close the bridge if so
}
/// Notification sent to the peer after a duplex call starts. It lists the ids of
/// the call's streams (see `DuplexStreamStarted`).
const METHOD_STREAMS_STARTED: &str = "streams_started";
/// Notification carrying one chunk of stream data; both sent to and handled
/// from the peer.
const METHOD_STREAM_DATA: &str = "stream_data";
/// Notification that a stream has no more data; both sent to and handled from
/// the peer.
const METHOD_STREAM_ENDED: &str = "stream_ended";
/// Compile-time check: the impl below only type-checks if `RpcDispatcher<S, C>`
/// is `Sync` (the supertrait) for every `S: Serialization` and `C: Send + Sync`.
#[allow(dead_code)] // false positive
trait AssertIsSync: Sync {}
impl<S: Serialization, C: Send + Sync> AssertIsSync for RpcDispatcher<S, C> {}
/// Approximate shape that is used to determine what kind of data is incoming.
///
/// A message with a `method` is routed to a local handler. A message without one
/// is treated as the response (successful, or carrying `error`) to a call made
/// through `RpcCaller`, and is matched by `id`.
#[derive(Deserialize, Debug)]
pub struct PartialIncoming {
pub id: Option<u32>,
pub method: Option<String>,
pub error: Option<ResponseError>,
}
/// Params of an incoming `stream_data` notification: a chunk of bytes for the
/// duplex stream `stream`. Owned counterpart of `StreamDataParams`.
#[derive(Deserialize)]
struct StreamDataIncomingParams {
// (De)serialized via `serde_bytes` as a byte string rather than a sequence of integers.
#[serde(with = "serde_bytes")]
pub segment: Vec<u8>,
pub stream: u32,
}
/// Params of an outgoing `stream_data` notification, borrowing the bytes just
/// read from the local end of the stream.
#[derive(Serialize, Deserialize)]
struct StreamDataParams<'a> {
#[serde(with = "serde_bytes")]
pub segment: &'a [u8],
pub stream: u32,
}
/// Params of the `stream_ended` notification, used both when sending and when
/// receiving.
#[derive(Serialize, Deserialize)]
struct StreamEndedParams {
pub stream: u32,
}
/// A complete outbound request. `id` is `None` for notifications, for which
/// handlers produce no response.
#[derive(Serialize)]
pub struct FullRequest<M: AsRef<str>, P> {
pub id: Option<u32>,
pub method: M,
pub params: P,
}
/// View of an incoming request used to extract its typed `params` for a handler;
/// the other fields of the message are not read here.
#[derive(Deserialize)]
struct RequestParams<P> {
pub params: P,
}
/// Successful response: serialized for handler results, and deserialized by
/// `RpcCaller::call` to extract `result`.
#[derive(Serialize, Deserialize)]
struct SuccessResponse<T> {
pub id: u32,
pub result: T,
}
/// Error response sent back for a request that has an `id`.
#[derive(Serialize, Deserialize)]
struct ErrorResponse {
pub id: u32,
pub error: ResponseError,
}
/// Error payload of a response. Codes used in this module:
/// - `0` when parameters or a call's result fail to deserialize.
/// - `-1` when a handler returns an error or the method is not found.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseError {
pub code: i32,
pub message: String,
}
/// Result of an outbound call as handed to its `DispatchMethod`.
enum Outcome {
/// The raw response body, decoded later by the call's handler.
Success(Vec<u8>),
/// The error the peer replied with.
Error(ResponseError),
}
/// Client ends of the streams created for one duplex call, each paired with its
/// stream id. Intended to be passed to `RpcDispatcher::register_stream`.
pub struct StreamDto {
/// Id of the request that opened the streams (0 if it had none).
req_id: u32,
streams: Vec<(u32, DuplexStream)>,
}
/// What `RpcDispatcher::dispatch` produced for an incoming message. In every
/// variant, the `Option<Vec<u8>>` is the response to send back, if any.
pub enum MaybeSync {
/// From a duplex method: streams for `RpcDispatcher::register_stream`, plus the response future.
Stream((Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)),
/// From an async method: a future resolving to the response.
Future(BoxFuture<'static, Option<Vec<u8>>>),
/// Handled synchronously: sync and unknown methods, responses, and undecodable messages.
Sync(Option<Vec<u8>>),
}
#[cfg(test)]
mod tests {
use super::*;
// Removing an idle stream drops its record and shuts the writer down, so the
// reader sees EOF without any data.
#[tokio::test]
async fn test_remove() {
let streams = Streams::default();
let (writer, mut reader) = tokio::io::duplex(1024);
streams.insert(1, tokio::io::split(writer).1);
streams.remove(1).await;
assert!(streams.map.lock().unwrap().get(&1).is_none());
let mut buffer = Vec::new();
assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 0);
}
// A write to an idle stream spawns a write loop that delivers the bytes.
#[tokio::test]
async fn test_write() {
let streams = Streams::default();
let (writer, mut reader) = tokio::io::duplex(1024);
streams.insert(1, tokio::io::split(writer).1);
streams.write(1, vec![1, 2, 3]);
let mut buffer = [0; 3];
assert_eq!(reader.read_exact(&mut buffer).await.unwrap(), 3);
assert_eq!(buffer, [1, 2, 3]);
}
// Data queued before the stream is ended must still be flushed, in order,
// before the writer is shut down. The 1-byte duplex buffer means the writes
// cannot complete until the reader drains them.
#[tokio::test]
async fn test_write_with_immediate_end() {
let streams = Streams::default();
let (writer, mut reader) = tokio::io::duplex(1);
streams.insert(1, tokio::io::split(writer).1);
streams.write(1, vec![1, 2, 3]); // spawn write loop
streams.write(1, vec![4, 5, 6]); // enqueued while writing
streams.remove(1).await; // end stream
let mut buffer = Vec::new();
assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 6);
assert_eq!(buffer, vec![1, 2, 3, 4, 5, 6]);
}
}