// Copyright 2020 The Evcxr Authors.
//
// Licensed under the Apache License, Version 2.0 or the MIT license, at your option. This file
// may not be copied, modified, or distributed except according to those terms.

// NOTE(review): in this copy of the file, generic argument lists appear to have been stripped
// (e.g. `Arc>>`, `Connection<:repsocket>`, `-> Result {`, `collect::>()`), and some `<...>`
// text inside string literals and comments may be missing too. Restore them from version
// control before building.

use crate::connection::Connection;
use crate::connection::ConnectionGroup;
use crate::connection::ConnectionShutdownRequester;
use crate::connection::RecvError;
use crate::control_file;
use crate::jupyter_message::JupyterMessage;
use anyhow::Result;
use anyhow::bail;
use ariadne::sources;
use colored::*;
use crossbeam_channel::RecvTimeoutError;
use crossbeam_channel::Select;
use evcxr::CommandContext;
use evcxr::Theme;
use json::JsonValue;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;

// Note, to avoid potential deadlocks, each thread should lock at most one mutex at a time.
/// State shared by all of the kernel's socket-handling tasks. Clones share the same sockets
/// and state, because every field is an `Arc` or a runtime handle.
#[derive(Clone)]
pub(crate) struct Server {
    /// IOPub socket on which `status`, `execute_input`, `execute_result`, `stream`, `error`
    /// and `comm_*` messages are published.
    iopub: Arc>>,
    /// Stdin socket, used by `request_input` to send an `input_request` and read the reply.
    stdin: Arc>>,
    /// The most recent `execute_request`. `pass_output_line` uses it as the parent of the
    /// `stream` messages that carry captured stdout/stderr.
    latest_execution_request: Arc>>,
    /// `Some(sender)` until `shutdown_io_thread` takes (and so drops) it. That disconnects the
    /// channel the output pass-through thread selects on.
    io_thread_shutdown_sender: Arc>>>,
    /// Handle to the runtime. It lets blocking threads `block_on` async work.
    tokio_handle: tokio::runtime::Handle,
}

impl Server {
    /// Builds a multi-threaded tokio runtime and runs the kernel on it until the control
    /// channel handler returns.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `run_async`.
    ///
    /// # Panics
    ///
    /// Panics if the tokio runtime cannot be built.
    pub(crate) fn run(config: &control_file::Control) -> Result<()> {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            // We only technically need 1 thread. However we've observed that
            // when using vscode's jupyter extension, we can get requests on the
            // shell socket before we have any subscribers on iopub. The iopub
            // subscription then completes, but the execution_state="idle"
            // message(s) have already been sent to a channel that at the time
            // had no subscriptions. The vscode extension then waits
            // indefinitely for an execution_state="idle" message that will
            // never come. Having multiple threads at least reduces the chances
            // of this happening.
            .worker_threads(4)
            .enable_all()
            .build()
            .unwrap();
        let handle = runtime.handle().clone();
        runtime.block_on(Self::run_async(config, handle))?;
        Ok(())
    }

    /// Binds the heartbeat, shell, control, stdin and iopub sockets and creates the
    /// `CommandContext`. It then spawns the heartbeat, shell, execution and output
    /// pass-through workers, and finally serves the control channel on the current task until
    /// that handler returns.
    async fn run_async(
        config: &control_file::Control,
        tokio_handle: tokio::runtime::Handle,
    ) -> Result<()> {
        let (connection_group, group_shutdown) = ConnectionGroup::new();
        // Heartbeat, shell and control are bound into `connection_group`; stdin and iopub are
        // not. Presumably this lets `shutdown_all_connections(group_shutdown)` in
        // `handle_control` stop exactly those three.
        let mut heartbeat = bind_socket::<:repsocket>(
            config,
            config.hb_port,
            Some(connection_group.clone()),
        )
        .await?;
        let shell_socket = bind_socket::<:routersocket>(
            config,
            config.shell_port,
            Some(connection_group.clone()),
        )
        .await?;
        let control_socket = bind_socket::<:routersocket>(
            config,
            config.control_port,
            Some(connection_group.clone()),
        )
        .await?;
        let stdin_socket =
            bind_socket::<:routersocket>(config, config.stdin_port, None).await?;
        let iopub_socket = bind_socket::<:pubsocket>(config, config.iopub_port, None).await?;
        let iopub = Arc::new(Mutex::new(iopub_socket));

        // Create a channel pair that's used to signal to the IO thread when to shut down. This
        // needs to be a crossbeam channel because the IO thread uses select together with other
        // crossbeam channels.
        let (shutdown_sender, shutdown_receiver) = crossbeam_channel::unbounded();

        let server = Server {
            iopub,
            latest_execution_request: Arc::new(Mutex::new(None)),
            stdin: Arc::new(Mutex::new(stdin_socket)),
            io_thread_shutdown_sender: Arc::new(Mutex::new(Some(shutdown_sender))),
            tokio_handle,
        };

        // The shell task forwards each `execute_request` over the first channel and waits for
        // the reply on the second. The execution task receives on the first and replies on
        // the second.
        let (execution_sender, mut execution_receiver) = tokio::sync::mpsc::unbounded_channel();
        let (execution_response_sender, mut execution_response_receiver) =
            tokio::sync::mpsc::unbounded_channel();

        tokio::spawn(async move {
            if let Err(error) = Self::handle_hb(&mut heartbeat).await {
                eprintln!("hb error: {error:?}");
            }
        });

        let (mut context, outputs) = CommandContext::new()?;
        context.execute(":load_config")?;
        // Taken before `context` moves into the mutex. `handle_control` uses it for
        // `interrupt_request`.
        let process_handle = context.process_handle();
        let context = Arc::new(std::sync::Mutex::new(context));

        {
            let context = context.clone();
            let server = server.clone();
            tokio::spawn(async move {
                let result = server
                    .handle_shell(
                        shell_socket,
                        &execution_sender,
                        &mut execution_response_receiver,
                        context,
                    )
                    .await;
                if let Err(error) = result {
                    eprintln!("shell error: {error:?}");
                }
            });
        }
        {
            let server = server.clone();
            tokio::spawn(async move {
                let result = server
                    .handle_execution_requests(
                        &context,
                        &mut execution_receiver,
                        &execution_response_sender,
                    )
                    .await;
                if let Err(error) = result {
                    eprintln!("execution error: {error:?}");
                }
            });
        }
        server
            .clone()
            .start_output_pass_through_thread(
                vec![("stdout", outputs.stdout), ("stderr", outputs.stderr)],
                shutdown_receiver.clone(),
            )
            .await;
        // Don't keep any outstanding instances of our connection group, otherwise things won't
        // shut down properly.
        drop(connection_group);
        // Run the control channel on the main task. Once the control channel handler
        // terminates, we're done.
        server
            .handle_control(control_socket, process_handle, group_shutdown)
            .await?;
        Ok(())
    }

    /// Drops the IO-thread shutdown sender. It is the only sender, so the shutdown channel
    /// becomes disconnected. The output pass-through thread's `Select` then reports that
    /// channel as ready, and the thread exits.
    async fn shutdown_io_thread(&mut self) {
        self.io_thread_shutdown_sender.lock().await.take();
    }

    /// Heartbeat loop. It answers every message received on the heartbeat socket with `ping`,
    /// and returns `Ok(())` once shutdown is requested.
    // NOTE(review): a Jupyter heartbeat is normally an echo of the received bytes. This
    // replies with a fixed `ping` instead; confirm frontends only check that a reply arrives.
    async fn handle_hb(connection: &mut Connection<:repsocket>) -> Result<()> {
        loop {
            match connection.recv().await {
                Ok(_) => {}
                Err(RecvError::ShutdownRequested) => return Ok(()),
                Err(RecvError::Other(e)) => return Err(e),
            }
            connection
                .send(zeromq::ZmqMessage::from(b"ping".to_vec()))
                .await?;
        }
    }

    /// Receives `execute_request`s from the shell task one at a time and evaluates each on a
    /// blocking thread. For each request it:
    ///
    /// - publishes `execute_input` before evaluating;
    /// - publishes `execute_result` (output and/or timing) or `error` messages on IOPub;
    /// - sends the `execute_reply` back to the shell task through `execution_reply_sender`.
    ///
    /// Returns `Ok(())` when the shell side drops its sender.
    async fn handle_execution_requests(
        self,
        context: &Arc<:sync::mutex>>,
        receiver: &mut tokio::sync::mpsc::UnboundedReceiver,
        execution_reply_sender: &tokio::sync::mpsc::UnboundedSender,
    ) -> Result<()> {
        // NOTE(review): this starts at 1 and is incremented before its first use, so the
        // first request is reported with execution_count 2. Confirm that is intended.
        let mut execution_count = 1;
        loop {
            let message = match receiver.recv().await {
                Some(x) => x,
                None => {
                    // Other end has closed. This is expected when we're shutting
                    // down.
                    return Ok(());
                }
            };
            // Remembered so that stdout/stderr lines forwarded by `pass_output_line` are
            // parented to this request.
            // If we want this clone to be cheaper, we probably only need the header, not the
            // whole message.
            *self.latest_execution_request.lock().await = Some(message.clone());
            let src = message.code().to_owned();
            execution_count += 1;
            message
                .new_message("execute_input")
                .with_content(object! {
                    "execution_count" => execution_count,
                    "code" => src
                })
                .send(&mut *self.iopub.lock().await)
                .await?;
            let context = Arc::clone(context);
            let server = self.clone();
            // Evaluation is synchronous, so it runs on a blocking thread. `message` moves into
            // the closure (the input callback needs it) and is handed back with the result.
            let (eval_result, message) = tokio::task::spawn_blocking(move || {
                // NOTE(review): the `context` std mutex stays locked while `input_reader`
                // runs, and `request_input` then locks `stdin`. This thread therefore holds
                // two mutexes at once, contrary to the note on `Server`.
                let eval_result = context.lock().unwrap().execute_with_callbacks(
                    message.code(),
                    &mut evcxr::EvalCallbacks {
                        input_reader: &|input_request| {
                            // This closure runs on the `spawn_blocking` thread rather than in
                            // an async task, hence `block_on`. If no input is obtained,
                            // `unwrap_or_default` supplies the default value.
                            server.tokio_handle.block_on(async {
                                server
                                    .request_input(
                                        &message,
                                        &input_request.prompt,
                                        input_request.is_password,
                                    )
                                    .await
                                    .unwrap_or_default()
                            })
                        },
                    },
                );
                (eval_result, message)
            })
            .await?;
            match eval_result {
                Ok(output) => {
                    if !output.is_empty() {
                        // Increase the odds that stdout will have been finished being sent. A
                        // less hacky alternative would be to add a print statement, then block
                        // waiting for it.
                        tokio::time::sleep(Duration::from_millis(1)).await;
                        let mut data = HashMap::new();
                        // At the time of writing the json crate appears to have a generic From
                        // implementation for a Vec<T> where T implements Into<JsonValue>. It
                        // also has conversion from HashMap<String, JsonValue>, but it doesn't
                        // have conversion from HashMap<String, String>. Perhaps send a PR? For
                        // now, we convert the values manually.
                        // Values whose mime type mentions "json" are parsed, so they are sent
                        // as structured JSON; if parsing fails they are sent as plain strings.
                        for (k, v) in output.content_by_mime_type {
                            if k.contains("json") {
                                data.insert(k, json::parse(&v).unwrap_or_else(|_| json::from(v)));
                            } else {
                                data.insert(k, json::from(v));
                            }
                        }
                        message
                            .new_message("execute_result")
                            .with_content(object! {
                                "execution_count" => execution_count,
                                "data" => data,
                                "metadata" => object!(),
                            })
                            .send(&mut *self.iopub.lock().await)
                            .await?;
                    }
                    if let Some(duration) = output.timing {
                        // TODO replace by duration.as_millis() when stable
                        // NOTE(review): `Duration::as_millis` has been stable since Rust 1.33
                        // (it returns `u128`), so this TODO can be resolved.
                        let ms = duration.as_secs() * 1000 + u64::from(duration.subsec_millis());
                        let mut data: HashMap = HashMap::new();
                        data.insert(
                            "text/html".into(),
                            json::from(format!(
                                "Took {}ms",
                                ms
                            )),
                        );
                        message
                            .new_message("execute_result")
                            .with_content(object! {
                                "execution_count" => execution_count,
                                "data" => data,
                                "metadata" => object!(),
                            })
                            .send(&mut *self.iopub.lock().await)
                            .await?;
                    }
                    execution_reply_sender.send(message.new_reply().with_content(object! {
                        "status" => "ok",
                        "execution_count" => execution_count,
                    }))?;
                }
                Err(errors) => {
                    self.emit_errors(&errors, &message, message.code(), execution_count)
                        .await?;
                    execution_reply_sender.send(message.new_reply().with_content(object! {
                        "status" => "error",
                        "execution_count" => execution_count
                    }))?;
                }
            };
        }
    }

    /// Asks the frontend for a line of input on behalf of `current_request`.
    ///
    /// Returns `None` in any of these cases:
    /// - the request did not set `allow_stdin` to `true`;
    /// - sending the `input_request` fails;
    /// - reading the reply fails;
    /// - the reply's `value` is not a string.
    async fn request_input(
        &self,
        current_request: &JupyterMessage,
        prompt: &str,
        password: bool,
    ) -> Option {
        if current_request.get_content()["allow_stdin"].as_bool() != Some(true) {
            return None;
        }
        // The stdin lock is held for both the request and the reply read.
        let mut stdin = self.stdin.lock().await;
        let stdin_request = current_request
            .new_reply()
            .with_message_type("input_request")
            .with_content(object! {
                "prompt" => prompt,
                "password" => password,
            });
        stdin_request.send(&mut *stdin).await.ok()?;
        let input_response = JupyterMessage::read(&mut *stdin).await.ok()?;
        input_response.get_content()["value"]
            .as_str()
            .map(|value| value.to_owned())
    }

    /// Reads messages from the shell socket until shutdown is requested, dispatching each one
    /// to `handle_shell_message`. Errors while handling a message are logged and the loop
    /// continues. Only a socket read error ends the loop with `Err`.
    async fn handle_shell(
        self,
        mut connection: Connection,
        execution_channel: &tokio::sync::mpsc::UnboundedSender,
        execution_reply_receiver: &mut tokio::sync::mpsc::UnboundedReceiver,
        context: Arc<:sync::mutex>>,
    ) -> Result<()> {
        loop {
            let message = match JupyterMessage::read(&mut connection).await {
                Ok(m) => m,
                Err(RecvError::ShutdownRequested) => return Ok(()),
                Err(RecvError::Other(error)) => return Err(error),
            };
            // Copied out because `message` is moved into `handle_shell_message`.
            let message_type = message.message_type().to_owned();
            let r = self
                .handle_shell_message(
                    message,
                    &mut connection,
                    execution_channel,
                    execution_reply_receiver,
                    &context,
                )
                .await;
            if let Err(error) = r {
                // We see this often after issuing a restart-kernel from the Jupyter UI. Not sure
                // why, but provided we continue to handle subsequent shell requests, things seem
                // to work. So for now, we just print the error and continue.
                eprintln!("Error handling shell message `{message_type}`: {error:#}");
            }
        }
    }

    /// Handles one shell-channel message, bracketing it with `busy` and `idle` status messages
    /// on IOPub.
    async fn handle_shell_message(
        &self,
        message: JupyterMessage,
        connection: &mut Connection,
        execution_channel: &tokio::sync::mpsc::UnboundedSender,
        execution_reply_receiver: &mut tokio::sync::mpsc::UnboundedReceiver,
        context: &Arc<:sync::mutex>>,
    ) -> Result<()> {
        // Processing of every message should be enclosed between "busy" and "idle"
        // see https://jupyter-client.readthedocs.io/en/latest/messaging.html#messages-on-the-shell-router-dealer-channel
        // JupyterLab doesn't use the kernel until it has received "idle" for kernel_info_request
        message
            .new_message("status")
            .with_content(object! {"execution_state" => "busy"})
            .send(&mut *self.iopub.lock().await)
            .await?;
        // Built up front because several branches below consume `message`.
        // NOTE(review): any `?` early return below skips sending this `idle` status, which can
        // leave a frontend waiting on "busy".
        let idle = message
            .new_message("status")
            .with_content(object! {"execution_state" => "idle"});
        if message.message_type() == "kernel_info_request" {
            message
                .new_reply()
                .with_content(kernel_info())
                .send(connection)
                .await?;
        } else if message.message_type() == "is_complete_request" {
            // Every input is reported as complete.
            message
                .new_reply()
                .with_content(object! {"status" => "complete"})
                .send(connection)
                .await?;
        } else if message.message_type() == "execute_request" {
            // Hand the request to the execution task and wait for its reply. While waiting, no
            // other shell message is read.
            execution_channel.send(message)?;
            if let Some(reply) = execution_reply_receiver.recv().await {
                reply.send(connection).await?;
            }
        } else if message.message_type() == "comm_open" {
            comm_open(message, context, Arc::clone(&self.iopub)).await?;
        } else if message.message_type() == "comm_msg"
            || message.message_type() == "comm_info_request"
        {
            // We don't handle this yet.
        } else if message.message_type() == "complete_request" {
            // A failed completion is still answered, with an error-status reply.
            let reply = message.new_reply().with_content(
                match handle_completion_request(context, message).await {
                    Ok(response_content) => response_content,
                    Err(error) => object! {
                        "status" => "error",
                        "ename" => error.to_string(),
                        "evalue" => "",
                    },
                },
            );
            reply.send(connection).await?;
        } else if message.message_type() == "history_request" {
            // We don't yet support history requests, but we don't want to print
            // a message in jupyter console.
        } else {
            eprintln!(
                "Got unrecognized message type on shell channel: {}",
                message.message_type()
            );
        }
        idle.send(&mut *self.iopub.lock().await).await?;
        Ok(())
    }

    /// Serves the control channel. It returns after replying to a `shutdown_request`, or when
    /// a shutdown is signalled while reading.
    async fn handle_control(
        mut self,
        mut connection: Connection<:routersocket>,
        process_handle: Arc<:sync::mutex>>,
        group_shutdown: ConnectionShutdownRequester,
    ) -> Result<()> {
        loop {
            let message = match JupyterMessage::read(&mut connection).await {
                Ok(m) => m,
                Err(RecvError::ShutdownRequested) => return Ok(()),
                Err(RecvError::Other(error)) => return Err(error),
            };
            match message.message_type() {
                "kernel_info_request" => {
                    message
                        .new_reply()
                        .with_content(kernel_info())
                        .send(&mut connection)
                        .await?
                }
                "shutdown_request" => {
                    let is_restart = message.get_content()["restart"].as_bool().unwrap_or(false);
                    let response = object! {
                        "status": "ok",
                        "restart": is_restart,
                    };
                    // Ask the connection group and the output pass-through thread to shut down
                    // before sending the reply.
                    connection.shutdown_all_connections(group_shutdown).await;
                    self.shutdown_io_thread().await;
                    message
                        .new_reply()
                        .with_content(response)
                        .send(&mut connection)
                        .await?;
                    return Ok(());
                }
                "interrupt_request" => {
                    // Interrupting calls `kill()` on the process handle obtained from the
                    // `CommandContext`. The call runs on a blocking thread because it takes a
                    // std mutex.
                    let process_handle = process_handle.clone();
                    tokio::task::spawn_blocking(move || {
                        if let Err(error) = process_handle.lock().unwrap().kill() {
                            eprintln!("Failed to restart subprocess: {}", error);
                        }
                    })
                    .await?;
                    message.new_reply().send(&mut connection).await?;
                }
                _ => {
                    eprintln!(
                        "Got unrecognized message type on control channel: {}",
                        message.message_type()
                    );
                }
            }
        }
    }

    /// Spawns a detached blocking thread that forwards lines from `channels` (the evaluator's
    /// stdout and stderr, as passed by `run_async`) to IOPub as `stream` messages. The thread
    /// exits when `shutdown_recv` becomes ready (see `shutdown_io_thread`) or when an output
    /// channel disconnects.
    async fn start_output_pass_through_thread(
        self,
        channels: Vec<(&'static str, crossbeam_channel::Receiver)>,
        shutdown_recv: crossbeam_channel::Receiver<()>,
    ) {
        let handle = tokio::runtime::Handle::current();
        // The returned JoinHandle is dropped, so the thread runs detached.
        tokio::task::spawn_blocking(move || {
            let mut select = Select::new();
            // Output channels are registered first, in order, so a ready index other than
            // `shutdown_index` can be used to index `channels` directly.
            for (_, channel) in &channels {
                select.recv(channel);
            }
            let shutdown_index = select.recv(&shutdown_recv);
            loop {
                let index = select.ready();
                if index == shutdown_index {
                    return;
                }
                let (output_name, channel) = &channels[index];
                // Needed in order to make the borrow checker happy.
                let output_name: &'static str = output_name;
                // Read from the channel that has output until it has been idle
                // for 1ms before we return to checking other channels. This
                // reduces the extent to which outputs interleave. e.g. a
                // multi-line print is performed to stderr, then another to
                // stdout - we can't guarantee the order in which they get sent,
                // but we'd like to try to make sure that we don't interleave
                // their lines if possible.
                loop {
                    match channel.recv_timeout(Duration::from_millis(1)) {
                        Ok(line) => {
                            let server = self.clone();
                            handle.block_on(server.pass_output_line(output_name, line));
                        }
                        Err(RecvTimeoutError::Timeout) => break,
                        Err(RecvTimeoutError::Disconnected) => return,
                    }
                }
            }
        });
    }

    /// Publishes one line of captured output as a `stream` message, parented to the latest
    /// `execute_request`. Lines that arrive before any request has been seen are dropped.
    /// Send errors are logged rather than returned.
    async fn pass_output_line(&self, output_name: &'static str, line: String) {
        let mut message = None;
        // The `latest_execution_request` guard is a temporary of this `if let`, so it is
        // released before `iopub` is locked below. Only one mutex is held at a time.
        if let Some(exec_request) = &*self.latest_execution_request.lock().await {
            message = Some(exec_request.new_message("stream"));
        }
        if let Some(message) = message {
            if let Err(error) = message
                .with_content(object! {
                    "name" => output_name,
                    "text" => format!("{}\n", line),
                })
                .send(&mut *self.iopub.lock().await)
                .await
            {
                eprintln!("output {output_name} error: {}", error);
            }
        }
    }

    /// Publishes `error` messages on IOPub for a failed evaluation.
    ///
    /// Compilation errors from user code send one message per error. Its traceback is an
    /// `ariadne` report when one can be built. Otherwise the traceback is a hand-drawn
    /// listing: the source lines, a `^` marker line under each span, then the error message
    /// and any `help` notes.
    ///
    /// Compilation errors not from user code, and all other errors, are sent with a
    /// one-element traceback.
    async fn emit_errors(
        &self,
        errors: &evcxr::Error,
        parent_message: &JupyterMessage,
        source: &str,
        execution_count: u32,
    ) -> Result<()> {
        match errors {
            evcxr::Error::CompilationErrors(errors) => {
                for error in errors {
                    let message = format!("{}", error.message().bright_red());
                    if error.is_from_user_code() {
                        // Name given to the cell's source in the report, e.g. `command_3`.
                        let file_name = format!("command_{}", execution_count);
                        let mut traceback = Vec::new();
                        if let Some(report) =
                            error.build_report(file_name.clone(), source.to_string(), Theme::Light)
                        {
                            let mut s = Vec::new();
                            report
                                .write(sources([(file_name, source.to_string())]), &mut s)
                                .unwrap();
                            let s = String::from_utf8_lossy(&s);
                            traceback = s.lines().map(|x| x.to_string()).collect::>();
                        } else {
                            // No report is available. For each spanned message, emit its
                            // source lines, then a marker line with `start_column - 1` spaces
                            // and `^` across the span (assumes 1-based columns), then the
                            // label.
                            for spanned_message in error.spanned_messages() {
                                for line in &spanned_message.lines {
                                    traceback.push(line.clone());
                                }
                                if let Some(span) = &spanned_message.span {
                                    let mut carrots = String::new();
                                    for _ in 1..span.start_column {
                                        carrots.push(' ');
                                    }
                                    for _ in span.start_column..span.end_column {
                                        carrots.push('^');
                                    }
                                    traceback.push(format!(
                                        "{} {}",
                                        carrots.bright_red(),
                                        spanned_message.label.bright_blue()
                                    ));
                                } else {
                                    traceback.push(spanned_message.label.clone());
                                }
                            }
                            traceback.push(error.message());
                            for help in error.help() {
                                traceback.push(format!("{}: {}", "help".bold(), help));
                            }
                        }
                        parent_message
                            .new_message("error")
                            .with_content(object! {
                                "ename" => "Error",
                                "evalue" => error.message(),
                                "traceback" => traceback,
                            })
                            .send(&mut *self.iopub.lock().await)
                            .await?;
                    } else {
                        parent_message
                            .new_message("error")
                            .with_content(object! {
                                "ename" => "Error",
                                "evalue" => error.message(),
                                "traceback" => array![
                                    message
                                ],
                            })
                            .send(&mut *self.iopub.lock().await)
                            .await?;
                    }
                }
            }
            error => {
                let displayed_error = format!("{}", error);
                parent_message
                    .new_message("error")
                    .with_content(object! {
                        "ename" => "Error",
                        "evalue" => displayed_error.clone(),
                        "traceback" => array![displayed_error],
                    })
                    .send(&mut *self.iopub.lock().await)
                    .await?;
            }
        }
        Ok(())
    }
}

/// Handles a `comm_open` message.
///
/// For the `evcxr-cargo-check` target, a background task does the following:
/// - runs `cargo_check` on `data.code` (if present);
/// - publishes the result as a `comm_msg`;
/// - then closes the comm.
///
/// Any other target is closed immediately. For `evcxr-cargo-check`, this function returns as
/// soon as the task is spawned. Send failures inside that task panic it via `unwrap` rather
/// than being reported to the caller.
async fn comm_open(
    message: JupyterMessage,
    context: &Arc<:sync::mutex>>,
    iopub: Arc>>,
) -> Result<()> {
    if message.target_name() == "evcxr-cargo-check" {
        let context = Arc::clone(context);
        tokio::spawn(async move {
            if let Some(code) = message.data()["code"].as_str() {
                let data = cargo_check(code.to_owned(), context).await;
                let response_content = object! {
                    "comm_id" => message.comm_id(),
                    "data" => data,
                };
                message
                    .new_message("comm_msg")
                    .without_parent_header()
                    .with_content(response_content)
                    .send(&mut *iopub.lock().await)
                    .await
                    .unwrap();
            }
            message
                .comm_close_message()
                .send(&mut *iopub.lock().await)
                .await
                .unwrap();
        });
        Ok(())
    } else {
        // Unrecognised comm target, just close the comm.
        message
            .comm_close_message()
            .send(&mut *iopub.lock().await)
            .await
    }
}

/// Checks `code` with the shared `CommandContext` on a blocking thread and returns
/// `{"problems": [...]}`.
///
/// There is one entry per problem that has a primary span. Each entry holds:
/// - the combined message text (label, problem message, and `help` notes);
/// - the severity;
/// - the span's start/end line and column.
///
/// A failed check (or a failed/panicked blocking task) is treated as having no problems.
async fn cargo_check(code: String, context: Arc<:sync::mutex>>) -> JsonValue {
    let problems = tokio::task::spawn_blocking(move || {
        context.lock().unwrap().check(&code).unwrap_or_default()
    })
    .await
    .unwrap_or_default();
    let problems_json: Vec = problems
        .iter()
        .filter_map(|problem| {
            if let Some(primary_spanned_message) = problem.primary_spanned_message() {
                if let Some(span) = primary_spanned_message.span {
                    use std::fmt::Write;
                    let mut message = primary_spanned_message.label.clone();
                    if !message.is_empty() {
                        message.push('\n');
                    }
                    message.push_str(&problem.message());
                    for help in problem.help() {
                        write!(message, "\nhelp: {}", help).unwrap();
                    }
                    return Some(object! {
                        "message" => message,
                        "severity" => problem.level(),
                        "start_line" => span.start_line,
                        "start_column" => span.start_column,
                        "end_column" => span.end_column,
                        "end_line" => span.end_line,
                    });
                }
            }
            None
        })
        .collect();
    object! {
        "problems" => problems_json,
    }
}

/// Binds a new socket of type `S` to `{transport}://{ip}:{port}` (built from `config` and
/// `port`). The socket is wrapped in a `Connection` created with `config.key` and the optional
/// connection `group`.
async fn bind_socket(
    config: &control_file::Control,
    port: u16,
    group: Option,
) -> Result> {
    let endpoint = format!("{}://{}:{}", config.transport, config.ip, port);
    let mut socket = S::new();
    socket.bind(&endpoint).await?;
    Connection::new(socket, &config.key, group)
}

/// Content of the reply to `kernel_info_request`, sent on both the shell and control channels.
///
/// See [Kernel info documentation](https://jupyter-client.readthedocs.io/en/stable/messaging.html#kernel-info)
fn kernel_info() -> JsonValue {
    object! {
        "protocol_version" => "5.3",
        "implementation" => env!("CARGO_PKG_NAME"),
        "implementation_version" => env!("CARGO_PKG_VERSION"),
        "language_info" => object!{
            "name" => "Rust",
            "version" => "",
            "mimetype" => "text/rust",
            "file_extension" => ".rs",
            // Pygments lexer, for highlighting. Only needed if it differs from the 'name' field.
            // see http://pygments.org/docs/lexers/#lexers-for-the-rust-language
            // NOTE(review): the messaging spec linked above names this key `pygments_lexer`,
            // so frontends may ignore `pygment_lexer`. Verify.
            "pygment_lexer" => "rust",
            // Codemirror mode, for highlighting in the notebook. Only needed if it differs from
            // the 'name' field.
            // codemirror uses text/x-rustsrc as its mimetype
            // see https://codemirror.net/mode/rust/
            "codemirror_mode" => "rust",
        },
        "banner" => format!("EvCxR {} - Evaluation Context for Rust", env!("CARGO_PKG_VERSION")),
        "help_links" => array![
            object!{"text" => "Rust std docs",
                    "url" => "https://doc.rust-lang.org/stable/std/"}
        ],
        "status" => "ok"
    }
}

/// Answers a `complete_request` using the shared `CommandContext` on a blocking thread.
///
/// The request's `cursor_pos` is treated as a grapheme offset and converted to a byte offset
/// for the completer. The completer's `start_offset`/`end_offset` are converted back to
/// grapheme offsets and returned as `cursor_start`/`cursor_end`.
///
/// # Errors
///
/// Returns an error if completion fails, if a returned offset is not on a grapheme boundary,
/// or if the blocking task fails.
// NOTE(review): Jupyter messaging protocol 5.2+ measures cursor positions in Unicode code
// points, not grapheme clusters. The two differ for text with combining characters; verify.
async fn handle_completion_request(
    context: &Arc<:sync::mutex>>,
    message: JupyterMessage,
) -> Result {
    let context = Arc::clone(context);
    tokio::task::spawn_blocking(move || {
        let code = message.code();
        let completions = context.lock().unwrap().completions(
            code,
            grapheme_offset_to_byte_offset(code, message.cursor_pos()),
        )?;
        let matches: Vec = completions
            .completions
            .into_iter()
            .map(|completion| completion.code)
            .collect();
        Ok(object! {
            "status" => "ok",
            "matches" => matches,
            "cursor_start" => byte_offset_to_grapheme_offset(code, completions.start_offset)?,
            "cursor_end" => byte_offset_to_grapheme_offset(code, completions.end_offset)?,
            "metadata" => object!{},
        })
    })
    .await?
}

/// Returns the byte offset for the start of the specified grapheme. Any grapheme beyond the last
/// grapheme will return the end position of the input.
fn grapheme_offset_to_byte_offset(code: &str, grapheme_offset: usize) -> usize { unicode_segmentation::UnicodeSegmentation::grapheme_indices(code, true) .nth(grapheme_offset) .map(|(byte_offset, _)| byte_offset) .unwrap_or_else(|| code.len()) } /// Returns the grapheme offset of the grapheme that starts at fn byte_offset_to_grapheme_offset(code: &str, target_byte_offset: usize) -> Result { let mut grapheme_offset = 0; for (byte_offset, _) in unicode_segmentation::UnicodeSegmentation::grapheme_indices(code, true) { if byte_offset == target_byte_offset { break; } if byte_offset > target_byte_offset { bail!( "Byte offset {} is not on a grapheme boundary in '{}'", target_byte_offset, code ); } grapheme_offset += 1; } Ok(grapheme_offset) } #[cfg(test)] mod tests { use super::*; #[test] fn grapheme_offsets() { let src = "a̐éx"; assert_eq!(grapheme_offset_to_byte_offset(src, 0), 0); assert_eq!(grapheme_offset_to_byte_offset(src, 1), 3); assert_eq!(grapheme_offset_to_byte_offset(src, 2), 6); assert_eq!(grapheme_offset_to_byte_offset(src, 3), 7); assert_eq!(byte_offset_to_grapheme_offset(src, 0).unwrap(), 0); assert_eq!(byte_offset_to_grapheme_offset(src, 3).unwrap(), 1); assert_eq!(byte_offset_to_grapheme_offset(src, 6).unwrap(), 2); assert_eq!(byte_offset_to_grapheme_offset(src, 7).unwrap(), 3); } }