/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
	collections::HashMap,
	future,
	sync::{
		atomic::{AtomicU32, Ordering},
		Arc, Mutex,
	},
};

use crate::log;
use futures::{future::BoxFuture, Future, FutureExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::{
	io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf},
	sync::{mpsc, oneshot},
};

use crate::util::errors::AnyError;

/// Handler for a method that produces its reply immediately. It receives the
/// request id (`None` for a notification) and the raw request body, and
/// returns the serialized reply, if any.
pub type SyncMethod = Arc<dyn Fn(Option<u32>, &[u8]) -> Option<Vec<u8>> + Send + Sync>;

/// Handler for a method whose serialized reply is produced by a future.
pub type AsyncMethod =
	Arc<dyn Fn(Option<u32>, &[u8]) -> BoxFuture<'static, Option<Vec<u8>>> + Send + Sync>;

/// Handler for a method that also opens duplex streams: alongside the future
/// producing the reply it returns the streams to announce to the peer, if any.
pub type Duplex = Arc<
	dyn Fn(Option<u32>, &[u8]) -> (Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)
		+ Send
		+ Sync,
>;

/// A registered RPC handler, tagged by how it produces its reply.
pub enum Method {
	/// Replies synchronously (see `RpcMethodBuilder::register_sync`).
	Sync(SyncMethod),
	/// Replies via a future (see `RpcMethodBuilder::register_async`).
	Async(AsyncMethod),
	/// Replies via a future and also opens duplex streams
	/// (see `RpcMethodBuilder::register_duplex`).
	Duplex(Duplex),
}

/// Serialization is given to the RpcBuilder and defines how data gets serialized
/// when calling methods. The same implementation encodes and decodes requests,
/// notifications and responses in both directions.
pub trait Serialization: Send + Sync + 'static {
	/// Encodes `value` into bytes. The signature has no error path.
	fn serialize(&self, value: impl Serialize) -> Vec<u8>;
	/// Decodes `b` into a `P`, returning an error if the bytes do not match.
	fn deserialize<P: DeserializeOwned>(&self, b: &[u8]) -> Result<P, AnyError>;
}

/// RPC is a basic, transport-agnostic builder for RPC methods. Call
/// `.methods(context)` to get an `RpcMethodBuilder`, register methods on that,
/// then call its `.build()` to get a "dispatcher" type.
pub struct RpcBuilder<S> {
	/// Shared with every caller and dispatcher created from this builder.
	serializer: Arc<S>,
	/// Handlers keyed by method name.
	methods: HashMap<&'static str, Method>,
	/// Pending outbound calls keyed by request id. Shared with callers from
	/// `get_caller`, so the dispatcher can complete them when responses arrive.
	calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
}

impl<S: Serialization> RpcBuilder<S> {
	/// Creates a new empty RPC builder.
	pub fn new(serializer: S) -> Self {
		Self {
			serializer: Arc::new(serializer),
			methods: HashMap::new(),
			calls: Arc::new(std::sync::Mutex::new(HashMap::new())),
		}
	}

	/// Creates a caller that will be connected to any eventual dispatchers,
	/// and that sends data to the "tx" channel.
	pub fn get_caller(&mut self, sender: mpsc::UnboundedSender<Vec<u8>>) -> RpcCaller<S> {
		RpcCaller {
			serializer: self.serializer.clone(),
			calls: self.calls.clone(),
			sender,
		}
	}

	/// Gets a method builder.
	pub fn methods<C: Send + Sync + 'static>(self, context: C) -> RpcMethodBuilder<S, C> {
		RpcMethodBuilder {
			context: Arc::new(context),
			serializer: self.serializer,
			methods: self.methods,
			calls: self.calls,
		}
	}
}

/// Second phase of building a dispatcher: register handlers with the
/// `register_*` methods, then call `build` to produce an `RpcDispatcher`.
pub struct RpcMethodBuilder<S, C> {
	/// Shared state passed to every registered handler.
	context: Arc<C>,
	serializer: Arc<S>,
	/// Handlers keyed by method name.
	methods: HashMap<&'static str, Method>,
	/// Pending outbound calls, shared with any `RpcCaller` from the same `RpcBuilder`.
	calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
}

/// Params of the `streams_started` notification sent by
/// `RpcDispatcher::register_stream`. It tells the peer which stream ids were
/// opened for a duplex request.
#[derive(Serialize)]
struct DuplexStreamStarted {
	/// Id of the originating request (0 if that request had no id).
	pub for_request_id: u32,
	pub stream_ids: Vec<u32>,
}

impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
	/// Registers a synchronous rpc call that returns its result directly.
	pub fn register_sync<P, R, F>(&mut self, method_name: &'static str, callback: F)
	where
		P: DeserializeOwned,
		R: Serialize,
		F: Fn(P, &C) -> Result<R, AnyError> + Send + Sync + 'static,
	{
		if self.methods.contains_key(method_name) {
			panic!("Method already registered: {method_name}");
		}

		let serial = self.serializer.clone();
		let context = self.context.clone();
		self.methods.insert(
			method_name,
			Method::Sync(Arc::new(move |id, body| {
				let param = match serial.deserialize::<RequestParams<P>>(body) {
					Ok(p) => p,
					Err(err) => {
						return id.map(|id| {
							serial.serialize(ErrorResponse {
								id,
								error: ResponseError {
									code: 0,
									message: format!("{err:?}"),
								},
							})
						})
					}
				};

				match callback(param.params, &context) {
					Ok(result) => id.map(|id| serial.serialize(&SuccessResponse { id, result })),
					Err(err) => id.map(|id| {
						serial.serialize(ErrorResponse {
							id,
							error: ResponseError {
								code: -1,
								message: format!("{err:?}"),
							},
						})
					}),
				}
			})),
		);
	}

	/// Registers an async rpc call that returns a Future.
	pub fn register_async<P, R, Fut, F>(&mut self, method_name: &'static str, callback: F)
	where
		P: DeserializeOwned + Send + 'static,
		R: Serialize + Send + Sync + 'static,
		Fut: Future<Output = Result<R, AnyError>> + Send,
		F: (Fn(P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
	{
		let serial = self.serializer.clone();
		let context = self.context.clone();
		self.methods.insert(
			method_name,
			Method::Async(Arc::new(move |id, body| {
				let param = match serial.deserialize::<RequestParams<P>>(body) {
					Ok(p) => p,
					Err(err) => {
						return future::ready(id.map(|id| {
							serial.serialize(ErrorResponse {
								id,
								error: ResponseError {
									code: 0,
									message: format!("{err:?}"),
								},
							})
						}))
						.boxed();
					}
				};

				let callback = callback.clone();
				let serial = serial.clone();
				let context = context.clone();
				let fut = async move {
					match callback(param.params, context).await {
						Ok(result) => {
							id.map(|id| serial.serialize(&SuccessResponse { id, result }))
						}
						Err(err) => id.map(|id| {
							serial.serialize(ErrorResponse {
								id,
								error: ResponseError {
									code: -1,
									message: format!("{err:?}"),
								},
							})
						}),
					}
				};

				fut.boxed()
			})),
		);
	}

	/// Registers an async rpc call that returns a Future containing a duplex
	/// stream that should be handled by the client.
	pub fn register_duplex<P, R, Fut, F>(
		&mut self,
		method_name: &'static str,
		streams: usize,
		callback: F,
	) where
		P: DeserializeOwned + Send + 'static,
		R: Serialize + Send + Sync + 'static,
		Fut: Future<Output = Result<R, AnyError>> + Send,
		F: (Fn(Vec<DuplexStream>, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
	{
		let serial = self.serializer.clone();
		let context = self.context.clone();
		self.methods.insert(
			method_name,
			Method::Duplex(Arc::new(move |id, body| {
				let param = match serial.deserialize::<RequestParams<P>>(body) {
					Ok(p) => p,
					Err(err) => {
						return (
							None,
							future::ready(id.map(|id| {
								serial.serialize(ErrorResponse {
									id,
									error: ResponseError {
										code: 0,
										message: format!("{err:?}"),
									},
								})
							}))
							.boxed(),
						);
					}
				};

				let callback = callback.clone();
				let serial = serial.clone();
				let context = context.clone();

				let mut dto = StreamDto {
					req_id: id.unwrap_or(0),
					streams: Vec::with_capacity(streams),
				};
				let mut servers = Vec::with_capacity(streams);

				for _ in 0..streams {
					let (client, server) = tokio::io::duplex(8192);
					servers.push(server);
					dto.streams.push((next_message_id(), client));
				}

				let fut = async move {
					match callback(servers, param.params, context).await {
						Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
						Err(err) => id.map(|id| {
							serial.serialize(ErrorResponse {
								id,
								error: ResponseError {
									code: -1,
									message: format!("{err:?}"),
								},
							})
						}),
					}
				};

				(Some(dto), fut.boxed())
			})),
		);
	}

	/// Builds into a usable, sync rpc dispatcher.
	pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> {
		let streams = Streams::default();

		let s1 = streams.clone();
		self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
			let s1 = s1.clone();
			async move {
				s1.remove(m.stream).await;
				Ok(())
			}
		});

		let s2 = streams.clone();
		self.register_sync(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
			s2.write(m.stream, m.segment);
			Ok(())
		});

		RpcDispatcher {
			log,
			context: self.context,
			calls: self.calls,
			serializer: self.serializer,
			methods: Arc::new(self.methods),
			streams,
		}
	}
}

/// One-shot completion handler for a pending outbound call, keyed by request
/// id in the shared `calls` table and consumed when the response arrives.
type DispatchMethod = Box<dyn FnOnce(Outcome) + Send + Sync>;

/// Caller returned from `RpcBuilder::get_caller` that provides a
/// transport-agnostic way to serialize outbound requests and notifications
/// and to await responses. Responses are matched to calls through the `calls`
/// table it shares with the dispatcher built from the same builder.
#[derive(Clone)]
pub struct RpcCaller<S: Serialization> {
	serializer: Arc<S>,
	/// Pending calls keyed by request id.
	calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
	/// Serialized outbound messages are pushed here (presumably drained by the
	/// transport's write loop).
	sender: mpsc::UnboundedSender<Vec<u8>>,
}

impl<S: Serialization> RpcCaller<S> {
	/// Serializes a notification without enqueuing it. A notification is a
	/// request with no id, so the peer produces no reply.
	pub fn serialize_notify<M, A>(serializer: &S, method: M, params: A) -> Vec<u8>
	where
		S: Serialization,
		M: AsRef<str> + serde::Serialize,
		A: Serialize,
	{
		serializer.serialize(&FullRequest {
			id: None,
			method,
			params,
		})
	}

	/// Enqueues an outbound notification. Returns whether the message was
	/// enqueued, i.e. whether the outbound channel was still open.
	pub fn notify<M, A>(&self, method: M, params: A) -> bool
	where
		M: AsRef<str> + serde::Serialize,
		A: Serialize,
	{
		self.sender
			.send(Self::serialize_notify(&self.serializer, method, params))
			.is_ok()
	}

	/// Enqueues an outbound call, returning a receiver for its result.
	///
	/// The receiver yields:
	/// - `Ok(Ok(r))` when the peer replies with a result that deserializes as `R`;
	/// - `Ok(Err(e))` when the peer replies with an error, or when the reply's
	///   result cannot be deserialized as `R` (reported with code `0`);
	/// - `Err(RecvError)` when the request could not be enqueued.
	pub fn call<M, A, R>(&self, method: M, params: A) -> oneshot::Receiver<Result<R, ResponseError>>
	where
		M: AsRef<str> + serde::Serialize,
		A: Serialize,
		R: DeserializeOwned + Send + 'static,
	{
		let (tx, rx) = oneshot::channel();
		let id = next_message_id();
		let body = self.serializer.serialize(&FullRequest {
			id: Some(id),
			method,
			params,
		});

		// Register the completion handler *before* the request leaves. Otherwise
		// the response could be dispatched (possibly on another thread) before
		// the handler exists, be silently dropped, and leave `rx` pending forever.
		let serializer = self.serializer.clone();
		self.calls.lock().unwrap().insert(
			id,
			Box::new(move |body| {
				match body {
					Outcome::Error(e) => tx.send(Err(e)).ok(),
					Outcome::Success(r) => match serializer.deserialize::<SuccessResponse<R>>(&r) {
						Ok(r) => tx.send(Ok(r.result)).ok(),
						Err(err) => tx
							.send(Err(ResponseError {
								code: 0,
								message: err.to_string(),
							}))
							.ok(),
					},
				};
			}),
		);

		if self.sender.send(body).is_err() {
			// Nothing was sent, so no response will arrive. Removing the handler
			// drops `tx`, which resolves `rx` with a `RecvError`. The lock guard
			// is released at the end of the `let` statement, before the drop.
			let handler = self.calls.lock().unwrap().remove(&id);
			drop(handler);
		}

		rx
	}
}

/// Dispatcher returned from a Builder that provides a transport-agnostic way to
/// deserialize and handle RPC calls. This structure may get more advanced as
/// time goes on...
#[derive(Clone)]
pub struct RpcDispatcher<S, C> {
	log: log::Logger,
	/// Shared context handed to registered handlers.
	context: Arc<C>,
	serializer: Arc<S>,
	/// Registered handlers keyed by method name, including the internal
	/// `stream_data` / `stream_ended` handlers added by `build`.
	methods: Arc<HashMap<&'static str, Method>>,
	/// Pending outbound calls (shared with `RpcCaller`s), completed when their
	/// responses are dispatched.
	calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
	/// Write halves of duplex streams registered via `register_stream`.
	streams: Streams,
}

/// Process-wide id source, shared by outbound call ids and duplex stream ids.
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);

/// Returns the next id from the process-wide counter, starting at 0. The
/// counter wraps around on overflow (atomic `fetch_add` semantics).
fn next_message_id() -> u32 {
	AtomicU32::fetch_add(&MESSAGE_ID_COUNTER, 1, Ordering::SeqCst)
}

impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
	/// Runs the incoming request, returning the result of the call synchronously
	/// or in a future. (The caller can then decide whether to run the future
	/// sequentially in its receive loop, or not.)
	///
	/// The future or return result will be optional bytes that should be sent
	/// back to the socket. A body that cannot be parsed even as a
	/// `PartialIncoming` is logged and yields no reply.
	pub fn dispatch(&self, body: &[u8]) -> MaybeSync {
		match self.serializer.deserialize::<PartialIncoming>(body) {
			Ok(partial) => self.dispatch_with_partial(body, partial),
			Err(_err) => {
				warning!(self.log, "Failed to deserialize request, hex: {:X?}", body);
				MaybeSync::Sync(None)
			}
		}
	}

	/// Like dispatch, but allows passing an existing PartialIncoming.
	///
	/// - A message with a `method` goes to the registered handler. If there is
	///   none, it is answered with a "Method not found" error (code `-1`).
	/// - Otherwise, a message with an `id` completes the matching pending
	///   outbound call, as an error if `error` is set, else as a success.
	/// - A message with neither is logged and ignored.
	pub fn dispatch_with_partial(&self, body: &[u8], partial: PartialIncoming) -> MaybeSync {
		let id = partial.id;

		if let Some(method_name) = partial.method {
			let method = self.methods.get(method_name.as_str());
			match method {
				Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)),
				Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)),
				Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)),
				None => MaybeSync::Sync(id.map(|id| {
					self.serializer.serialize(ErrorResponse {
						id,
						error: ResponseError {
							code: -1,
							message: format!("Method not found: {method_name}"),
						},
					})
				})),
			}
		} else if let Some(id) = id {
			// Take the handler out first so the `calls` lock is released before
			// it runs (the handler deserializes the whole response).
			let handler = self.calls.lock().unwrap().remove(&id);
			if let Some(handler) = handler {
				handler(match partial.error {
					Some(err) => Outcome::Error(err),
					None => Outcome::Success(body.to_vec()),
				});
			}
			MaybeSync::Sync(None)
		} else {
			// Neither a request nor a response. This input comes from the peer,
			// so ignore it rather than panicking on the missing id.
			warning!(
				self.log,
				"Received message with neither method nor id, hex: {:X?}",
				body
			);
			MaybeSync::Sync(None)
		}
	}

	/// Registers a stream call returned from dispatch().
	///
	/// The steps are:
	/// 1. Each duplex's write half is stored so that incoming `stream_data` /
	///    `stream_ended` notifications can reach it.
	/// 2. A `streams_started` notification is sent on `write_tx`.
	/// 3. One task per stream forwards bytes read from the duplex as
	///    `stream_data` notifications, then sends `stream_ended` once the read
	///    side reaches EOF or errors.
	///
	/// If `write_tx` is already closed, the stored streams are removed again.
	pub async fn register_stream(
		&self,
		write_tx: mpsc::Sender<impl 'static + From<Vec<u8>> + Send>,
		dto: StreamDto,
	) {
		let stream_ids: Vec<u32> = dto.streams.iter().map(|(id, _)| *id).collect();

		// Store the write halves *before* announcing the streams. Once the peer
		// sees `streams_started` it may immediately send `stream_data`, which
		// `Streams::write` silently drops for an unknown stream id.
		let mut readers = Vec::with_capacity(dto.streams.len());
		for (stream_id, duplex) in dto.streams {
			let (read, write) = tokio::io::split(duplex);
			self.streams.insert(stream_id, write);
			readers.push((stream_id, read));
		}

		let r = write_tx
			.send(
				self.serializer
					.serialize(&FullRequest {
						id: None,
						method: METHOD_STREAMS_STARTED,
						params: DuplexStreamStarted {
							stream_ids,
							for_request_id: dto.req_id,
						},
					})
					.into(),
			)
			.await;

		if r.is_err() {
			// The peer will never learn about these streams; tear our side down.
			for (stream_id, _) in readers {
				self.streams.remove(stream_id).await;
			}
			return;
		}

		for (stream_id, mut read) in readers {
			let write_tx = write_tx.clone();
			let serial = self.serializer.clone();
			tokio::spawn(async move {
				let mut buf = vec![0; 4096];
				loop {
					match read.read(&mut buf).await {
						Ok(0) | Err(_) => break,
						Ok(n) => {
							let r = write_tx
								.send(
									serial
										.serialize(&FullRequest {
											id: None,
											method: METHOD_STREAM_DATA,
											params: StreamDataParams {
												segment: &buf[..n],
												stream: stream_id,
											},
										})
										.into(),
								)
								.await;

							// Outbound channel closed: nobody to tell that the stream ended.
							if r.is_err() {
								return;
							}
						}
					}
				}

				let _ = write_tx
					.send(
						serial
							.serialize(&FullRequest {
								id: None,
								method: METHOD_STREAM_ENDED,
								params: StreamEndedParams { stream: stream_id },
							})
							.into(),
					)
					.await;
			});
		}
	}

	/// Returns a handle to the shared context given to `RpcBuilder::methods`.
	pub fn context(&self) -> Arc<C> {
		self.context.clone()
	}
}

/// Per-stream state for data arriving from the peer (via `stream_data`) that
/// is written into a local duplex.
struct StreamRec {
	/// The write half, parked here while no write loop is running. It is
	/// `None` while a `write_loop` task owns it.
	write: Option<WriteHalf<DuplexStream>>,
	/// Segments waiting to be written by the write loop.
	q: Vec<Vec<u8>>,
	/// Set by `Streams::remove` while a write loop owns the writer. The loop
	/// then removes the record and shuts the writer down once `q` drains.
	ended: bool,
}

/// Registry of incoming streams keyed by stream id. Clones share the same map.
#[derive(Clone, Default)]
struct Streams {
	map: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
}

impl Streams {
	pub async fn remove(&self, id: u32) {
		let mut remove = None;

		{
			let mut map = self.map.lock().unwrap();
			if let Some(s) = map.get_mut(&id) {
				if let Some(w) = s.write.take() {
					map.remove(&id);
					remove = Some(w);
				} else {
					s.ended = true; // will shut down in write loop
				}
			}
		}

		// do this outside of the sync lock:
		if let Some(mut w) = remove {
			let _ = w.shutdown().await;
		}
	}

	pub fn write(&self, id: u32, buf: Vec<u8>) {
		let mut map = self.map.lock().unwrap();
		if let Some(s) = map.get_mut(&id) {
			s.q.push(buf);

			if let Some(w) = s.write.take() {
				tokio::spawn(write_loop(id, w, self.map.clone()));
			}
		}
	}

	pub fn insert(&self, id: u32, stream: WriteHalf<DuplexStream>) {
		self.map.lock().unwrap().insert(
			id,
			StreamRec {
				write: Some(stream),
				q: Vec::new(),
				ended: false,
			},
		);
	}
}

/// Write loop started by `Streams.write`. It takes the WriteHalf and runs
/// until the record's write queue is empty. At that point:
/// - if the record still exists and is not marked `ended`, it parks the
///   WriteHalf back in the record, so the next `write` call starts the loop
///   again;
/// - if the record is marked `ended`, it removes the record and shuts down
///   the WriteHalf;
/// - if the record has disappeared, it shuts down the WriteHalf.
///
/// This is the equivalent of the same write_loop in the server_multiplexer.
/// I couldn't figure out a nice way to abstract it without introducing
/// performance overhead...
async fn write_loop(
	id: u32,
	mut w: WriteHalf<DuplexStream>,
	streams: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
) {
	// Reused batch buffer; swapped with the record's queue so each batch is
	// written without holding the lock.
	let mut items_vec = vec![];
	loop {
		{
			let mut lock = streams.lock().unwrap();
			let stream_rec = match lock.get_mut(&id) {
				Some(b) => b,
				None => break,
			};

			if stream_rec.q.is_empty() {
				if stream_rec.ended {
					lock.remove(&id);
					break;
				} else {
					// Park the writer; returning skips the shutdown below.
					stream_rec.write = Some(w);
					return;
				}
			}

			std::mem::swap(&mut stream_rec.q, &mut items_vec);
		}

		for item in items_vec.drain(..) {
			// On a write error this `break` only leaves the `for`: the rest of
			// this batch is discarded (dropping `Drain` removes the remaining
			// items) and the outer loop goes on to re-check the record.
			if w.write_all(&item).await.is_err() {
				break;
			}
		}
	}

	let _ = w.shutdown().await; // got here from `break` above, meaning our record got cleared. Close the bridge if so
}

/// Notification announcing the stream ids opened for a duplex request (sent only).
const METHOD_STREAMS_STARTED: &str = "streams_started";
/// Notification carrying a segment of stream data (both sent and handled).
const METHOD_STREAM_DATA: &str = "stream_data";
/// Notification that a stream has ended (both sent and handled).
const METHOD_STREAM_ENDED: &str = "stream_ended";

// Compile-time check that `RpcDispatcher` is `Sync`. The impl below only
// type-checks if `RpcDispatcher<S, C>` satisfies the `Sync` supertrait.
#[allow(dead_code)] // false positive
trait AssertIsSync: Sync {}
impl<S: Serialization, C: Send + Sync> AssertIsSync for RpcDispatcher<S, C> {}

/// Approximate shape that is used to determine what kind of data is incoming.
/// A present `method` marks a request or notification. Otherwise the message
/// is treated as a response to an outbound call, with `error` set on failure.
#[derive(Deserialize, Debug)]
pub struct PartialIncoming {
	pub id: Option<u32>,
	pub method: Option<String>,
	pub error: Option<ResponseError>,
}

/// Params of an incoming `stream_data` notification (owned bytes).
#[derive(Deserialize)]
struct StreamDataIncomingParams {
	#[serde(with = "serde_bytes")]
	pub segment: Vec<u8>,
	pub stream: u32,
}

/// Params of an outgoing `stream_data` notification, borrowing the read buffer.
#[derive(Serialize, Deserialize)]
struct StreamDataParams<'a> {
	#[serde(with = "serde_bytes")]
	pub segment: &'a [u8],
	pub stream: u32,
}

/// Params of a `stream_ended` notification.
#[derive(Serialize, Deserialize)]
struct StreamEndedParams {
	pub stream: u32,
}

/// Outbound request envelope. With `id: None` it is a notification, and
/// handlers produce no reply for it.
#[derive(Serialize)]
pub struct FullRequest<M: AsRef<str>, P> {
	pub id: Option<u32>,
	pub method: M,
	pub params: P,
}

/// Incoming request envelope. Only `params` is read here.
#[derive(Deserialize)]
struct RequestParams<P> {
	pub params: P,
}

/// Successful reply to request `id`.
#[derive(Serialize, Deserialize)]
struct SuccessResponse<T> {
	pub id: u32,
	pub result: T,
}

/// Failed reply to request `id`.
#[derive(Serialize, Deserialize)]
struct ErrorResponse {
	pub id: u32,
	pub error: ResponseError,
}

/// Error payload. In this file, code `0` is used for params or results that
/// failed to deserialize, and `-1` for handler errors and unknown methods.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseError {
	pub code: i32,
	pub message: String,
}

/// What a pending outbound call is completed with.
enum Outcome {
	/// Raw bytes of the whole response message; the caller deserializes them.
	Success(Vec<u8>),
	/// The error carried by the response.
	Error(ResponseError),
}

/// Duplex streams created for one request. It is returned from dispatch so it
/// can be passed to `RpcDispatcher::register_stream`.
pub struct StreamDto {
	/// Id of the originating request (0 if it had none).
	req_id: u32,
	/// Stream ids paired with the client ends of the duplexes. The handler
	/// received the other ends.
	streams: Vec<(u32, DuplexStream)>,
}

/// Result of dispatching a message. It is one of:
/// - a reply available immediately;
/// - a future producing the reply;
/// - such a future plus duplex streams to register.
pub enum MaybeSync {
	Stream((Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)),
	Future(BoxFuture<'static, Option<Vec<u8>>>),
	Sync(Option<Vec<u8>>),
}

#[cfg(test)]
mod tests {
	use super::*;

	/// Builds a `Streams` with stream `1` registered. Returns it together with
	/// the peer end that receives whatever is written to that stream.
	fn streams_with_peer(buffer_size: usize) -> (Streams, DuplexStream) {
		let streams = Streams::default();
		let (local, peer) = tokio::io::duplex(buffer_size);
		let (_, write_half) = tokio::io::split(local);
		streams.insert(1, write_half);
		(streams, peer)
	}

	#[tokio::test]
	async fn test_remove() {
		let (streams, mut peer) = streams_with_peer(1024);
		streams.remove(1).await;

		assert!(!streams.map.lock().unwrap().contains_key(&1));
		let mut received = Vec::new();
		assert_eq!(peer.read_to_end(&mut received).await.unwrap(), 0);
	}

	#[tokio::test]
	async fn test_write() {
		let (streams, mut peer) = streams_with_peer(1024);
		streams.write(1, vec![1, 2, 3]);

		let mut received = [0; 3];
		assert_eq!(peer.read_exact(&mut received).await.unwrap(), 3);
		assert_eq!(received, [1, 2, 3]);
	}

	#[tokio::test]
	async fn test_write_with_immediate_end() {
		// With a 1-byte duplex, the write loop can only progress as the peer reads.
		let (streams, mut peer) = streams_with_peer(1);
		streams.write(1, vec![1, 2, 3]); // takes the writer and spawns the write loop
		streams.write(1, vec![4, 5, 6]); // queued while the loop owns the writer
		streams.remove(1).await; // marks ended; the loop shuts down after draining

		let mut received = Vec::new();
		assert_eq!(peer.read_to_end(&mut received).await.unwrap(), 6);
		assert_eq!(received, vec![1, 2, 3, 4, 5, 6]);
	}
}