diff --git a/Cargo.lock b/Cargo.lock index 16954ed..d5a69d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -177,6 +177,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7f0778972c64420fdedc63f09919c8a88bda7b25135357fd25a5d9f3257e832" dependencies = [ "memchr", + "once_cell", + "regex-automata", "serde", ] @@ -925,6 +927,7 @@ version = "0.1.0" dependencies = [ "anyhow", "axum", + "bstr", "chrono", "clap", "dashmap", @@ -1351,6 +1354,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" + [[package]] name = "regex-syntax" version = "0.6.28" @@ -1475,9 +1484,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.91" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c235533714907a8c2464236f5c4b2a17262ef1bd71f38f35ea592c8da6883" +checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" dependencies = [ "itoa", "ryu", diff --git a/Cargo.toml b/Cargo.toml index c6e7c1c..b207c8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] anyhow = { version = "1.0.68", features = ["backtrace"] } axum = { version = "0.6.4", features = ["ws", "http2", "macros", "headers"] } +bstr = "1.2.0" chrono = "0.4.23" clap = { version = "4.0.32", features = ["derive", "env"] } dashmap = "5.4.0" @@ -21,7 +22,7 @@ rand = "0.8.5" serde = "1.0.152" serde_bytes = "0.11.9" serde_derive = "1.0.152" -serde_json = "1.0.91" +serde_json = "1.0.93" signal-hook = "0.3.14" tera = "1.17.1" termios = "0.3.3" diff --git a/src/asciicast.rs b/src/asciicast.rs index 926b65a..060096d 100644 --- a/src/asciicast.rs +++ b/src/asciicast.rs @@ -2,7 +2,7 @@ //! //! [asciicast v2]: https://github.com/asciinema/asciinema/blob/develop/doc/asciicast-v2.md -use std::{collections::HashMap, fmt}; +use std::{collections::HashMap, fmt, io::Cursor}; use serde::{ de::{ @@ -10,7 +10,10 @@ use serde::{ }, ser::{Serialize, SerializeSeq, Serializer}, }; -use serde_bytes::ByteBuf; +use serde_bytes::{ByteBuf, Bytes}; +use serde_json::ser::{CharEscape, CompactFormatter, Formatter}; + +use crate::char_escaper::CharEscaper; #[derive(Clone, Debug, Builder, Serialize, Deserialize)] pub struct Header { @@ -74,11 +77,15 @@ impl Serialize for Event { match &self.1 { EventKind::Output(s) => { seq.serialize_element("o")?; - seq.serialize_element(std::str::from_utf8(s).unwrap())?; + let ce = CharEscaper::new(s); + let s = ce.to_string(); + seq.serialize_element(&s)?; } EventKind::Input(s) => { seq.serialize_element("i")?; - seq.serialize_element(std::str::from_utf8(s).unwrap())?; + let ce = CharEscaper::new(s); + let s = ce.to_string(); + seq.serialize_element(&s)?; } } seq.end() @@ -116,17 +123,13 @@ impl<'d> Deserialize<'d> for Event { } }; - // Must first go through &[u8] then to Vec because serde_json treats - // &[u8] specially when it comes to deserializing from binary strings - let data = { - let data: ByteBuf = match seq.next_element()? { - Some(v) => v, - None => { - return Err(A::Error::invalid_length(2, &"an array of length 3")) - } - }; - data.into_vec() + let data: &'de [u8] = match seq.next_element()? { + Some(v) => v, + None => { + return Err(A::Error::invalid_length(2, &"an array of length 3")) + } }; + let data = data.to_vec(); let event_kind = match io { 'i' => EventKind::Input(data), @@ -149,9 +152,10 @@ impl<'d> Deserialize<'d> for Event { #[test] fn test() { - let evt = Event(1.5, EventKind::Output(vec![69, 42])); + let evt = Event(1.5, EventKind::Output(vec![69, 42, 255, 97, 240, 159, 135, 122].into())); let evt_ser = serde_json::to_string(&evt).unwrap(); eprintln!("ser: {evt_ser}"); let evt_de: Event = serde_json::from_str(&evt_ser).unwrap(); assert_eq!(evt, evt_de); + panic!() } diff --git a/src/char_escaper.rs b/src/char_escaper.rs new file mode 100644 index 0000000..099d084 --- /dev/null +++ b/src/char_escaper.rs @@ -0,0 +1,78 @@ +// I hate unicode!!! :WAAH: + +use std::io::{Cursor, Write}; + +use bstr::{BStr, ByteSlice, CharIndices}; +use serde_json::ser::{CharEscape, CompactFormatter, Formatter}; + +pub struct CharEscaper<'a>(CharIndices<'a>, &'a [u8]); + +impl<'a> CharEscaper<'a> { + pub fn new(src: &'a [u8]) -> Self { + let bstr = BStr::new(src); + Self(bstr.char_indices(), src) + } + + pub fn to_string(self) -> String { + let s = Vec::new(); + let mut c = Cursor::new(s); + let mut f = CompactFormatter; + for ch in self { + match ch { + Ok(ch) => { + write!(c, "{ch}").unwrap(); + } + Err((c0, c1, c2)) => { + f.write_char_escape(&mut c, c0).unwrap(); + if let Some(c1) = c1 { + f.write_char_escape(&mut c, c1).unwrap(); + } + if let Some(c2) = c2 { + f.write_char_escape(&mut c, c2).unwrap(); + } + } + } + } + let s = c.into_inner(); + String::from_utf8(s).unwrap() + } +} + +impl<'a> Iterator for CharEscaper<'a> { + type Item = + Result, Option)>; + + fn next(&mut self) -> Option { + let (start, end, ch) = match self.0.next() { + Some(v) => v, + None => return None, + }; + + let mut c1 = None; + let mut c2 = None; + + let c0 = match ch { + '"' => CharEscape::Quote, + '\\' => CharEscape::ReverseSolidus, + '/' => CharEscape::Solidus, + '\u{0008}' => CharEscape::Backspace, + '\u{000c}' => CharEscape::FormFeed, + '\u{000a}' => CharEscape::LineFeed, + '\u{000d}' => CharEscape::CarriageReturn, + '\u{0009}' => CharEscape::Tab, + '\u{fffd}' => { + let c0 = CharEscape::AsciiControl(self.1[start]); + if end - start > 1 { + c1 = Some(CharEscape::AsciiControl(self.1[start + 1])); + } + if end - start > 2 { + c2 = Some(CharEscape::AsciiControl(self.1[start + 2])); + } + c0 + } + _ => return Some(Ok(ch)), + }; + + Some(Err((c0, c1, c2))) + } +} diff --git a/src/client/terminal.rs b/src/client/terminal.rs index e7ec25d..a900880 100644 --- a/src/client/terminal.rs +++ b/src/client/terminal.rs @@ -120,7 +120,7 @@ impl Terminal { let event_kind = (match output { true => EventKind::Output, false => EventKind::Input, - })(data.to_vec()); + })(data.to_vec().into()); let event = Event(elapsed, event_kind); self.event_tx.send(event)?; diff --git a/src/lib.rs b/src/lib.rs index cc6ee2e..e7b1dd7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,3 +13,4 @@ pub mod asciicast; pub mod client; pub mod message; pub mod server; +pub mod char_escaper; diff --git a/src/message.rs b/src/message.rs index bdbd64d..c23c03c 100644 --- a/src/message.rs +++ b/src/message.rs @@ -20,7 +20,7 @@ fn test() { const msg: &'static str = "{\"AsciicastEvent\":[0.058985141,\"o\",\"\\u001b[1m\\u001b[3m%\\u001b[23m\\u001b[1m\\u001b[0m \\r \\r\"]}"; let msg2: Message = serde_json::from_str(&msg).unwrap(); if let Message::AsciicastEvent(ref evt) = msg2 { - if let EventKind::Output(ref d) = evt.1 { + if let EventKind::Output(ref _d) = evt.1 { // println!("==> {} <==", String::from_utf8(d.to_vec()).unwrap()); } } diff --git a/src/server/broadcast.rs b/src/server/broadcast.rs index cdd03cf..381c9f1 100644 --- a/src/server/broadcast.rs +++ b/src/server/broadcast.rs @@ -15,7 +15,7 @@ use serde::Deserialize; use tokio::sync::broadcast::{self, Sender}; use crate::{ - asciicast::{Header}, + asciicast::Header, message::{Message, ServerHello}, }; @@ -46,7 +46,8 @@ pub async fn broadcast(ws: WebSocketUpgrade) -> Response { // Distribution guarantees UTF-8-safety String::from_utf8_unchecked(room_id) }; - let _entry = match BROADCASTS.entry(room_id.clone()) { + + match BROADCASTS.entry(room_id.clone()) { Entry::Occupied(_) => { ct += 1; diff --git a/src/server/mod.rs b/src/server/mod.rs index a19d446..0995c95 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -6,11 +6,7 @@ use std::net::SocketAddr; use anyhow::Result; use axum::{ - extract::{Path}, - http::{StatusCode}, - response::Html, - routing::{get}, - Router, + extract::Path, http::StatusCode, response::Html, routing::get, Router, }; use tera::{Context, Tera}; use tokio::runtime::Runtime; diff --git a/src/static/index.html b/src/static/index.html index 9a4e1ff..48fa5b6 100644 --- a/src/static/index.html +++ b/src/static/index.html @@ -38,8 +38,10 @@ switch (location.protocol) { case "https:": protocol = "wss:"; + break; case "http:": protocol = "ws:"; + break; default: throw new Error("L"); }