panorama/imap/src/client/inner.rs

264 lines
8.2 KiB
Rust
Raw Normal View History

2021-02-20 05:03:33 +00:00
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
2021-02-21 13:42:40 +00:00
use std::task::{Context, Poll, Waker};
2021-02-20 05:03:33 +00:00
2021-02-21 13:42:40 +00:00
use anyhow::{Context as AnyhowContext, Result};
use futures::future::{self, Either, Future, FutureExt};
2021-02-20 05:03:33 +00:00
use panorama_strings::{StringEntry, StringStore};
2021-02-20 07:30:58 +00:00
use parking_lot::{Mutex, RwLock};
2021-02-20 05:03:33 +00:00
use tokio::{
io::{
self, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader,
2021-02-21 13:42:40 +00:00
ReadHalf, WriteHalf,
2021-02-20 05:03:33 +00:00
},
2021-02-21 13:42:40 +00:00
sync::{mpsc, oneshot},
2021-02-20 05:03:33 +00:00
task::JoinHandle,
};
2021-02-21 13:54:46 +00:00
use tokio_rustls::{
client::TlsStream, rustls::ClientConfig as RustlsConfig, webpki::DNSNameRef, TlsConnector,
};
2021-02-20 05:03:33 +00:00
use crate::command::Command;
2021-02-21 13:42:40 +00:00
use crate::types::Response;
2021-02-21 13:54:46 +00:00
use super::ClientConfig;
2021-02-20 05:03:33 +00:00
/// Boxed, type-erased callback that takes no arguments and returns nothing.
/// (Not referenced anywhere else in this module.)
pub type BoxedFunc = Box<dyn Fn()>;

/// Pending command results, keyed by the numeric id embedded in each command's tag.
///
/// Each value pairs the tagged response line (filled in by the listener task once it
/// arrives) with the waker of the task currently waiting for that response.
pub type ResultMap = Arc<RwLock<HashMap<usize, (Option<String>, Option<Waker>)>>>;

/// Greeting flag shared with the listener task: whether the server's greeting has been
/// seen yet, together with the waker of the task waiting for it.
pub type GreetingState = Arc<RwLock<(bool, Option<Waker>)>>;

/// Prefix of every tag this client sends; a full tag is this prefix immediately followed
/// by the command's numeric id (for example `panorama0`).
pub const TAG_PREFIX: &str = "panorama";
2021-02-20 05:03:33 +00:00
2021-02-21 13:42:40 +00:00
/// The lower-level Client struct, that is shared by all of the exported structs in the state machine.
2021-02-20 05:03:33 +00:00
pub struct Client<C> {
2021-02-21 13:54:46 +00:00
config: ClientConfig,
2021-02-20 05:03:33 +00:00
conn: WriteHalf<C>,
symbols: StringStore,
id: usize,
2021-02-20 07:30:58 +00:00
results: ResultMap,
2021-02-20 05:03:33 +00:00
2021-02-21 13:42:40 +00:00
/// cached set of capabilities
2021-02-20 05:03:33 +00:00
caps: Vec<StringEntry>,
2021-02-21 13:42:40 +00:00
/// join handle for the listener thread
listener_handle: JoinHandle<Result<ReadHalf<C>>>,
/// used for telling the listener thread to stop and return the read half
exit_tx: mpsc::Sender<()>,
/// used for receiving the greeting
2021-02-21 13:54:46 +00:00
greeting: GreetingState,
2021-02-20 05:03:33 +00:00
}
impl<C> Client<C>
where
C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
/// Creates a new client that wraps a connection
2021-02-21 13:54:46 +00:00
pub fn new(conn: C, config: ClientConfig) -> Self {
2021-02-20 05:03:33 +00:00
let (read_half, write_half) = io::split(conn);
2021-02-20 07:30:58 +00:00
let results = Arc::new(RwLock::new(HashMap::new()));
2021-02-21 13:42:40 +00:00
let (exit_tx, exit_rx) = mpsc::channel(1);
let greeting = Arc::new(RwLock::new((false, None)));
let listen_fut = tokio::spawn(listen(
read_half,
results.clone(),
exit_rx,
greeting.clone(),
));
2021-02-20 05:03:33 +00:00
Client {
2021-02-21 13:42:40 +00:00
config,
2021-02-20 05:03:33 +00:00
conn: write_half,
symbols: StringStore::new(256),
id: 0,
2021-02-20 07:30:58 +00:00
results,
2021-02-20 05:03:33 +00:00
caps: Vec::new(),
2021-02-21 13:42:40 +00:00
listener_handle: listen_fut,
exit_tx,
greeting,
2021-02-20 05:03:33 +00:00
}
}
2021-02-21 13:54:46 +00:00
/// Returns a future that doesn't resolve until we receive a greeting from the server.
pub fn wait_for_greeting(&self) -> GreetingWaiter {
2021-02-21 13:42:40 +00:00
debug!("waiting for greeting");
2021-02-21 13:54:46 +00:00
GreetingWaiter(self.greeting.clone())
2021-02-21 13:42:40 +00:00
}
2021-02-20 05:03:33 +00:00
/// Sends a command to the server and returns a handle to retrieve the result
2021-02-21 13:42:40 +00:00
pub async fn execute(&mut self, cmd: Command) -> Result<String> {
2021-02-20 07:30:58 +00:00
debug!("executing command {:?}", cmd);
2021-02-20 05:03:33 +00:00
let id = self.id;
self.id += 1;
{
2021-02-20 07:30:58 +00:00
let mut handlers = self.results.write();
handlers.insert(id, (None, None));
2021-02-20 05:03:33 +00:00
}
2021-02-21 01:13:10 +00:00
let cmd_str = format!("{}{} {}\r\n", TAG_PREFIX, id, cmd);
2021-02-20 07:30:58 +00:00
debug!("[{}] writing to socket: {:?}", id, cmd_str);
2021-02-20 05:03:33 +00:00
self.conn.write_all(cmd_str.as_bytes()).await?;
self.conn.flush().await?;
2021-02-20 07:30:58 +00:00
debug!("[{}] written.", id);
2021-02-20 05:03:33 +00:00
2021-02-21 13:54:46 +00:00
ExecWaiter(self, id).await;
2021-02-20 07:30:58 +00:00
let resp = {
let mut handlers = self.results.write();
handlers.remove(&id).unwrap().0.unwrap()
};
2021-02-21 13:42:40 +00:00
Ok(resp)
2021-02-20 05:03:33 +00:00
}
/// Executes the CAPABILITY command
2021-02-21 13:42:40 +00:00
pub async fn capabilities(&mut self) -> Result<()> {
2021-02-20 05:03:33 +00:00
let cmd = Command::Capability;
2021-02-20 07:30:58 +00:00
debug!("sending: {:?} {:?}", cmd, cmd.to_string());
2021-02-21 13:42:40 +00:00
let result = self
.execute(cmd)
.await
.context("error executing CAPABILITY command")?;
let (_, resp) = Response::from_bytes(result.as_bytes())
.map_err(|err| anyhow!(""))
.context("error parsing response from CAPABILITY")?;
debug!("cap resp: {:?}", resp);
if let Response::Capabilities(caps) = resp {
debug!("capabilities: {:?}", caps);
}
2021-02-20 07:30:58 +00:00
Ok(())
2021-02-20 05:03:33 +00:00
}
2021-02-21 13:42:40 +00:00
/// Attempts to upgrade this connection using STARTTLS
pub async fn upgrade(mut self) -> Result<Client<TlsStream<C>>> {
// TODO: make sure STARTTLS is in the capability list
// first, send the STARTTLS command
let resp = self.execute(Command::Starttls).await?;
debug!("server response to starttls: {:?}", resp);
debug!("sending exit ()");
self.exit_tx.send(()).await?;
let reader = self.listener_handle.await??;
let writer = self.conn;
let conn = reader.unsplit(writer);
let server_name = &self.config.hostname;
2021-02-21 13:54:46 +00:00
let mut tls_config = RustlsConfig::new();
2021-02-21 13:42:40 +00:00
tls_config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
let tls_config = TlsConnector::from(Arc::new(tls_config));
let dnsname = DNSNameRef::try_from_ascii_str(server_name).unwrap();
let stream = tls_config.connect(dnsname, conn).await?;
Ok(Client::new(stream, self.config))
}
}
2021-02-21 13:54:46 +00:00
pub struct GreetingWaiter(GreetingState);
2021-02-21 13:42:40 +00:00
2021-02-21 13:54:46 +00:00
impl Future for GreetingWaiter {
2021-02-21 13:42:40 +00:00
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let (state, waker) = &mut *self.0.write();
if waker.is_none() {
*waker = Some(cx.waker().clone());
}
match state {
true => Poll::Ready(()),
false => Poll::Pending,
}
}
2021-02-20 05:03:33 +00:00
}
2021-02-21 13:54:46 +00:00
pub struct ExecWaiter<'a, C>(&'a Client<C>, usize);
2021-02-20 05:03:33 +00:00
2021-02-21 13:54:46 +00:00
impl<'a, C> Future for ExecWaiter<'a, C> {
2021-02-20 05:03:33 +00:00
type Output = ();
2021-02-20 07:30:58 +00:00
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut handlers = self.0.results.write();
2021-02-21 01:13:10 +00:00
let state = handlers.get_mut(&self.1);
2021-02-20 05:03:33 +00:00
// TODO: handle the None case here
2021-02-20 07:30:58 +00:00
debug!("f[{}] {:?}", self.1, state);
let (result, waker) = state.unwrap();
match result {
Some(_) => Poll::Ready(()),
None => {
*waker = Some(cx.waker().clone());
Poll::Pending
}
2021-02-20 05:03:33 +00:00
}
}
}
2021-02-21 13:42:40 +00:00
/// Main listen loop for the application
async fn listen<C>(
conn: C,
results: ResultMap,
mut exit: mpsc::Receiver<()>,
2021-02-21 13:54:46 +00:00
greeting: GreetingState,
2021-02-21 13:42:40 +00:00
) -> Result<C>
where
C: AsyncRead + Unpin,
{
2021-02-20 05:24:46 +00:00
debug!("amogus");
2021-02-20 05:03:33 +00:00
let mut reader = BufReader::new(conn);
2021-02-21 13:42:40 +00:00
let mut greeting = Some(greeting);
2021-02-20 05:03:33 +00:00
loop {
let mut next_line = String::new();
2021-02-21 13:42:40 +00:00
let fut = reader.read_line(&mut next_line).fuse();
pin_mut!(fut);
let fut2 = exit.recv().fuse();
pin_mut!(fut2);
match future::select(fut, fut2).await {
Either::Left((_, _)) => {
debug!("got a new line");
2021-02-21 13:42:40 +00:00
let next_line = next_line.trim_end_matches('\n').trim_end_matches('\r');
let mut parts = next_line.split(" ");
let tag = parts.next().unwrap();
let rest = parts.collect::<Vec<_>>().join(" ");
if tag == "*" {
debug!("UNTAGGED {:?}", rest);
2021-02-21 13:54:46 +00:00
// TODO: verify that the greeting is actually an OK
2021-02-21 13:42:40 +00:00
if let Some(greeting) = greeting.take() {
let (greeting, waker) = &mut *greeting.write();
debug!("got greeting");
*greeting = true;
if let Some(waker) = waker.take() {
waker.wake();
}
}
} else if tag.starts_with(TAG_PREFIX) {
let id = tag.trim_start_matches(TAG_PREFIX).parse::<usize>()?;
debug!("set {} to {:?}", id, rest);
let mut results = results.write();
if let Some((c, w)) = results.get_mut(&id) {
// *c = Some(rest.to_string());
*c = Some(next_line.to_owned());
if let Some(waker) = w.take() {
waker.wake();
}
}
}
}
Either::Right((_, _)) => {
debug!("exiting read loop");
break;
2021-02-20 07:30:58 +00:00
}
}
2021-02-20 05:03:33 +00:00
}
2021-02-21 13:42:40 +00:00
let conn = reader.into_inner();
Ok(conn)
2021-02-20 05:03:33 +00:00
}