// Copyright 2020 The Evcxr Authors.
//
// Licensed under the Apache License, Version 2.0 or the MIT license, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
use crate::connection::Connection;
use crate::connection::ConnectionGroup;
use crate::connection::ConnectionShutdownRequester;
use crate::connection::RecvError;
use crate::control_file;
use crate::jupyter_message::JupyterMessage;
use anyhow::Result;
use anyhow::bail;
use ariadne::sources;
use colored::*;
use crossbeam_channel::RecvTimeoutError;
use crossbeam_channel::Select;
use evcxr::CommandContext;
use evcxr::Theme;
use json::JsonValue;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
// Note, to avoid potential deadlocks, each thread should lock at most one mutex at a time.
#[derive(Clone)]
pub(crate) struct Server {
iopub: Arc>>,
stdin: Arc>>,
latest_execution_request: Arc>>,
io_thread_shutdown_sender: Arc>>>,
tokio_handle: tokio::runtime::Handle,
}
impl Server {
    /// Builds a tokio runtime and runs the kernel to completion. Returns once
    /// a shutdown has been requested via the control channel.
    pub(crate) fn run(config: &control_file::Control) -> Result<()> {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            // We only technically need 1 thread. However we've observed that
            // when using vscode's jupyter extension, we can get requests on the
            // shell socket before we have any subscribers on iopub. The iopub
            // subscription then completes, but the execution_state="idle"
            // message(s) have already been sent to a channel that at the time
            // had no subscriptions. The vscode extension then waits
            // indefinitely for an execution_state="idle" message that will
            // never come. Having multiple threads at least reduces the chances
            // of this happening.
            .worker_threads(4)
            .enable_all()
            .build()
            .unwrap();
        let handle = runtime.handle().clone();
        runtime.block_on(Self::run_async(config, handle))?;
        Ok(())
    }

    /// Binds all five Jupyter sockets, spawns the heartbeat, shell and
    /// execution tasks plus the output pass-through thread, then services the
    /// control channel until shutdown.
    async fn run_async(
        config: &control_file::Control,
        tokio_handle: tokio::runtime::Handle,
    ) -> Result<()> {
        let (connection_group, group_shutdown) = ConnectionGroup::new();
        let mut heartbeat = bind_socket::<zeromq::RepSocket>(
            config,
            config.hb_port,
            Some(connection_group.clone()),
        )
        .await?;
        let shell_socket = bind_socket::<zeromq::RouterSocket>(
            config,
            config.shell_port,
            Some(connection_group.clone()),
        )
        .await?;
        let control_socket = bind_socket::<zeromq::RouterSocket>(
            config,
            config.control_port,
            Some(connection_group.clone()),
        )
        .await?;
        let stdin_socket =
            bind_socket::<zeromq::RouterSocket>(config, config.stdin_port, None).await?;
        let iopub_socket =
            bind_socket::<zeromq::PubSocket>(config, config.iopub_port, None).await?;
        let iopub = Arc::new(Mutex::new(iopub_socket));
        // Create a channel pair that's used to signal to the IO thread when to shut down. This
        // needs to be a crossbeam channel because the IO thread uses select together with other
        // crossbeam channels.
        let (shutdown_sender, shutdown_receiver) = crossbeam_channel::unbounded();
        let server = Server {
            iopub,
            latest_execution_request: Arc::new(Mutex::new(None)),
            stdin: Arc::new(Mutex::new(stdin_socket)),
            io_thread_shutdown_sender: Arc::new(Mutex::new(Some(shutdown_sender))),
            tokio_handle,
        };
        let (execution_sender, mut execution_receiver) = tokio::sync::mpsc::unbounded_channel();
        let (execution_response_sender, mut execution_response_receiver) =
            tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            if let Err(error) = Self::handle_hb(&mut heartbeat).await {
                eprintln!("hb error: {error:?}");
            }
        });
        let (mut context, outputs) = CommandContext::new()?;
        context.execute(":load_config")?;
        let process_handle = context.process_handle();
        let context = Arc::new(std::sync::Mutex::new(context));
        {
            let context = context.clone();
            let server = server.clone();
            tokio::spawn(async move {
                let result = server
                    .handle_shell(
                        shell_socket,
                        &execution_sender,
                        &mut execution_response_receiver,
                        context,
                    )
                    .await;
                if let Err(error) = result {
                    eprintln!("shell error: {error:?}");
                }
            });
        }
        {
            let server = server.clone();
            tokio::spawn(async move {
                let result = server
                    .handle_execution_requests(
                        &context,
                        &mut execution_receiver,
                        &execution_response_sender,
                    )
                    .await;
                if let Err(error) = result {
                    eprintln!("execution error: {error:?}");
                }
            });
        }
        server
            .clone()
            .start_output_pass_through_thread(
                vec![("stdout", outputs.stdout), ("stderr", outputs.stderr)],
                shutdown_receiver.clone(),
            )
            .await;
        // Don't keep any outstanding instances of our connection group, otherwise things won't shut
        // down properly.
        drop(connection_group);
        // Run the control channel on the main task. Once the control channel handler terminates,
        // we're done.
        server
            .handle_control(control_socket, process_handle, group_shutdown)
            .await?;
        Ok(())
    }

    /// Signals the output pass-through thread to stop by dropping the
    /// shutdown sender, which disconnects its receiver.
    async fn shutdown_io_thread(&mut self) {
        self.io_thread_shutdown_sender.lock().await.take();
    }

    /// Heartbeat loop: echo a "ping" reply for every message received, until
    /// shutdown is requested.
    async fn handle_hb(connection: &mut Connection<zeromq::RepSocket>) -> Result<()> {
        loop {
            match connection.recv().await {
                Ok(_) => {}
                Err(RecvError::ShutdownRequested) => return Ok(()),
                Err(RecvError::Other(e)) => return Err(e),
            }
            connection
                .send(zeromq::ZmqMessage::from(b"ping".to_vec()))
                .await?;
        }
    }

    /// Services execute_requests forwarded from the shell handler. Evaluation
    /// runs on a blocking thread; results/errors are published on iopub and
    /// the execute_reply is sent back via `execution_reply_sender`.
    async fn handle_execution_requests(
        self,
        context: &Arc<std::sync::Mutex<CommandContext>>,
        receiver: &mut tokio::sync::mpsc::UnboundedReceiver<JupyterMessage>,
        execution_reply_sender: &tokio::sync::mpsc::UnboundedSender<JupyterMessage>,
    ) -> Result<()> {
        // Incremented before each execution, so the first cell is numbered 1.
        let mut execution_count: u32 = 0;
        loop {
            let message = match receiver.recv().await {
                Some(x) => x,
                None => {
                    // Other end has closed. This is expected when we're shutting
                    // down.
                    return Ok(());
                }
            };
            // If we want this clone to be cheaper, we probably only need the header, not the
            // whole message.
            *self.latest_execution_request.lock().await = Some(message.clone());
            let src = message.code().to_owned();
            execution_count += 1;
            message
                .new_message("execute_input")
                .with_content(object! {
                    "execution_count" => execution_count,
                    "code" => src
                })
                .send(&mut *self.iopub.lock().await)
                .await?;
            let context = Arc::clone(context);
            let server = self.clone();
            let (eval_result, message) = tokio::task::spawn_blocking(move || {
                let eval_result = context.lock().unwrap().execute_with_callbacks(
                    message.code(),
                    &mut evcxr::EvalCallbacks {
                        // Evaluated code may ask for stdin; bridge the blocking
                        // callback back onto the async runtime.
                        input_reader: &|input_request| {
                            server.tokio_handle.block_on(async {
                                server
                                    .request_input(
                                        &message,
                                        &input_request.prompt,
                                        input_request.is_password,
                                    )
                                    .await
                                    .unwrap_or_default()
                            })
                        },
                    },
                );
                (eval_result, message)
            })
            .await?;
            match eval_result {
                Ok(output) => {
                    if !output.is_empty() {
                        // Increase the odds that stdout will have been finished being sent. A
                        // less hacky alternative would be to add a print statement, then block
                        // waiting for it.
                        tokio::time::sleep(Duration::from_millis(1)).await;
                        let mut data = HashMap::new();
                        // At the time of writing the json crate appears to have a generic From
                        // implementation for a Vec<T> where T implements Into<JsonValue>. It also
                        // has conversion from HashMap<String, JsonValue>, but it doesn't have
                        // conversion from HashMap<String, String>. Perhaps send a PR? For now, we
                        // convert the values manually.
                        for (k, v) in output.content_by_mime_type {
                            if k.contains("json") {
                                data.insert(k, json::parse(&v).unwrap_or_else(|_| json::from(v)));
                            } else {
                                data.insert(k, json::from(v));
                            }
                        }
                        message
                            .new_message("execute_result")
                            .with_content(object! {
                                "execution_count" => execution_count,
                                "data" => data,
                                "metadata" => object!(),
                            })
                            .send(&mut *self.iopub.lock().await)
                            .await?;
                    }
                    if let Some(duration) = output.timing {
                        let ms = duration.as_millis();
                        let mut data: HashMap<String, JsonValue> = HashMap::new();
                        data.insert(
                            "text/html".into(),
                            json::from(format!(
                                "Took {}ms",
                                ms
                            )),
                        );
                        message
                            .new_message("execute_result")
                            .with_content(object! {
                                "execution_count" => execution_count,
                                "data" => data,
                                "metadata" => object!(),
                            })
                            .send(&mut *self.iopub.lock().await)
                            .await?;
                    }
                    execution_reply_sender.send(message.new_reply().with_content(object! {
                        "status" => "ok",
                        "execution_count" => execution_count,
                    }))?;
                }
                Err(errors) => {
                    self.emit_errors(&errors, &message, message.code(), execution_count)
                        .await?;
                    execution_reply_sender.send(message.new_reply().with_content(object! {
                        "status" => "error",
                        "execution_count" => execution_count
                    }))?;
                }
            };
        }
    }

    /// Asks the frontend for a line of input via the stdin channel. Returns
    /// `None` if the frontend doesn't allow stdin, or on any send/receive
    /// failure.
    async fn request_input(
        &self,
        current_request: &JupyterMessage,
        prompt: &str,
        password: bool,
    ) -> Option<String> {
        if current_request.get_content()["allow_stdin"].as_bool() != Some(true) {
            return None;
        }
        let mut stdin = self.stdin.lock().await;
        let stdin_request = current_request
            .new_reply()
            .with_message_type("input_request")
            .with_content(object! {
                "prompt" => prompt,
                "password" => password,
            });
        stdin_request.send(&mut *stdin).await.ok()?;
        let input_response = JupyterMessage::read(&mut *stdin).await.ok()?;
        input_response.get_content()["value"]
            .as_str()
            .map(|value| value.to_owned())
    }

    /// Reads messages from the shell channel forever, dispatching each to
    /// `handle_shell_message`. Per-message errors are logged and the loop
    /// continues; only a shutdown request or a read failure ends it.
    async fn handle_shell(
        self,
        mut connection: Connection<zeromq::RouterSocket>,
        execution_channel: &tokio::sync::mpsc::UnboundedSender<JupyterMessage>,
        execution_reply_receiver: &mut tokio::sync::mpsc::UnboundedReceiver<JupyterMessage>,
        context: Arc<std::sync::Mutex<CommandContext>>,
    ) -> Result<()> {
        loop {
            let message = match JupyterMessage::read(&mut connection).await {
                Ok(m) => m,
                Err(RecvError::ShutdownRequested) => return Ok(()),
                Err(RecvError::Other(error)) => return Err(error),
            };
            let message_type = message.message_type().to_owned();
            let r = self
                .handle_shell_message(
                    message,
                    &mut connection,
                    execution_channel,
                    execution_reply_receiver,
                    &context,
                )
                .await;
            if let Err(error) = r {
                // We see this often after issuing a restart-kernel from the Jupyter UI. Not sure
                // why, but provided we continue to handle subsequent shell requests, things seem to
                // work. So for now, we just print the error and continue.
                eprintln!("Error handling shell message `{message_type}`: {error:#}");
            }
        }
    }

    /// Dispatches a single shell-channel message by type, bracketing the work
    /// with status busy/idle messages on iopub.
    async fn handle_shell_message(
        &self,
        message: JupyterMessage,
        connection: &mut Connection<zeromq::RouterSocket>,
        execution_channel: &tokio::sync::mpsc::UnboundedSender<JupyterMessage>,
        execution_reply_receiver: &mut tokio::sync::mpsc::UnboundedReceiver<JupyterMessage>,
        context: &Arc<std::sync::Mutex<CommandContext>>,
    ) -> Result<()> {
        // Processing of every message should be enclosed between "busy" and "idle"
        // see https://jupyter-client.readthedocs.io/en/latest/messaging.html#messages-on-the-shell-router-dealer-channel
        // Jupyter Lab doesn't use the kernel until it received "idle" for kernel_info_request
        message
            .new_message("status")
            .with_content(object! {"execution_state" => "busy"})
            .send(&mut *self.iopub.lock().await)
            .await?;
        let idle = message
            .new_message("status")
            .with_content(object! {"execution_state" => "idle"});
        if message.message_type() == "kernel_info_request" {
            message
                .new_reply()
                .with_content(kernel_info())
                .send(connection)
                .await?;
        } else if message.message_type() == "is_complete_request" {
            message
                .new_reply()
                .with_content(object! {"status" => "complete"})
                .send(connection)
                .await?;
        } else if message.message_type() == "execute_request" {
            // Execution happens on a separate task; we wait here for its reply
            // so that replies stay in request order.
            execution_channel.send(message)?;
            if let Some(reply) = execution_reply_receiver.recv().await {
                reply.send(connection).await?;
            }
        } else if message.message_type() == "comm_open" {
            comm_open(message, context, Arc::clone(&self.iopub)).await?;
        } else if message.message_type() == "comm_msg"
            || message.message_type() == "comm_info_request"
        {
            // We don't handle this yet.
        } else if message.message_type() == "complete_request" {
            let reply = message.new_reply().with_content(
                match handle_completion_request(context, message).await {
                    Ok(response_content) => response_content,
                    Err(error) => object! {
                        "status" => "error",
                        "ename" => error.to_string(),
                        "evalue" => "",
                    },
                },
            );
            reply.send(connection).await?;
        } else if message.message_type() == "history_request" {
            // We don't yet support history requests, but we don't want to print
            // a message in jupyter console.
        } else {
            eprintln!(
                "Got unrecognized message type on shell channel: {}",
                message.message_type()
            );
        }
        idle.send(&mut *self.iopub.lock().await).await?;
        Ok(())
    }

    /// Services the control channel. Handles kernel_info, shutdown and
    /// interrupt requests; returns (ending the kernel) once a
    /// shutdown_request has been acknowledged.
    async fn handle_control(
        mut self,
        mut connection: Connection<zeromq::RouterSocket>,
        process_handle: Arc<std::sync::Mutex<std::process::Child>>,
        group_shutdown: ConnectionShutdownRequester,
    ) -> Result<()> {
        loop {
            let message = match JupyterMessage::read(&mut connection).await {
                Ok(m) => m,
                Err(RecvError::ShutdownRequested) => return Ok(()),
                Err(RecvError::Other(error)) => return Err(error),
            };
            match message.message_type() {
                "kernel_info_request" => {
                    message
                        .new_reply()
                        .with_content(kernel_info())
                        .send(&mut connection)
                        .await?
                }
                "shutdown_request" => {
                    let is_restart = message.get_content()["restart"].as_bool().unwrap_or(false);
                    let response = object! {
                        "status": "ok",
                        "restart": is_restart,
                    };
                    // Shut everything down before replying, so the frontend
                    // only sees the ack once we're actually done.
                    connection.shutdown_all_connections(group_shutdown).await;
                    self.shutdown_io_thread().await;
                    message
                        .new_reply()
                        .with_content(response)
                        .send(&mut connection)
                        .await?;
                    return Ok(());
                }
                "interrupt_request" => {
                    // Interrupt by killing the subprocess that runs user code;
                    // it gets restarted for subsequent executions.
                    let process_handle = process_handle.clone();
                    tokio::task::spawn_blocking(move || {
                        if let Err(error) = process_handle.lock().unwrap().kill() {
                            eprintln!("Failed to kill subprocess: {}", error);
                        }
                    })
                    .await?;
                    message.new_reply().send(&mut connection).await?;
                }
                _ => {
                    eprintln!(
                        "Got unrecognized message type on control channel: {}",
                        message.message_type()
                    );
                }
            }
        }
    }

    /// Spawns a blocking task that forwards lines from the supplied stdout /
    /// stderr channels to iopub as "stream" messages, until the shutdown
    /// channel is signalled (or disconnects).
    async fn start_output_pass_through_thread(
        self,
        channels: Vec<(&'static str, crossbeam_channel::Receiver<String>)>,
        shutdown_recv: crossbeam_channel::Receiver<()>,
    ) {
        let handle = tokio::runtime::Handle::current();
        tokio::task::spawn_blocking(move || {
            let mut select = Select::new();
            for (_, channel) in &channels {
                select.recv(channel);
            }
            let shutdown_index = select.recv(&shutdown_recv);
            loop {
                let index = select.ready();
                if index == shutdown_index {
                    return;
                }
                let (output_name, channel) = &channels[index];
                // Needed in order to make the borrow checker happy.
                let output_name: &'static str = output_name;
                // Read from the channel that has output until it has been idle
                // for 1ms before we return to checking other channels. This
                // reduces the extent to which outputs interleave. e.g. a
                // multi-line print is performed to stderr, then another to
                // stdout - we can't guarantee the order in which they get sent,
                // but we'd like to try to make sure that we don't interleave
                // their lines if possible.
                loop {
                    match channel.recv_timeout(Duration::from_millis(1)) {
                        Ok(line) => {
                            let server = self.clone();
                            handle.block_on(server.pass_output_line(output_name, line));
                        }
                        Err(RecvTimeoutError::Timeout) => break,
                        Err(RecvTimeoutError::Disconnected) => return,
                    }
                }
            }
        });
    }

    /// Publishes one line of subprocess output as a "stream" message, parented
    /// to the most recent execute_request (if any). Errors are logged, not
    /// propagated.
    async fn pass_output_line(&self, output_name: &'static str, line: String) {
        let mut message = None;
        if let Some(exec_request) = &*self.latest_execution_request.lock().await {
            message = Some(exec_request.new_message("stream"));
        }
        if let Some(message) = message {
            if let Err(error) = message
                .with_content(object! {
                    "name" => output_name,
                    "text" => format!("{}\n", line),
                })
                .send(&mut *self.iopub.lock().await)
                .await
            {
                eprintln!("output {output_name} error: {}", error);
            }
        }
    }

    /// Publishes evaluation errors on iopub as "error" messages. Compilation
    /// errors from user code get a rich traceback (via ariadne when a report
    /// can be built, otherwise hand-drawn span markers); everything else is
    /// reported as a single line.
    async fn emit_errors(
        &self,
        errors: &evcxr::Error,
        parent_message: &JupyterMessage,
        source: &str,
        execution_count: u32,
    ) -> Result<()> {
        match errors {
            evcxr::Error::CompilationErrors(errors) => {
                for error in errors {
                    let message = format!("{}", error.message().bright_red());
                    if error.is_from_user_code() {
                        let file_name = format!("command_{}", execution_count);
                        let mut traceback = Vec::new();
                        if let Some(report) =
                            error.build_report(file_name.clone(), source.to_string(), Theme::Light)
                        {
                            let mut s = Vec::new();
                            report
                                .write(sources([(file_name, source.to_string())]), &mut s)
                                .unwrap();
                            let s = String::from_utf8_lossy(&s);
                            traceback = s.lines().map(|x| x.to_string()).collect::<Vec<_>>();
                        } else {
                            // Fallback: print each spanned message followed by
                            // a caret line under the offending columns.
                            for spanned_message in error.spanned_messages() {
                                for line in &spanned_message.lines {
                                    traceback.push(line.clone());
                                }
                                if let Some(span) = &spanned_message.span {
                                    let mut carrots = String::new();
                                    for _ in 1..span.start_column {
                                        carrots.push(' ');
                                    }
                                    for _ in span.start_column..span.end_column {
                                        carrots.push('^');
                                    }
                                    traceback.push(format!(
                                        "{} {}",
                                        carrots.bright_red(),
                                        spanned_message.label.bright_blue()
                                    ));
                                } else {
                                    traceback.push(spanned_message.label.clone());
                                }
                            }
                            traceback.push(error.message());
                            for help in error.help() {
                                traceback.push(format!("{}: {}", "help".bold(), help));
                            }
                        }
                        parent_message
                            .new_message("error")
                            .with_content(object! {
                                "ename" => "Error",
                                "evalue" => error.message(),
                                "traceback" => traceback,
                            })
                            .send(&mut *self.iopub.lock().await)
                            .await?;
                    } else {
                        parent_message
                            .new_message("error")
                            .with_content(object! {
                                "ename" => "Error",
                                "evalue" => error.message(),
                                "traceback" => array![
                                    message
                                ],
                            })
                            .send(&mut *self.iopub.lock().await)
                            .await?;
                    }
                }
            }
            error => {
                let displayed_error = format!("{}", error);
                parent_message
                    .new_message("error")
                    .with_content(object! {
                        "ename" => "Error",
                        "evalue" => displayed_error.clone(),
                        "traceback" => array![displayed_error],
                    })
                    .send(&mut *self.iopub.lock().await)
                    .await?;
            }
        }
        Ok(())
    }
}
async fn comm_open(
message: JupyterMessage,
context: &Arc<:sync::mutex>>,
iopub: Arc>>,
) -> Result<()> {
if message.target_name() == "evcxr-cargo-check" {
let context = Arc::clone(context);
tokio::spawn(async move {
if let Some(code) = message.data()["code"].as_str() {
let data = cargo_check(code.to_owned(), context).await;
let response_content = object! {
"comm_id" => message.comm_id(),
"data" => data,
};
message
.new_message("comm_msg")
.without_parent_header()
.with_content(response_content)
.send(&mut *iopub.lock().await)
.await
.unwrap();
}
message
.comm_close_message()
.send(&mut *iopub.lock().await)
.await
.unwrap();
});
Ok(())
} else {
// Unrecognised comm target, just close the comm.
message
.comm_close_message()
.send(&mut *iopub.lock().await)
.await
}
}
async fn cargo_check(code: String, context: Arc<:sync::mutex>>) -> JsonValue {
let problems = tokio::task::spawn_blocking(move || {
context.lock().unwrap().check(&code).unwrap_or_default()
})
.await
.unwrap_or_default();
let problems_json: Vec = problems
.iter()
.filter_map(|problem| {
if let Some(primary_spanned_message) = problem.primary_spanned_message() {
if let Some(span) = primary_spanned_message.span {
use std::fmt::Write;
let mut message = primary_spanned_message.label.clone();
if !message.is_empty() {
message.push('\n');
}
message.push_str(&problem.message());
for help in problem.help() {
write!(message, "\nhelp: {}", help).unwrap();
}
return Some(object! {
"message" => message,
"severity" => problem.level(),
"start_line" => span.start_line,
"start_column" => span.start_column,
"end_column" => span.end_column,
"end_line" => span.end_line,
});
}
}
None
})
.collect();
object! {
"problems" => problems_json,
}
}
async fn bind_socket(
config: &control_file::Control,
port: u16,
group: Option,
) -> Result> {
let endpoint = format!("{}://{}:{}", config.transport, config.ip, port);
let mut socket = S::new();
socket.bind(&endpoint).await?;
Connection::new(socket, &config.key, group)
}
/// See [Kernel info documentation](https://jupyter-client.readthedocs.io/en/stable/messaging.html#kernel-info)
fn kernel_info() -> JsonValue {
object! {
"protocol_version" => "5.3",
"implementation" => env!("CARGO_PKG_NAME"),
"implementation_version" => env!("CARGO_PKG_VERSION"),
"language_info" => object!{
"name" => "Rust",
"version" => "",
"mimetype" => "text/rust",
"file_extension" => ".rs",
// Pygments lexer, for highlighting Only needed if it differs from the 'name' field.
// see http://pygments.org/docs/lexers/#lexers-for-the-rust-language
"pygment_lexer" => "rust",
// Codemirror mode, for for highlighting in the notebook. Only needed if it differs from the 'name' field.
// codemirror use text/x-rustsrc as mimetypes
// see https://codemirror.net/mode/rust/
"codemirror_mode" => "rust",
},
"banner" => format!("EvCxR {} - Evaluation Context for Rust", env!("CARGO_PKG_VERSION")),
"help_links" => array![
object!{"text" => "Rust std docs",
"url" => "https://doc.rust-lang.org/stable/std/"}
],
"status" => "ok"
}
}
async fn handle_completion_request(
context: &Arc<:sync::mutex>>,
message: JupyterMessage,
) -> Result {
let context = Arc::clone(context);
tokio::task::spawn_blocking(move || {
let code = message.code();
let completions = context.lock().unwrap().completions(
code,
grapheme_offset_to_byte_offset(code, message.cursor_pos()),
)?;
let matches: Vec = completions
.completions
.into_iter()
.map(|completion| completion.code)
.collect();
Ok(object! {
"status" => "ok",
"matches" => matches,
"cursor_start" => byte_offset_to_grapheme_offset(code, completions.start_offset)?,
"cursor_end" => byte_offset_to_grapheme_offset(code, completions.end_offset)?,
"metadata" => object!{},
})
})
.await?
}
/// Returns the byte offset for the start of the specified grapheme. Any grapheme beyond the last
/// grapheme will return the end position of the input.
fn grapheme_offset_to_byte_offset(code: &str, grapheme_offset: usize) -> usize {
    let mut graphemes = unicode_segmentation::UnicodeSegmentation::grapheme_indices(code, true);
    match graphemes.nth(grapheme_offset) {
        Some((byte_offset, _grapheme)) => byte_offset,
        // Past the end of the input: clamp to its length.
        None => code.len(),
    }
}
/// Returns the grapheme offset of the grapheme that starts at
fn byte_offset_to_grapheme_offset(code: &str, target_byte_offset: usize) -> Result {
let mut grapheme_offset = 0;
for (byte_offset, _) in unicode_segmentation::UnicodeSegmentation::grapheme_indices(code, true)
{
if byte_offset == target_byte_offset {
break;
}
if byte_offset > target_byte_offset {
bail!(
"Byte offset {} is not on a grapheme boundary in '{}'",
target_byte_offset,
code
);
}
grapheme_offset += 1;
}
Ok(grapheme_offset)
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn grapheme_offsets() {
        // Written with explicit escapes so the source encoding can't mangle
        // the combining characters: "a" + U+0310 and "e" + U+0301 are single
        // graphemes of 3 bytes each; "x" is a 1-byte grapheme.
        let src = "a\u{310}e\u{301}x";
        assert_eq!(grapheme_offset_to_byte_offset(src, 0), 0);
        assert_eq!(grapheme_offset_to_byte_offset(src, 1), 3);
        assert_eq!(grapheme_offset_to_byte_offset(src, 2), 6);
        assert_eq!(grapheme_offset_to_byte_offset(src, 3), 7);
        assert_eq!(byte_offset_to_grapheme_offset(src, 0).unwrap(), 0);
        assert_eq!(byte_offset_to_grapheme_offset(src, 3).unwrap(), 1);
        assert_eq!(byte_offset_to_grapheme_offset(src, 6).unwrap(), 2);
        assert_eq!(byte_offset_to_grapheme_offset(src, 7).unwrap(), 3);
    }
}