).
@@ -19,13 +24,19 @@ def get(self):
class RedirectHandler(ExtensionHandlerMixin, JupyterHandler):
+ """A redirect handler."""
+
def get(self):
+ """Handle a redirect."""
self.redirect(f"/static/{self.name}/favicon.ico")
class ParameterHandler(ExtensionHandlerMixin, JupyterHandler):
+ """A parameterized handler."""
+
def get(self, matched_part=None, *args, **kwargs):
- var1 = self.get_argument("var1", default=None)
+ """Handle a get with parameters."""
+ var1 = self.get_argument("var1", default="")
components = [x for x in self.request.path.split("/") if x]
self.write("Hello Simple App 1 from Handler.
")
self.write(f"matched_part: {url_escape(matched_part)}
")
@@ -34,21 +45,28 @@ def get(self, matched_part=None, *args, **kwargs):
class BaseTemplateHandler(ExtensionHandlerJinjaMixin, ExtensionHandlerMixin, JupyterHandler):
- pass
+ """The base template handler."""
class TypescriptHandler(BaseTemplateHandler):
+ """A typescript handler."""
+
def get(self):
+ """Get the typescript template."""
self.write(self.render_template("typescript.html"))
class TemplateHandler(BaseTemplateHandler):
+ """A template handler."""
+
def get(self, path):
"""Optionally, you can print(self.get_template('simple1.html'))"""
self.write(self.render_template("simple1.html", path=path))
class ErrorHandler(BaseTemplateHandler):
+ """An error handler."""
+
def get(self, path):
- # write_error renders template from error.html file.
+ """Write_error renders template from error.html file."""
self.write_error(400)
diff --git a/examples/simple/simple_ext1/static/bundle.js b/examples/simple/simple_ext1/static/bundle.js
deleted file mode 100644
index 9017d3883d..0000000000
--- a/examples/simple/simple_ext1/static/bundle.js
+++ /dev/null
@@ -1,144 +0,0 @@
-/******/ (function (modules) {
- // webpackBootstrap
- /******/ // The module cache
- /******/ var installedModules = {}; // The require function
- /******/
- /******/ /******/ function __webpack_require__(moduleId) {
- /******/
- /******/ // Check if module is in cache
- /******/ if (installedModules[moduleId]) {
- /******/ return installedModules[moduleId].exports;
- /******/
- } // Create a new module (and put it into the cache)
- /******/ /******/ var module = (installedModules[moduleId] = {
- /******/ i: moduleId,
- /******/ l: false,
- /******/ exports: {},
- /******/
- }); // Execute the module function
- /******/
- /******/ /******/ modules[moduleId].call(
- module.exports,
- module,
- module.exports,
- __webpack_require__
- ); // Flag the module as loaded
- /******/
- /******/ /******/ module.l = true; // Return the exports of the module
- /******/
- /******/ /******/ return module.exports;
- /******/
- } // expose the modules object (__webpack_modules__)
- /******/
- /******/
- /******/ /******/ __webpack_require__.m = modules; // expose the module cache
- /******/
- /******/ /******/ __webpack_require__.c = installedModules; // define getter function for harmony exports
- /******/
- /******/ /******/ __webpack_require__.d = function (exports, name, getter) {
- /******/ if (!__webpack_require__.o(exports, name)) {
- /******/ Object.defineProperty(exports, name, {
- enumerable: true,
- get: getter,
- });
- /******/
- }
- /******/
- }; // define __esModule on exports
- /******/
- /******/ /******/ __webpack_require__.r = function (exports) {
- /******/ if (typeof Symbol !== "undefined" && Symbol.toStringTag) {
- /******/ Object.defineProperty(exports, Symbol.toStringTag, {
- value: "Module",
- });
- /******/
- }
- /******/ Object.defineProperty(exports, "__esModule", { value: true });
- /******/
- }; // create a fake namespace object // mode & 1: value is a module id, require it // mode & 2: merge all properties of value into the ns // mode & 4: return value when already ns object // mode & 8|1: behave like require
- /******/
- /******/ /******/ /******/ /******/ /******/ /******/ __webpack_require__.t =
- function (value, mode) {
- /******/ if (mode & 1) value = __webpack_require__(value);
- /******/ if (mode & 8) return value;
- /******/ if (
- mode & 4 &&
- typeof value === "object" &&
- value &&
- value.__esModule
- )
- return value;
- /******/ var ns = Object.create(null);
- /******/ __webpack_require__.r(ns);
- /******/ Object.defineProperty(ns, "default", {
- enumerable: true,
- value: value,
- });
- /******/ if (mode & 2 && typeof value != "string")
- for (var key in value)
- __webpack_require__.d(
- ns,
- key,
- function (key) {
- return value[key];
- }.bind(null, key)
- );
- /******/ return ns;
- /******/
- }; // getDefaultExport function for compatibility with non-harmony modules
- /******/
- /******/ /******/ __webpack_require__.n = function (module) {
- /******/ var getter =
- module && module.__esModule
- ? /******/ function getDefault() {
- return module["default"];
- }
- : /******/ function getModuleExports() {
- return module;
- };
- /******/ __webpack_require__.d(getter, "a", getter);
- /******/ return getter;
- /******/
- }; // Object.prototype.hasOwnProperty.call
- /******/
- /******/ /******/ __webpack_require__.o = function (object, property) {
- return Object.prototype.hasOwnProperty.call(object, property);
- }; // __webpack_public_path__
- /******/
- /******/ /******/ __webpack_require__.p = ""; // Load entry module and return exports
- /******/
- /******/
- /******/ /******/ return __webpack_require__((__webpack_require__.s = 0));
- /******/
-})(
- /************************************************************************/
- /******/ {
- /***/ "./simple_ext1/static/index.js":
- /*!*************************************!*\
- !*** ./simple_ext1/static/index.js ***!
- \*************************************/
- /*! no static exports found */
- /***/ function (module, exports) {
- eval(
- 'function main() {\n let div = document.getElementById("mydiv");\n div.innerText = "Hello from Typescript";\n}\nwindow.addEventListener(\'load\', main);\n\n\n//# sourceURL=webpack:///./simple_ext1/static/index.js?'
- );
-
- /***/
- },
-
- /***/ 0:
- /*!*******************************************!*\
- !*** multi ./simple_ext1/static/index.js ***!
- \*******************************************/
- /*! no static exports found */
- /***/ function (module, exports, __webpack_require__) {
- eval(
- 'module.exports = __webpack_require__(/*! ./simple_ext1/static/index.js */"./simple_ext1/static/index.js");\n\n\n//# sourceURL=webpack:///multi_./simple_ext1/static/index.js?'
- );
-
- /***/
- },
-
- /******/
- }
-);
diff --git a/examples/simple/simple_ext1/static/index.js b/examples/simple/simple_ext1/static/index.js
index 4cc84b9bc3..a6c59f1086 100644
--- a/examples/simple/simple_ext1/static/index.js
+++ b/examples/simple/simple_ext1/static/index.js
@@ -1,5 +1,5 @@
function main() {
- let div = document.getElementById("mydiv");
- div.innerText = "Hello from Typescript";
+ let div = document.getElementById("mydiv");
+ div.innerText = "Hello from Typescript";
}
window.addEventListener("load", main);
diff --git a/examples/simple/simple_ext1/static/tsconfig.tsbuildinfo b/examples/simple/simple_ext1/static/tsconfig.tsbuildinfo
index 8167ef00a2..27452c1246 100644
--- a/examples/simple/simple_ext1/static/tsconfig.tsbuildinfo
+++ b/examples/simple/simple_ext1/static/tsconfig.tsbuildinfo
@@ -50,7 +50,7 @@
"signature": "3e0a459888f32b42138d5a39f706ff2d55d500ab1031e0988b5568b0f67c2303"
},
"../../src/index.ts": {
- "version": "fd4f62325debd29128c1990caa4d546f2c48c21ea133fbcbb3e29f9fbef55e49",
+ "version": "a5398b1577287a9a5a7e190a9a7283ee67b12fcc0dbc6d2cac55ef25ed166bb2",
"signature": "ed4b087ea2a2e4a58647864cf512c7534210bfc2f9d236a2f9ed5245cf7a0896"
}
},
diff --git a/examples/simple/simple_ext11/__init__.py b/examples/simple/simple_ext11/__init__.py
index abe0f73a2a..d7c3e4341b 100644
--- a/examples/simple/simple_ext11/__init__.py
+++ b/examples/simple/simple_ext11/__init__.py
@@ -1,3 +1,4 @@
+"""Extension entry point."""
from .application import SimpleApp11
diff --git a/examples/simple/simple_ext11/__main__.py b/examples/simple/simple_ext11/__main__.py
index 317a0bd1f5..90b15cbc92 100644
--- a/examples/simple/simple_ext11/__main__.py
+++ b/examples/simple/simple_ext11/__main__.py
@@ -1,3 +1,4 @@
+"""Application cli main."""
from .application import main
if __name__ == "__main__":
diff --git a/examples/simple/simple_ext11/application.py b/examples/simple/simple_ext11/application.py
index fb4e6f846f..398716f213 100644
--- a/examples/simple/simple_ext11/application.py
+++ b/examples/simple/simple_ext11/application.py
@@ -1,6 +1,7 @@
+"""A Jupyter Server example application."""
import os
-from simple_ext1.application import SimpleApp1
+from simple_ext1.application import SimpleApp1 # type:ignore[import-not-found]
from traitlets import Bool, Unicode, observe
from jupyter_server.serverapp import aliases, flags
@@ -10,6 +11,8 @@
class SimpleApp11(SimpleApp1):
+ """A simple application."""
+
flags["hello"] = ({"SimpleApp11": {"hello": True}}, "Say hello on startup.")
aliases.update(
{
@@ -20,7 +23,7 @@ class SimpleApp11(SimpleApp1):
# The name of the extension.
name = "simple_ext11"
- # Te url that your extension will serve its homepage.
+ # The url that your extension will serve its homepage.
extension_url = "/simple_ext11/default"
# Local path to static files directory.
@@ -53,6 +56,7 @@ def simple11_dir_formatted(self):
return "/" + self.simple11_dir
def initialize_settings(self):
+ """Initialize settings."""
self.log.info(f"hello: {self.hello}")
if self.hello is True:
self.log.info(
@@ -62,6 +66,7 @@ def initialize_settings(self):
super().initialize_settings()
def initialize_handlers(self):
+ """Initialize handlers."""
super().initialize_handlers()
diff --git a/examples/simple/simple_ext2/__init__.py b/examples/simple/simple_ext2/__init__.py
index ffe7bc43c3..3059dbda49 100644
--- a/examples/simple/simple_ext2/__init__.py
+++ b/examples/simple/simple_ext2/__init__.py
@@ -1,3 +1,4 @@
+"""The extension entry point."""
from .application import SimpleApp2
diff --git a/examples/simple/simple_ext2/__main__.py b/examples/simple/simple_ext2/__main__.py
index 317a0bd1f5..465db9c1c2 100644
--- a/examples/simple/simple_ext2/__main__.py
+++ b/examples/simple/simple_ext2/__main__.py
@@ -1,3 +1,4 @@
+"""The application cli main."""
from .application import main
if __name__ == "__main__":
diff --git a/examples/simple/simple_ext2/application.py b/examples/simple/simple_ext2/application.py
index 5ca3fac882..b9da358131 100644
--- a/examples/simple/simple_ext2/application.py
+++ b/examples/simple/simple_ext2/application.py
@@ -1,3 +1,4 @@
+"""A simple Jupyter Server extension example."""
import os
from traitlets import Unicode
@@ -11,25 +12,27 @@
class SimpleApp2(ExtensionAppJinjaMixin, ExtensionApp):
+ """A simple application."""
# The name of the extension.
name = "simple_ext2"
- # Te url that your extension will serve its homepage.
+ # The url that your extension will serve its homepage.
extension_url = "/simple_ext2"
# Should your extension expose other server extensions when launched directly?
load_other_extensions = True
# Local path to static files directory.
- static_paths = [DEFAULT_STATIC_FILES_PATH]
+ static_paths = [DEFAULT_STATIC_FILES_PATH] # type:ignore[assignment]
# Local path to templates directory.
- template_paths = [DEFAULT_TEMPLATE_FILES_PATH]
+ template_paths = [DEFAULT_TEMPLATE_FILES_PATH] # type:ignore[assignment]
configD = Unicode("", config=True, help="Config D example.")
def initialize_handlers(self):
+ """Initialize handlers."""
self.handlers.extend(
[
(r"/simple_ext2/params/(.+)$", ParameterHandler),
@@ -40,6 +43,7 @@ def initialize_handlers(self):
)
def initialize_settings(self):
+ """Initialize settings."""
self.log.info(f"Config {self.config}")
diff --git a/examples/simple/simple_ext2/handlers.py b/examples/simple/simple_ext2/handlers.py
index acd908cfb5..4f52e6f061 100644
--- a/examples/simple/simple_ext2/handlers.py
+++ b/examples/simple/simple_ext2/handlers.py
@@ -1,14 +1,15 @@
+"""API handlers for the Jupyter Server example."""
from jupyter_server.base.handlers import JupyterHandler
-from jupyter_server.extension.handler import (
- ExtensionHandlerJinjaMixin,
- ExtensionHandlerMixin,
-)
+from jupyter_server.extension.handler import ExtensionHandlerJinjaMixin, ExtensionHandlerMixin
from jupyter_server.utils import url_escape
class ParameterHandler(ExtensionHandlerMixin, JupyterHandler):
+ """A parameterized handler."""
+
def get(self, matched_part=None, *args, **kwargs):
- var1 = self.get_argument("var1", default=None)
+ """Get a parameterized response."""
+ var1 = self.get_argument("var1", default="")
components = [x for x in self.request.path.split("/") if x]
self.write("Hello Simple App 2 from Handler.
")
self.write(f"matched_part: {url_escape(matched_part)}
")
@@ -17,20 +18,28 @@ def get(self, matched_part=None, *args, **kwargs):
class BaseTemplateHandler(ExtensionHandlerJinjaMixin, ExtensionHandlerMixin, JupyterHandler):
- pass
+ """A base template handler."""
class IndexHandler(BaseTemplateHandler):
+ """The root API handler."""
+
def get(self):
+ """Get the root response."""
self.write(self.render_template("index.html"))
class TemplateHandler(BaseTemplateHandler):
+ """A template handler."""
+
def get(self, path):
- print(self.get_template("simple_ext2.html"))
+ """Get the template for the path."""
self.write(self.render_template("simple_ext2.html", path=path))
class ErrorHandler(BaseTemplateHandler):
+ """An error handler."""
+
def get(self, path):
+ """Handle the error."""
self.write_error(400)
diff --git a/examples/simple/tests/test_handlers.py b/examples/simple/tests/test_handlers.py
index a46bb2b868..59b9d045ae 100644
--- a/examples/simple/tests/test_handlers.py
+++ b/examples/simple/tests/test_handlers.py
@@ -1,20 +1,43 @@
+"""Tests for the simple handler."""
import pytest
-@pytest.fixture
-def jp_server_config(jp_template_dir):
+@pytest.fixture()
+def jp_server_auth_resources(jp_server_auth_core_resources):
+ """The server auth resources."""
+ for url_regex in [
+ "/simple_ext1/default",
+ ]:
+ jp_server_auth_core_resources[url_regex] = "simple_ext1:default"
+ return jp_server_auth_core_resources
+
+
+@pytest.fixture()
+def jp_server_config(jp_template_dir, jp_server_authorizer):
+ """The server config."""
return {
- "ServerApp": {"jpserver_extensions": {"simple_ext1": True}},
+ "ServerApp": {
+ "jpserver_extensions": {"simple_ext1": True},
+ "authorizer_class": jp_server_authorizer,
+ },
}
-async def test_handler_default(jp_fetch):
+async def test_handler_default(jp_fetch, jp_serverapp):
+ """Test the default handler."""
+ jp_serverapp.authorizer.permissions = {
+ "actions": ["read"],
+ "resources": [
+ "simple_ext1:default",
+ ],
+ }
r = await jp_fetch("simple_ext1/default", method="GET")
assert r.code == 200
assert r.body.decode().index("Hello Simple 1 - I am the default...") > -1
async def test_handler_template(jp_fetch):
+ """Test the template handler."""
path = "/custom/path"
r = await jp_fetch(f"simple_ext1/template1/{path}", method="GET")
assert r.code == 200
@@ -22,10 +45,12 @@ async def test_handler_template(jp_fetch):
async def test_handler_typescript(jp_fetch):
+ """Test the typescript handler."""
r = await jp_fetch("simple_ext1/typescript", method="GET")
assert r.code == 200
async def test_handler_error(jp_fetch):
+ """Test the error handler."""
r = await jp_fetch("simple_ext1/nope", method="GET")
assert r.body.decode().index("400 : Bad Request") > -1
diff --git a/examples/simple/webpack.config.js b/examples/simple/webpack.config.js
index c0f4735649..5acc57fa89 100644
--- a/examples/simple/webpack.config.js
+++ b/examples/simple/webpack.config.js
@@ -3,6 +3,7 @@ module.exports = {
output: {
path: require("path").join(__dirname, "simple_ext1", "static"),
filename: "bundle.js",
+ hashFunction: 'sha256'
},
mode: "development",
};
diff --git a/jupyter_server/__init__.py b/jupyter_server/__init__.py
index d5b97f0c90..3d85bbd2c8 100644
--- a/jupyter_server/__init__.py
+++ b/jupyter_server/__init__.py
@@ -1,7 +1,6 @@
"""The Jupyter Server"""
import os
-import subprocess
-import sys
+import pathlib
DEFAULT_STATIC_FILES_PATH = os.path.join(os.path.dirname(__file__), "static")
DEFAULT_TEMPLATE_PATH_LIST = [
@@ -10,20 +9,19 @@
]
DEFAULT_JUPYTER_SERVER_PORT = 8888
-
-del os
-
-from ._version import __version__, version_info # noqa
-
-
-def _cleanup():
- pass
-
-
-# patch subprocess on Windows for python<3.7
-# see https://bugs.python.org/issue37380
-# the fix for python3.7: https://github.com/python/cpython/pull/15706/files
-if sys.platform == "win32":
- if sys.version_info < (3, 7):
- subprocess._cleanup = _cleanup
- subprocess._active = None
+JUPYTER_SERVER_EVENTS_URI = "https://events.jupyter.org/jupyter_server"
+DEFAULT_EVENTS_SCHEMA_PATH = pathlib.Path(__file__).parent / "event_schemas"
+
+from ._version import __version__, version_info
+from .base.call_context import CallContext
+
+__all__ = [
+ "DEFAULT_STATIC_FILES_PATH",
+ "DEFAULT_TEMPLATE_PATH_LIST",
+ "DEFAULT_JUPYTER_SERVER_PORT",
+ "JUPYTER_SERVER_EVENTS_URI",
+ "DEFAULT_EVENTS_SCHEMA_PATH",
+ "__version__",
+ "version_info",
+ "CallContext",
+]
diff --git a/jupyter_server/__main__.py b/jupyter_server/__main__.py
index 6ada4be7ea..70a8ef34bd 100644
--- a/jupyter_server/__main__.py
+++ b/jupyter_server/__main__.py
@@ -1,3 +1,5 @@
+"""The main entry point for Jupyter Server."""
+
if __name__ == "__main__":
from jupyter_server import serverapp as app
diff --git a/jupyter_server/_sysinfo.py b/jupyter_server/_sysinfo.py
index a0a430bb2c..f167c4e92a 100644
--- a/jupyter_server/_sysinfo.py
+++ b/jupyter_server/_sysinfo.py
@@ -73,17 +73,17 @@ def pkg_info(pkg_path):
with named parameters of interest
"""
src, hsh = pkg_commit_hash(pkg_path)
- return dict(
- jupyter_server_version=jupyter_server.__version__,
- jupyter_server_path=pkg_path,
- commit_source=src,
- commit_hash=hsh,
- sys_version=sys.version,
- sys_executable=sys.executable,
- sys_platform=sys.platform,
- platform=platform.platform(),
- os_name=os.name,
- )
+ return {
+ "jupyter_server_version": jupyter_server.__version__,
+ "jupyter_server_path": pkg_path,
+ "commit_source": src,
+ "commit_hash": hsh,
+ "sys_version": sys.version,
+ "sys_executable": sys.executable,
+ "sys_platform": sys.platform,
+ "platform": platform.platform(),
+ "os_name": os.name,
+ }
def get_sys_info():
diff --git a/jupyter_server/_tz.py b/jupyter_server/_tz.py
index 41d8222889..a7a495de85 100644
--- a/jupyter_server/_tz.py
+++ b/jupyter_server/_tz.py
@@ -5,7 +5,9 @@
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
-from datetime import datetime, timedelta, tzinfo
+from __future__ import annotations
+
+from datetime import datetime, timedelta, timezone, tzinfo
# constant for zero offset
ZERO = timedelta(0)
@@ -14,31 +16,28 @@
class tzUTC(tzinfo):
"""tzinfo object for UTC (zero offset)"""
- def utcoffset(self, d):
+ def utcoffset(self, d: datetime | None) -> timedelta:
+ """Compute utcoffset."""
return ZERO
- def dst(self, d):
+ def dst(self, d: datetime | None) -> timedelta:
+ """Compute dst."""
return ZERO
-UTC = tzUTC()
-
-
-def utc_aware(unaware):
- """decorator for adding UTC tzinfo to datetime's utcfoo methods"""
+def utcnow() -> datetime:
+ """Return timezone-aware UTC timestamp"""
+ return datetime.now(timezone.utc)
- def utc_method(*args, **kwargs):
- dt = unaware(*args, **kwargs)
- return dt.replace(tzinfo=UTC)
- return utc_method
+def utcfromtimestamp(timestamp: float) -> datetime:
+ return datetime.fromtimestamp(timestamp, timezone.utc)
-utcfromtimestamp = utc_aware(datetime.utcfromtimestamp)
-utcnow = utc_aware(datetime.utcnow)
+UTC = tzUTC() # type:ignore[abstract]
-def isoformat(dt):
+def isoformat(dt: datetime) -> str:
"""Return iso-formatted timestamp
Like .isoformat(), but uses Z for UTC instead of +00:00
diff --git a/jupyter_server/_version.py b/jupyter_server/_version.py
index 77a1118f9b..bef2f8e281 100644
--- a/jupyter_server/_version.py
+++ b/jupyter_server/_version.py
@@ -2,5 +2,17 @@
store the current version info of the server.
"""
-version_info = (1, 16, 1, ".dev", "0")
-__version__ = ".".join(map(str, version_info[:3])) + "".join(version_info[3:])
+import re
+from typing import List
+
+# Version string must appear intact for automatic versioning
+__version__ = "2.11.2"
+
+# Build up version_info tuple for backwards compatibility
+pattern = r"(?P\d+).(?P\d+).(?P\d+)(?P.*)"
+match = re.match(pattern, __version__)
+assert match is not None
+parts: List[object] = [int(match[part]) for part in ["major", "minor", "patch"]]
+if match["rest"]:
+ parts.append(match["rest"])
+version_info = tuple(parts)
diff --git a/jupyter_server/auth/__init__.py b/jupyter_server/auth/__init__.py
index 54477ffd1b..36418f214b 100644
--- a/jupyter_server/auth/__init__.py
+++ b/jupyter_server/auth/__init__.py
@@ -1,3 +1,4 @@
-from .authorizer import * # noqa
-from .decorator import authorized # noqa
-from .security import passwd # noqa
+from .authorizer import *
+from .decorator import authorized
+from .identity import *
+from .security import passwd
diff --git a/jupyter_server/auth/__main__.py b/jupyter_server/auth/__main__.py
index b34a3189c1..d1573f11a1 100644
--- a/jupyter_server/auth/__main__.py
+++ b/jupyter_server/auth/__main__.py
@@ -1,22 +1,27 @@
+"""The cli for auth."""
import argparse
import sys
+import warnings
from getpass import getpass
from jupyter_core.paths import jupyter_config_dir
+from traitlets.log import get_logger
-from jupyter_server.auth import passwd
+from jupyter_server.auth import passwd # type:ignore[attr-defined]
from jupyter_server.config_manager import BaseJSONConfigManager
def set_password(args):
+ """Set a password."""
password = args.password
+
while not password:
password1 = getpass("" if args.quiet else "Provide password: ")
password_repeat = getpass("" if args.quiet else "Repeat password: ")
if password1 != password_repeat:
- print("Passwords do not match, try again")
+ warnings.warn("Passwords do not match, try again", stacklevel=2)
elif len(password1) < 4:
- print("Please provide at least 4 characters")
+ warnings.warn("Please provide at least 4 characters", stacklevel=2)
else:
password = password1
@@ -31,10 +36,12 @@ def set_password(args):
},
)
if not args.quiet:
- print("password stored in config dir: %s" % jupyter_config_dir())
+ log = get_logger()
+ log.info("password stored in config dir: %s" % jupyter_config_dir())
def main(argv):
+ """The main cli handler."""
parser = argparse.ArgumentParser(argv[0])
subparsers = parser.add_subparsers()
parser_password = subparsers.add_parser(
diff --git a/jupyter_server/auth/authorizer.py b/jupyter_server/auth/authorizer.py
index 952cb0278d..f22dbe5463 100644
--- a/jupyter_server/auth/authorizer.py
+++ b/jupyter_server/auth/authorizer.py
@@ -7,9 +7,17 @@
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from traitlets import Instance
from traitlets.config import LoggingConfigurable
-from jupyter_server.base.handlers import JupyterHandler
+from .identity import IdentityProvider, User
+
+if TYPE_CHECKING:
+ from jupyter_server.base.handlers import JupyterHandler
class Authorizer(LoggingConfigurable):
@@ -18,10 +26,10 @@ class Authorizer(LoggingConfigurable):
All authorizers used in Jupyter Server
should inherit from this base class and, at the very minimum,
- implement an `is_authorized` method with the
+ implement an ``is_authorized`` method with the
same signature as in this base class.
- The `is_authorized` method is called by the `@authorized` decorator
+ The ``is_authorized`` method is called by the ``@authorized`` decorator
in JupyterHandler. If it returns True, the incoming request
to the server is accepted; if it returns False, the server
returns a 403 (Forbidden) error code.
@@ -32,23 +40,30 @@ class Authorizer(LoggingConfigurable):
.. versionadded:: 2.0
"""
- def is_authorized(self, handler: JupyterHandler, user: str, action: str, resource: str) -> bool:
- """A method to determine if `user` is authorized to perform `action`
- (read, write, or execute) on the `resource` type.
+ identity_provider = Instance(IdentityProvider)
+
+ def is_authorized(
+ self, handler: JupyterHandler, user: User, action: str, resource: str
+ ) -> bool:
+ """A method to determine if ``user`` is authorized to perform ``action``
+ (read, write, or execute) on the ``resource`` type.
Parameters
----------
- user : usually a dict or string
- A truthy model representing the authenticated user.
- A username string by default,
- but usually a dict when integrating with an auth provider.
+ user : jupyter_server.auth.User
+ An object representing the authenticated user,
+ as returned by :meth:`jupyter_server.auth.IdentityProvider.get_user`.
+
action : str
the category of action for the current request: read, write, or execute.
resource : str
the type of resource (i.e. contents, kernels, files, etc.) the user is requesting.
- Returns True if user authorized to make request; otherwise, returns False.
+ Returns
+ -------
+ bool
+ True if user authorized to make request; False, otherwise
"""
raise NotImplementedError()
@@ -61,7 +76,9 @@ class AllowAllAuthorizer(Authorizer):
.. versionadded:: 2.0
"""
- def is_authorized(self, handler: JupyterHandler, user: str, action: str, resource: str) -> bool:
+ def is_authorized(
+ self, handler: JupyterHandler, user: User, action: str, resource: str
+ ) -> bool:
"""This method always returns True.
All authenticated users are allowed to do anything in the Jupyter Server.
diff --git a/jupyter_server/auth/decorator.py b/jupyter_server/auth/decorator.py
index 72a489dbe9..fd38cda1e7 100644
--- a/jupyter_server/auth/decorator.py
+++ b/jupyter_server/auth/decorator.py
@@ -3,19 +3,21 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from functools import wraps
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, TypeVar, Union, cast
from tornado.log import app_log
from tornado.web import HTTPError
-from .utils import HTTP_METHOD_TO_AUTH_ACTION, warn_disabled_authorization
+from .utils import HTTP_METHOD_TO_AUTH_ACTION
+
+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
def authorized(
- action: Optional[Union[str, Callable]] = None,
+ action: Optional[Union[str, FuncT]] = None,
resource: Optional[str] = None,
message: Optional[str] = None,
-) -> Callable:
+) -> FuncT:
"""A decorator for tornado.web.RequestHandler methods
that verifies whether the current user is authorized
to make the following request.
@@ -57,18 +59,13 @@ def inner(self, *args, **kwargs):
if not user:
app_log.warning("Attempting to authorize request without authentication!")
raise HTTPError(status_code=403, log_message=message)
-
- # Handle the case where an authorizer wasn't attached to the handler.
- if not self.authorizer:
- warn_disabled_authorization()
- return method(self, *args, **kwargs)
-
- # Only return the method if the action is authorized.
+ # If the user is allowed to do this action,
+ # call the method.
if self.authorizer.is_authorized(self, user, action, resource):
return method(self, *args, **kwargs)
-
- # Raise an exception if the method wasn't returned (i.e. not authorized)
- raise HTTPError(status_code=403, log_message=message)
+ # else raise an exception.
+ else:
+ raise HTTPError(status_code=403, log_message=message)
return inner
@@ -76,6 +73,6 @@ def inner(self, *args, **kwargs):
method = action
action = None
# no-arguments `@authorized` decorator called
- return wrapper(method)
+ return cast(FuncT, wrapper(method))
- return wrapper
+ return cast(FuncT, wrapper)
diff --git a/jupyter_server/auth/identity.py b/jupyter_server/auth/identity.py
new file mode 100644
index 0000000000..adeb567b5b
--- /dev/null
+++ b/jupyter_server/auth/identity.py
@@ -0,0 +1,738 @@
+"""Identity Provider interface
+
+This defines the _authentication_ layer of Jupyter Server,
+to be used in combination with Authorizer for _authorization_.
+
+.. versionadded:: 2.0
+"""
+from __future__ import annotations
+
+import binascii
+import datetime
+import json
+import os
+import re
+import sys
+import typing as t
+import uuid
+from dataclasses import asdict, dataclass
+from http.cookies import Morsel
+
+from tornado import escape, httputil, web
+from traitlets import Bool, Dict, Type, Unicode, default
+from traitlets.config import LoggingConfigurable
+
+from jupyter_server.transutils import _i18n
+
+from .security import passwd_check, set_password
+from .utils import get_anonymous_username
+
+_non_alphanum = re.compile(r"[^A-Za-z0-9]")
+
+
+@dataclass
+class User:
+ """Object representing a User
+
+ This or a subclass should be returned from IdentityProvider.get_user
+ """
+
+ username: str # the only truly required field
+
+ # these fields are filled from username if not specified
+ # name is the 'real' name of the user
+ name: str = ""
+ # display_name is a shorter name for us in UI,
+ # if different from name. e.g. a nickname
+ display_name: str = ""
+
+ # these fields are left as None if undefined
+ initials: str | None = None
+ avatar_url: str | None = None
+ color: str | None = None
+
+ # TODO: extension fields?
+ # ext: Dict[str, Dict[str, Any]] = field(default_factory=dict)
+
+ def __post_init__(self):
+ self.fill_defaults()
+
+ def fill_defaults(self):
+ """Fill out default fields in the identity model
+
+ - Ensures all values are defined
+        - Fills out derivative values for name fields
+ - Fills out null values for optional fields
+ """
+
+ # username is the only truly required field
+ if not self.username:
+ msg = f"user.username must not be empty: {self}"
+ raise ValueError(msg)
+
+ # derive name fields from username -> name -> display name
+ if not self.name:
+ self.name = self.username
+ if not self.display_name:
+ self.display_name = self.name
+
+
+def _backward_compat_user(got_user: t.Any) -> User:
+ """Backward-compatibility for LoginHandler.get_user
+
+ Prior to 2.0, LoginHandler.get_user could return anything truthy.
+
+ Typically, this was either a simple string username,
+ or a simple dict.
+
+ Make some effort to allow common patterns to keep working.
+ """
+ if isinstance(got_user, str):
+ return User(username=got_user)
+ elif isinstance(got_user, dict):
+ kwargs = {}
+ if "username" not in got_user and "name" in got_user:
+ kwargs["username"] = got_user["name"]
+ for field in User.__dataclass_fields__:
+ if field in got_user:
+ kwargs[field] = got_user[field]
+ try:
+ return User(**kwargs)
+ except TypeError:
+ msg = f"Unrecognized user: {got_user}"
+ raise ValueError(msg) from None
+ else:
+ msg = f"Unrecognized user: {got_user}"
+ raise ValueError(msg)
+
+
+class IdentityProvider(LoggingConfigurable):
+ """
+ Interface for providing identity management and authentication.
+
+    Two principal methods:
+
+ - :meth:`~jupyter_server.auth.IdentityProvider.get_user` returns a :class:`~.User` object
+ for successful authentication, or None for no-identity-found.
+ - :meth:`~jupyter_server.auth.IdentityProvider.identity_model` turns a :class:`~jupyter_server.auth.User` into a JSONable dict.
+    The default is to use :py:func:`dataclasses.asdict`,
+ and usually shouldn't need override.
+
+ Additional methods can customize authentication.
+
+ .. versionadded:: 2.0
+ """
+
+ cookie_name: str | Unicode[str, str | bytes] = Unicode(
+ "",
+ config=True,
+ help=_i18n("Name of the cookie to set for persisting login. Default: username-${Host}."),
+ )
+
+ cookie_options = Dict(
+ config=True,
+ help=_i18n(
+ "Extra keyword arguments to pass to `set_secure_cookie`."
+ " See tornado's set_secure_cookie docs for details."
+ ),
+ )
+
+ secure_cookie: bool | Bool[bool | None, bool | int | None] = Bool(
+ None,
+ allow_none=True,
+ config=True,
+ help=_i18n(
+ "Specify whether login cookie should have the `secure` property (HTTPS-only)."
+ "Only needed when protocol-detection gives the wrong answer due to proxies."
+ ),
+ )
+
+ get_secure_cookie_kwargs = Dict(
+ config=True,
+ help=_i18n(
+ "Extra keyword arguments to pass to `get_secure_cookie`."
+ " See tornado's get_secure_cookie docs for details."
+ ),
+ )
+
+ token: str | Unicode[str, str | bytes] = Unicode(
+ "",
+ help=_i18n(
+ """Token used for authenticating first-time connections to the server.
+
+ The token can be read from the file referenced by JUPYTER_TOKEN_FILE or set directly
+ with the JUPYTER_TOKEN environment variable.
+
+ When no password is enabled,
+ the default is to generate a new, random token.
+
+ Setting to an empty string disables authentication altogether, which is NOT RECOMMENDED.
+
+ Prior to 2.0: configured as ServerApp.token
+ """
+ ),
+ ).tag(config=True)
+
+ login_handler_class = Type(
+ default_value="jupyter_server.auth.login.LoginFormHandler",
+ klass=web.RequestHandler,
+ config=True,
+ help=_i18n("The login handler class to use, if any."),
+ )
+
+ logout_handler_class = Type(
+ default_value="jupyter_server.auth.logout.LogoutHandler",
+ klass=web.RequestHandler,
+ config=True,
+ help=_i18n("The logout handler class to use."),
+ )
+
+ token_generated = False
+
+ @default("token")
+ def _token_default(self):
+ if os.getenv("JUPYTER_TOKEN"):
+ self.token_generated = False
+ return os.environ["JUPYTER_TOKEN"]
+ if os.getenv("JUPYTER_TOKEN_FILE"):
+ self.token_generated = False
+ with open(os.environ["JUPYTER_TOKEN_FILE"]) as token_file:
+ return token_file.read()
+ if not self.need_token:
+ # no token if password is enabled
+ self.token_generated = False
+ return ""
+ else:
+ self.token_generated = True
+ return binascii.hexlify(os.urandom(24)).decode("ascii")
+
+ need_token: bool | Bool[bool, t.Union[bool, int]] = Bool(True)
+
+ def get_user(self, handler: web.RequestHandler) -> User | None | t.Awaitable[User | None]:
+ """Get the authenticated user for a request
+
+ Must return a :class:`jupyter_server.auth.User`,
+ though it may be a subclass.
+
+ Return None if the request is not authenticated.
+
+ _may_ be a coroutine
+ """
+ return self._get_user(handler)
+
+ # not sure how to have optional-async type signature
+ # on base class with `async def` without splitting it into two methods
+
+ async def _get_user(self, handler: web.RequestHandler) -> User | None:
+ """Get the user."""
+ if getattr(handler, "_jupyter_current_user", None):
+ # already authenticated
+ return t.cast(User, handler._jupyter_current_user) # type:ignore[attr-defined]
+ _token_user: User | None | t.Awaitable[User | None] = self.get_user_token(handler)
+ if isinstance(_token_user, t.Awaitable):
+ _token_user = await _token_user
+ token_user: User | None = _token_user # need second variable name to collapse type
+ _cookie_user = self.get_user_cookie(handler)
+ if isinstance(_cookie_user, t.Awaitable):
+ _cookie_user = await _cookie_user
+ cookie_user: User | None = _cookie_user
+ # prefer token to cookie if both given,
+ # because token is always explicit
+ user = token_user or cookie_user
+
+ if user is not None and token_user is not None:
+ # if token-authenticated, persist user_id in cookie
+ # if it hasn't already been stored there
+ if user != cookie_user:
+ self.set_login_cookie(handler, user)
+ # Record that the current request has been authenticated with a token.
+ # Used in is_token_authenticated above.
+ handler._token_authenticated = True # type:ignore[attr-defined]
+
+ if user is None:
+ # If an invalid cookie was sent, clear it to prevent unnecessary
+ # extra warnings. But don't do this on a request with *no* cookie,
+ # because that can erroneously log you out (see gh-3365)
+ cookie_name = self.get_cookie_name(handler)
+ cookie = handler.get_cookie(cookie_name)
+ if cookie is not None:
+ self.log.warning(f"Clearing invalid/expired login cookie {cookie_name}")
+ self.clear_login_cookie(handler)
+ if not self.auth_enabled:
+ # Completely insecure! No authentication at all.
+ # No need to warn here, though; validate_security will have already done that.
+ user = self.generate_anonymous_user(handler)
+ # persist user on first request
+ # so the user data is stable for a given browser session
+ self.set_login_cookie(handler, user)
+
+ return user
+
+ def identity_model(self, user: User) -> dict[str, t.Any]:
+ """Return a User as an Identity model"""
+ # TODO: validate?
+ return asdict(user)
+
+ def get_handlers(self) -> list[tuple[str, object]]:
+ """Return list of additional handlers for this identity provider
+
+ For example, an OAuth callback handler.
+ """
+ handlers = []
+ if self.login_available:
+ handlers.append((r"/login", self.login_handler_class))
+ if self.logout_available:
+ handlers.append((r"/logout", self.logout_handler_class))
+ return handlers
+
+ def user_to_cookie(self, user: User) -> str:
+ """Serialize a user to a string for storage in a cookie
+
+ If overriding in a subclass, make sure to define user_from_cookie as well.
+
+ Default is just the user's username.
+ """
+ # default: username is enough
+ cookie = json.dumps(
+ {
+ "username": user.username,
+ "name": user.name,
+ "display_name": user.display_name,
+ "initials": user.initials,
+ "color": user.color,
+ }
+ )
+ return cookie
+
+ def user_from_cookie(self, cookie_value: str) -> User | None:
+ """Inverse of user_to_cookie"""
+ user = json.loads(cookie_value)
+ return User(
+ user["username"],
+ user["name"],
+ user["display_name"],
+ user["initials"],
+ None,
+ user["color"],
+ )
+
+ def get_cookie_name(self, handler: web.RequestHandler) -> str:
+ """Return the login cookie name
+
+ Uses IdentityProvider.cookie_name, if defined.
+ Default is to generate a string taking host into account to avoid
+ collisions for multiple servers on one hostname with different ports.
+ """
+ if self.cookie_name:
+ return self.cookie_name
+ else:
+ return _non_alphanum.sub("-", f"username-{handler.request.host}")
+
+ def set_login_cookie(self, handler: web.RequestHandler, user: User) -> None:
+ """Call this on handlers to set the login cookie for success"""
+ cookie_options = {}
+ cookie_options.update(self.cookie_options)
+ cookie_options.setdefault("httponly", True)
+ # tornado <4.2 has a bug that considers secure==True as soon as
+ # 'secure' kwarg is passed to set_secure_cookie
+ secure_cookie = self.secure_cookie
+ if secure_cookie is None:
+ secure_cookie = handler.request.protocol == "https"
+ if secure_cookie:
+ cookie_options.setdefault("secure", True)
+ cookie_options.setdefault("path", handler.base_url) # type:ignore[attr-defined]
+ cookie_name = self.get_cookie_name(handler)
+ handler.set_secure_cookie(cookie_name, self.user_to_cookie(user), **cookie_options)
+
+ def _force_clear_cookie(
+ self, handler: web.RequestHandler, name: str, path: str = "/", domain: str | None = None
+ ) -> None:
+ """Deletes the cookie with the given name.
+
+ Tornado's cookie handling currently (Jan 2018) stores cookies in a dict
+ keyed by name, so it can only modify one cookie with a given name per
+ response. The browser can store multiple cookies with the same name
+ but different domains and/or paths. This method lets us clear multiple
+ cookies with the same name.
+
+ Due to limitations of the cookie protocol, you must pass the same
+ path and domain to clear a cookie as were used when that cookie
+ was set (but there is no way to find out on the server side
+ which values were used for a given cookie).
+ """
+ name = escape.native_str(name)
+ expires = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=365)
+
+ morsel: Morsel[t.Any] = Morsel()
+ morsel.set(name, "", '""')
+ morsel["expires"] = httputil.format_timestamp(expires)
+ morsel["path"] = path
+ if domain:
+ morsel["domain"] = domain
+ handler.add_header("Set-Cookie", morsel.OutputString())
+
+ def clear_login_cookie(self, handler: web.RequestHandler) -> None:
+ """Clear the login cookie, effectively logging out the session."""
+ cookie_options = {}
+ cookie_options.update(self.cookie_options)
+ path = cookie_options.setdefault("path", handler.base_url) # type:ignore[attr-defined]
+ cookie_name = self.get_cookie_name(handler)
+ handler.clear_cookie(cookie_name, path=path)
+ if path and path != "/":
+ # also clear cookie on / to ensure old cookies are cleared
+ # after the change in path behavior.
+ # N.B. This bypasses the normal cookie handling, which can't update
+ # two cookies with the same name. See the method above.
+ self._force_clear_cookie(handler, cookie_name)
+
+ def get_user_cookie(
+ self, handler: web.RequestHandler
+ ) -> User | None | t.Awaitable[User | None]:
+ """Get user from a cookie
+
+ Calls user_from_cookie to deserialize cookie value
+ """
+ _user_cookie = handler.get_secure_cookie(
+ self.get_cookie_name(handler),
+ **self.get_secure_cookie_kwargs,
+ )
+ if not _user_cookie:
+ return None
+ user_cookie = _user_cookie.decode()
+ # TODO: try/catch in case of change in config?
+ try:
+ return self.user_from_cookie(user_cookie)
+ except Exception as e:
+ # log bad cookie itself, only at debug-level
+ self.log.debug(f"Error unpacking user from cookie: cookie={user_cookie}", exc_info=True)
+ self.log.error(f"Error unpacking user from cookie: {e}")
+ return None
+
+ auth_header_pat = re.compile(r"(token|bearer)\s+(.+)", re.IGNORECASE)
+
+ def get_token(self, handler: web.RequestHandler) -> str | None:
+ """Get the user token from a request
+
+ Default:
+
+        - in URL parameters: ?token=<token>
+        - in header: Authorization: token <token>
+ """
+ user_token = handler.get_argument("token", "")
+ if not user_token:
+ # get it from Authorization header
+ m = self.auth_header_pat.match(handler.request.headers.get("Authorization", ""))
+ if m:
+ user_token = m.group(2)
+ return user_token
+
+ async def get_user_token(self, handler: web.RequestHandler) -> User | None:
+ """Identify the user based on a token in the URL or Authorization header
+
+ Returns:
+        - the authenticated User if the token matches
+ - None if not
+ """
+ token = t.cast("str | None", handler.token) # type:ignore[attr-defined]
+ if not token:
+ return None
+ # check login token from URL argument or Authorization header
+ user_token = self.get_token(handler)
+ authenticated = False
+ if user_token == token:
+ # token-authenticated, set the login cookie
+ self.log.debug(
+ "Accepting token-authenticated request from %s",
+ handler.request.remote_ip,
+ )
+ authenticated = True
+
+ if authenticated:
+ # token does not correspond to user-id,
+ # which is stored in a cookie.
+ # still check the cookie for the user id
+ _user = self.get_user_cookie(handler)
+ if isinstance(_user, t.Awaitable):
+ _user = await _user
+ user: User | None = _user
+ if user is None:
+ user = self.generate_anonymous_user(handler)
+ return user
+ else:
+ return None
+
+ def generate_anonymous_user(self, handler: web.RequestHandler) -> User:
+ """Generate a random anonymous user.
+
+ For use when a single shared token is used,
+ but does not identify a user.
+ """
+ user_id = uuid.uuid4().hex
+ moon = get_anonymous_username()
+ name = display_name = f"Anonymous {moon}"
+ initials = f"A{moon[0]}"
+ color = None
+ handler.log.debug(f"Generating new user for token-authenticated request: {user_id}") # type:ignore[attr-defined]
+ return User(user_id, name, display_name, initials, None, color)
+
+ def should_check_origin(self, handler: web.RequestHandler) -> bool:
+ """Should the Handler check for CORS origin validation?
+
+ Origin check should be skipped for token-authenticated requests.
+
+ Returns:
+ - True, if Handler must check for valid CORS origin.
+ - False, if Handler should skip origin check since requests are token-authenticated.
+ """
+ return not self.is_token_authenticated(handler)
+
+ def is_token_authenticated(self, handler: web.RequestHandler) -> bool:
+ """Returns True if handler has been token authenticated. Otherwise, False.
+
+ Login with a token is used to signal certain things, such as:
+
+ - permit access to REST API
+ - xsrf protection
+ - skip origin-checks for scripts
+ """
+ # ensure get_user has been called, so we know if we're token-authenticated
+ handler.current_user # noqa: B018
+ return getattr(handler, "_token_authenticated", False)
+
+ def validate_security(
+ self,
+ app: t.Any,
+ ssl_options: dict[str, t.Any] | None = None,
+ ) -> None:
+ """Check the application's security.
+
+ Show messages, or abort if necessary, based on the security configuration.
+ """
+ if not app.ip:
+ warning = "WARNING: The Jupyter server is listening on all IP addresses"
+ if ssl_options is None:
+ app.log.warning(f"{warning} and not using encryption. This is not recommended.")
+ if not self.auth_enabled:
+ app.log.warning(
+ f"{warning} and not using authentication. "
+ "This is highly insecure and not recommended."
+ )
+ elif not self.auth_enabled:
+ app.log.warning(
+ "All authentication is disabled."
+ " Anyone who can connect to this server will be able to run code."
+ )
+
+ def process_login_form(self, handler: web.RequestHandler) -> User | None:
+ """Process login form data
+
+ Return authenticated User if successful, None if not.
+ """
+ typed_password = handler.get_argument("password", default="")
+ user = None
+ if not self.auth_enabled:
+ self.log.warning("Accepting anonymous login because auth fully disabled!")
+ return self.generate_anonymous_user(handler)
+
+ if self.token and self.token == typed_password:
+ return t.cast(User, self.user_for_token(typed_password)) # type:ignore[attr-defined]
+
+ return user
+
+ @property
+ def auth_enabled(self):
+ """Is authentication enabled?
+
+ Should always be True, but may be False in rare, insecure cases
+ where requests with no auth are allowed.
+
+ Previously: LoginHandler.get_login_available
+ """
+ return True
+
+ @property
+ def login_available(self):
+ """Whether a LoginHandler is needed - and therefore whether the login page should be displayed."""
+ return self.auth_enabled
+
+ @property
+ def logout_available(self):
+ """Whether a LogoutHandler is needed."""
+ return True
+
+
+class PasswordIdentityProvider(IdentityProvider):
+ """A password identity provider."""
+
+ hashed_password = Unicode(
+ "",
+ config=True,
+ help=_i18n(
+ """
+ Hashed password to use for web authentication.
+
+ To generate, type in a python/IPython shell:
+
+ from jupyter_server.auth import passwd; passwd()
+
+ The string should be of the form type:salt:hashed-password.
+ """
+ ),
+ )
+
+ password_required = Bool(
+ False,
+ config=True,
+ help=_i18n(
+ """
+ Forces users to use a password for the Jupyter server.
+ This is useful in a multi user environment, for instance when
+ everybody in the LAN can access each other's machine through ssh.
+
+ In such a case, serving on localhost is not secure since
+ any user can connect to the Jupyter server via ssh.
+
+ """
+ ),
+ )
+
+ allow_password_change = Bool(
+ True,
+ config=True,
+ help=_i18n(
+ """
+ Allow password to be changed at login for the Jupyter server.
+
+ While logging in with a token, the Jupyter server UI will give the opportunity to
+ the user to enter a new password at the same time that will replace
+ the token login mechanism.
+
+ This can be set to False to prevent changing password from the UI/API.
+ """
+ ),
+ )
+
+ @default("need_token")
+ def _need_token_default(self):
+ return not bool(self.hashed_password)
+
+ @property
+ def login_available(self) -> bool:
+ """Whether a LoginHandler is needed - and therefore whether the login page should be displayed."""
+ return self.auth_enabled
+
+ @property
+ def auth_enabled(self) -> bool:
+ """Return whether any auth is enabled"""
+ return bool(self.hashed_password or self.token)
+
+ def passwd_check(self, password):
+ """Check password against our stored hashed password"""
+ return passwd_check(self.hashed_password, password)
+
+ def process_login_form(self, handler: web.RequestHandler) -> User | None:
+ """Process login form data
+
+ Return authenticated User if successful, None if not.
+ """
+ typed_password = handler.get_argument("password", default="")
+ new_password = handler.get_argument("new_password", default="")
+ user = None
+ if not self.auth_enabled:
+ self.log.warning("Accepting anonymous login because auth fully disabled!")
+ return self.generate_anonymous_user(handler)
+
+ if self.passwd_check(typed_password) and not new_password:
+ return self.generate_anonymous_user(handler)
+ elif self.token and self.token == typed_password:
+ user = self.generate_anonymous_user(handler)
+ if new_password and self.allow_password_change:
+ config_dir = handler.settings.get("config_dir", "")
+ config_file = os.path.join(config_dir, "jupyter_server_config.json")
+ self.hashed_password = set_password(new_password, config_file=config_file)
+ self.log.info(_i18n(f"Wrote hashed password to {config_file}"))
+
+ return user
+
+ def validate_security(
+ self,
+ app: t.Any,
+ ssl_options: dict[str, t.Any] | None = None,
+ ) -> None:
+ """Handle security validation."""
+ super().validate_security(app, ssl_options)
+ if self.password_required and (not self.hashed_password):
+ self.log.critical(
+ _i18n("Jupyter servers are configured to only be run with a password.")
+ )
+ self.log.critical(_i18n("Hint: run the following command to set a password"))
+ self.log.critical(_i18n("\t$ python -m jupyter_server.auth password"))
+ sys.exit(1)
+
+
+class LegacyIdentityProvider(PasswordIdentityProvider):
+ """Legacy IdentityProvider for use with custom LoginHandlers
+
+ Login configuration has moved from LoginHandler to IdentityProvider
+ in Jupyter Server 2.0.
+ """
+
+    # settings dict must be passed through for the legacy LoginHandler API
+ settings = Dict()
+
+ @default("settings")
+ def _default_settings(self):
+ return {
+ "token": self.token,
+ "password": self.hashed_password,
+ }
+
+ @default("login_handler_class")
+ def _default_login_handler_class(self):
+ from .login import LegacyLoginHandler
+
+ return LegacyLoginHandler
+
+ @property
+ def auth_enabled(self):
+ return self.login_available
+
+ def get_user(self, handler: web.RequestHandler) -> User | None:
+ """Get the user."""
+ user = self.login_handler_class.get_user(handler) # type:ignore[attr-defined]
+ if user is None:
+ return None
+ return _backward_compat_user(user)
+
+ @property
+ def login_available(self) -> bool:
+ return bool(
+ self.login_handler_class.get_login_available( # type:ignore[attr-defined]
+ self.settings
+ )
+ )
+
+ def should_check_origin(self, handler: web.RequestHandler) -> bool:
+ """Whether we should check origin."""
+ return bool(self.login_handler_class.should_check_origin(handler)) # type:ignore[attr-defined]
+
+ def is_token_authenticated(self, handler: web.RequestHandler) -> bool:
+ """Whether we are token authenticated."""
+ return bool(self.login_handler_class.is_token_authenticated(handler)) # type:ignore[attr-defined]
+
+ def validate_security(
+ self,
+ app: t.Any,
+ ssl_options: dict[str, t.Any] | None = None,
+ ) -> None:
+ """Validate security."""
+ if self.password_required and (not self.hashed_password):
+ self.log.critical(
+ _i18n("Jupyter servers are configured to only be run with a password.")
+ )
+ self.log.critical(_i18n("Hint: run the following command to set a password"))
+ self.log.critical(_i18n("\t$ python -m jupyter_server.auth password"))
+ sys.exit(1)
+ self.login_handler_class.validate_security( # type:ignore[attr-defined]
+ app, ssl_options
+ )
diff --git a/jupyter_server/auth/login.py b/jupyter_server/auth/login.py
index 382077d9e0..22832df341 100644
--- a/jupyter_server/auth/login.py
+++ b/jupyter_server/auth/login.py
@@ -12,13 +12,14 @@
from .security import passwd_check, set_password
-class LoginHandler(JupyterHandler):
+class LoginFormHandler(JupyterHandler):
"""The basic tornado login handler
- authenticates with a hashed password from the configuration.
+ accepts login form, passed to IdentityProvider.process_login_form.
"""
def _render(self, message=None):
+ """Render the login form."""
self.write(
self.render_template(
"login.html",
@@ -40,12 +41,25 @@ def _redirect_safe(self, url, default=None):
# \ is not valid in urls, but some browsers treat it as /
# instead of %5C, causing `\\` to behave as `//`
url = url.replace("\\", "%5C")
+ # urllib and browsers interpret extra '/' in the scheme separator (`scheme:///host/path`)
+ # differently.
+ # urllib gives scheme=scheme, netloc='', path='/host/path', while
+ # browsers get scheme=scheme, netloc='host', path='/path'
+ # so make sure ':///*' collapses to '://' by splitting and stripping any additional leading slash
+ # don't allow any kind of `:/` shenanigans by splitting on ':' only
+ # and replacing `:/*` with exactly `://`
+ if ":" in url:
+ scheme, _, rest = url.partition(":")
+ url = f"{scheme}://{rest.lstrip('/')}"
parsed = urlparse(url)
- if parsed.netloc or not (parsed.path + "/").startswith(self.base_url):
+ # full url may be `//host/path` (empty scheme == same scheme as request)
+ # or `https://host/path`
+ # or even `https:///host/path` (invalid, but accepted and ambiguously interpreted)
+ if (parsed.scheme or parsed.netloc) or not (parsed.path + "/").startswith(self.base_url):
# require that next_url be absolute path within our path
allow = False
# OR pass our cross-origin check
- if parsed.netloc:
+ if parsed.scheme or parsed.netloc:
# if full URL, run our cross-origin check:
origin = f"{parsed.scheme}://{parsed.netloc}"
origin = origin.lower()
@@ -60,20 +74,44 @@ def _redirect_safe(self, url, default=None):
self.redirect(url)
def get(self):
+ """Get the login form."""
if self.current_user:
next_url = self.get_argument("next", default=self.base_url)
self._redirect_safe(next_url)
else:
self._render()
+ def post(self):
+ """Post a login."""
+ user = self.current_user = self.identity_provider.process_login_form(self)
+ if user is None:
+ self.set_status(401)
+ self._render(message={"error": "Invalid credentials"})
+ return
+
+ self.log.info(f"User {user.username} logged in.")
+ self.identity_provider.set_login_cookie(self, user)
+ next_url = self.get_argument("next", default=self.base_url)
+ self._redirect_safe(next_url)
+
+
+class LegacyLoginHandler(LoginFormHandler):
+ """Legacy LoginHandler, implementing most custom auth configuration.
+
+ Deprecated in jupyter-server 2.0.
+ Login configuration has moved to IdentityProvider.
+ """
+
@property
def hashed_password(self):
return self.password_from_settings(self.settings)
def passwd_check(self, a, b):
+ """Check a passwd."""
return passwd_check(a, b)
def post(self):
+ """Post a login form."""
typed_password = self.get_argument("password", default="")
new_password = self.get_argument("new_password", default="")
@@ -82,10 +120,13 @@ def post(self):
self.set_login_cookie(self, uuid.uuid4().hex)
elif self.token and self.token == typed_password:
self.set_login_cookie(self, uuid.uuid4().hex)
- if new_password and self.settings.get("allow_password_change"):
- config_dir = self.settings.get("config_dir")
+ if new_password and getattr(self.identity_provider, "allow_password_change", False):
+ config_dir = self.settings.get("config_dir", "")
config_file = os.path.join(config_dir, "jupyter_server_config.json")
- set_password(new_password, config_file=config_file)
+ if hasattr(self.identity_provider, "hashed_password"):
+ self.identity_provider.hashed_password = self.settings[
+ "password"
+ ] = set_password(new_password, config_file=config_file)
self.log.info("Wrote hashed password to %s" % config_file)
else:
self.set_status(401)
@@ -130,52 +171,38 @@ def get_token(cls, handler):
@classmethod
def should_check_origin(cls, handler):
- """Should the Handler check for CORS origin validation?
-
- Origin check should be skipped for token-authenticated requests.
-
- Returns:
- - True, if Handler must check for valid CORS origin.
- - False, if Handler should skip origin check since requests are token-authenticated.
- """
+ """DEPRECATED in 2.0, use IdentityProvider API"""
return not cls.is_token_authenticated(handler)
@classmethod
def is_token_authenticated(cls, handler):
- """Returns True if handler has been token authenticated. Otherwise, False.
-
- Login with a token is used to signal certain things, such as:
-
- - permit access to REST API
- - xsrf protection
- - skip origin-checks for scripts
- """
+ """DEPRECATED in 2.0, use IdentityProvider API"""
if getattr(handler, "_user_id", None) is None:
# ensure get_user has been called, so we know if we're token-authenticated
- handler.get_current_user()
+ handler.current_user # noqa: B018
return getattr(handler, "_token_authenticated", False)
@classmethod
def get_user(cls, handler):
- """Called by handlers.get_current_user for identifying the current user.
-
- See tornado.web.RequestHandler.get_current_user for details.
- """
+ """DEPRECATED in 2.0, use IdentityProvider API"""
# Can't call this get_current_user because it will collide when
# called on LoginHandler itself.
if getattr(handler, "_user_id", None):
return handler._user_id
- user_id = cls.get_user_token(handler)
- if user_id is None:
- get_secure_cookie_kwargs = handler.settings.get("get_secure_cookie_kwargs", {})
- user_id = handler.get_secure_cookie(handler.cookie_name, **get_secure_cookie_kwargs)
- if user_id:
- user_id = user_id.decode()
- else:
- cls.set_login_cookie(handler, user_id)
+ token_user_id = cls.get_user_token(handler)
+ cookie_user_id = cls.get_user_cookie(handler)
+ # prefer token to cookie if both given,
+ # because token is always explicit
+ user_id = token_user_id or cookie_user_id
+ if token_user_id:
+ # if token-authenticated, persist user_id in cookie
+ # if it hasn't already been stored there
+ if user_id != cookie_user_id:
+ cls.set_login_cookie(handler, user_id)
# Record that the current request has been authenticated with a token.
# Used in is_token_authenticated above.
handler._token_authenticated = True
+
if user_id is None:
# If an invalid cookie was sent, clear it to prevent unnecessary
# extra warnings. But don't do this on a request with *no* cookie,
@@ -193,16 +220,20 @@ def get_user(cls, handler):
return user_id
@classmethod
- def get_user_token(cls, handler):
- """Identify the user based on a token in the URL or Authorization header
+ def get_user_cookie(cls, handler):
+ """DEPRECATED in 2.0, use IdentityProvider API"""
+ get_secure_cookie_kwargs = handler.settings.get("get_secure_cookie_kwargs", {})
+ user_id = handler.get_secure_cookie(handler.cookie_name, **get_secure_cookie_kwargs)
+ if user_id:
+ user_id = user_id.decode()
+ return user_id
- Returns:
- - uuid if authenticated
- - None if not
- """
+ @classmethod
+ def get_user_token(cls, handler):
+ """DEPRECATED in 2.0, use IdentityProvider API"""
token = handler.token
if not token:
- return
+ return None
# check login token from URL argument or Authorization header
user_token = cls.get_token(handler)
authenticated = False
@@ -215,16 +246,23 @@ def get_user_token(cls, handler):
authenticated = True
if authenticated:
- return uuid.uuid4().hex
+ # token does not correspond to user-id,
+ # which is stored in a cookie.
+ # still check the cookie for the user id
+ user_id = cls.get_user_cookie(handler)
+ if user_id is None:
+ # no cookie, generate new random user_id
+ user_id = uuid.uuid4().hex
+ handler.log.info(
+ f"Generating new user_id for token-authenticated request: {user_id}"
+ )
+ return user_id
else:
return None
@classmethod
def validate_security(cls, app, ssl_options=None):
- """Check the application's security.
-
- Show messages, or abort if necessary, based on the security configuration.
- """
+ """DEPRECATED in 2.0, use IdentityProvider API"""
if not app.ip:
warning = "WARNING: The Jupyter server is listening on all IP addresses"
if ssl_options is None:
@@ -234,22 +272,23 @@ def validate_security(cls, app, ssl_options=None):
f"{warning} and not using authentication. "
"This is highly insecure and not recommended."
)
- else:
- if not app.password and not app.token:
- app.log.warning(
- "All authentication is disabled."
- " Anyone who can connect to this server will be able to run code."
- )
+ elif not app.password and not app.token:
+ app.log.warning(
+ "All authentication is disabled."
+ " Anyone who can connect to this server will be able to run code."
+ )
@classmethod
def password_from_settings(cls, settings):
- """Return the hashed password from the tornado settings.
-
- If there is no configured password, an empty string will be returned.
- """
+ """DEPRECATED in 2.0, use IdentityProvider API"""
return settings.get("password", "")
@classmethod
def get_login_available(cls, settings):
- """Whether this LoginHandler is needed - and therefore whether the login page should be displayed."""
+ """DEPRECATED in 2.0, use IdentityProvider API"""
+
return bool(cls.password_from_settings(settings) or settings.get("token"))
+
+
+# deprecated import, so deprecated implementations get the Legacy class instead
+LoginHandler = LegacyLoginHandler
diff --git a/jupyter_server/auth/logout.py b/jupyter_server/auth/logout.py
index abe23425c9..3db7f796ba 100644
--- a/jupyter_server/auth/logout.py
+++ b/jupyter_server/auth/logout.py
@@ -6,8 +6,11 @@
class LogoutHandler(JupyterHandler):
+ """An auth logout handler."""
+
def get(self):
- self.clear_login_cookie()
+ """Handle a logout."""
+ self.identity_provider.clear_login_cookie(self)
if self.login_available:
message = {"info": "Successfully logged out."}
else:
diff --git a/jupyter_server/auth/security.py b/jupyter_server/auth/security.py
index fa7dded7fb..a5ae185f1e 100644
--- a/jupyter_server/auth/security.py
+++ b/jupyter_server/auth/security.py
@@ -11,7 +11,8 @@
from contextlib import contextmanager
from jupyter_core.paths import jupyter_config_dir
-from traitlets.config import Config, ConfigFileNotFound, JSONFileConfigLoader
+from traitlets.config import Config
+from traitlets.config.loader import ConfigFileNotFound, JSONFileConfigLoader
# Length of the salt in nr of hex chars, which implies salt_len * 4
# bits of randomness.
@@ -51,10 +52,10 @@ def passwd(passphrase=None, algorithm="argon2"):
if p0 == p1:
passphrase = p0
break
- else:
- print("Passwords do not match.")
+ warnings.warn("Passwords do not match.", stacklevel=2)
else:
- raise ValueError("No matching passwords found. Giving up.")
+ msg = "No matching passwords found. Giving up."
+ raise ValueError(msg)
if algorithm == "argon2":
import argon2
@@ -64,9 +65,9 @@ def passwd(passphrase=None, algorithm="argon2"):
time_cost=10,
parallelism=8,
)
- h = ph.hash(passphrase)
+ h_ph = ph.hash(passphrase)
- return ":".join((algorithm, h))
+ return ":".join((algorithm, h_ph))
h = hashlib.new(algorithm)
salt = ("%0" + str(salt_len) + "x") % random.getrandbits(4 * salt_len)
@@ -160,7 +161,9 @@ def persist_config(config_file=None, mode=0o600):
os.chmod(config_file, mode)
except Exception:
tb = traceback.format_exc()
- warnings.warn(f"Failed to set permissions on {config_file}:\n{tb}", RuntimeWarning)
+ warnings.warn(
+ f"Failed to set permissions on {config_file}:\n{tb}", RuntimeWarning, stacklevel=2
+ )
def set_password(password=None, config_file=None):
@@ -169,4 +172,5 @@ def set_password(password=None, config_file=None):
hashed_password = passwd(password)
with persist_config(config_file) as config:
- config.ServerApp.password = hashed_password
+ config.IdentityProvider.hashed_password = hashed_password
+ return hashed_password
diff --git a/jupyter_server/auth/utils.py b/jupyter_server/auth/utils.py
index b939b87ae0..b0f790be1f 100644
--- a/jupyter_server/auth/utils.py
+++ b/jupyter_server/auth/utils.py
@@ -3,21 +3,17 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import importlib
+import random
import re
import warnings
def warn_disabled_authorization():
+ """DEPRECATED, does nothing"""
warnings.warn(
- "The Tornado web application does not have an 'authorizer' defined "
- "in its settings. In future releases of jupyter_server, this will "
- "be a required key for all subclasses of `JupyterHandler`. For an "
- "example, see the jupyter_server source code for how to "
- "add an authorizer to the tornado settings: "
- "https://github.com/jupyter-server/jupyter_server/blob/"
- "653740cbad7ce0c8a8752ce83e4d3c2c754b13cb/jupyter_server/serverapp.py"
- "#L234-L256",
- FutureWarning,
+ "jupyter_server.auth.utils.warn_disabled_authorization is deprecated",
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -44,9 +40,9 @@ def get_regex_to_resource_map():
from jupyter_server.serverapp import JUPYTER_SERVICE_HANDLERS
modules = []
- for mod in JUPYTER_SERVICE_HANDLERS.values():
- if mod:
- modules.extend(mod)
+ for mod_name in JUPYTER_SERVICE_HANDLERS.values():
+ if mod_name:
+ modules.extend(mod_name)
resource_map = {}
for handler_module in modules:
mod = importlib.import_module(handler_module)
@@ -79,3 +75,95 @@ def match_url_to_resource(url, regex_mapping=None):
pattern = re.compile(regex)
if pattern.fullmatch(url):
return auth_resource
+
+
+# From https://en.wikipedia.org/wiki/Moons_of_Jupiter
+moons_of_jupyter = [
+ "Metis",
+ "Adrastea",
+ "Amalthea",
+ "Thebe",
+ "Io",
+ "Europa",
+ "Ganymede",
+ "Callisto",
+ "Themisto",
+ "Leda",
+ "Ersa",
+ "Pandia",
+ "Himalia",
+ "Lysithea",
+ "Elara",
+ "Dia",
+ "Carpo",
+ "Valetudo",
+ "Euporie",
+ "Eupheme",
+ # 'S/2003 J 18',
+ # 'S/2010 J 2',
+ "Helike",
+ # 'S/2003 J 16',
+ # 'S/2003 J 2',
+ "Euanthe",
+ # 'S/2017 J 7',
+ "Hermippe",
+ "Praxidike",
+ "Thyone",
+ "Thelxinoe",
+ # 'S/2017 J 3',
+ "Ananke",
+ "Mneme",
+ # 'S/2016 J 1',
+ "Orthosie",
+ "Harpalyke",
+ "Iocaste",
+ # 'S/2017 J 9',
+ # 'S/2003 J 12',
+ # 'S/2003 J 4',
+ "Erinome",
+ "Aitne",
+ "Herse",
+ "Taygete",
+ # 'S/2017 J 2',
+ # 'S/2017 J 6',
+ "Eukelade",
+ "Carme",
+ # 'S/2003 J 19',
+ "Isonoe",
+ # 'S/2003 J 10',
+ "Autonoe",
+ "Philophrosyne",
+ "Cyllene",
+ "Pasithee",
+ # 'S/2010 J 1',
+ "Pasiphae",
+ "Sponde",
+ # 'S/2017 J 8',
+ "Eurydome",
+ # 'S/2017 J 5',
+ "Kalyke",
+ "Hegemone",
+ "Kale",
+ "Kallichore",
+ # 'S/2011 J 1',
+ # 'S/2017 J 1',
+ "Chaldene",
+ "Arche",
+ "Eirene",
+ "Kore",
+ # 'S/2011 J 2',
+ # 'S/2003 J 9',
+ "Megaclite",
+ "Aoede",
+ # 'S/2003 J 23',
+ "Callirrhoe",
+ "Sinope",
+]
+
+
+def get_anonymous_username() -> str:
+ """
+ Get a random user-name based on the moons of Jupyter.
+ This function returns names like "Anonymous Io" or "Anonymous Metis".
+ """
+ return moons_of_jupyter[random.randint(0, len(moons_of_jupyter) - 1)]
diff --git a/jupyter_server/base/call_context.py b/jupyter_server/base/call_context.py
new file mode 100644
index 0000000000..3d989121c2
--- /dev/null
+++ b/jupyter_server/base/call_context.py
@@ -0,0 +1,88 @@
+"""Provides access to variables pertaining to specific call contexts."""
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+from contextvars import Context, ContextVar, copy_context
+from typing import Any, Dict, List
+
+
+class CallContext:
+ """CallContext essentially acts as a namespace for managing context variables.
+
+ Although not required, it is recommended that any "file-spanning" context variable
+ names (i.e., variables that will be set or retrieved from multiple files or services) be
+ added as constants to this class definition.
+ """
+
+ # Add well-known (file-spanning) names here.
+ #: Provides access to the current request handler once set.
+ JUPYTER_HANDLER: str = "JUPYTER_HANDLER"
+
+ # A map of variable name to value is maintained as the single ContextVar. This also enables
+ # easier management over maintaining a set of ContextVar instances, since the Context is a
+ # map of ContextVar instances to their values, and the "name" is no longer a lookup key.
+ _NAME_VALUE_MAP = "_name_value_map"
+ _name_value_map: ContextVar[Dict[str, Any]] = ContextVar(_NAME_VALUE_MAP)
+
+ @classmethod
+ def get(cls, name: str) -> Any:
+        """Returns the value corresponding to the named variable relative to this context.
+
+ If the named variable doesn't exist, None will be returned.
+
+ Parameters
+ ----------
+ name : str
+ The name of the variable to get from the call context
+
+ Returns
+ -------
+ value: Any
+ The value associated with the named variable for this call context
+ """
+ name_value_map = CallContext._get_map()
+
+ if name in name_value_map:
+ return name_value_map[name]
+ return None # TODO - should this raise `LookupError` (or a custom error derived from said)
+
+ @classmethod
+ def set(cls, name: str, value: Any) -> None:
+ """Sets the named variable to the specified value in the current call context.
+
+ Parameters
+ ----------
+ name : str
+ The name of the variable to store into the call context
+ value : Any
+ The value of the variable to store into the call context
+
+ Returns
+ -------
+ None
+ """
+ name_value_map = CallContext._get_map()
+ name_value_map[name] = value
+
+ @classmethod
+ def context_variable_names(cls) -> List[str]:
+ """Returns a list of variable names set for this call context.
+
+ Returns
+ -------
+ names: List[str]
+ A list of variable names set for this call context.
+ """
+ name_value_map = CallContext._get_map()
+ return list(name_value_map.keys())
+
+ @classmethod
+ def _get_map(cls) -> Dict[str, Any]:
+ """Get the map of names to their values from the _NAME_VALUE_MAP context var.
+
+ If the map does not exist in the current context, an empty map is created and returned.
+ """
+ ctx: Context = copy_context()
+ if CallContext._name_value_map not in ctx:
+ CallContext._name_value_map.set({})
+ return CallContext._name_value_map.get()
diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py
index 42f7fb3d5e..b1b783cca9 100644
--- a/jupyter_server/base/handlers.py
+++ b/jupyter_server/base/handlers.py
@@ -1,30 +1,35 @@
"""Base Tornado handlers for the Jupyter server."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
-import datetime
+from __future__ import annotations
+
import functools
+import inspect
import ipaddress
import json
import mimetypes
import os
import re
-import traceback
import types
import warnings
from http.client import responses
-from http.cookies import Morsel
+from logging import Logger
+from typing import TYPE_CHECKING, Any, Awaitable, Sequence, cast
from urllib.parse import urlparse
import prometheus_client
from jinja2 import TemplateNotFound
from jupyter_core.paths import is_hidden
-from tornado import escape, httputil, web
+from jupyter_events import EventLogger
+from tornado import web
from tornado.log import app_log
from traitlets.config import Application
import jupyter_server
+from jupyter_server import CallContext
from jupyter_server._sysinfo import get_sys_info
from jupyter_server._tz import utcnow
+from jupyter_server.auth.decorator import authorized
from jupyter_server.i18n import combine_translations
from jupyter_server.services.security import csp_report_uri
from jupyter_server.utils import (
@@ -36,24 +41,38 @@
urldecode_unix_socket_path,
)
+if TYPE_CHECKING:
+ from jupyter_client.kernelspec import KernelSpecManager
+ from jupyter_server_terminals.terminalmanager import TerminalManager
+ from tornado.concurrent import Future
+
+ from jupyter_server.auth.authorizer import Authorizer
+ from jupyter_server.auth.identity import IdentityProvider, User
+ from jupyter_server.serverapp import ServerApp
+ from jupyter_server.services.config.manager import ConfigManager
+ from jupyter_server.services.contents.manager import ContentsManager
+ from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
+ from jupyter_server.services.sessions.sessionmanager import SessionManager
+
# -----------------------------------------------------------------------------
# Top-level handlers
# -----------------------------------------------------------------------------
-non_alphanum = re.compile(r"[^A-Za-z0-9]")
_sys_info_cache = None
def json_sys_info():
- global _sys_info_cache
+ """Get sys info as json."""
+ global _sys_info_cache # noqa: PLW0603
if _sys_info_cache is None:
_sys_info_cache = json.dumps(get_sys_info())
return _sys_info_cache
-def log():
+def log() -> Logger:
+ """Get the application log."""
if Application.initialized():
- return Application.instance().log
+ return cast(Logger, Application.instance().log)
else:
return app_log
@@ -62,14 +81,18 @@ class AuthenticatedHandler(web.RequestHandler):
"""A RequestHandler with an authenticated user."""
@property
- def content_security_policy(self):
+ def base_url(self) -> str:
+ return cast(str, self.settings.get("base_url", "/"))
+
+ @property
+ def content_security_policy(self) -> str:
"""The default Content-Security-Policy header
Can be overridden by defining Content-Security-Policy in settings['headers']
"""
if "Content-Security-Policy" in self.settings.get("headers", {}):
# user-specified, don't override
- return self.settings["headers"]["Content-Security-Policy"]
+ return cast(str, self.settings["headers"]["Content-Security-Policy"])
return "; ".join(
[
@@ -80,7 +103,8 @@ def content_security_policy(self):
]
)
- def set_default_headers(self):
+ def set_default_headers(self) -> None:
+ """Set the default headers."""
headers = {}
headers["X-Content-Type-Options"] = "nosniff"
headers.update(self.settings.get("headers", {}))
@@ -95,50 +119,62 @@ def set_default_headers(self):
# tornado raise Exception (not a subclass)
# if method is unsupported (websocket and Access-Control-Allow-Origin
# for example, so just ignore)
- self.log.debug(e)
-
- def force_clear_cookie(self, name, path="/", domain=None):
- """Deletes the cookie with the given name.
+ self.log.exception( # type:ignore[attr-defined]
+ "Could not set default headers: %s", e
+ )
- Tornado's cookie handling currently (Jan 2018) stores cookies in a dict
- keyed by name, so it can only modify one cookie with a given name per
- response. The browser can store multiple cookies with the same name
- but different domains and/or paths. This method lets us clear multiple
- cookies with the same name.
+ @property
+ def cookie_name(self) -> str:
+ warnings.warn(
+ """JupyterHandler.login_handler is deprecated in 2.0,
+ use JupyterHandler.identity_provider.
+ """,
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.identity_provider.get_cookie_name(self)
+
+ def force_clear_cookie(self, name: str, path: str = "/", domain: str | None = None) -> None:
+ """Force a cookie clear."""
+ warnings.warn(
+ """JupyterHandler.login_handler is deprecated in 2.0,
+ use JupyterHandler.identity_provider.
+ """,
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self.identity_provider._force_clear_cookie(self, name, path=path, domain=domain)
+
+ def clear_login_cookie(self) -> None:
+ """Clear a login cookie."""
+ warnings.warn(
+ """JupyterHandler.login_handler is deprecated in 2.0,
+ use JupyterHandler.identity_provider.
+ """,
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self.identity_provider.clear_login_cookie(self)
+
+ def get_current_user(self) -> str:
+ """Get the current user."""
+ clsname = self.__class__.__name__
+ msg = (
+ f"Calling `{clsname}.get_current_user()` directly is deprecated in jupyter-server 2.0."
+ " Use `self.current_user` instead (works in all versions)."
+ )
+ if hasattr(self, "_jupyter_current_user"):
+ # backward-compat: return _jupyter_current_user
+ warnings.warn(
+ msg,
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return cast(str, self._jupyter_current_user)
+ # haven't called get_user in prepare, raise
+ raise RuntimeError(msg)
- Due to limitations of the cookie protocol, you must pass the same
- path and domain to clear a cookie as were used when that cookie
- was set (but there is no way to find out on the server side
- which values were used for a given cookie).
- """
- name = escape.native_str(name)
- expires = datetime.datetime.utcnow() - datetime.timedelta(days=365)
-
- morsel = Morsel()
- morsel.set(name, "", '""')
- morsel["expires"] = httputil.format_timestamp(expires)
- morsel["path"] = path
- if domain:
- morsel["domain"] = domain
- self.add_header("Set-Cookie", morsel.OutputString())
-
- def clear_login_cookie(self):
- cookie_options = self.settings.get("cookie_options", {})
- path = cookie_options.setdefault("path", self.base_url)
- self.clear_cookie(self.cookie_name, path=path)
- if path and path != "/":
- # also clear cookie on / to ensure old cookies are cleared
- # after the change in path behavior.
- # N.B. This bypasses the normal cookie handling, which can't update
- # two cookies with the same name. See the method above.
- self.force_clear_cookie(self.cookie_name)
-
- def get_current_user(self):
- if self.login_handler is None:
- return "anonymous"
- return self.login_handler.get_user(self)
-
- def skip_check_origin(self):
+ def skip_check_origin(self) -> bool:
"""Ask my login_handler if I should skip the origin_check
For example: in the default LoginHandler, if a request is token-authenticated,
@@ -147,53 +183,89 @@ def skip_check_origin(self):
if self.request.method == "OPTIONS":
# no origin-check on options requests, which are used to check origins!
return True
- if self.login_handler is None or not hasattr(self.login_handler, "should_check_origin"):
- return False
- return not self.login_handler.should_check_origin(self)
+ return not self.identity_provider.should_check_origin(self)
@property
- def token_authenticated(self):
+ def token_authenticated(self) -> bool:
"""Have I been authenticated with a token?"""
- if self.login_handler is None or not hasattr(self.login_handler, "is_token_authenticated"):
- return False
- return self.login_handler.is_token_authenticated(self)
+ return self.identity_provider.is_token_authenticated(self)
@property
- def cookie_name(self):
- default_cookie_name = non_alphanum.sub("-", f"username-{self.request.host}")
- return self.settings.get("cookie_name", default_cookie_name)
-
- @property
- def logged_in(self):
+ def logged_in(self) -> bool:
"""Is a user currently logged in?"""
- user = self.get_current_user()
- return user and not user == "anonymous"
+ user = self.current_user
+ return bool(user and user != "anonymous")
@property
- def login_handler(self):
+ def login_handler(self) -> Any:
"""Return the login handler for this application, if any."""
- return self.settings.get("login_handler_class", None)
+ warnings.warn(
+ """JupyterHandler.login_handler is deprecated in 2.0,
+ use JupyterHandler.identity_provider.
+ """,
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.identity_provider.login_handler_class
@property
- def token(self):
+ def token(self) -> str | None:
"""Return the login token for this application, if any."""
- return self.settings.get("token", None)
+ return self.identity_provider.token
@property
- def login_available(self):
+ def login_available(self) -> bool:
"""May a user proceed to log in?
This returns True if login capability is available, irrespective of
whether the user is already logged in or not.
"""
- if self.login_handler is None:
- return False
- return bool(self.login_handler.get_login_available(self.settings))
+ return cast(bool, self.identity_provider.login_available)
@property
- def authorizer(self):
- return self.settings.get("authorizer")
+ def authorizer(self) -> Authorizer:
+ if "authorizer" not in self.settings:
+ warnings.warn(
+ "The Tornado web application does not have an 'authorizer' defined "
+ "in its settings. In future releases of jupyter_server, this will "
+ "be a required key for all subclasses of `JupyterHandler`. For an "
+ "example, see the jupyter_server source code for how to "
+ "add an authorizer to the tornado settings: "
+ "https://github.com/jupyter-server/jupyter_server/blob/"
+ "653740cbad7ce0c8a8752ce83e4d3c2c754b13cb/jupyter_server/serverapp.py"
+ "#L234-L256",
+ stacklevel=2,
+ )
+ from jupyter_server.auth import AllowAllAuthorizer
+
+ self.settings["authorizer"] = AllowAllAuthorizer(
+ config=self.settings.get("config", None),
+ identity_provider=self.identity_provider,
+ )
+
+ return cast("Authorizer", self.settings.get("authorizer"))
+
+ @property
+ def identity_provider(self) -> IdentityProvider:
+ if "identity_provider" not in self.settings:
+ warnings.warn(
+ "The Tornado web application does not have an 'identity_provider' defined "
+ "in its settings. In future releases of jupyter_server, this will "
+ "be a required key for all subclasses of `JupyterHandler`. For an "
+ "example, see the jupyter_server source code for how to "
+ "add an identity provider to the tornado settings: "
+ "https://github.com/jupyter-server/jupyter_server/blob/v2.0.0/"
+ "jupyter_server/serverapp.py#L242",
+ stacklevel=2,
+ )
+ from jupyter_server.auth import IdentityProvider
+
+ # no identity provider set, load default
+ self.settings["identity_provider"] = IdentityProvider(
+ config=self.settings.get("config", None)
+ )
+ return cast("IdentityProvider", self.settings["identity_provider"])
class JupyterHandler(AuthenticatedHandler):
@@ -203,113 +275,120 @@ class JupyterHandler(AuthenticatedHandler):
"""
@property
- def config(self):
- return self.settings.get("config", None)
+ def config(self) -> dict[str, Any] | None:
+ return cast("dict[str, Any] | None", self.settings.get("config", None))
@property
- def log(self):
+ def log(self) -> Logger:
"""use the Jupyter log by default, falling back on tornado's logger"""
return log()
@property
- def jinja_template_vars(self):
+ def jinja_template_vars(self) -> dict[str, Any]:
"""User-supplied values to supply to jinja templates."""
- return self.settings.get("jinja_template_vars", {})
+ return cast("dict[str, Any]", self.settings.get("jinja_template_vars", {}))
@property
- def serverapp(self):
- return self.settings["serverapp"]
+ def serverapp(self) -> ServerApp | None:
+ return cast("ServerApp | None", self.settings["serverapp"])
# ---------------------------------------------------------------
# URLs
# ---------------------------------------------------------------
@property
- def version_hash(self):
+ def version_hash(self) -> str:
"""The version hash to use for cache hints for static files"""
- return self.settings.get("version_hash", "")
+ return cast(str, self.settings.get("version_hash", ""))
@property
- def mathjax_url(self):
- url = self.settings.get("mathjax_url", "")
+ def mathjax_url(self) -> str:
+ url = cast(str, self.settings.get("mathjax_url", ""))
if not url or url_is_absolute(url):
return url
return url_path_join(self.base_url, url)
@property
- def mathjax_config(self):
- return self.settings.get("mathjax_config", "TeX-AMS-MML_HTMLorMML-full,Safe")
-
- @property
- def base_url(self):
- return self.settings.get("base_url", "/")
+ def mathjax_config(self) -> str:
+ return cast(str, self.settings.get("mathjax_config", "TeX-AMS-MML_HTMLorMML-full,Safe"))
@property
- def default_url(self):
- return self.settings.get("default_url", "")
+ def default_url(self) -> str:
+ return cast(str, self.settings.get("default_url", ""))
@property
- def ws_url(self):
- return self.settings.get("websocket_url", "")
+ def ws_url(self) -> str:
+ return cast(str, self.settings.get("websocket_url", ""))
@property
- def contents_js_source(self):
+ def contents_js_source(self) -> str:
self.log.debug(
"Using contents: %s",
self.settings.get("contents_js_source", "services/contents"),
)
- return self.settings.get("contents_js_source", "services/contents")
+ return cast(str, self.settings.get("contents_js_source", "services/contents"))
# ---------------------------------------------------------------
# Manager objects
# ---------------------------------------------------------------
@property
- def kernel_manager(self):
- return self.settings["kernel_manager"]
+ def kernel_manager(self) -> AsyncMappingKernelManager:
+ return cast("AsyncMappingKernelManager", self.settings["kernel_manager"])
@property
- def contents_manager(self):
- return self.settings["contents_manager"]
+ def contents_manager(self) -> ContentsManager:
+ return cast("ContentsManager", self.settings["contents_manager"])
@property
- def session_manager(self):
- return self.settings["session_manager"]
+ def session_manager(self) -> SessionManager:
+ return cast("SessionManager", self.settings["session_manager"])
@property
- def terminal_manager(self):
- return self.settings["terminal_manager"]
+ def terminal_manager(self) -> TerminalManager:
+ return cast("TerminalManager", self.settings["terminal_manager"])
@property
- def kernel_spec_manager(self):
- return self.settings["kernel_spec_manager"]
+ def kernel_spec_manager(self) -> KernelSpecManager:
+ return cast("KernelSpecManager", self.settings["kernel_spec_manager"])
@property
- def config_manager(self):
- return self.settings["config_manager"]
+ def config_manager(self) -> ConfigManager:
+ return cast("ConfigManager", self.settings["config_manager"])
+
+ @property
+ def event_logger(self) -> EventLogger:
+ return cast("EventLogger", self.settings["event_logger"])
# ---------------------------------------------------------------
# CORS
# ---------------------------------------------------------------
@property
- def allow_origin(self):
+ def allow_origin(self) -> str:
"""Normal Access-Control-Allow-Origin"""
- return self.settings.get("allow_origin", "")
+ return cast(str, self.settings.get("allow_origin", ""))
@property
- def allow_origin_pat(self):
+ def allow_origin_pat(self) -> str | None:
"""Regular expression version of allow_origin"""
- return self.settings.get("allow_origin_pat", None)
+ return cast("str | None", self.settings.get("allow_origin_pat", None))
@property
- def allow_credentials(self):
+ def allow_credentials(self) -> bool:
"""Whether to set Access-Control-Allow-Credentials"""
- return self.settings.get("allow_credentials", False)
+ return cast(bool, self.settings.get("allow_credentials", False))
- def set_default_headers(self):
+ def set_default_headers(self) -> None:
"""Add CORS headers, if defined"""
super().set_default_headers()
+
+ def set_cors_headers(self) -> None:
+ """Add CORS headers, if defined
+
+ Now that current_user is async (jupyter-server 2.0),
+ must be called at the end of prepare(), instead of in set_default_headers.
+ """
if self.allow_origin:
self.set_header("Access-Control-Allow-Origin", self.allow_origin)
elif self.allow_origin_pat:
@@ -326,7 +405,7 @@ def set_default_headers(self):
if self.allow_credentials:
self.set_header("Access-Control-Allow-Credentials", "true")
- def set_attachment_header(self, filename):
+ def set_attachment_header(self, filename: str) -> None:
"""Set Content-Disposition: attachment header
As a method to ensure handling of filename encoding
@@ -334,13 +413,10 @@ def set_attachment_header(self, filename):
escaped_filename = url_escape(filename)
self.set_header(
"Content-Disposition",
- "attachment;"
- " filename*=utf-8''{utf8}".format(
- utf8=escaped_filename,
- ),
+ f"attachment; filename*=utf-8''{escaped_filename}",
)
- def get_origin(self):
+ def get_origin(self) -> str | None:
# Handle WebSocket Origin naming convention differences
# The difference between version 8 and 13 is that in 8 the
# client sends a "Sec-Websocket-Origin" header and in 13 it's
@@ -353,7 +429,7 @@ def get_origin(self):
# origin_to_satisfy_tornado is present because tornado requires
# check_origin to take an origin argument, but we don't use it
- def check_origin(self, origin_to_satisfy_tornado=""):
+ def check_origin(self, origin_to_satisfy_tornado: str = "") -> bool:
"""Check Origin for cross-site API requests, including websockets
Copied from WebSocket with changes:
@@ -385,7 +461,7 @@ def check_origin(self, origin_to_satisfy_tornado=""):
# Check CORS headers
if self.allow_origin:
- allow = self.allow_origin == origin
+ allow = bool(self.allow_origin == origin)
elif self.allow_origin_pat:
allow = bool(re.match(self.allow_origin_pat, origin))
else:
@@ -400,7 +476,7 @@ def check_origin(self, origin_to_satisfy_tornado=""):
)
return allow
- def check_referer(self):
+ def check_referer(self) -> bool:
"""Check Referer for cross-site requests.
Disables requests to certain endpoints with
external or missing Referer.
@@ -446,15 +522,18 @@ def check_referer(self):
)
return allow
- def check_xsrf_cookie(self):
+ def check_xsrf_cookie(self) -> None:
"""Bypass xsrf cookie checks when token-authenticated"""
+ if not hasattr(self, "_jupyter_current_user"):
+ # Called too early, will be checked later
+ return None
if self.token_authenticated or self.settings.get("disable_check_xsrf", False):
# Token-authenticated requests do not need additional XSRF-check
# Servers without authentication are vulnerable to XSRF
- return
+ return None
try:
return super().check_xsrf_cookie()
- except web.HTTPError:
+ except web.HTTPError as e:
if self.request.method in {"GET", "HEAD"}:
# Consider Referer a sufficient cross-origin check for GET requests
if not self.check_referer():
@@ -463,11 +542,11 @@ def check_xsrf_cookie(self):
msg = f"Blocking Cross Origin request from {referer}."
else:
msg = "Blocking request from unknown origin"
- raise web.HTTPError(403, msg)
+ raise web.HTTPError(403, msg) from e
else:
raise
- def check_host(self):
+ def check_host(self) -> bool:
"""Check the host header if remote access disallowed.
Returns True if the request should continue, False otherwise.
@@ -476,7 +555,9 @@ def check_host(self):
return True
# Remove port (e.g. ':8888') from host
- host = re.match(r"^(.*?)(:\d+)?$", self.request.host).group(1)
+ match = re.match(r"^(.*?)(:\d+)?$", self.request.host)
+ assert match is not None
+ host = match.group(1)
# Browsers format IPv6 addresses like [::1]; we need to remove the []
if host.startswith("[") and host.endswith("]"):
@@ -507,9 +588,47 @@ def check_host(self):
)
return allow
- def prepare(self):
+ async def prepare(self) -> Awaitable[None] | None: # type:ignore[override]
+ """Prepare a response."""
+ # Set the current Jupyter Handler context variable.
+ CallContext.set(CallContext.JUPYTER_HANDLER, self)
+
if not self.check_host():
+ self.current_user = self._jupyter_current_user = None
raise web.HTTPError(403)
+
+ from jupyter_server.auth import IdentityProvider
+
+ mod_obj = inspect.getmodule(self.get_current_user)
+ assert mod_obj is not None
+ user: User | None = None
+
+ if type(self.identity_provider) is IdentityProvider and mod_obj.__name__ != __name__:
+ # check for overridden get_current_user + default IdentityProvider
+ # deprecated way to override auth (e.g. JupyterHub < 3.0)
+ # allow deprecated, overridden get_current_user
+ warnings.warn(
+ "Overriding JupyterHandler.get_current_user is deprecated in jupyter-server 2.0."
+ " Use an IdentityProvider class.",
+ DeprecationWarning,
+ stacklevel=1,
+ )
+ user = User(self.get_current_user())
+ else:
+ _user = self.identity_provider.get_user(self)
+ if isinstance(_user, Awaitable):
+ # IdentityProvider.get_user _may_ be async
+ _user = await _user
+ user = _user
+
+ # self.current_user for tornado's @web.authenticated
+ # self._jupyter_current_user for backward-compat in deprecated get_current_user calls
+ # and our own private checks for whether .current_user has been set
+ self.current_user = self._jupyter_current_user = user
+ # complete initial steps which require auth to resolve first:
+ self.set_cors_headers()
+ if self.request.method not in {"GET", "HEAD", "OPTIONS"}:
+ self.check_xsrf_cookie()
return super().prepare()
# ---------------------------------------------------------------
@@ -521,19 +640,21 @@ def get_template(self, name):
return self.settings["jinja2_env"].get_template(name)
def render_template(self, name, **ns):
+ """Render a template by name."""
ns.update(self.template_namespace)
template = self.get_template(name)
return template.render(**ns)
@property
- def template_namespace(self):
+ def template_namespace(self) -> dict[str, Any]:
return dict(
base_url=self.base_url,
default_url=self.default_url,
ws_url=self.ws_url,
logged_in=self.logged_in,
- allow_password_change=self.settings.get("allow_password_change"),
- login_available=self.login_available,
+ allow_password_change=getattr(self.identity_provider, "allow_password_change", False),
+ auth_enabled=self.identity_provider.auth_enabled,
+ login_available=self.identity_provider.login_available,
token_available=bool(self.token),
static_url=self.static_url,
sys_info=json_sys_info(),
@@ -548,7 +669,7 @@ def template_namespace(self):
**self.jinja_template_vars,
)
- def get_json_body(self):
+ def get_json_body(self) -> dict[str, Any] | None:
"""Return the body of the request as JSON data."""
if not self.request.body:
return None
@@ -560,14 +681,14 @@ def get_json_body(self):
self.log.debug("Bad JSON: %r", body)
self.log.error("Couldn't parse JSON", exc_info=True)
raise web.HTTPError(400, "Invalid JSON in body of request") from e
- return model
+ return cast("dict[str, Any]", model)
- def write_error(self, status_code, **kwargs):
+ def write_error(self, status_code: int, **kwargs: Any) -> None:
"""render custom error pages"""
exc_info = kwargs.get("exc_info")
message = ""
status_message = responses.get(status_code, "Unknown HTTP Error")
- exception = "(unknown)"
+
if exc_info:
exception = exc_info[1]
# get the custom message, if defined
@@ -580,14 +701,16 @@ def write_error(self, status_code, **kwargs):
reason = getattr(exception, "reason", "")
if reason:
status_message = reason
+ else:
+ exception = "(unknown)"
# build template namespace
- ns = dict(
- status_code=status_code,
- status_message=status_message,
- message=message,
- exception=exception,
- )
+ ns = {
+ "status_code": status_code,
+ "status_message": status_message,
+ "message": message,
+ "exception": exception,
+ }
self.set_header("Content-Type", "text/html")
# render the template
@@ -602,16 +725,17 @@ def write_error(self, status_code, **kwargs):
class APIHandler(JupyterHandler):
"""Base class for API handlers"""
- def prepare(self):
+ async def prepare(self) -> None:
+ """Prepare an API response."""
+ await super().prepare()
if not self.check_origin():
raise web.HTTPError(404)
- return super().prepare()
- def write_error(self, status_code, **kwargs):
+ def write_error(self, status_code: int, **kwargs: Any) -> None:
"""APIHandler errors are JSON, not human pages"""
self.set_header("Content-Type", "application/json")
message = responses.get(status_code, "Unknown HTTP Error")
- reply = {
+ reply: dict[str, Any] = {
"message": message,
}
exc_info = kwargs.get("exc_info")
@@ -623,19 +747,14 @@ def write_error(self, status_code, **kwargs):
else:
reply["message"] = "Unhandled error"
reply["reason"] = None
- reply["traceback"] = "".join(traceback.format_exception(*exc_info))
- self.log.warning(reply["message"])
+ # backward-compatibility: traceback field is present,
+ # but always empty
+ reply["traceback"] = ""
+ self.log.warning("wrote error: %r", reply["message"], exc_info=True)
self.finish(json.dumps(reply))
- def get_current_user(self):
- """Raise 403 on API handlers instead of redirecting to human login page"""
- # preserve _user_cache so we don't raise more than once
- if hasattr(self, "_user_cache"):
- return self._user_cache
- self._user_cache = user = super().get_current_user()
- return user
-
- def get_login_url(self):
+ def get_login_url(self) -> str:
+ """Get the login url."""
# if get_login_url is invoked in an API handler,
# that means @web.authenticated is trying to trigger a redirect.
# instead of redirecting, raise 403 instead.
@@ -644,7 +763,7 @@ def get_login_url(self):
return super().get_login_url()
@property
- def content_security_policy(self):
+ def content_security_policy(self) -> str:
csp = "; ".join(
[
super().content_security_policy,
@@ -656,22 +775,26 @@ def content_security_policy(self):
# set _track_activity = False on API handlers that shouldn't track activity
_track_activity = True
- def update_api_activity(self):
+ def update_api_activity(self) -> None:
"""Update last_activity of API requests"""
# record activity of authenticated requests
if (
self._track_activity
- and getattr(self, "_user_cache", None)
+ and getattr(self, "_jupyter_current_user", None)
and self.get_argument("no_track_activity", None) is None
):
self.settings["api_last_activity"] = utcnow()
- def finish(self, *args, **kwargs):
+ def finish(self, *args: Any, **kwargs: Any) -> Future[Any]:
+ """Finish an API response."""
self.update_api_activity()
- self.set_header("Content-Type", "application/json")
+ # Allow caller to indicate content-type...
+ set_content_type = kwargs.pop("set_content_type", "application/json")
+ self.set_header("Content-Type", set_content_type)
return super().finish(*args, **kwargs)
- def options(self, *args, **kwargs):
+ def options(self, *args: Any, **kwargs: Any) -> None:
+ """Get the options."""
if "Access-Control-Allow-Headers" in self.settings.get("headers", {}):
self.set_header(
"Access-Control-Allow-Headers",
@@ -713,33 +836,46 @@ def options(self, *args, **kwargs):
class Template404(JupyterHandler):
"""Render our 404 template"""
- def prepare(self):
+ async def prepare(self) -> None:
+ """Prepare a 404 response."""
+ await super().prepare()
raise web.HTTPError(404)
class AuthenticatedFileHandler(JupyterHandler, web.StaticFileHandler):
"""static files should only be accessible when logged in"""
+ auth_resource = "contents"
+
@property
- def content_security_policy(self):
+ def content_security_policy(self) -> str:
# In case we're serving HTML/SVG, confine any Javascript to a unique
# origin so it can't interact with the Jupyter server.
return super().content_security_policy + "; sandbox allow-scripts"
@web.authenticated
- def head(self, path):
+ @authorized
+ def head(self, path: str) -> Awaitable[None]: # type:ignore[override]
+ """Get the head response for a path."""
self.check_xsrf_cookie()
return super().head(path)
@web.authenticated
- def get(self, path):
- if os.path.splitext(path)[1] == ".ipynb" or self.get_argument("download", False):
+ @authorized
+ def get( # type:ignore[override]
+ self, path: str, **kwargs: Any
+ ) -> Awaitable[None]:
+ """Get a file by path."""
+ self.check_xsrf_cookie()
+ if os.path.splitext(path)[1] == ".ipynb" or self.get_argument("download", None):
name = path.rsplit("/", 1)[-1]
self.set_attachment_header(name)
- return web.StaticFileHandler.get(self, path)
+ return web.StaticFileHandler.get(self, path, **kwargs)
- def get_content_type(self):
+ def get_content_type(self) -> str:
+ """Get the content type."""
+ assert self.absolute_path is not None
path = self.absolute_path.strip("/")
if "/" in path:
_, name = path.rsplit("/", 1)
@@ -754,16 +890,18 @@ def get_content_type(self):
else:
return super().get_content_type()
- def set_headers(self):
+ def set_headers(self) -> None:
+ """Set the headers."""
super().set_headers()
# disable browser caching, rely on 304 replies for savings
if "v" not in self.request.arguments:
self.add_header("Cache-Control", "no-cache")
- def compute_etag(self):
+ def compute_etag(self) -> str | None:
+ """Compute the etag."""
return None
- def validate_absolute_path(self, root, absolute_path):
+ def validate_absolute_path(self, root: str, absolute_path: str) -> str:
"""Validate and return the absolute path.
Requires tornado 3.1
@@ -772,7 +910,8 @@ def validate_absolute_path(self, root, absolute_path):
"""
abs_path = super().validate_absolute_path(root, absolute_path)
abs_root = os.path.abspath(root)
- if is_hidden(abs_path, abs_root) and not self.contents_manager.allow_hidden:
+ assert abs_path is not None
+ if not self.contents_manager.allow_hidden and is_hidden(abs_path, abs_root):
self.log.info(
"Refusing to serve hidden file, via 404 Error, use flag 'ContentsManager.allow_hidden' to enable"
)
@@ -780,7 +919,7 @@ def validate_absolute_path(self, root, absolute_path):
return abs_path
-def json_errors(method):
+def json_errors(method: Any) -> Any: # pragma: no cover
"""Decorate methods with this to return GitHub style JSON errors.
This should be used on any JSON API on any handler method that can raise HTTPErrors.
@@ -815,33 +954,55 @@ def wrapper(self, *args, **kwargs):
class FileFindHandler(JupyterHandler, web.StaticFileHandler):
- """subclass of StaticFileHandler for serving files from a search path"""
+ """subclass of StaticFileHandler for serving files from a search path
+
+ The setting "static_immutable_cache" can be set to serve some static
+ files as immutable (e.g. file names containing a hash). The setting is a
+ list of base URLs; every static file whose URL starts with one of these
+ will be served as immutable.
+ """
# cache search results, don't search for files more than once
- _static_paths = {}
+ _static_paths: dict[str, str] = {}
+ root: tuple[str] # type:ignore[assignment]
- def set_headers(self):
+ def set_headers(self) -> None:
+ """Set the headers."""
super().set_headers()
+
+ immutable_paths = self.settings.get("static_immutable_cache", [])
+
+ # allow immutable cache for files
+ if any(self.request.path.startswith(path) for path in immutable_paths):
+ self.set_header("Cache-Control", "public, max-age=31536000, immutable")
+
# disable browser caching, rely on 304 replies for savings
- if "v" not in self.request.arguments or any(
+ elif "v" not in self.request.arguments or any(
self.request.path.startswith(path) for path in self.no_cache_paths
):
self.set_header("Cache-Control", "no-cache")
- def initialize(self, path, default_filename=None, no_cache_paths=None):
+ def initialize(
+ self,
+ path: str | list[str],
+ default_filename: str | None = None,
+ no_cache_paths: list[str] | None = None,
+ ) -> None:
+ """Initialize the file find handler."""
self.no_cache_paths = no_cache_paths or []
if isinstance(path, str):
path = [path]
- self.root = tuple(os.path.abspath(os.path.expanduser(p)) + os.sep for p in path)
+ self.root = tuple(os.path.abspath(os.path.expanduser(p)) + os.sep for p in path) # type:ignore[assignment]
self.default_filename = default_filename
- def compute_etag(self):
+ def compute_etag(self) -> str | None:
+ """Compute the etag."""
return None
@classmethod
- def get_absolute_path(cls, roots, path):
+ def get_absolute_path(cls, roots: Sequence[str], path: str) -> str:
"""locate a file to serve on our static file search path"""
with cls._lock:
if path in cls._static_paths:
@@ -857,9 +1018,9 @@ def get_absolute_path(cls, roots, path):
log().debug(f"Path {path} served from {abspath}")
return abspath
- def validate_absolute_path(self, root, absolute_path):
+ def validate_absolute_path(self, root: str, absolute_path: str) -> str | None:
"""check if the file should be served (raises 404, 403, etc.)"""
- if absolute_path == "":
+ if not absolute_path:
raise web.HTTPError(404)
for root in self.root:
@@ -870,7 +1031,12 @@ def validate_absolute_path(self, root, absolute_path):
class APIVersionHandler(APIHandler):
- def get(self):
+ """An API handler for the server version."""
+
+ _track_activity = False
+
+ def get(self) -> None:
+ """Get the server version info."""
# not authenticated, so give as few info as possible
self.finish(json.dumps({"version": jupyter_server.__version__}))
@@ -881,7 +1047,9 @@ class TrailingSlashHandler(web.RequestHandler):
This should be the first, highest priority handler.
"""
- def get(self):
+ def get(self) -> None:
+ """Handle trailing slashes in a get."""
+ assert self.request.uri is not None
path, *rest = self.request.uri.partition("?")
# trim trailing *and* leading /
# to avoid misinterpreting repeated '//'
@@ -895,7 +1063,8 @@ def get(self):
class MainHandler(JupyterHandler):
"""Simple handler for base_url."""
- def get(self):
+ def get(self) -> None:
+ """Get the main template."""
html = self.render_template("main.html")
self.write(html)
@@ -906,7 +1075,7 @@ class FilesRedirectHandler(JupyterHandler):
"""Handler for redirecting relative URLs to the /files/ handler"""
@staticmethod
- async def redirect_to_files(self, path):
+ async def redirect_to_files(self: Any, path: str) -> None:
"""make redirect logic a reusable static method
so it can be called from other handlers.
@@ -934,18 +1103,20 @@ async def redirect_to_files(self, path):
self.log.debug("Redirecting %s to %s", self.request.path, url)
self.redirect(url)
- def get(self, path=""):
- return self.redirect_to_files(self, path)
+ async def get(self, path: str = "") -> None:
+ return await self.redirect_to_files(self, path)
class RedirectWithParams(web.RequestHandler):
"""Sam as web.RedirectHandler, but preserves URL parameters"""
- def initialize(self, url, permanent=True):
+ def initialize(self, url: str, permanent: bool = True) -> None:
+ """Initialize a redirect handler."""
self._url = url
self._permanent = permanent
- def get(self):
+ def get(self) -> None:
+ """Get a redirect."""
sep = "&" if "?" in self._url else "?"
url = sep.join([self._url, self.request.query])
self.redirect(url, permanent=self._permanent)
@@ -953,10 +1124,11 @@ def get(self):
class PrometheusMetricsHandler(JupyterHandler):
"""
- Return prometheus metrics for this notebook server
+ Return prometheus metrics for this server
"""
- def get(self):
+ def get(self) -> None:
+ """Get prometheus metrics."""
if self.settings["authenticate_prometheus"] and not self.logged_in:
raise web.HTTPError(403)
@@ -965,7 +1137,7 @@ def get(self):
# -----------------------------------------------------------------------------
-# URL pattern fragments for re-use
+# URL pattern fragments for reuse
# -----------------------------------------------------------------------------
# path matches any number of `/foo[/bar...]` or just `/` or ''
diff --git a/jupyter_server/base/websocket.py b/jupyter_server/base/websocket.py
new file mode 100644
index 0000000000..a27b7a72a7
--- /dev/null
+++ b/jupyter_server/base/websocket.py
@@ -0,0 +1,128 @@
+"""Base websocket classes."""
+import re
+from typing import Optional, no_type_check
+from urllib.parse import urlparse
+
+from tornado import ioloop
+from tornado.iostream import IOStream
+
+# ping interval for keeping websockets alive (30 seconds)
+WS_PING_INTERVAL = 30000
+
+
+class WebSocketMixin:
+ """Mixin for common websocket options"""
+
+ ping_callback = None
+ last_ping = 0.0
+ last_pong = 0.0
+ stream: Optional[IOStream] = None
+
+ @property
+ def ping_interval(self):
+ """The interval for websocket keep-alive pings.
+
+ Set ws_ping_interval = 0 to disable pings.
+ """
+ return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) # type:ignore[attr-defined]
+
+ @property
+ def ping_timeout(self):
+ """If no ping is received in this many milliseconds,
+ close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
+ Default is max of 3 pings or 30 seconds.
+ """
+ return self.settings.get( # type:ignore[attr-defined]
+ "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)
+ )
+
+ @no_type_check
+ def check_origin(self, origin: Optional[str] = None) -> bool:
+ """Check Origin == Host or Access-Control-Allow-Origin.
+
+ Tornado >= 4 calls this method automatically, raising 403 if it returns False.
+ """
+
+ if self.allow_origin == "*" or (
+ hasattr(self, "skip_check_origin") and self.skip_check_origin()
+ ):
+ return True
+
+ host = self.request.headers.get("Host")
+ if origin is None:
+ origin = self.get_origin()
+
+ # If no origin or host header is provided, assume from script
+ if origin is None or host is None:
+ return True
+
+ origin = origin.lower()
+ origin_host = urlparse(origin).netloc
+
+ # OK if origin matches host
+ if origin_host == host:
+ return True
+
+ # Check CORS headers
+ if self.allow_origin:
+ allow = self.allow_origin == origin
+ elif self.allow_origin_pat:
+ allow = bool(re.match(self.allow_origin_pat, origin))
+ else:
+ # No CORS headers: deny the request
+ allow = False
+ if not allow:
+ self.log.warning(
+ "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
+ origin,
+ host,
+ )
+ return allow
+
+ def clear_cookie(self, *args, **kwargs):
+ """meaningless for websockets"""
+
+ @no_type_check
+ def open(self, *args, **kwargs):
+ """Open the websocket."""
+ self.log.debug("Opening websocket %s", self.request.path)
+
+ # start the pinging
+ if self.ping_interval > 0:
+ loop = ioloop.IOLoop.current()
+ self.last_ping = loop.time() # Remember time of last ping
+ self.last_pong = self.last_ping
+ self.ping_callback = ioloop.PeriodicCallback(
+ self.send_ping,
+ self.ping_interval,
+ )
+ self.ping_callback.start()
+ return super().open(*args, **kwargs)
+
+ @no_type_check
+ def send_ping(self):
+ """send a ping to keep the websocket alive"""
+ if self.ws_connection is None and self.ping_callback is not None:
+ self.ping_callback.stop()
+ return
+
+ if self.ws_connection.client_terminated:
+ self.close()
+ return
+
+ # check for timeout on pong. Make sure that we really have sent a recent ping in
+ # case the machine with both server and client has been suspended since the last ping.
+ now = ioloop.IOLoop.current().time()
+ since_last_pong = 1e3 * (now - self.last_pong)
+ since_last_ping = 1e3 * (now - self.last_ping)
+ if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
+ self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
+ self.close()
+ return
+
+ self.ping(b"")
+ self.last_ping = now
+
+ def on_pong(self, data):
+ """Handle a pong message."""
+ self.last_pong = ioloop.IOLoop.current().time()
diff --git a/jupyter_server/base/zmqhandlers.py b/jupyter_server/base/zmqhandlers.py
index 28e296c722..4490380a34 100644
--- a/jupyter_server/base/zmqhandlers.py
+++ b/jupyter_server/base/zmqhandlers.py
@@ -1,348 +1,19 @@
-"""Tornado handlers for WebSocket <-> ZMQ sockets."""
-# Copyright (c) Jupyter Development Team.
-# Distributed under the terms of the Modified BSD License.
-import json
-import re
-import struct
-import sys
-from urllib.parse import urlparse
+"""This module is deprecated in Jupyter Server 2.0"""
+# Raise a warning that this module is deprecated.
+import warnings
-import tornado
-
-try:
- from jupyter_client.jsonutil import json_default
-except ImportError:
- from jupyter_client.jsonutil import date_default as json_default
-
-from jupyter_client.jsonutil import extract_dates
-from jupyter_client.session import Session
-from tornado import ioloop, web
from tornado.websocket import WebSocketHandler
-from jupyter_server.auth.utils import warn_disabled_authorization
-
-from .handlers import JupyterHandler
-
-
-def serialize_binary_message(msg):
- """serialize a message as a binary blob
-
- Header:
-
- 4 bytes: number of msg parts (nbufs) as 32b int
- 4 * nbufs bytes: offset for each buffer as integer as 32b int
-
- Offsets are from the start of the buffer, including the header.
-
- Returns
- -------
- The message serialized to bytes.
-
- """
- # don't modify msg or buffer list in-place
- msg = msg.copy()
- buffers = list(msg.pop("buffers"))
- if sys.version_info < (3, 4):
- buffers = [x.tobytes() for x in buffers]
- bmsg = json.dumps(msg, default=json_default).encode("utf8")
- buffers.insert(0, bmsg)
- nbufs = len(buffers)
- offsets = [4 * (nbufs + 1)]
- for buf in buffers[:-1]:
- offsets.append(offsets[-1] + len(buf))
- offsets_buf = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets)
- buffers.insert(0, offsets_buf)
- return b"".join(buffers)
-
-
-def deserialize_binary_message(bmsg):
- """deserialize a message from a binary blog
-
- Header:
-
- 4 bytes: number of msg parts (nbufs) as 32b int
- 4 * nbufs bytes: offset for each buffer as integer as 32b int
-
- Offsets are from the start of the buffer, including the header.
-
- Returns
- -------
- message dictionary
- """
- nbufs = struct.unpack("!i", bmsg[:4])[0]
- offsets = list(struct.unpack("!" + "I" * nbufs, bmsg[4 : 4 * (nbufs + 1)]))
- offsets.append(None)
- bufs = []
- for start, stop in zip(offsets[:-1], offsets[1:]):
- bufs.append(bmsg[start:stop])
- msg = json.loads(bufs[0].decode("utf8"))
- msg["header"] = extract_dates(msg["header"])
- msg["parent_header"] = extract_dates(msg["parent_header"])
- msg["buffers"] = bufs[1:]
- return msg
-
-
-def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None):
- if pack:
- msg_list = [
- pack(msg_or_list["header"]),
- pack(msg_or_list["parent_header"]),
- pack(msg_or_list["metadata"]),
- pack(msg_or_list["content"]),
- ]
- else:
- msg_list = msg_or_list
- channel = channel.encode("utf-8")
- offsets = []
- offsets.append(8 * (1 + 1 + len(msg_list) + 1))
- offsets.append(len(channel) + offsets[-1])
- for msg in msg_list:
- offsets.append(len(msg) + offsets[-1])
- offset_number = len(offsets).to_bytes(8, byteorder="little")
- offsets = [offset.to_bytes(8, byteorder="little") for offset in offsets]
- bin_msg = b"".join([offset_number] + offsets + [channel] + msg_list)
- return bin_msg
-
-
-def deserialize_msg_from_ws_v1(ws_msg):
- offset_number = int.from_bytes(ws_msg[:8], "little")
- offsets = [
- int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") for i in range(offset_number)
- ]
- channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8")
- msg_list = [ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1)]
- return channel, msg_list
-
-
-# ping interval for keeping websockets alive (30 seconds)
-WS_PING_INTERVAL = 30000
-
-
-class WebSocketMixin:
- """Mixin for common websocket options"""
-
- ping_callback = None
- last_ping = 0
- last_pong = 0
- stream = None
-
- @property
- def ping_interval(self):
- """The interval for websocket keep-alive pings.
-
- Set ws_ping_interval = 0 to disable pings.
- """
- return self.settings.get("ws_ping_interval", WS_PING_INTERVAL)
-
- @property
- def ping_timeout(self):
- """If no ping is received in this many milliseconds,
- close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
- Default is max of 3 pings or 30 seconds.
- """
- return self.settings.get("ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL))
-
- def check_origin(self, origin=None):
- """Check Origin == Host or Access-Control-Allow-Origin.
-
- Tornado >= 4 calls this method automatically, raising 403 if it returns False.
- """
-
- if self.allow_origin == "*" or (
- hasattr(self, "skip_check_origin") and self.skip_check_origin()
- ):
- return True
-
- host = self.request.headers.get("Host")
- if origin is None:
- origin = self.get_origin()
-
- # If no origin or host header is provided, assume from script
- if origin is None or host is None:
- return True
-
- origin = origin.lower()
- origin_host = urlparse(origin).netloc
-
- # OK if origin matches host
- if origin_host == host:
- return True
-
- # Check CORS headers
- if self.allow_origin:
- allow = self.allow_origin == origin
- elif self.allow_origin_pat:
- allow = bool(re.match(self.allow_origin_pat, origin))
- else:
- # No CORS headers deny the request
- allow = False
- if not allow:
- self.log.warning(
- "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
- origin,
- host,
- )
- return allow
-
- def clear_cookie(self, *args, **kwargs):
- """meaningless for websockets"""
- pass
-
- def open(self, *args, **kwargs):
- self.log.debug("Opening websocket %s", self.request.path)
-
- # start the pinging
- if self.ping_interval > 0:
- loop = ioloop.IOLoop.current()
- self.last_ping = loop.time() # Remember time of last ping
- self.last_pong = self.last_ping
- self.ping_callback = ioloop.PeriodicCallback(
- self.send_ping,
- self.ping_interval,
- )
- self.ping_callback.start()
- return super().open(*args, **kwargs)
-
- def send_ping(self):
- """send a ping to keep the websocket alive"""
- if self.ws_connection is None and self.ping_callback is not None:
- self.ping_callback.stop()
- return
-
- if self.ws_connection.client_terminated:
- self.close()
- return
-
- # check for timeout on pong. Make sure that we really have sent a recent ping in
- # case the machine with both server and client has been suspended since the last ping.
- now = ioloop.IOLoop.current().time()
- since_last_pong = 1e3 * (now - self.last_pong)
- since_last_ping = 1e3 * (now - self.last_ping)
- if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
- self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
- self.close()
- return
-
- self.ping(b"")
- self.last_ping = now
-
- def on_pong(self, data):
- self.last_pong = ioloop.IOLoop.current().time()
-
-
-class ZMQStreamHandler(WebSocketMixin, WebSocketHandler):
-
- if tornado.version_info < (4, 1):
- """Backport send_error from tornado 4.1 to 4.0"""
-
- def send_error(self, *args, **kwargs):
- if self.stream is None:
- super(WebSocketHandler, self).send_error(*args, **kwargs)
- else:
- # If we get an uncaught exception during the handshake,
- # we have no choice but to abruptly close the connection.
- # TODO: for uncaught exceptions after the handshake,
- # we can close the connection more gracefully.
- self.stream.close()
-
- def _reserialize_reply(self, msg_or_list, channel=None):
- """Reserialize a reply message using JSON.
-
- msg_or_list can be an already-deserialized msg dict or the zmq buffer list.
- If it is the zmq list, it will be deserialized with self.session.
-
- This takes the msg list from the ZMQ socket and serializes the result for the websocket.
- This method should be used by self._on_zmq_reply to build messages that can
- be sent back to the browser.
-
- """
- if isinstance(msg_or_list, dict):
- # already unpacked
- msg = msg_or_list
- else:
- idents, msg_list = self.session.feed_identities(msg_or_list)
- msg = self.session.deserialize(msg_list)
- if channel:
- msg["channel"] = channel
- if msg["buffers"]:
- buf = serialize_binary_message(msg)
- return buf
- else:
- return json.dumps(msg, default=json_default)
-
- def select_subprotocol(self, subprotocols):
- preferred_protocol = self.settings.get("kernel_ws_protocol")
- if preferred_protocol is None:
- preferred_protocol = "v1.kernel.websocket.jupyter.org"
- elif preferred_protocol == "":
- preferred_protocol = None
- selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
- # None is the default, "legacy" protocol
- return selected_subprotocol
-
- def _on_zmq_reply(self, stream, msg_list):
- # Sometimes this gets triggered when the on_close method is scheduled in the
- # eventloop but hasn't been called.
- if self.ws_connection is None or stream.closed():
- self.log.warning("zmq message arrived on closed channel")
- self.close()
- return
- channel = getattr(stream, "channel", None)
- if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
- bin_msg = serialize_msg_to_ws_v1(msg_list, channel)
- self.write_message(bin_msg, binary=True)
- else:
- try:
- msg = self._reserialize_reply(msg_list, channel=channel)
- except Exception:
- self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
- else:
- self.write_message(msg, binary=isinstance(msg, bytes))
-
-
-class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler):
- def set_default_headers(self):
- """Undo the set_default_headers in JupyterHandler
-
- which doesn't make sense for websockets
- """
- pass
-
- def pre_get(self):
- """Run before finishing the GET request
-
- Extend this method to add logic that should fire before
- the websocket finishes completing.
- """
- # authenticate the request before opening the websocket
- user = self.get_current_user()
- if user is None:
- self.log.warning("Couldn't authenticate WebSocket connection")
- raise web.HTTPError(403)
-
- # authorize the user.
- if not self.authorizer:
- # Warn if there is not authorizer.
- warn_disabled_authorization()
- elif not self.authorizer.is_authorized(self, user, "execute", "kernels"):
- raise web.HTTPError(403)
-
- if self.get_argument("session_id", False):
- self.session.session = self.get_argument("session_id")
- else:
- self.log.warning("No session ID specified")
-
- async def get(self, *args, **kwargs):
- # pre_get can be a coroutine in subclasses
- # assign and yield in two step to avoid tornado 3 issues
- res = self.pre_get()
- await res
- res = super().get(*args, **kwargs)
- await res
-
- def initialize(self):
- self.log.debug("Initializing websocket connection %s", self.request.path)
- self.session = Session(config=self.config)
-
- def get_compression_options(self):
- return self.settings.get("websocket_compression_options", None)
+from jupyter_server.base.websocket import WebSocketMixin
+from jupyter_server.services.kernels.connection.base import (
+ deserialize_binary_message,
+ deserialize_msg_from_ws_v1,
+ serialize_binary_message,
+ serialize_msg_to_ws_v1,
+)
+
+warnings.warn(
+ "jupyter_server.base.zmqhandlers module is deprecated in Jupyter Server 2.0",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/jupyter_server/config_manager.py b/jupyter_server/config_manager.py
index 25c5efd28f..87480d7609 100644
--- a/jupyter_server/config_manager.py
+++ b/jupyter_server/config_manager.py
@@ -1,17 +1,22 @@
"""Manager to read and modify config data in JSON files."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
+from __future__ import annotations
+
import copy
import errno
import glob
import json
import os
+import typing as t
from traitlets.config import LoggingConfigurable
from traitlets.traitlets import Bool, Unicode
+StrDict = t.Dict[str, t.Any]
+
-def recursive_update(target, new):
+def recursive_update(target: StrDict, new: StrDict) -> None:
"""Recursively update one dictionary using another.
None values will delete their keys.
@@ -32,7 +37,7 @@ def recursive_update(target, new):
target[k] = v
-def remove_defaults(data, defaults):
+def remove_defaults(data: StrDict, defaults: StrDict) -> None:
"""Recursively remove items from dict that are already in defaults"""
# copy the iterator, since data will be modified
for key, value in list(data.items()):
@@ -41,9 +46,8 @@ def remove_defaults(data, defaults):
remove_defaults(data[key], defaults[key])
if not data[key]: # prune empty subdicts
del data[key]
- else:
- if value == defaults[key]:
- del data[key]
+ elif value == defaults[key]:
+ del data[key]
class BaseJSONConfigManager(LoggingConfigurable):
@@ -56,7 +60,7 @@ class BaseJSONConfigManager(LoggingConfigurable):
config_dir = Unicode(".")
read_directory = Bool(True)
- def ensure_config_dir_exists(self):
+ def ensure_config_dir_exists(self) -> None:
"""Will try to create the config_dir directory."""
try:
os.makedirs(self.config_dir, 0o755)
@@ -64,15 +68,15 @@ def ensure_config_dir_exists(self):
if e.errno != errno.EEXIST:
raise
- def file_name(self, section_name):
+ def file_name(self, section_name: str) -> str:
"""Returns the json filename for the section_name: {config_dir}/{section_name}.json"""
return os.path.join(self.config_dir, section_name + ".json")
- def directory(self, section_name):
+ def directory(self, section_name: str) -> str:
"""Returns the directory name for the section name: {config_dir}/{section_name}.d"""
return os.path.join(self.config_dir, section_name + ".d")
- def get(self, section_name, include_root=True):
+ def get(self, section_name: str, include_root: bool = True) -> dict[str, t.Any]:
"""Retrieve the config data for the specified section.
Returns the data as a dictionary, or an empty dictionary if the file
@@ -95,14 +99,14 @@ def get(self, section_name, include_root=True):
section_name,
"\n\t".join(paths),
)
- data = {}
+ data: dict[str, t.Any] = {}
for path in paths:
if os.path.isfile(path):
with open(path, encoding="utf-8") as f:
recursive_update(data, json.load(f))
return data
- def set(self, section_name, data):
+ def set(self, section_name: str, data: t.Any) -> None:
"""Store the given config data."""
filename = self.file_name(section_name)
self.ensure_config_dir_exists()
@@ -116,11 +120,10 @@ def set(self, section_name, data):
# Generate the JSON up front, since it could raise an exception,
# in order to avoid writing half-finished corrupted data to disk.
json_content = json.dumps(data, indent=2)
- f = open(filename, "w", encoding="utf-8")
- with f:
+ with open(filename, "w", encoding="utf-8") as f:
f.write(json_content)
- def update(self, section_name, new_data):
+ def update(self, section_name: str, new_data: t.Any) -> dict[str, t.Any]:
"""Modify the config section by recursively updating it with new_data.
Returns the modified config data as a dictionary.
diff --git a/jupyter_server/event_schemas/contents_service/v1.yaml b/jupyter_server/event_schemas/contents_service/v1.yaml
new file mode 100644
index 0000000000..a787f9b2b0
--- /dev/null
+++ b/jupyter_server/event_schemas/contents_service/v1.yaml
@@ -0,0 +1,73 @@
+"$id": https://events.jupyter.org/jupyter_server/contents_service/v1
+version: 1
+title: Contents Manager activities
+personal-data: true
+description: |
+ Record actions on files via the ContentsManager.
+
+ The notebook ContentsManager REST API is used by all frontends to retrieve,
+ save, list, delete and perform other actions on notebooks, directories,
+ and other files through the UI. This is pluggable - the default acts on
+ the file system, but can be replaced with a different ContentsManager
+ implementation - to work on S3, Postgres, other object stores, etc.
+ The events get recorded regardless of the ContentsManager implementation
+ being used.
+
+ Limitations:
+
+ 1. This does not record all filesystem access, just the ones that happen
+ explicitly via the notebook server's REST API. Users can (and often do)
+ trivially access the filesystem in many other ways (such as `open()` calls
+ in their code), so this is usually never a complete record.
+ 2. As with all events recorded by the notebook server, users most likely
+ have the ability to modify the code of the notebook server. Unless other
+ security measures are in place, these events should be treated as user
+ controlled and not used in high security areas.
+ 3. Events are only recorded when an action succeeds.
+type: object
+required:
+ - action
+ - path
+properties:
+ action:
+ enum:
+ - get
+ - create
+ - save
+ - upload
+ - rename
+ - copy
+ - delete
+ description: |
+ Action performed by the ContentsManager API.
+
+ This is a required field.
+
+ Possible values:
+
+ 1. get
+ Get contents of a particular file, or list contents of a directory.
+
+ 2. save
+ Save a file at path with contents from the client
+
+ 3. rename
+ Rename a file or directory from value in source_path to
+ value in path.
+
+ 4. copy
+ Copy a file or directory from value in source_path to
+ value in path.
+
+ 5. delete
+ Delete a file or empty directory at given path
+ path:
+ type: string
+ description: |
+ Logical path on which the operation was performed.
+
+ This is a required field.
+ source_path:
+ type: string
+ description: |
+ Source path of an operation when action is 'copy' or 'rename'
diff --git a/jupyter_server/event_schemas/gateway_client/v1.yaml b/jupyter_server/event_schemas/gateway_client/v1.yaml
new file mode 100644
index 0000000000..0a35d2464d
--- /dev/null
+++ b/jupyter_server/event_schemas/gateway_client/v1.yaml
@@ -0,0 +1,40 @@
+"$id": https://events.jupyter.org/jupyter_server/gateway_client/v1
+version: 1
+title: Gateway Client activities.
+personal-data: true
+description: |
+ Record events of a gateway client.
+type: object
+required:
+ - status
+ - msg
+properties:
+ status:
+ enum:
+ - error
+ - success
+ description: |
+      Status received by the Gateway client from a REST API operation on the gateway kernel.
+
+ This is a required field.
+
+ Possible values:
+
+ 1. error
+ Error response from a rest api operation to gateway kernel.
+
+ 2. success
+ Success response from a rest api operation to gateway kernel.
+ status_code:
+ type: number
+ description: |
+      HTTP response codes from a REST API operation to the gateway kernel.
+ Examples: 200, 400, 502, 503, 599 etc.
+ msg:
+ type: string
+ description: |
+ Description of the event being emitted.
+ gateway_url:
+ type: string
+ description: |
+      Gateway url where the remote server exists.
diff --git a/jupyter_server/event_schemas/kernel_actions/v1.yaml b/jupyter_server/event_schemas/kernel_actions/v1.yaml
new file mode 100644
index 0000000000..e0375e5aaa
--- /dev/null
+++ b/jupyter_server/event_schemas/kernel_actions/v1.yaml
@@ -0,0 +1,80 @@
+"$id": https://events.jupyter.org/jupyter_server/kernel_actions/v1
+version: 1
+title: Kernel Manager activities
+personal-data: true
+description: |
+ Record events of a kernel manager.
+type: object
+required:
+ - action
+ - msg
+properties:
+ action:
+ enum:
+ - start
+ - interrupt
+ - shutdown
+ - restart
+ description: |
+ Action performed by the Kernel Manager.
+
+ This is a required field.
+
+ Possible values:
+
+ 1. start
+ A kernel has been started with the given kernel id.
+
+ 2. interrupt
+ A kernel has been interrupted for the given kernel id.
+
+ 3. shutdown
+ A kernel has been shut down for the given kernel id.
+
+ 4. restart
+ A kernel has been restarted for the given kernel id.
+ kernel_id:
+ type: string
+ description: |
+ Kernel id.
+
+ This is a required field for all actions and statuses except action start with status error.
+ kernel_name:
+ type: string
+ description: |
+ Name of the kernel.
+ status:
+ enum:
+ - error
+ - success
+ description: |
+      Status received from a REST API operation to the kernel server.
+
+ This is a required field.
+
+ Possible values:
+
+ 1. error
+ Error response from a rest api operation to kernel server.
+
+ 2. success
+ Success response from a rest api operation to kernel server.
+ status_code:
+ type: number
+ description: |
+      HTTP response codes from a REST API operation to the kernel server.
+      Examples: 200, 400, 502, 503, 599 etc.
+ msg:
+ type: string
+ description: |
+ Description of the event specified in action.
+if:
+ not:
+ properties:
+ status:
+ const: error
+ action:
+ const: start
+then:
+ required:
+ - kernel_id
diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py
index 167f6dd94e..aeeab5a94d 100644
--- a/jupyter_server/extension/application.py
+++ b/jupyter_server/extension/application.py
@@ -1,6 +1,10 @@
+"""An extension application."""
+from __future__ import annotations
+
import logging
import re
import sys
+import typing as t
from jinja2 import Environment, FileSystemLoader
from jupyter_core.application import JupyterApp, NoStart
@@ -20,24 +24,24 @@
# -----------------------------------------------------------------------------
-def _preparse_for_subcommand(Application, argv):
+def _preparse_for_subcommand(application_klass, argv):
"""Preparse command line to look for subcommands."""
# Read in arguments from command line.
if len(argv) == 0:
- return
+ return None
# Find any subcommands.
- if Application.subcommands and len(argv) > 0:
+ if application_klass.subcommands and len(argv) > 0:
# we have subcommands, and one may have been specified
subc, subargv = argv[0], argv[1:]
- if re.match(r"^\w(\-?\w)*$", subc) and subc in Application.subcommands:
+ if re.match(r"^\w(\-?\w)*$", subc) and subc in application_klass.subcommands:
# it's a subcommand, and *not* a flag or class parameter
- app = Application()
+ app = application_klass()
app.initialize_subcommand(subc, subargv)
return app.subapp
-def _preparse_for_stopping_flags(Application, argv):
+def _preparse_for_stopping_flags(application_klass, argv):
"""Looks for 'help', 'version', and 'generate-config; commands
in command line. If found, raises the help and version of
current Application.
@@ -57,19 +61,19 @@ def _preparse_for_stopping_flags(Application, argv):
# Catch any help calls.
if any(x in interpreted_argv for x in ("-h", "--help-all", "--help")):
- app = Application()
+ app = application_klass()
app.print_help("--help-all" in interpreted_argv)
app.exit(0)
# Catch version commands
if "--version" in interpreted_argv or "-V" in interpreted_argv:
- app = Application()
+ app = application_klass()
app.print_version()
app.exit(0)
# Catch generate-config commands.
if "--generate-config" in interpreted_argv:
- app = Application()
+ app = application_klass()
app.write_default_config()
app.exit(0)
@@ -84,8 +88,9 @@ class ExtensionAppJinjaMixin(HasTraits):
)
).tag(config=True)
+ @t.no_type_check
def _prepare_templates(self):
- # Get templates defined in a subclass.
+ """Get templates defined in a subclass."""
self.initialize_templates()
# Add templates to web app settings if extension has templates.
if len(self.template_paths) > 0:
@@ -121,12 +126,14 @@ class ExtensionApp(JupyterApp):
"""Base class for configurable Jupyter Server Extension Applications.
ExtensionApp subclasses can be initialized two ways:
- 1. Extension is listed as a jpserver_extension, and ServerApp calls
- its load_jupyter_server_extension classmethod. This is the
- classic way of loading a server extension.
- 2. Extension is launched directly by calling its `launch_instance`
- class method. This method can be set as a entry_point in
- the extensions setup.py
+
+ - Extension is listed as a jpserver_extension, and ServerApp calls
+ its load_jupyter_server_extension classmethod. This is the
+ classic way of loading a server extension.
+
+ - Extension is launched directly by calling its `launch_instance`
+      class method. This method can be set as an entry_point in
+      the extension's setup.py.
"""
# Subclasses should override this trait. Tells the server if
@@ -137,7 +144,7 @@ class method. This method can be set as a entry_point in
# A useful class property that subclasses can override to
# configure the underlying Jupyter Server when this extension
# is launched directly (using its `launch_instance` method).
- serverapp_config = {}
+ serverapp_config: dict[str, t.Any] = {}
# Some subclasses will likely override this trait to flip
# the default value to False if they don't offer a browser
@@ -153,22 +160,25 @@ class method. This method can be set as a entry_point in
@default("open_browser")
def _default_open_browser(self):
+ assert self.serverapp is not None
return self.serverapp.config["ServerApp"].get("open_browser", True)
@property
def config_file_paths(self):
"""Look on the same path as our parent for config files"""
# rely on parent serverapp, which should control all config loading
+ assert self.serverapp is not None
return self.serverapp.config_file_paths
# The extension name used to name the jupyter config
# file, jupyter_{name}_config.
# This should also match the jupyter subcommand used to launch
# this extension from the CLI, e.g. `jupyter {name}`.
- name = None
+ name: str | Unicode[str, str] = "ExtensionApp" # type:ignore[assignment]
@classmethod
def get_extension_package(cls):
+ """Get an extension package."""
parts = cls.__module__.split(".")
if is_namespace_package(parts[0]):
# in this case the package name is `.`.
@@ -177,6 +187,7 @@ def get_extension_package(cls):
@classmethod
def get_extension_point(cls):
+ """Get an extension point."""
return cls.__module__
# Extension URL sets the default landing page for this extension.
@@ -199,7 +210,7 @@ def _default_url(self):
]
# A ServerApp is not defined yet, but will be initialized below.
- serverapp = Any()
+ serverapp: ServerApp | None = Any() # type:ignore[assignment]
@default("serverapp")
def _default_serverapp(self):
@@ -215,7 +226,7 @@ def _default_serverapp(self):
# declare an empty one
return ServerApp()
- _log_formatter_cls = LogFormatter
+ _log_formatter_cls = LogFormatter # type:ignore[assignment]
@default("log_level")
def _default_log_level(self):
@@ -235,6 +246,7 @@ def _default_log_format(self):
@default("static_url_prefix")
def _default_static_url_prefix(self):
static_url = f"static/{self.name}/"
+ assert self.serverapp is not None
return url_path_join(self.serverapp.base_url, static_url)
static_paths = List(
@@ -257,7 +269,9 @@ def _default_static_url_prefix(self):
settings = Dict(help=_i18n("""Settings that will passed to the server.""")).tag(config=True)
- handlers = List(help=_i18n("""Handlers appended to the server.""")).tag(config=True)
+ handlers: List[tuple[t.Any, ...]] = List(
+ help=_i18n("""Handlers appended to the server.""")
+ ).tag(config=True)
def _config_file_name_default(self):
"""The default config file name."""
@@ -267,15 +281,12 @@ def _config_file_name_default(self):
def initialize_settings(self):
"""Override this method to add handling of settings."""
- pass
def initialize_handlers(self):
"""Override this method to append handlers to a Jupyter Server."""
- pass
def initialize_templates(self):
"""Override this method to add handling of template files."""
- pass
def _prepare_config(self):
"""Builds a Config object from the extension's traits and passes
@@ -286,7 +297,9 @@ def _prepare_config(self):
self.settings[f"{self.name}_config"] = self.extension_config
def _prepare_settings(self):
+ """Prepare the settings."""
# Make webapp settings accessible to initialize_settings method
+ assert self.serverapp is not None
webapp = self.serverapp.web_app
self.settings.update(**webapp.settings)
@@ -305,6 +318,8 @@ def _prepare_settings(self):
webapp.settings.update(**self.settings)
def _prepare_handlers(self):
+ """Prepare the handlers."""
+ assert self.serverapp is not None
webapp = self.serverapp.web_app
# Get handlers defined by extension subclass.
@@ -318,7 +333,7 @@ def _prepare_handlers(self):
handler = handler_items[1]
# Get handler kwargs, if given
- kwargs = {}
+ kwargs: dict[str, t.Any] = {}
if issubclass(handler, ExtensionHandlerMixin):
kwargs["name"] = self.name
@@ -343,15 +358,16 @@ def _prepare_handlers(self):
)
new_handlers.append(handler)
- webapp.add_handlers(".*$", new_handlers)
+ webapp.add_handlers(".*$", new_handlers) # type:ignore[arg-type]
def _prepare_templates(self):
- # Add templates to web app settings if extension has templates.
+ """Add templates to web app settings if extension has templates."""
if len(self.template_paths) > 0:
self.settings.update({f"{self.name}_template_paths": self.template_paths})
self.initialize_templates()
def _jupyter_server_config(self):
+ """The jupyter server config."""
base_config = {
"ServerApp": {
"default_url": self.default_url,
@@ -362,7 +378,7 @@ def _jupyter_server_config(self):
base_config["ServerApp"].update(self.serverapp_config)
return base_config
- def _link_jupyter_server_extension(self, serverapp):
+ def _link_jupyter_server_extension(self, serverapp: ServerApp) -> None:
"""Link the ExtensionApp to an initialized ServerApp.
The ServerApp is stored as an attribute and config
@@ -401,11 +417,10 @@ def initialize(self):
corresponding server app and webapp should already
be initialized by this step.
- 1) Appends Handlers to the ServerApp,
- 2) Passes config and settings from ExtensionApp
- to the Tornado web application
- 3) Points Tornado Webapp to templates and
- static assets.
+ - Appends Handlers to the ServerApp,
+ - Passes config and settings from ExtensionApp
+ to the Tornado web application
+ - Points Tornado Webapp to templates and static assets.
"""
if not self.serverapp:
msg = (
@@ -427,13 +442,19 @@ def start(self):
"""
super().start()
# Start the server.
+ assert self.serverapp is not None
self.serverapp.start()
+ def current_activity(self):
+ """Return a list of activity happening in this extension."""
+ return
+
async def stop_extension(self):
"""Cleanup any resources managed by this extension."""
def stop(self):
"""Stop the underlying Jupyter server."""
+ assert self.serverapp is not None
self.serverapp.stop()
self.serverapp.clear_instance()
@@ -534,10 +555,20 @@ def load_classic_server_extension(cls, serverapp):
)
extension.initialize()
+ serverapp_class = ServerApp
+
+ @classmethod
+ def make_serverapp(cls, **kwargs: t.Any) -> ServerApp:
+ """Instantiate the ServerApp
+
+ Override to customize the ServerApp before it loads any configuration
+ """
+ return cls.serverapp_class.instance(**kwargs)
+
@classmethod
def initialize_server(cls, argv=None, load_other_extensions=True, **kwargs):
"""Creates an instance of ServerApp and explicitly sets
- this extension to enabled=True (i.e. superceding disabling
+ this extension to enabled=True (i.e. superseding disabling
found in other config from files).
The `launch_instance` method uses this method to initialize
@@ -549,8 +580,8 @@ def initialize_server(cls, argv=None, load_other_extensions=True, **kwargs):
jpserver_extensions.update(cls.serverapp_config["jpserver_extensions"])
cls.serverapp_config["jpserver_extensions"] = jpserver_extensions
find_extensions = False
- serverapp = ServerApp.instance(jpserver_extensions=jpserver_extensions, **kwargs)
- serverapp.aliases.update(cls.aliases)
+ serverapp = cls.make_serverapp(jpserver_extensions=jpserver_extensions, **kwargs)
+ serverapp.aliases.update(cls.aliases) # type:ignore[has-type]
serverapp.initialize(
argv=argv or [],
starter_extension=cls.name,
@@ -565,7 +596,7 @@ def launch_instance(cls, argv=None, **kwargs):
extension's landing page.
"""
# Handle arguments.
- if argv is None:
+ if argv is None: # noqa: SIM108
args = sys.argv[1:] # slice out extension config.
else:
args = argv
@@ -585,10 +616,7 @@ def launch_instance(cls, argv=None, **kwargs):
# Log if extension is blocking other extensions from loading.
if not cls.load_other_extensions:
- serverapp.log.info(
- "{ext_name} is running without loading "
- "other extensions.".format(ext_name=cls.name)
- )
+ serverapp.log.info(f"{cls.name} is running without loading other extensions.")
# Start the server.
try:
serverapp.start()
diff --git a/jupyter_server/extension/config.py b/jupyter_server/extension/config.py
index 15a3cfbd0c..47b4f6cce1 100644
--- a/jupyter_server/extension/config.py
+++ b/jupyter_server/extension/config.py
@@ -1,3 +1,4 @@
+"""Extension config."""
from jupyter_server.services.config.manager import ConfigManager
DEFAULT_SECTION_NAME = "jupyter_server_config"
@@ -24,9 +25,11 @@ def enabled(self, name, section_name=DEFAULT_SECTION_NAME, include_root=True):
return False
def enable(self, name):
+ """Enable an extension by name."""
data = {"ServerApp": {"jpserver_extensions": {name: True}}}
self.update(name, data)
def disable(self, name):
+ """Disable an extension by name."""
data = {"ServerApp": {"jpserver_extensions": {name: False}}}
self.update(name, data)
diff --git a/jupyter_server/extension/handler.py b/jupyter_server/extension/handler.py
index 164d74bb15..55f5aff2c3 100644
--- a/jupyter_server/extension/handler.py
+++ b/jupyter_server/extension/handler.py
@@ -1,20 +1,32 @@
+"""An extension handler."""
+from __future__ import annotations
+
+from logging import Logger
+from typing import TYPE_CHECKING, Any, cast
+
from jinja2.exceptions import TemplateNotFound
from jupyter_server.base.handlers import FileFindHandler
+if TYPE_CHECKING:
+ from traitlets.config import Config
+
+ from jupyter_server.extension.application import ExtensionApp
+ from jupyter_server.serverapp import ServerApp
+
class ExtensionHandlerJinjaMixin:
"""Mixin class for ExtensionApp handlers that use jinja templating for
template rendering.
"""
- def get_template(self, name):
+ def get_template(self, name: str) -> str:
"""Return the jinja template object for a given name"""
try:
- env = f"{self.name}_jinja2_env"
- return self.settings[env].get_template(name)
+ env = f"{self.name}_jinja2_env" # type:ignore[attr-defined]
+ return cast(str, self.settings[env].get_template(name)) # type:ignore[attr-defined]
except TemplateNotFound:
- return super().get_template(name)
+ return cast(str, super().get_template(name)) # type:ignore[misc]
class ExtensionHandlerMixin:
@@ -28,49 +40,55 @@ class ExtensionHandlerMixin:
other extensions.
"""
- def initialize(self, name):
+ settings: dict[str, Any]
+
+ def initialize(self, name: str, *args: Any, **kwargs: Any) -> None:
self.name = name
+ try:
+ super().initialize(*args, **kwargs) # type:ignore[misc]
+ except TypeError:
+ pass
@property
- def extensionapp(self):
- return self.settings[self.name]
+ def extensionapp(self) -> ExtensionApp:
+ return cast("ExtensionApp", self.settings[self.name])
@property
- def serverapp(self):
+ def serverapp(self) -> ServerApp:
key = "serverapp"
- return self.settings[key]
+ return cast("ServerApp", self.settings[key])
@property
- def log(self):
+ def log(self) -> Logger:
if not hasattr(self, "name"):
- return super().log
+ return cast(Logger, super().log) # type:ignore[misc]
# Attempt to pull the ExtensionApp's log, otherwise fall back to ServerApp.
try:
- return self.extensionapp.log
+ return cast(Logger, self.extensionapp.log)
except AttributeError:
- return self.serverapp.log
+ return cast(Logger, self.serverapp.log)
@property
- def config(self):
- return self.settings[f"{self.name}_config"]
+ def config(self) -> Config:
+ return cast("Config", self.settings[f"{self.name}_config"])
@property
- def server_config(self):
- return self.settings["config"]
+ def server_config(self) -> Config:
+ return cast("Config", self.settings["config"])
@property
- def base_url(self):
- return self.settings.get("base_url", "/")
+ def base_url(self) -> str:
+ return cast(str, self.settings.get("base_url", "/"))
@property
- def static_url_prefix(self):
+ def static_url_prefix(self) -> str:
return self.extensionapp.static_url_prefix
@property
- def static_path(self):
- return self.settings[f"{self.name}_static_paths"]
+ def static_path(self) -> str:
+ return cast(str, self.settings[f"{self.name}_static_paths"])
- def static_url(self, path, include_host=None, **kwargs):
+ def static_url(self, path: str, include_host: bool | None = None, **kwargs: Any) -> str:
"""Returns a static URL for the given relative static file path.
This method requires you set the ``{name}_static_path``
setting in your extension (which specifies the root directory
@@ -89,13 +107,14 @@ def static_url(self, path, include_host=None, **kwargs):
"""
key = f"{self.name}_static_paths"
try:
- self.require_setting(key, "static_url")
+ self.require_setting(key, "static_url") # type:ignore[attr-defined]
except Exception as e:
if key in self.settings:
- raise Exception(
+ msg = (
"This extension doesn't have any static paths listed. Check that the "
"extension's `static_paths` trait is set."
- ) from e
+ )
+ raise Exception(msg) from None
else:
raise e
@@ -104,10 +123,9 @@ def static_url(self, path, include_host=None, **kwargs):
if include_host is None:
include_host = getattr(self, "include_host", False)
+ base = ""
if include_host:
- base = self.request.protocol + "://" + self.request.host
- else:
- base = ""
+ base = self.request.protocol + "://" + self.request.host # type:ignore[attr-defined]
# Hijack settings dict to send extension templates to extension
# static directory.
@@ -116,4 +134,4 @@ def static_url(self, path, include_host=None, **kwargs):
"static_url_prefix": self.static_url_prefix,
}
- return base + get_url(settings, path, **kwargs)
+ return base + cast(str, get_url(settings, path, **kwargs))
diff --git a/jupyter_server/extension/manager.py b/jupyter_server/extension/manager.py
index 1efb2cadd0..3509e2e9f6 100644
--- a/jupyter_server/extension/manager.py
+++ b/jupyter_server/extension/manager.py
@@ -1,19 +1,16 @@
+"""The extension manager."""
+from __future__ import annotations
+
import importlib
-import sys
-import traceback
+from itertools import starmap
from tornado.gen import multi
-from traitlets import Any, Bool, Dict, HasTraits, Instance, Unicode, default, observe
+from traitlets import Any, Bool, Dict, HasTraits, Instance, List, Unicode, default, observe
from traitlets import validate as validate_trait
from traitlets.config import LoggingConfigurable
from .config import ExtensionConfigManager
-from .utils import (
- ExtensionMetadataError,
- ExtensionModuleNotFound,
- get_loader,
- get_metadata,
-)
+from .utils import ExtensionMetadataError, ExtensionModuleNotFound, get_loader, get_metadata
class ExtensionPoint(HasTraits):
@@ -28,22 +25,23 @@ class ExtensionPoint(HasTraits):
@validate_trait("metadata")
def _valid_metadata(self, proposed):
+ """Validate metadata."""
metadata = proposed["value"]
# Verify that the metadata has a "name" key.
try:
self._module_name = metadata["module"]
except KeyError:
- raise ExtensionMetadataError(
- "There is no 'module' key in the extension's metadata packet."
- )
+ msg = "There is no 'module' key in the extension's metadata packet."
+ raise ExtensionMetadataError(msg) from None
try:
self._module = importlib.import_module(self._module_name)
except ImportError:
- raise ExtensionModuleNotFound(
- "The submodule '{}' could not be found. Are you "
- "sure the extension is installed?".format(self._module_name)
+ msg = (
+ f"The submodule '{self._module_name}' could not be found. Are you "
+ "sure the extension is installed?"
)
+ raise ExtensionModuleNotFound(msg) from None
# If the metadata includes an ExtensionApp, create an instance.
if "app" in metadata:
self._app = metadata["app"]()
@@ -99,6 +97,7 @@ def module(self):
return self._module
def _get_linker(self):
+ """Get a linker."""
if self.app:
linker = self.app._link_jupyter_server_extension
else:
@@ -112,6 +111,7 @@ def _get_linker(self):
return linker
def _get_loader(self):
+ """Get a loader."""
loc = self.app
if not loc:
loc = self.module
@@ -150,7 +150,7 @@ def load(self, serverapp):
return loader(serverapp)
-class ExtensionPackage(HasTraits):
+class ExtensionPackage(LoggingConfigurable):
"""An API for interfacing with a Jupyter Server extension package.
Usage:
@@ -160,74 +160,74 @@ class ExtensionPackage(HasTraits):
"""
name = Unicode(help="Name of the an importable Python package.")
- enabled = Bool(False).tag(config=True)
+ enabled = Bool(False, help="Whether the extension package is enabled.")
+
+ _linked_points = Dict()
+ extension_points = Dict()
+ module = Any(allow_none=True, help="The module for this extension package. None if not enabled")
+ metadata = List(Dict(), help="Extension metadata loaded from the extension package.")
+ version = Unicode(
+ help="""
+ The version of this extension package, if it can be found.
+ Otherwise, an empty string.
+ """,
+ )
- def __init__(self, *args, **kwargs):
- # Store extension points that have been linked.
- self._linked_points = {}
- super().__init__(*args, **kwargs)
+ @default("version")
+ def _load_version(self):
+ if not self.enabled:
+ return ""
+ return getattr(self.module, "__version__", "")
- _linked_points = {}
+ def __init__(self, **kwargs):
+ """Initialize an extension package."""
+ super().__init__(**kwargs)
+ if self.enabled:
+ self._load_metadata()
- @validate_trait("name")
- def _validate_name(self, proposed):
- name = proposed["value"]
- self._extension_points = {}
+ def _load_metadata(self):
+ """Import package and load metadata
+
+ Only used if extension package is enabled
+ """
+ name = self.name
try:
- self._module, self._metadata = get_metadata(name)
- except ImportError:
- raise ExtensionModuleNotFound(
- "The module '{name}' could not be found. Are you "
- "sure the extension is installed?".format(name=name)
+ self.module, self.metadata = get_metadata(name, logger=self.log)
+ except ImportError as e:
+ msg = (
+ f"The module '{name}' could not be found ({e}). Are you "
+ "sure the extension is installed?"
)
+ raise ExtensionModuleNotFound(msg) from None
# Create extension point interfaces for each extension path.
- for m in self._metadata:
+ for m in self.metadata:
point = ExtensionPoint(metadata=m)
- self._extension_points[point.name] = point
+ self.extension_points[point.name] = point
return name
- @property
- def module(self):
- """Extension metadata loaded from the extension package."""
- return self._module
-
- @property
- def version(self):
- """Get the version of this package, if it's given. Otherwise, return an empty string"""
- return getattr(self._module, "__version__", "")
-
- @property
- def metadata(self):
- """Extension metadata loaded from the extension package."""
- return self._metadata
-
- @property
- def extension_points(self):
- """A dictionary of extension points."""
- return self._extension_points
-
def validate(self):
"""Validate all extension points in this package."""
- for extension in self.extension_points.values():
- if not extension.validate():
- return False
- return True
+ return all(extension.validate() for extension in self.extension_points.values())
def link_point(self, point_name, serverapp):
+ """Link an extension point."""
linked = self._linked_points.get(point_name, False)
if not linked:
point = self.extension_points[point_name]
point.link(serverapp)
def load_point(self, point_name, serverapp):
+ """Load an extension point."""
point = self.extension_points[point_name]
return point.load(serverapp)
def link_all_points(self, serverapp):
+ """Link all extension points."""
for point_name in self.extension_points:
self.link_point(point_name, serverapp)
def load_all_points(self, serverapp):
+ """Load all extension points."""
return [self.load_point(point_name, serverapp) for point_name in self.extension_points]
@@ -324,12 +324,19 @@ def add_extension(self, extension_name, enabled=False):
return True
# Raise a warning if the extension cannot be loaded.
except Exception as e:
- if self.serverapp.reraise_server_extension_failures:
+ if self.serverapp and self.serverapp.reraise_server_extension_failures:
raise
- self.log.warning(e)
+ self.log.warning(
+ "%s | error adding extension (enabled: %s): %s",
+ extension_name,
+ enabled,
+ e,
+ exc_info=True,
+ )
return False
def link_extension(self, name):
+ """Link an extension by name."""
linked = self.linked_extensions.get(name, False)
extension = self.extensions[name]
if not linked and extension.enabled:
@@ -337,36 +344,34 @@ def link_extension(self, name):
# Link extension and store links
extension.link_all_points(self.serverapp)
self.linked_extensions[name] = True
- self.log.info(f"{name} | extension was successfully linked.")
+ self.log.info("%s | extension was successfully linked.", name)
except Exception as e:
- if self.serverapp.reraise_server_extension_failures:
+ if self.serverapp and self.serverapp.reraise_server_extension_failures:
raise
- self.log.warning(e)
+ self.log.warning("%s | error linking extension: %s", name, e, exc_info=True)
def load_extension(self, name):
+ """Load an extension by name."""
extension = self.extensions.get(name)
- if extension.enabled:
+ if extension and extension.enabled:
try:
extension.load_all_points(self.serverapp)
except Exception as e:
- if self.serverapp.reraise_server_extension_failures:
+ if self.serverapp and self.serverapp.reraise_server_extension_failures:
raise
- self.log.debug("".join(traceback.format_exception(*sys.exc_info())))
self.log.warning(
- "{name} | extension failed loading with message: {error}".format(
- name=name, error=str(e)
- )
+ "%s | extension failed loading with message: %r", name, e, exc_info=True
)
else:
- self.log.info(f"{name} | extension was successfully loaded.")
+ self.log.info("%s | extension was successfully loaded.", name)
async def stop_extension(self, name, apps):
"""Call the shutdown hooks in the specified apps."""
for app in apps:
- self.log.debug(f'{name} | extension app "{app.name}" stopping')
+ self.log.debug("%s | extension app %r stopping", name, app.name)
await app.stop_extension()
- self.log.debug(f'{name} | extension app "{app.name}" stopped')
+ self.log.debug("%s | extension app %r stopped", name, app.name)
def link_all_extensions(self):
"""Link all enabled extensions
@@ -374,7 +379,7 @@ def link_all_extensions(self):
"""
# Sort the extension names to enforce deterministic linking
# order.
- for name in self.sorted_extensions.keys():
+ for name in self.sorted_extensions:
self.link_extension(name)
def load_all_extensions(self):
@@ -383,14 +388,16 @@ def load_all_extensions(self):
"""
# Sort the extension names to enforce deterministic loading
# order.
- for name in self.sorted_extensions.keys():
+ for name in self.sorted_extensions:
self.load_extension(name)
async def stop_all_extensions(self):
"""Call the shutdown hooks in all extensions."""
- await multi(
- [
- self.stop_extension(name, apps)
- for name, apps in sorted(dict(self.extension_apps).items())
- ]
- )
+ await multi(list(starmap(self.stop_extension, sorted(dict(self.extension_apps).items()))))
+
+ def any_activity(self):
+ """Check for any activity currently happening across all extension applications."""
+ for _, apps in sorted(dict(self.extension_apps).items()):
+ for app in apps:
+ if app.current_activity():
+ return True
diff --git a/jupyter_server/extension/serverextension.py b/jupyter_server/extension/serverextension.py
index 23c1bde231..19f3a30709 100644
--- a/jupyter_server/extension/serverextension.py
+++ b/jupyter_server/extension/serverextension.py
@@ -1,8 +1,12 @@
"""Utilities for installing extensions"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
+from __future__ import annotations
+
+import logging
import os
import sys
+import typing as t
from jupyter_core.application import JupyterApp
from jupyter_core.paths import ENV_CONFIG_PATH, SYSTEM_CONFIG_PATH, jupyter_config_dir
@@ -14,7 +18,7 @@
from jupyter_server.extension.manager import ExtensionManager, ExtensionPackage
-def _get_config_dir(user=False, sys_prefix=False):
+def _get_config_dir(user: bool = False, sys_prefix: bool = False) -> str:
"""Get the location of config files for the current context
Returns the string to the environment
@@ -37,7 +41,9 @@ def _get_config_dir(user=False, sys_prefix=False):
return extdir
-def _get_extmanager_for_context(write_dir="jupyter_server_config.d", user=False, sys_prefix=False):
+def _get_extmanager_for_context(
+ write_dir: str = "jupyter_server_config.d", user: bool = False, sys_prefix: bool = False
+) -> tuple[str, ExtensionManager]:
"""Get an extension manager pointing at the current context
Returns the path to the current context and an ExtensionManager object.
@@ -66,7 +72,7 @@ class ArgumentConflict(ValueError):
pass
-_base_flags = {}
+_base_flags: dict[str, t.Any] = {}
_base_flags.update(JupyterApp.flags)
_base_flags.pop("y", None)
_base_flags.pop("generate-config", None)
@@ -109,14 +115,14 @@ class ArgumentConflict(ValueError):
)
_base_flags["python"] = _base_flags["py"]
-_base_aliases = {}
+_base_aliases: dict[str, t.Any] = {}
_base_aliases.update(JupyterApp.aliases)
class BaseExtensionApp(JupyterApp):
"""Base extension installer app"""
- _log_formatter_cls = LogFormatter
+ _log_formatter_cls = LogFormatter # type:ignore[assignment]
flags = _base_flags
aliases = _base_aliases
version = __version__
@@ -125,12 +131,12 @@ class BaseExtensionApp(JupyterApp):
sys_prefix = Bool(True, config=True, help="Use the sys.prefix as the prefix")
python = Bool(False, config=True, help="Install from a Python package")
- def _log_format_default(self):
+ def _log_format_default(self) -> str:
"""A default format for messages"""
return "%(message)s"
@property
- def config_dir(self):
+ def config_dir(self) -> str: # type:ignore[override]
return _get_config_dir(user=self.user, sys_prefix=self.sys_prefix)
@@ -147,8 +153,12 @@ def config_dir(self):
def toggle_server_extension_python(
- import_name, enabled=None, parent=None, user=False, sys_prefix=True
-):
+ import_name: str,
+ enabled: bool | None = None,
+ parent: t.Any = None,
+ user: bool = False,
+ sys_prefix: bool = True,
+) -> None:
"""Toggle the boolean setting for a given server extension
in a Jupyter config file.
"""
@@ -212,11 +222,14 @@ def toggle_server_extension_python(
flags["python"] = flags["py"]
+_desc = "Enable/disable a server extension using frontend configuration files."
+
+
class ToggleServerExtensionApp(BaseExtensionApp):
"""A base class for enabling/disabling extensions"""
name = "jupyter server extension enable/disable"
- description = "Enable/disable a server extension using frontend configuration files."
+ description = _desc
flags = flags
@@ -224,7 +237,7 @@ class ToggleServerExtensionApp(BaseExtensionApp):
_toggle_pre_message = ""
_toggle_post_message = ""
- def toggle_server_extension(self, import_name):
+ def toggle_server_extension(self, import_name: str) -> None:
"""Change the status of a named server extension.
Uses the value of `self._toggle_value`.
@@ -253,17 +266,18 @@ def toggle_server_extension(self, import_name):
# Toggle extension config.
config = extension_manager.config_manager
- if self._toggle_value is True:
- config.enable(import_name)
- else:
- config.disable(import_name)
+ if config:
+ if self._toggle_value is True:
+ config.enable(import_name)
+ else:
+ config.disable(import_name)
# If successful, let's log.
self.log.info(f" - Extension successfully {self._toggle_post_message}.")
except Exception as err:
self.log.info(f" {RED_X} Validation failed: {err}")
- def start(self):
+ def start(self) -> None:
"""Perform the App's actions as configured"""
if not self.extra_args:
sys.exit("Please specify a server extension/package to enable or disable")
@@ -281,7 +295,7 @@ class EnableServerExtensionApp(ToggleServerExtensionApp):
Usage
jupyter server extension enable [--system|--sys-prefix]
"""
- _toggle_value = True
+ _toggle_value = True # type:ignore[assignment]
_toggle_pre_message = "enabling"
_toggle_post_message = "enabled"
@@ -296,7 +310,7 @@ class DisableServerExtensionApp(ToggleServerExtensionApp):
Usage
jupyter server extension disable [--system|--sys-prefix]
"""
- _toggle_value = False
+ _toggle_value = False # type:ignore[assignment]
_toggle_pre_message = "disabling"
_toggle_post_message = "disabled"
@@ -308,7 +322,7 @@ class ListServerExtensionsApp(BaseExtensionApp):
version = __version__
description = "List all server extensions known by the configuration system"
- def list_server_extensions(self):
+ def list_server_extensions(self) -> None:
"""List all enabled and disabled server extensions, by config path
Enabled extensions are validated, potentially generating warnings.
@@ -320,24 +334,34 @@ def list_server_extensions(self):
)
for option in configurations:
- config_dir, ext_manager = _get_extmanager_for_context(**option)
+ config_dir = _get_config_dir(**option)
self.log.info(f"Config dir: {config_dir}")
- for name, extension in ext_manager.extensions.items():
- enabled = extension.enabled
+ write_dir = "jupyter_server_config.d"
+ config_manager = ExtensionConfigManager(
+ read_config_path=[config_dir],
+ write_config_dir=os.path.join(config_dir, write_dir),
+ )
+ jpserver_extensions = config_manager.get_jpserver_extensions()
+ for name, enabled in jpserver_extensions.items():
# Attempt to get extension metadata
self.log.info(f" {name} {GREEN_ENABLED if enabled else RED_DISABLED}")
try:
self.log.info(f" - Validating {name}...")
+ extension = ExtensionPackage(name=name, enabled=enabled)
if not extension.validate():
- raise ValueError("validation failed")
+ msg = "validation failed"
+ raise ValueError(msg)
version = extension.version
self.log.info(f" {name} {version} {GREEN_OK}")
except Exception as err:
- self.log.warning(f" {RED_X} {err}")
+ exc_info = False
+ if int(self.log_level) <= logging.DEBUG: # type:ignore[call-overload]
+ exc_info = True
+ self.log.warning(f" {RED_X} {err}", exc_info=exc_info)
# Add a blank line between paths.
self.log.info("")
- def start(self):
+ def start(self) -> None:
"""Perform the App's actions as configured"""
self.list_server_extensions()
@@ -354,16 +378,16 @@ class ServerExtensionApp(BaseExtensionApp):
name = "jupyter server extension"
version = __version__
- description = "Work with Jupyter server extensions"
+ description: str = "Work with Jupyter server extensions"
examples = _examples
- subcommands = dict(
- enable=(EnableServerExtensionApp, "Enable a server extension"),
- disable=(DisableServerExtensionApp, "Disable a server extension"),
- list=(ListServerExtensionsApp, "List server extensions"),
- )
+ subcommands: dict[str, t.Any] = {
+ "enable": (EnableServerExtensionApp, "Enable a server extension"),
+ "disable": (DisableServerExtensionApp, "Disable a server extension"),
+ "list": (ListServerExtensionsApp, "List server extensions"),
+ }
- def start(self):
+ def start(self) -> None:
"""Perform the App's actions as configured"""
super().start()
diff --git a/jupyter_server/extension/utils.py b/jupyter_server/extension/utils.py
index a8c93a0580..5d18939ab2 100644
--- a/jupyter_server/extension/utils.py
+++ b/jupyter_server/extension/utils.py
@@ -1,21 +1,23 @@
+"""Extension utilities."""
import importlib
+import time
import warnings
class ExtensionLoadingError(Exception):
- pass
+ """An extension loading error."""
class ExtensionMetadataError(Exception):
- pass
+ """An extension metadata error."""
class ExtensionModuleNotFound(Exception):
- pass
+ """An extension module not found error."""
class NotAnExtensionApp(Exception):
- pass
+ """An error raised when a module is not an extension."""
def get_loader(obj, logger=None):
@@ -26,19 +28,25 @@ def get_loader(obj, logger=None):
underscore prefix.
"""
try:
- func = getattr(obj, "_load_jupyter_server_extension") # noqa B009
+ return obj._load_jupyter_server_extension
except AttributeError:
- func = getattr(obj, "load_jupyter_server_extension", None)
- warnings.warn(
- "A `_load_jupyter_server_extension` function was not "
- "found in {name!s}. Instead, a `load_jupyter_server_extension` "
- "function was found and will be used for now. This function "
- "name will be deprecated in future releases "
- "of Jupyter Server.".format(name=obj),
- DeprecationWarning,
- )
- except Exception:
- raise ExtensionLoadingError("_load_jupyter_server_extension function was not found.")
+ pass
+
+ try:
+ func = obj.load_jupyter_server_extension
+ except AttributeError:
+ msg = "_load_jupyter_server_extension function was not found."
+ raise ExtensionLoadingError(msg) from None
+
+ warnings.warn(
+ "A `_load_jupyter_server_extension` function was not "
+ f"found in {obj!s}. Instead, a `load_jupyter_server_extension` "
+ "function was found and will be used for now. This function "
+ "name will be deprecated in future releases "
+ "of Jupyter Server.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
return func
@@ -47,12 +55,20 @@ def get_metadata(package_name, logger=None):
This looks for a `_jupyter_server_extension_points` function
that returns metadata about all extension points within a Jupyter
- Server Extension pacakge.
+ Server Extension package.
If it doesn't exist, return a basic metadata packet given
the module name.
"""
+ start_time = time.perf_counter()
module = importlib.import_module(package_name)
+ end_time = time.perf_counter()
+ duration = end_time - start_time
+ # Sometimes packages can take a *while* to import, so we report how long
+ # each module took to import. This makes it much easier for users to report
+ # slow loading modules upstream, as slow loading modules will block server startup
+ if logger:
+ logger.info(f"Package {package_name} took {duration:.4f}s to import")
try:
return module, module._jupyter_server_extension_points()
@@ -67,10 +83,10 @@ def get_metadata(package_name, logger=None):
if logger:
logger.warning(
"A `_jupyter_server_extension_points` function was not "
- "found in {name}. Instead, a `_jupyter_server_extension_paths` "
+ f"found in {package_name}. Instead, a `_jupyter_server_extension_paths` "
"function was found and will be used for now. This function "
"name will be deprecated in future releases "
- "of Jupyter Server.".format(name=package_name)
+ "of Jupyter Server."
)
return module, extension_points
except AttributeError:
@@ -81,9 +97,9 @@ def get_metadata(package_name, logger=None):
if logger:
logger.debug(
"A `_jupyter_server_extension_points` function was "
- "not found in {name}, so Jupyter Server will look "
+ f"not found in {package_name}, so Jupyter Server will look "
"for extension points in the extension pacakge's "
- "root.".format(name=package_name)
+ "root."
)
return module, [{"module": package_name, "name": package_name}]
diff --git a/jupyter_server/files/handlers.py b/jupyter_server/files/handlers.py
index c76fdc28d3..043c581034 100644
--- a/jupyter_server/files/handlers.py
+++ b/jupyter_server/files/handlers.py
@@ -1,20 +1,22 @@
"""Serve files directly from the ContentsManager."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
-import json
+from __future__ import annotations
+
import mimetypes
from base64 import decodebytes
+from typing import Awaitable
+from jupyter_core.utils import ensure_async
from tornado import web
-from jupyter_server.auth import authorized
+from jupyter_server.auth.decorator import authorized
from jupyter_server.base.handlers import JupyterHandler
-from jupyter_server.utils import ensure_async
AUTH_RESOURCE = "contents"
-class FilesHandler(JupyterHandler):
+class FilesHandler(JupyterHandler, web.StaticFileHandler):
"""serve files via ContentsManager
Normally used when ContentsManager is not a FileContentsManager.
@@ -27,13 +29,15 @@ class FilesHandler(JupyterHandler):
@property
def content_security_policy(self):
+ """The content security policy."""
# In case we're serving HTML/SVG, confine any Javascript to a unique
# origin so it can't interact with the notebook server.
return super().content_security_policy + "; sandbox allow-scripts"
@web.authenticated
@authorized
- def head(self, path):
+ def head(self, path: str) -> Awaitable[None] | None: # type:ignore[override]
+ """The head response."""
self.check_xsrf_cookie()
return self.get(path, include_body=False)
@@ -41,11 +45,12 @@ def head(self, path):
@web.authenticated
@authorized
async def get(self, path, include_body=True):
+ """Get a file by path."""
# /files/ requests must originate from the same site
self.check_xsrf_cookie()
cm = self.contents_manager
- if await ensure_async(cm.is_hidden(path)) and not cm.allow_hidden:
+ if not cm.allow_hidden and await ensure_async(cm.is_hidden(path)):
self.log.info("Refusing to serve hidden file, via 404 Error")
raise web.HTTPError(404)
@@ -57,7 +62,7 @@ async def get(self, path, include_body=True):
model = await ensure_async(cm.get(path, type="file", content=include_body))
- if self.get_argument("download", False):
+ if self.get_argument("download", None):
self.set_attachment_header(name)
# get mimetype from filename
@@ -74,21 +79,18 @@ async def get(self, path, include_body=True):
self.set_header("Content-Type", "application/octet-stream")
elif cur_mime is not None:
self.set_header("Content-Type", cur_mime)
+ elif model["format"] == "base64":
+ self.set_header("Content-Type", "application/octet-stream")
else:
- if model["format"] == "base64":
- self.set_header("Content-Type", "application/octet-stream")
- else:
- self.set_header("Content-Type", "text/plain; charset=UTF-8")
+ self.set_header("Content-Type", "text/plain; charset=UTF-8")
if include_body:
if model["format"] == "base64":
b64_bytes = model["content"].encode("ascii")
self.write(decodebytes(b64_bytes))
- elif model["format"] == "json":
- self.write(json.dumps(model["content"]))
else:
self.write(model["content"])
self.flush()
-default_handlers = []
+default_handlers: list[JupyterHandler] = []
diff --git a/jupyter_server/gateway/connections.py b/jupyter_server/gateway/connections.py
new file mode 100644
index 0000000000..028a0f8f4e
--- /dev/null
+++ b/jupyter_server/gateway/connections.py
@@ -0,0 +1,176 @@
+"""Gateway connection classes."""
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+from __future__ import annotations
+
+import asyncio
+import logging
+import random
+from typing import Any, cast
+
+import tornado.websocket as tornado_websocket
+from tornado.concurrent import Future
+from tornado.escape import json_decode, url_escape, utf8
+from tornado.httpclient import HTTPRequest
+from tornado.ioloop import IOLoop
+from traitlets import Bool, Instance, Int
+
+from ..services.kernels.connection.base import BaseKernelWebsocketConnection
+from ..utils import url_path_join
+from .gateway_client import GatewayClient
+
+
+class GatewayWebSocketConnection(BaseKernelWebsocketConnection):
+ """Web socket connection that proxies to a kernel/enterprise gateway."""
+
+ ws = Instance(klass=tornado_websocket.WebSocketClientConnection, allow_none=True)
+
+ ws_future = Instance(klass=Future, allow_none=True)
+
+ disconnected = Bool(False)
+
+ retry = Int(0)
+
+ async def connect(self):
+ """Connect to the socket."""
+ # websocket is initialized before connection
+ self.ws = None
+ ws_url = url_path_join(
+ GatewayClient.instance().ws_url or "",
+ GatewayClient.instance().kernels_endpoint,
+ url_escape(self.kernel_id),
+ "channels",
+ )
+ self.log.info(f"Connecting to {ws_url}")
+ kwargs: dict[str, Any] = {}
+ kwargs = GatewayClient.instance().load_connection_args(**kwargs)
+
+ request = HTTPRequest(ws_url, **kwargs)
+ self.ws_future = cast("Future[Any]", tornado_websocket.websocket_connect(request))
+ self.ws_future.add_done_callback(self._connection_done)
+
+ loop = IOLoop.current()
+ loop.add_future(self.ws_future, lambda future: self._read_messages())
+
+ def _connection_done(self, fut):
+ """Handle a finished connection."""
+ if (
+ not self.disconnected and fut.exception() is None
+ ): # prevent concurrent.futures._base.CancelledError
+ self.ws = fut.result()
+ self.retry = 0
+ self.log.debug(f"Connection is ready: ws: {self.ws}")
+ else:
+ self.log.warning(
+ "Websocket connection has been closed via client disconnect or due to error. "
+ "Kernel with ID '{}' may not be terminated on GatewayClient: {}".format(
+ self.kernel_id, GatewayClient.instance().url
+ )
+ )
+
+ def disconnect(self):
+ """Handle a disconnect."""
+ self.disconnected = True
+ if self.ws is not None:
+ # Close connection
+ self.ws.close()
+ elif self.ws_future and not self.ws_future.done():
+ # Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally
+ self.ws_future.cancel()
+ self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")
+
+ async def _read_messages(self):
+ """Read messages from gateway server."""
+ while self.ws is not None:
+ message = None
+ if not self.disconnected:
+ try:
+ message = await self.ws.read_message()
+ except Exception as e:
+ self.log.error(
+ f"Exception reading message from websocket: {e}"
+ ) # , exc_info=True)
+ if message is None:
+ if not self.disconnected:
+ self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
+ break
+ if isinstance(message, bytes):
+ message = message.decode("utf8")
+ self.handle_outgoing_message(
+ message
+ ) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
+ else: # ws cancelled - stop reading
+ break
+
+ # NOTE(esevan): if websocket is not disconnected by client, try to reconnect.
+ if not self.disconnected and self.retry < GatewayClient.instance().gateway_retry_max:
+ jitter = random.randint(10, 100) * 0.01
+ retry_interval = (
+ min(
+ GatewayClient.instance().gateway_retry_interval * (2**self.retry),
+ GatewayClient.instance().gateway_retry_interval_max,
+ )
+ + jitter
+ )
+ self.retry += 1
+ self.log.info(
+ "Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s",
+ retry_interval,
+ self.retry,
+ GatewayClient.instance().gateway_retry_max,
+ self.kernel_id,
+ )
+ await asyncio.sleep(retry_interval)
+ loop = IOLoop.current()
+ loop.spawn_callback(self.connect)
+
+ def handle_outgoing_message(self, incoming_msg: str, *args: Any) -> None:
+ """Send message to the notebook client."""
+ try:
+ self.websocket_handler.write_message(incoming_msg)
+ except tornado_websocket.WebSocketClosedError:
+ if self.log.isEnabledFor(logging.DEBUG):
+ msg_summary = GatewayWebSocketConnection._get_message_summary(
+ json_decode(utf8(incoming_msg))
+ )
+ self.log.debug(
+ f"Notebook client closed websocket connection - message dropped: {msg_summary}"
+ )
+
+ def handle_incoming_message(self, message: str) -> None:
+ """Send message to gateway server."""
+ if self.ws is None and self.ws_future is not None:
+ loop = IOLoop.current()
+ loop.add_future(self.ws_future, lambda future: self.handle_incoming_message(message))
+ else:
+ self._write_message(message)
+
+ def _write_message(self, message):
+ """Send message to gateway server."""
+ try:
+ if not self.disconnected and self.ws is not None:
+ self.ws.write_message(message)
+ except Exception as e:
+ self.log.error(f"Exception writing message to websocket: {e}") # , exc_info=True)
+
+ @staticmethod
+ def _get_message_summary(message):
+ """Get a summary of a message."""
+ summary = []
+ message_type = message["msg_type"]
+ summary.append(f"type: {message_type}")
+
+ if message_type == "status":
+ summary.append(", state: {}".format(message["content"]["execution_state"]))
+ elif message_type == "error":
+ summary.append(
+ ", {}:{}:{}".format(
+ message["content"]["ename"],
+ message["content"]["evalue"],
+ message["content"]["traceback"],
+ )
+ )
+ else:
+ summary.append(", ...") # don't display potentially sensitive data
+
+ return "".join(summary)
diff --git a/jupyter_server/gateway/gateway_client.py b/jupyter_server/gateway/gateway_client.py
index 396a9a1abc..437d54d227 100644
--- a/jupyter_server/gateway/gateway_client.py
+++ b/jupyter_server/gateway/gateway_client.py
@@ -1,30 +1,127 @@
+"""A kernel gateway client."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
+from __future__ import annotations
+
+import asyncio
import json
+import logging
import os
+import typing as ty
+from abc import ABC, ABCMeta, abstractmethod
+from datetime import datetime, timezone
+from email.utils import parsedate_to_datetime
+from http.cookies import SimpleCookie
from socket import gaierror
+from jupyter_events import EventLogger
from tornado import web
-from tornado.httpclient import AsyncHTTPClient, HTTPError
-from traitlets import Bool, Float, Int, TraitError, Unicode, default, validate
-from traitlets.config import SingletonConfigurable
+from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPResponse
+from traitlets import (
+ Bool,
+ Float,
+ Instance,
+ Int,
+ TraitError,
+ Type,
+ Unicode,
+ default,
+ observe,
+ validate,
+)
+from traitlets.config import LoggingConfigurable, SingletonConfigurable
+
+from jupyter_server import DEFAULT_EVENTS_SCHEMA_PATH, JUPYTER_SERVER_EVENTS_URI
+
+ERROR_STATUS = "error"
+SUCCESS_STATUS = "success"
+STATUS_KEY = "status"
+STATUS_CODE_KEY = "status_code"
+MESSAGE_KEY = "msg"
+
+if ty.TYPE_CHECKING:
+ from http.cookies import Morsel
+
+
+class GatewayTokenRenewerMeta(ABCMeta, type(LoggingConfigurable)): # type: ignore[misc]
+ """The metaclass necessary for proper ABC behavior in a Configurable."""
+
+
+class GatewayTokenRenewerBase( # type:ignore[misc]
+ ABC, LoggingConfigurable, metaclass=GatewayTokenRenewerMeta
+):
+ """
+ Abstract base class for refreshing tokens used between this server and a Gateway
+ server. Implementations requiring additional configuration can extend their class
+ with appropriate configuration values or convey those values via appropriate
+ environment variables relative to the implementation.
+ """
+ @abstractmethod
+ def get_token(
+ self,
+ auth_header_key: str,
+ auth_scheme: ty.Union[str, None],
+ auth_token: str,
+ **kwargs: ty.Any,
+ ) -> str:
+ """
+ Given the current authorization header key, scheme, and token, this method returns
+ a (potentially renewed) token for use against the Gateway server.
+ """
-class GatewayClient(SingletonConfigurable):
- """This class manages the configuration. It's its own singleton class so that we
- can share these values across all objects. It also contains some helper methods
- to build request arguments out of the various config options.
+class NoOpTokenRenewer(GatewayTokenRenewerBase): # type:ignore[misc]
+ """NoOpTokenRenewer is the default value to the GatewayClient trait
+ `gateway_token_renewer` and merely returns the provided token.
"""
+ def get_token(
+ self,
+ auth_header_key: str,
+ auth_scheme: ty.Union[str, None],
+ auth_token: str,
+ **kwargs: ty.Any,
+ ) -> str:
+ """This implementation simply returns the current authorization token."""
+ return auth_token
+
+
+class GatewayClient(SingletonConfigurable):
+ """This class manages the configuration. It's its own singleton class so
+ that we can share these values across all objects. It also contains some
+ helper methods to build request arguments out of the various config
+ options.
+ """
+
+ event_schema_id = JUPYTER_SERVER_EVENTS_URI + "/gateway_client/v1"
+ event_logger = Instance(EventLogger).tag(config=True)
+
+ @default("event_logger")
+ def _default_event_logger(self):
+ if self.parent and hasattr(self.parent, "event_logger"):
+ # Event logger is attached from serverapp.
+ return self.parent.event_logger
+ else:
+ # If parent does not have an event logger, create one.
+ logger = EventLogger()
+ schema_path = DEFAULT_EVENTS_SCHEMA_PATH / "gateway_client" / "v1.yaml"
+ logger.register_event_schema(schema_path)
+ self.log.info("Event is registered in GatewayClient.")
+ return logger
+
+ def emit(self, data):
+ """Emit event using the core event schema from Jupyter Server's Gateway Client."""
+ self.event_logger.emit(schema_id=self.event_schema_id, data=data)
+
url = Unicode(
default_value=None,
allow_none=True,
config=True,
help="""The url of the Kernel or Enterprise Gateway server where
- kernel specifications are defined and kernel management takes place.
- If defined, this Notebook server acts as a proxy for all kernel
- management and kernel specification retrieval. (JUPYTER_GATEWAY_URL env var)
+kernel specifications are defined and kernel management takes place.
+If defined, this Notebook server acts as a proxy for all kernel
+management and kernel specification retrieval. (JUPYTER_GATEWAY_URL env var)
""",
)
@@ -38,9 +135,10 @@ def _url_default(self):
def _url_validate(self, proposal):
value = proposal["value"]
# Ensure value, if present, starts with 'http'
- if value is not None and len(value) > 0:
- if not str(value).lower().startswith("http"):
- raise TraitError("GatewayClient url must start with 'http': '%r'" % value)
+ if value is not None and len(value) > 0 and not str(value).lower().startswith("http"):
+ message = "GatewayClient url must start with 'http': '%r'" % value
+ self.emit(data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 400, MESSAGE_KEY: message})
+ raise TraitError(message)
return value
ws_url = Unicode(
@@ -48,7 +146,7 @@ def _url_validate(self, proposal):
allow_none=True,
config=True,
help="""The websocket url of the Kernel or Enterprise Gateway server. If not provided, this value
- will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var)
+will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var)
""",
)
@@ -57,18 +155,18 @@ def _url_validate(self, proposal):
@default("ws_url")
def _ws_url_default(self):
default_value = os.environ.get(self.ws_url_env)
- if default_value is None:
- if self.gateway_enabled:
- default_value = self.url.lower().replace("http", "ws")
+ if self.url is not None and default_value is None and self.gateway_enabled:
+ default_value = self.url.lower().replace("http", "ws")
return default_value
@validate("ws_url")
def _ws_url_validate(self, proposal):
value = proposal["value"]
# Ensure value, if present, starts with 'ws'
- if value is not None and len(value) > 0:
- if not str(value).lower().startswith("ws"):
- raise TraitError("GatewayClient ws_url must start with 'ws': '%r'" % value)
+ if value is not None and len(value) > 0 and not str(value).lower().startswith("ws"):
+ message = "GatewayClient ws_url must start with 'ws': '%r'" % value
+ self.emit(data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 400, MESSAGE_KEY: message})
+ raise TraitError(message)
return value
kernels_endpoint_default_value = "/api/kernels"
@@ -103,7 +201,7 @@ def _kernelspecs_endpoint_default(self):
default_value=kernelspecs_resource_endpoint_default_value,
config=True,
help="""The gateway endpoint for accessing kernelspecs resources
- (JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT env var)""",
+(JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT env var)""",
)
@default("kernelspecs_resource_endpoint")
@@ -119,16 +217,14 @@ def _kernelspecs_resource_endpoint_default(self):
default_value=connect_timeout_default_value,
config=True,
help="""The time allowed for HTTP connection establishment with the Gateway server.
- (JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""",
+(JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""",
)
@default("connect_timeout")
- def connect_timeout_default(self):
- return float(
- os.environ.get("JUPYTER_GATEWAY_CONNECT_TIMEOUT", self.connect_timeout_default_value)
- )
+ def _connect_timeout_default(self):
+ return float(os.environ.get(self.connect_timeout_env, self.connect_timeout_default_value))
- request_timeout_default_value = 40.0
+ request_timeout_default_value = 42.0
request_timeout_env = "JUPYTER_GATEWAY_REQUEST_TIMEOUT"
request_timeout = Float(
default_value=request_timeout_default_value,
@@ -137,10 +233,8 @@ def connect_timeout_default(self):
)
@default("request_timeout")
- def request_timeout_default(self):
- return float(
- os.environ.get("JUPYTER_GATEWAY_REQUEST_TIMEOUT", self.request_timeout_default_value)
- )
+ def _request_timeout_default(self):
+ return float(os.environ.get(self.request_timeout_env, self.request_timeout_default_value))
client_key = Unicode(
default_value=None,
@@ -222,36 +316,54 @@ def _http_pwd_default(self):
def _headers_default(self):
return os.environ.get(self.headers_env, self.headers_default_value)
+ auth_header_key_default_value = "Authorization"
+ auth_header_key = Unicode(
+ config=True,
+ help="""The authorization header's key name (typically 'Authorization') used in the HTTP headers. The
+header will be formatted as::
+
+{'{auth_header_key}': '{auth_scheme} {auth_token}'}
+
+If the authorization header key takes a single value, `auth_scheme` should be set to None and
+'auth_token' should be configured to use the appropriate value.
+
+(JUPYTER_GATEWAY_AUTH_HEADER_KEY env var)""",
+ )
+ auth_header_key_env = "JUPYTER_GATEWAY_AUTH_HEADER_KEY"
+
+ @default("auth_header_key")
+ def _auth_header_key_default(self):
+ return os.environ.get(self.auth_header_key_env, self.auth_header_key_default_value)
+
+ auth_token_default_value = ""
auth_token = Unicode(
default_value=None,
allow_none=True,
config=True,
help="""The authorization token used in the HTTP headers. The header will be formatted as::
- {
- 'Authorization': '{auth_scheme} {auth_token}'
- }
+{'{auth_header_key}': '{auth_scheme} {auth_token}'}
- (JUPYTER_GATEWAY_AUTH_TOKEN env var)""",
+(JUPYTER_GATEWAY_AUTH_TOKEN env var)""",
)
auth_token_env = "JUPYTER_GATEWAY_AUTH_TOKEN"
@default("auth_token")
def _auth_token_default(self):
- return os.environ.get(self.auth_token_env, "")
+ return os.environ.get(self.auth_token_env, self.auth_token_default_value)
+ auth_scheme_default_value = "token" # This value is purely for backwards compatibility
auth_scheme = Unicode(
- default_value=None,
allow_none=True,
config=True,
help="""The auth scheme, added as a prefix to the authorization token used in the HTTP headers.
- (JUPYTER_GATEWAY_AUTH_SCHEME env var)""",
+(JUPYTER_GATEWAY_AUTH_SCHEME env var)""",
)
auth_scheme_env = "JUPYTER_GATEWAY_AUTH_SCHEME"
@default("auth_scheme")
def _auth_scheme_default(self):
- return os.environ.get(self.auth_scheme_env, "token")
+ return os.environ.get(self.auth_scheme_env, self.auth_scheme_default_value)
validate_cert_default_value = True
validate_cert_env = "JUPYTER_GATEWAY_VALIDATE_CERT"
@@ -259,34 +371,39 @@ def _auth_scheme_default(self):
default_value=validate_cert_default_value,
config=True,
help="""For HTTPS requests, determines if server's certificate should be validated or not.
- (JUPYTER_GATEWAY_VALIDATE_CERT env var)""",
+(JUPYTER_GATEWAY_VALIDATE_CERT env var)""",
)
@default("validate_cert")
- def validate_cert_default(self):
+ def _validate_cert_default(self):
return bool(
os.environ.get(self.validate_cert_env, str(self.validate_cert_default_value))
not in ["no", "false"]
)
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._static_args = {} # initialized on first use
-
- env_whitelist_default_value = ""
- env_whitelist_env = "JUPYTER_GATEWAY_ENV_WHITELIST"
- env_whitelist = Unicode(
- default_value=env_whitelist_default_value,
+ allowed_envs_default_value = ""
+ allowed_envs_env = "JUPYTER_GATEWAY_ALLOWED_ENVS"
+ allowed_envs = Unicode(
+ default_value=allowed_envs_default_value,
config=True,
help="""A comma-separated list of environment variable names that will be included, along with
- their values, in the kernel startup request. The corresponding `env_whitelist` configuration
- value must also be set on the Gateway server - since that configuration value indicates which
- environmental values to make available to the kernel. (JUPYTER_GATEWAY_ENV_WHITELIST env var)""",
+their values, in the kernel startup request. The corresponding `client_envs` configuration
+value must also be set on the Gateway server - since that configuration value indicates which
+environmental values to make available to the kernel. (JUPYTER_GATEWAY_ALLOWED_ENVS env var)""",
)
- @default("env_whitelist")
- def _env_whitelist_default(self):
- return os.environ.get(self.env_whitelist_env, self.env_whitelist_default_value)
+ @default("allowed_envs")
+ def _allowed_envs_default(self):
+ return os.environ.get(
+ self.allowed_envs_env,
+ os.environ.get("JUPYTER_GATEWAY_ENV_WHITELIST", self.allowed_envs_default_value),
+ )
+
+ env_whitelist = Unicode(
+ default_value=allowed_envs_default_value,
+ config=True,
+ help="""Deprecated, use `GatewayClient.allowed_envs`""",
+ )
gateway_retry_interval_default_value = 1.0
gateway_retry_interval_env = "JUPYTER_GATEWAY_RETRY_INTERVAL"
@@ -294,16 +411,16 @@ def _env_whitelist_default(self):
default_value=gateway_retry_interval_default_value,
config=True,
help="""The time allowed for HTTP reconnection with the Gateway server for the first time.
- Next will be JUPYTER_GATEWAY_RETRY_INTERVAL multiplied by two in factor of numbers of retries
- but less than JUPYTER_GATEWAY_RETRY_INTERVAL_MAX.
- (JUPYTER_GATEWAY_RETRY_INTERVAL env var)""",
+Next will be JUPYTER_GATEWAY_RETRY_INTERVAL multiplied by two in factor of numbers of retries
+but less than JUPYTER_GATEWAY_RETRY_INTERVAL_MAX.
+(JUPYTER_GATEWAY_RETRY_INTERVAL env var)""",
)
@default("gateway_retry_interval")
- def gateway_retry_interval_default(self):
+ def _gateway_retry_interval_default(self):
return float(
os.environ.get(
- "JUPYTER_GATEWAY_RETRY_INTERVAL",
+ self.gateway_retry_interval_env,
self.gateway_retry_interval_default_value,
)
)
@@ -314,14 +431,14 @@ def gateway_retry_interval_default(self):
default_value=gateway_retry_interval_max_default_value,
config=True,
help="""The maximum time allowed for HTTP reconnection retry with the Gateway server.
- (JUPYTER_GATEWAY_RETRY_INTERVAL_MAX env var)""",
+(JUPYTER_GATEWAY_RETRY_INTERVAL_MAX env var)""",
)
@default("gateway_retry_interval_max")
- def gateway_retry_interval_max_default(self):
+ def _gateway_retry_interval_max_default(self):
return float(
os.environ.get(
- "JUPYTER_GATEWAY_RETRY_INTERVAL_MAX",
+ self.gateway_retry_interval_max_env,
self.gateway_retry_interval_max_default_value,
)
)
@@ -332,15 +449,87 @@ def gateway_retry_interval_max_default(self):
default_value=gateway_retry_max_default_value,
config=True,
help="""The maximum retries allowed for HTTP reconnection with the Gateway server.
- (JUPYTER_GATEWAY_RETRY_MAX env var)""",
+(JUPYTER_GATEWAY_RETRY_MAX env var)""",
)
@default("gateway_retry_max")
- def gateway_retry_max_default(self):
- return int(
- os.environ.get("JUPYTER_GATEWAY_RETRY_MAX", self.gateway_retry_max_default_value)
+ def _gateway_retry_max_default(self):
+ return int(os.environ.get(self.gateway_retry_max_env, self.gateway_retry_max_default_value))
+
+ gateway_token_renewer_class_default_value = (
+ "jupyter_server.gateway.gateway_client.NoOpTokenRenewer"
+ )
+ gateway_token_renewer_class_env = "JUPYTER_GATEWAY_TOKEN_RENEWER_CLASS"
+ gateway_token_renewer_class = Type(
+ klass=GatewayTokenRenewerBase,
+ config=True,
+ help="""The class to use for Gateway token renewal. (JUPYTER_GATEWAY_TOKEN_RENEWER_CLASS env var)""",
+ )
+
+ @default("gateway_token_renewer_class")
+ def _gateway_token_renewer_class_default(self):
+ return os.environ.get(
+ self.gateway_token_renewer_class_env, self.gateway_token_renewer_class_default_value
+ )
+
+ launch_timeout_pad_default_value = 2.0
+ launch_timeout_pad_env = "JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD"
+ launch_timeout_pad = Float(
+ default_value=launch_timeout_pad_default_value,
+ config=True,
+ help="""Timeout pad to be ensured between KERNEL_LAUNCH_TIMEOUT and request_timeout
+such that request_timeout >= KERNEL_LAUNCH_TIMEOUT + launch_timeout_pad.
+(JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD env var)""",
+ )
+
+ @default("launch_timeout_pad")
+ def _launch_timeout_pad_default(self):
+ return float(
+ os.environ.get(
+ self.launch_timeout_pad_env,
+ self.launch_timeout_pad_default_value,
+ )
)
+ accept_cookies_value = False
+ accept_cookies_env = "JUPYTER_GATEWAY_ACCEPT_COOKIES"
+ accept_cookies = Bool(
+ default_value=accept_cookies_value,
+ config=True,
+ help="""Accept and manage cookies sent by the service side. This is often useful
+ for load balancers to decide which backend node to use.
+ (JUPYTER_GATEWAY_ACCEPT_COOKIES env var)""",
+ )
+
+ @default("accept_cookies")
+ def _accept_cookies_default(self):
+ return bool(
+ os.environ.get(self.accept_cookies_env, str(self.accept_cookies_value).lower())
+ not in ["no", "false"]
+ )
+
+ _deprecated_traits = {
+ "env_whitelist": ("allowed_envs", "2.0"),
+ }
+
+ # Method copied from
+ # https://github.com/jupyterhub/jupyterhub/blob/d1a85e53dccfc7b1dd81b0c1985d158cc6b61820/jupyterhub/auth.py#L143-L161
+ @observe(*list(_deprecated_traits))
+ def _deprecated_trait(self, change):
+ """observer for deprecated traits"""
+ old_attr = change.name
+ new_attr, version = self._deprecated_traits[old_attr]
+ new_value = getattr(self, new_attr)
+ if new_value != change.new:
+ # only warn if different
+ # protects backward-compatible config from warnings
+ # if they set the same value under both names
+ self.log.warning(
+ f"{self.__class__.__name__}.{old_attr} is deprecated in jupyter_server "
+ f"{version}, use {self.__class__.__name__}.{new_attr} instead"
+ )
+ setattr(self, new_attr, change.new)
+
@property
def gateway_enabled(self):
return bool(self.url is not None and len(self.url) > 0)
@@ -348,84 +537,289 @@ def gateway_enabled(self):
# Ensure KERNEL_LAUNCH_TIMEOUT has a default value.
KERNEL_LAUNCH_TIMEOUT = int(os.environ.get("KERNEL_LAUNCH_TIMEOUT", 40))
- def init_static_args(self):
- """Initialize arguments used on every request. Since these are static values, we'll
- perform this operation once.
+ _connection_args: dict[str, ty.Any] # initialized on first use
+
+ gateway_token_renewer: GatewayTokenRenewerBase
+
+ def __init__(self, **kwargs):
+ """Initialize a gateway client."""
+ super().__init__(**kwargs)
+ self._connection_args = {} # initialized on first use
+ self.gateway_token_renewer = self.gateway_token_renewer_class(parent=self, log=self.log) # type:ignore[abstract]
+
+ # store of cookies with store time
+ self._cookies: dict[str, tuple[Morsel[ty.Any], datetime]] = {}
+ def init_connection_args(self):
+ """Initialize arguments used on every request. Since these are primarily static values,
+ we'll perform this operation once.
"""
- # Ensure that request timeout and KERNEL_LAUNCH_TIMEOUT are the same, taking the
- # greater value of the two.
- if self.request_timeout < float(GatewayClient.KERNEL_LAUNCH_TIMEOUT):
- self.request_timeout = float(GatewayClient.KERNEL_LAUNCH_TIMEOUT)
- elif self.request_timeout > float(GatewayClient.KERNEL_LAUNCH_TIMEOUT):
- GatewayClient.KERNEL_LAUNCH_TIMEOUT = int(self.request_timeout)
+ # Ensure that request timeout and KERNEL_LAUNCH_TIMEOUT are in sync, taking the
+ # greater value of the two and taking into account the following relation:
+    # request_timeout = KERNEL_LAUNCH_TIMEOUT + padding
+ minimum_request_timeout = (
+ float(GatewayClient.KERNEL_LAUNCH_TIMEOUT) + self.launch_timeout_pad
+ )
+ if self.request_timeout < minimum_request_timeout:
+ self.request_timeout = minimum_request_timeout
+ elif self.request_timeout > minimum_request_timeout:
+ GatewayClient.KERNEL_LAUNCH_TIMEOUT = int(
+ self.request_timeout - self.launch_timeout_pad
+ )
# Ensure any adjustments are reflected in env.
os.environ["KERNEL_LAUNCH_TIMEOUT"] = str(GatewayClient.KERNEL_LAUNCH_TIMEOUT)
- self._static_args["headers"] = json.loads(self.headers)
- if "Authorization" not in self._static_args["headers"].keys():
- self._static_args["headers"].update(
- {"Authorization": f"{self.auth_scheme} {self.auth_token}"}
- )
- self._static_args["connect_timeout"] = self.connect_timeout
- self._static_args["request_timeout"] = self.request_timeout
- self._static_args["validate_cert"] = self.validate_cert
+ if self.headers:
+ self._connection_args["headers"] = json.loads(self.headers)
+ if self.auth_header_key not in self._connection_args["headers"]:
+ self._connection_args["headers"].update(
+ {f"{self.auth_header_key}": f"{self.auth_scheme} {self.auth_token}"}
+ )
+ self._connection_args["connect_timeout"] = self.connect_timeout
+ self._connection_args["request_timeout"] = self.request_timeout
+ self._connection_args["validate_cert"] = self.validate_cert
if self.client_cert:
- self._static_args["client_cert"] = self.client_cert
- self._static_args["client_key"] = self.client_key
+ self._connection_args["client_cert"] = self.client_cert
+ self._connection_args["client_key"] = self.client_key
if self.ca_certs:
- self._static_args["ca_certs"] = self.ca_certs
+ self._connection_args["ca_certs"] = self.ca_certs
if self.http_user:
- self._static_args["auth_username"] = self.http_user
+ self._connection_args["auth_username"] = self.http_user
if self.http_pwd:
- self._static_args["auth_password"] = self.http_pwd
+ self._connection_args["auth_password"] = self.http_pwd
def load_connection_args(self, **kwargs):
- """Merges the static args relative to the connection, with the given keyword arguments. If statics
- have yet to be initialized, we'll do that here.
+ """Merges the static args relative to the connection, with the given keyword arguments. If static
+ args have yet to be initialized, we'll do that here.
"""
- if len(self._static_args) == 0:
- self.init_static_args()
+ if len(self._connection_args) == 0:
+ self.init_connection_args()
+
+ # Give token renewal a shot at renewing the token
+ prev_auth_token = self.auth_token
+ if self.auth_token is not None:
+ try:
+ self.auth_token = self.gateway_token_renewer.get_token(
+ self.auth_header_key, self.auth_scheme, self.auth_token
+ )
+ except Exception as ex:
+ self.log.error(
+ f"An exception occurred attempting to renew the "
+ f"Gateway authorization token using an instance of class "
+ f"'{self.gateway_token_renewer_class}'. The request will "
+ f"proceed using the current token value. Exception was: {ex}"
+ )
+ self.auth_token = prev_auth_token
+
+ for arg, value in self._connection_args.items():
+ if arg == "headers":
+ given_value = kwargs.setdefault(arg, {})
+ if isinstance(given_value, dict):
+ given_value.update(value)
+ # Ensure the auth header is current
+ given_value.update(
+ {f"{self.auth_header_key}": f"{self.auth_scheme} {self.auth_token}"}
+ )
+ else:
+ kwargs[arg] = value
+
+ if self.accept_cookies:
+ self._update_cookie_header(kwargs)
- kwargs.update(self._static_args)
return kwargs
+ def update_cookies(self, cookie: SimpleCookie) -> None:
+ """Update cookies from existing requests for load balancers"""
+ if not self.accept_cookies:
+ return
+
+ store_time = datetime.now(tz=timezone.utc)
+ for key, item in cookie.items():
+ # Convert "expires" arg into "max-age" to facilitate expiration management.
+ # As "max-age" has precedence, ignore "expires" when "max-age" exists.
+ if item.get("expires") and not item.get("max-age"):
+ expire_timedelta = parsedate_to_datetime(item["expires"]) - store_time
+ item["max-age"] = str(expire_timedelta.total_seconds())
+
+ self._cookies[key] = (item, store_time)
+
+ def _clear_expired_cookies(self) -> None:
+ """Clear expired cookies."""
+ check_time = datetime.now(tz=timezone.utc)
+ expired_keys = []
+
+ for key, (morsel, store_time) in self._cookies.items():
+ cookie_max_age = morsel.get("max-age")
+ if not cookie_max_age:
+ continue
+ expired_timedelta = check_time - store_time
+ if expired_timedelta.total_seconds() > float(cookie_max_age):
+ expired_keys.append(key)
+
+ for key in expired_keys:
+ self._cookies.pop(key)
+
+ def _update_cookie_header(self, connection_args: dict[str, ty.Any]) -> None:
+ """Update a cookie header."""
+ self._clear_expired_cookies()
+
+ gateway_cookie_values = "; ".join(
+ f"{name}={morsel.coded_value}" for name, (morsel, _time) in self._cookies.items()
+ )
+ if gateway_cookie_values:
+ headers = connection_args.get("headers", {})
+
+ # As headers are case-insensitive, we get existing name of cookie header,
+ # or use "Cookie" by default.
+ cookie_header_name = next(
+ (header_key for header_key in headers if header_key.lower() == "cookie"),
+ "Cookie",
+ )
+ existing_cookie = headers.get(cookie_header_name)
+
+ # merge gateway-managed cookies with cookies already in arguments
+ if existing_cookie:
+ gateway_cookie_values = existing_cookie + "; " + gateway_cookie_values
+ headers[cookie_header_name] = gateway_cookie_values
+
+ connection_args["headers"] = headers
+
+
+class RetryableHTTPClient:
+ """
+    Inspired by urllib3.util.Retry (https://urllib3.readthedocs.io/en/stable/reference/urllib3.util.html),
+ this class is initialized with desired retry characteristics, uses a recursive method `fetch()` against an instance
+ of `AsyncHTTPClient` which tracks the current retry count across applicable request retries.
+ """
+
+ MAX_RETRIES_DEFAULT = 2
+ MAX_RETRIES_CAP = 10 # The upper limit to max_retries value.
+ max_retries: int = int(os.getenv("JUPYTER_GATEWAY_MAX_REQUEST_RETRIES", MAX_RETRIES_DEFAULT))
+ max_retries = max(0, min(max_retries, MAX_RETRIES_CAP)) # Enforce boundaries
+ retried_methods: set[str] = {"GET", "DELETE"}
+ retried_errors: set[int] = {502, 503, 504, 599}
+ retried_exceptions: set[type] = {ConnectionError}
+ backoff_factor: float = 0.1
+
+ def __init__(self):
+ """Initialize the retryable http client."""
+ self.retry_count: int = 0
+ self.client: AsyncHTTPClient = AsyncHTTPClient()
+
+ async def fetch(self, endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
+ """
+ Retryable AsyncHTTPClient.fetch() method. When the request fails, this method will
+ recurse up to max_retries times if the condition deserves a retry.
+ """
+ self.retry_count = 0
+ return await self._fetch(endpoint, **kwargs)
+
+ async def _fetch(self, endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
+ """
+ Performs the fetch against the contained AsyncHTTPClient instance and determines
+ if retry is necessary on any exceptions. If so, retry is performed recursively.
+ """
+ try:
+ response: HTTPResponse = await self.client.fetch(endpoint, **kwargs)
+ except Exception as e:
+ is_retryable: bool = await self._is_retryable(kwargs["method"], e)
+ if not is_retryable:
+ raise e
+ logging.getLogger("ServerApp").info(
+ f"Attempting retry ({self.retry_count}) against "
+ f"endpoint '{endpoint}'. Retried error: '{e!r}'"
+ )
+ response = await self._fetch(endpoint, **kwargs)
+ return response
+
+ async def _is_retryable(self, method: str, exception: Exception) -> bool:
+ """Determines if the given exception is retryable based on object's configuration."""
+
+ if method not in self.retried_methods:
+ return False
+ if self.retry_count == self.max_retries:
+ return False
+
+ # Determine if error is retryable...
+ if isinstance(exception, HTTPClientError):
+ hce: HTTPClientError = exception
+ if hce.code not in self.retried_errors:
+ return False
+ elif not any(isinstance(exception, error) for error in self.retried_exceptions):
+ return False
+
+ # Is retryable, wait for backoff, then increment count
+ await asyncio.sleep(self.backoff_factor * (2**self.retry_count))
+ self.retry_count += 1
+ return True
+
-async def gateway_request(endpoint, **kwargs):
+async def gateway_request(endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
"""Make an async request to kernel gateway endpoint, returns a response"""
- client = AsyncHTTPClient()
kwargs = GatewayClient.instance().load_connection_args(**kwargs)
+ rhc = RetryableHTTPClient()
try:
- response = await client.fetch(endpoint, **kwargs)
+ response = await rhc.fetch(endpoint, **kwargs)
+ GatewayClient.instance().emit(
+ data={STATUS_KEY: SUCCESS_STATUS, STATUS_CODE_KEY: 200, MESSAGE_KEY: "success"}
+ )
# Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect
# or the server is not running.
- # NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes
+ # NOTE: We do this here since this handler is called during the server's startup and subsequent refreshes
# of the tree view.
- except ConnectionRefusedError as e:
+ except HTTPClientError as e:
+ GatewayClient.instance().emit(
+ data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: e.code, MESSAGE_KEY: str(e.message)}
+ )
+ error_reason = f"Exception while attempting to connect to Gateway server url '{GatewayClient.instance().url}'"
+ error_message = e.message
+ if e.response:
+ try:
+ error_payload = json.loads(e.response.body)
+ error_reason = error_payload.get("reason") or error_reason
+ error_message = error_payload.get("message") or error_message
+ except json.decoder.JSONDecodeError:
+ error_reason = e.response.body.decode()
+
raise web.HTTPError(
- 503,
- "Connection refused from Gateway server url '{}'. "
- "Check to be sure the Gateway instance is running.".format(
- GatewayClient.instance().url
- ),
+ e.code,
+ f"Error from Gateway: [{error_message}] {error_reason}. "
+ "Ensure gateway url is valid and the Gateway instance is running.",
) from e
- except HTTPError as e:
- # This can occur if the host is valid (e.g., foo.com) but there's nothing there.
+ except ConnectionError as e:
+ GatewayClient.instance().emit(
+ data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 503, MESSAGE_KEY: str(e)}
+ )
raise web.HTTPError(
- e.code,
- "Error attempting to connect to Gateway server url '{}'. "
- "Ensure gateway url is valid and the Gateway instance is running.".format(
- GatewayClient.instance().url
- ),
+ 503,
+ f"ConnectionError was received from Gateway server url '{GatewayClient.instance().url}'. "
+ "Check to be sure the Gateway instance is running.",
) from e
except gaierror as e:
+ GatewayClient.instance().emit(
+ data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 404, MESSAGE_KEY: str(e)}
+ )
raise web.HTTPError(
404,
- "The Gateway server specified in the gateway_url '{}' doesn't appear to be valid. "
- "Ensure gateway url is valid and the Gateway instance is running.".format(
- GatewayClient.instance().url
- ),
+ f"The Gateway server specified in the gateway_url '{GatewayClient.instance().url}' doesn't "
+ f"appear to be valid. Ensure gateway url is valid and the Gateway instance is running.",
) from e
-
+ except Exception as e:
+ GatewayClient.instance().emit(
+ data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 505, MESSAGE_KEY: str(e)}
+ )
+ logging.getLogger("ServerApp").error(
+ f"Exception while trying to launch kernel via Gateway URL {GatewayClient.instance().url} , {e}",
+ e,
+ )
+ raise e
+
+ if GatewayClient.instance().accept_cookies:
+ # Update cookies on GatewayClient from server if configured.
+ cookie_values = response.headers.get("Set-Cookie")
+ if cookie_values:
+ cookie: SimpleCookie = SimpleCookie()
+ cookie.load(cookie_values)
+ GatewayClient.instance().update_cookies(cookie)
return response
diff --git a/jupyter_server/gateway/handlers.py b/jupyter_server/gateway/handlers.py
index a36f2d4faf..dcde4cd5ca 100644
--- a/jupyter_server/gateway/handlers.py
+++ b/jupyter_server/gateway/handlers.py
@@ -1,10 +1,15 @@
+"""Gateway API handlers."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
+from __future__ import annotations
+
import asyncio
import logging
import mimetypes
import os
import random
+import warnings
+from typing import Any, Optional, cast
from jupyter_client.session import Session
from tornado import web
@@ -17,13 +22,21 @@
from ..base.handlers import APIHandler, JupyterHandler
from ..utils import url_path_join
-from .managers import GatewayClient
+from .gateway_client import GatewayClient
+
+warnings.warn(
+ "The jupyter_server.gateway.handlers module is deprecated and will not be supported in Jupyter Server 3.0",
+ DeprecationWarning,
+ stacklevel=2,
+)
+
# Keepalive ping interval (default: 30 seconds)
-GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv("GATEWAY_WS_PING_INTERVAL_SECS", 30))
+GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv("GATEWAY_WS_PING_INTERVAL_SECS", "30"))
class WebSocketChannelsHandler(WebSocketHandler, JupyterHandler):
+ """Gateway web socket channels handler."""
session = None
gateway = None
@@ -31,13 +44,14 @@ class WebSocketChannelsHandler(WebSocketHandler, JupyterHandler):
ping_callback = None
def check_origin(self, origin=None):
+ """Check origin for the socket."""
return JupyterHandler.check_origin(self, origin)
def set_default_headers(self):
"""Undo the set_default_headers in JupyterHandler which doesn't make sense for websockets"""
- pass
def get_compression_options(self):
+ """Get the compression options for the socket."""
# use deflate compress websocket
return {}
@@ -48,28 +62,33 @@ def authenticate(self):
the websocket finishes completing.
"""
# authenticate the request before opening the websocket
- if self.get_current_user() is None:
+ if self.current_user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)
- if self.get_argument("session_id", False):
- self.session.session = self.get_argument("session_id")
+ if self.get_argument("session_id", None):
+ assert self.session is not None
+ self.session.session = self.get_argument("session_id") # type:ignore[unreachable]
else:
self.log.warning("No session ID specified")
def initialize(self):
+ """Initialize the socket."""
self.log.debug("Initializing websocket connection %s", self.request.path)
self.session = Session(config=self.config)
self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url)
async def get(self, kernel_id, *args, **kwargs):
+ """Get the socket."""
self.authenticate()
self.kernel_id = kernel_id
- await super().get(kernel_id=kernel_id, *args, **kwargs)
+ kwargs["kernel_id"] = kernel_id
+ await super().get(*args, **kwargs)
def send_ping(self):
+ """Send a ping to the socket."""
if self.ws_connection is None and self.ping_callback is not None:
- self.ping_callback.stop()
+ self.ping_callback.stop() # type:ignore[unreachable]
return
self.ping(b"")
@@ -79,6 +98,7 @@ def open(self, kernel_id, *args, **kwargs):
self.ping_callback = PeriodicCallback(self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000)
self.ping_callback.start()
+ assert self.gateway is not None
self.gateway.on_open(
kernel_id=kernel_id,
message_callback=self.write_message,
@@ -87,6 +107,7 @@ def open(self, kernel_id, *args, **kwargs):
def on_message(self, message):
"""Forward message to gateway web socket handler."""
+ assert self.gateway is not None
self.gateway.on_message(message)
def write_message(self, message, binary=False):
@@ -98,18 +119,19 @@ def write_message(self, message, binary=False):
elif self.log.isEnabledFor(logging.DEBUG):
msg_summary = WebSocketChannelsHandler._get_message_summary(json_decode(utf8(message)))
self.log.debug(
- "Notebook client closed websocket connection - message dropped: {}".format(
- msg_summary
- )
+ f"Notebook client closed websocket connection - message dropped: {msg_summary}"
)
def on_close(self):
+ """Handle a closing socket."""
self.log.debug("Closing websocket connection %s", self.request.path)
+ assert self.gateway is not None
self.gateway.on_close()
super().on_close()
@staticmethod
def _get_message_summary(message):
+ """Get a summary of a message."""
summary = []
message_type = message["msg_type"]
summary.append(f"type: {message_type}")
@@ -134,35 +156,41 @@ class GatewayWebSocketClient(LoggingConfigurable):
"""Proxy web socket connection to a kernel/enterprise gateway."""
def __init__(self, **kwargs):
- super().__init__(**kwargs)
+ """Initialize the gateway web socket client."""
+ super().__init__()
self.kernel_id = None
self.ws = None
- self.ws_future = Future()
+ self.ws_future: Future[Any] = Future()
self.disconnected = False
self.retry = 0
async def _connect(self, kernel_id, message_callback):
+ """Connect to the socket."""
# websocket is initialized before connection
self.ws = None
self.kernel_id = kernel_id
+ client = GatewayClient.instance()
+ assert client.ws_url is not None
+
ws_url = url_path_join(
- GatewayClient.instance().ws_url,
- GatewayClient.instance().kernels_endpoint,
+ client.ws_url,
+ client.kernels_endpoint,
url_escape(kernel_id),
"channels",
)
self.log.info(f"Connecting to {ws_url}")
- kwargs = {}
- kwargs = GatewayClient.instance().load_connection_args(**kwargs)
+ kwargs: dict[str, Any] = {}
+ kwargs = client.load_connection_args(**kwargs)
request = HTTPRequest(ws_url, **kwargs)
- self.ws_future = websocket_connect(request)
+ self.ws_future = cast("Future[Any]", websocket_connect(request))
self.ws_future.add_done_callback(self._connection_done)
loop = IOLoop.current()
loop.add_future(self.ws_future, lambda future: self._read_messages(message_callback))
def _connection_done(self, fut):
+ """Handle a finished connection."""
if (
not self.disconnected and fut.exception() is None
): # prevent concurrent.futures._base.CancelledError
@@ -178,6 +206,7 @@ def _connection_done(self, fut):
)
def _disconnect(self):
+ """Handle a disconnect."""
self.disconnected = True
if self.ws is not None:
# Close connection
@@ -261,16 +290,20 @@ class GatewayResourceHandler(APIHandler):
@web.authenticated
async def get(self, kernel_name, path, include_body=True):
+ """Get a gateway resource by name and path."""
+ mimetype: Optional[str] = None
ksm = self.kernel_spec_manager
- kernel_spec_res = await ksm.get_kernel_spec_resource(kernel_name, path)
+ kernel_spec_res = await ksm.get_kernel_spec_resource( # type:ignore[attr-defined]
+ kernel_name, path
+ )
if kernel_spec_res is None:
self.log.warning(
"Kernelspec resource '{}' for '{}' not found. Gateway may not support"
" resource serving.".format(path, kernel_name)
)
else:
- self.set_header("Content-Type", mimetypes.guess_type(path)[0])
- self.finish(kernel_spec_res)
+ mimetype = mimetypes.guess_type(path)[0] or "text/plain"
+ self.finish(kernel_spec_res, set_content_type=mimetype)
from ..services.kernels.handlers import _kernel_id_regex
diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py
index 4645429cf1..cd0b27b50d 100644
--- a/jupyter_server/gateway/managers.py
+++ b/jupyter_server/gateway/managers.py
@@ -1,27 +1,36 @@
+"""Kernel gateway managers."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
+from __future__ import annotations
+
+import asyncio
import datetime
import json
import os
from logging import Logger
-from queue import Queue
+from queue import Empty, Queue
from threading import Thread
-from typing import Dict
+from time import monotonic
+from typing import Any, Optional, cast
import websocket
from jupyter_client.asynchronous.client import AsyncKernelClient
from jupyter_client.clientabc import KernelClientABC
from jupyter_client.kernelspec import KernelSpecManager
-from jupyter_client.manager import AsyncKernelManager
from jupyter_client.managerabc import KernelManagerABC
+from jupyter_core.utils import ensure_async
from tornado import web
from tornado.escape import json_decode, json_encode, url_escape, utf8
from traitlets import DottedObjectName, Instance, Type, default
-from .._tz import UTC
-from ..services.kernels.kernelmanager import AsyncMappingKernelManager
+from .._tz import UTC, utcnow
+from ..services.kernels.kernelmanager import (
+ AsyncMappingKernelManager,
+ ServerKernelManager,
+ emit_kernel_action_event,
+)
from ..services.sessions.sessionmanager import SessionManager
-from ..utils import ensure_async, url_path_join
+from ..utils import url_path_join
from .gateway_client import GatewayClient, gateway_request
@@ -29,7 +38,7 @@ class GatewayMappingKernelManager(AsyncMappingKernelManager):
"""Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway."""
# We'll maintain our own set of kernel ids
- _kernels: Dict[str, "GatewayKernelManager"] = {}
+ _kernels: dict[str, GatewayKernelManager] = {} # type:ignore[assignment]
@default("kernel_manager_class")
def _default_kernel_manager_class(self):
@@ -40,9 +49,10 @@ def _default_shared_context(self):
return False # no need to share zmq contexts
def __init__(self, **kwargs):
+ """Initialize a gateway mapping kernel manager."""
super().__init__(**kwargs)
self.kernels_url = url_path_join(
- GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint
+ GatewayClient.instance().url or "", GatewayClient.instance().kernels_endpoint or ""
)
def remove_kernel(self, kernel_id):
@@ -52,7 +62,7 @@ def remove_kernel(self, kernel_id):
except KeyError:
pass
- async def start_kernel(self, kernel_id=None, path=None, **kwargs):
+ async def start_kernel(self, *, kernel_id=None, path=None, **kwargs):
"""Start a kernel for a session and return its kernel_id.
Parameters
@@ -67,15 +77,13 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):
"""
self.log.info(f"Request start kernel: kernel_id={kernel_id}, path='{path}'")
- if kernel_id is None:
- if path is not None:
- kwargs["cwd"] = self.cwd_for_path(path)
+ if kernel_id is None and path is not None:
+ kwargs["cwd"] = self.cwd_for_path(path)
km = self.kernel_manager_factory(parent=self, log=self.log)
await km.start_kernel(kernel_id=kernel_id, **kwargs)
kernel_id = km.kernel_id
self._kernels[kernel_id] = km
-
# Initialize culling if not already
if not self._initialized_culler:
self.initialize_culler()
@@ -92,9 +100,9 @@ async def kernel_model(self, kernel_id):
The uuid of the kernel.
"""
model = None
- km = self.get_kernel(kernel_id)
- if km:
- model = km.kernel
+ km = self.get_kernel(str(kernel_id))
+ if km: # type:ignore[truthy-bool]
+ model = km.kernel # type:ignore[attr-defined]
return model
async def list_kernels(self, **kwargs):
@@ -119,11 +127,34 @@ async def list_kernels(self, **kwargs):
culled_ids = []
for kid, _ in our_kernels.items():
if kid not in kernel_models:
+ # The upstream kernel was not reported in the list of kernels.
self.log.warning(
- f"Kernel {kid} no longer active - probably culled on Gateway server."
+ f"Kernel {kid} not present in the list of kernels - possibly culled on Gateway server."
)
- self._kernels.pop(kid, None)
- culled_ids.append(kid) # TODO: Figure out what do with these.
+ try:
+ # Try to directly refresh the model for this specific kernel in case
+ # the upstream list of kernels was erroneously incomplete.
+ #
+ # That might happen if the case of a proxy that manages multiple
+ # backends where there could be transient connectivity issues with
+ # a single backend.
+ #
+ # Alternatively, it could happen if there is simply a bug in the
+ # upstream gateway server.
+ #
+ # Either way, including this check improves our reliability in the
+ # face of such scenarios.
+ model = await self._kernels[kid].refresh_model()
+ except web.HTTPError:
+ model = None
+ if model:
+ kernel_models[kid] = model
+ else:
+ self.log.warning(
+ f"Kernel {kid} no longer active - probably culled on Gateway server."
+ )
+ self._kernels.pop(kid, None)
+                culled_ids.append(kid)  # TODO: Figure out what to do with these.
return list(kernel_models.values())
async def shutdown_kernel(self, kernel_id, now=False, restart=False):
@@ -139,7 +170,7 @@ async def shutdown_kernel(self, kernel_id, now=False, restart=False):
The purpose of this shutdown is to restart the kernel (True)
"""
km = self.get_kernel(kernel_id)
- await km.shutdown_kernel(now=now, restart=restart)
+ await ensure_async(km.shutdown_kernel(now=now, restart=restart))
self.remove_kernel(kernel_id)
async def restart_kernel(self, kernel_id, now=False, **kwargs):
@@ -151,7 +182,7 @@ async def restart_kernel(self, kernel_id, now=False, **kwargs):
The id of the kernel to restart.
"""
km = self.get_kernel(kernel_id)
- await km.restart_kernel(now=now, **kwargs)
+ await ensure_async(km.restart_kernel(now=now, **kwargs))
async def interrupt_kernel(self, kernel_id, **kwargs):
"""Interrupt a kernel by its kernel uuid.
@@ -162,44 +193,73 @@ async def interrupt_kernel(self, kernel_id, **kwargs):
The id of the kernel to interrupt.
"""
km = self.get_kernel(kernel_id)
- await km.interrupt_kernel()
+ await ensure_async(km.interrupt_kernel())
async def shutdown_all(self, now=False):
"""Shutdown all kernels."""
- for kernel_id in self._kernels:
+ kids = list(self._kernels)
+ for kernel_id in kids:
km = self.get_kernel(kernel_id)
- await km.shutdown_kernel(now=now)
+ await ensure_async(km.shutdown_kernel(now=now))
self.remove_kernel(kernel_id)
async def cull_kernels(self):
- """Override cull_kernels so we can be sure their state is current."""
+ """Override cull_kernels, so we can be sure their state is current."""
await self.list_kernels()
await super().cull_kernels()
class GatewayKernelSpecManager(KernelSpecManager):
+ """A gateway kernel spec manager."""
+
def __init__(self, **kwargs):
+ """Initialize a gateway kernel spec manager."""
super().__init__(**kwargs)
base_endpoint = url_path_join(
- GatewayClient.instance().url, GatewayClient.instance().kernelspecs_endpoint
+ GatewayClient.instance().url or "", GatewayClient.instance().kernelspecs_endpoint
)
self.base_endpoint = GatewayKernelSpecManager._get_endpoint_for_user_filter(base_endpoint)
self.base_resource_endpoint = url_path_join(
- GatewayClient.instance().url,
+ GatewayClient.instance().url or "",
GatewayClient.instance().kernelspecs_resource_endpoint,
)
@staticmethod
def _get_endpoint_for_user_filter(default_endpoint):
+ """Get the endpoint for a user filter."""
kernel_user = os.environ.get("KERNEL_USERNAME")
if kernel_user:
return "?user=".join([default_endpoint, kernel_user])
return default_endpoint
+ def _replace_path_kernelspec_resources(self, kernel_specs):
+ """Helper method that replaces any gateway base_url with the server's base_url
+ This enables clients to properly route through jupyter_server to a gateway
+ for kernel resources such as logo files
+ """
+ if not self.parent:
+ return {}
+ kernelspecs = kernel_specs["kernelspecs"]
+ for kernel_name in kernelspecs:
+ resources = kernelspecs[kernel_name]["resources"]
+ for resource_name in resources:
+ original_path = resources[resource_name]
+ split_eg_base_url = str.rsplit(original_path, sep="/kernelspecs/", maxsplit=1)
+ if len(split_eg_base_url) > 1:
+ new_path = url_path_join(
+ self.parent.base_url, "kernelspecs", split_eg_base_url[1]
+ )
+ kernel_specs["kernelspecs"][kernel_name]["resources"][resource_name] = new_path
+ if original_path != new_path:
+ self.log.debug(
+ f"Replaced original kernel resource path {original_path} with new "
+ f"path {kernel_specs['kernelspecs'][kernel_name]['resources'][resource_name]}"
+ )
+ return kernel_specs
+
def _get_kernelspecs_endpoint_url(self, kernel_name=None):
"""Builds a url for the kernels endpoint
-
Parameters
----------
kernel_name : kernel name (optional)
@@ -210,12 +270,15 @@ def _get_kernelspecs_endpoint_url(self, kernel_name=None):
return self.base_endpoint
async def get_all_specs(self):
+ """Get all of the kernel specs for the gateway."""
fetched_kspecs = await self.list_kernel_specs()
# get the default kernel name and compare to that of this server.
# If different log a warning and reset the default. However, the
# caller of this method will still return this server's value until
# the next fetch of kernelspecs - at which time they'll match.
+ if not self.parent:
+ return {}
km = self.parent.kernel_manager
remote_default_kernel_name = fetched_kspecs.get("default")
if remote_default_kernel_name != km.default_kernel_name:
@@ -234,6 +297,7 @@ async def list_kernel_specs(self):
self.log.debug(f"Request list kernel specs at: {kernel_spec_url}")
response = await gateway_request(kernel_spec_url, method="GET")
kernel_specs = json_decode(response.body)
+ kernel_specs = self._replace_path_kernelspec_resources(kernel_specs)
return kernel_specs
async def get_kernel_spec(self, kernel_name, **kwargs):
@@ -252,12 +316,8 @@ async def get_kernel_spec(self, kernel_name, **kwargs):
if error.status_code == 404:
# Convert not found to KeyError since that's what the Notebook handler expects
# message is not used, but might as well make it useful for troubleshooting
- raise KeyError(
- "kernelspec {kernel_name} not found on Gateway server at: {gateway_url}".format(
- kernel_name=kernel_name,
- gateway_url=GatewayClient.instance().url,
- )
- ) from error
+ msg = f"kernelspec {kernel_name} not found on Gateway server at: {GatewayClient.instance().url}"
+ raise KeyError(msg) from None
else:
raise
else:
@@ -292,26 +352,32 @@ async def get_kernel_spec_resource(self, kernel_name, path):
class GatewaySessionManager(SessionManager):
+ """A gateway session manager."""
+
kernel_manager = Instance("jupyter_server.gateway.managers.GatewayMappingKernelManager")
- async def kernel_culled(self, kernel_id):
- """Checks if the kernel is still considered alive and returns true if its not found."""
- kernel = None
+ async def kernel_culled(self, kernel_id: str) -> bool: # typing: ignore
+ """Checks if the kernel is still considered alive and returns true if it's not found."""
+ km: Optional[GatewayKernelManager] = None
try:
+ # Since we keep the models up-to-date via client polling, use that state to determine
+ # if this kernel no longer exists on the gateway server rather than perform a redundant
+ # fetch operation - especially since this is called at approximately the same interval.
+ # This has the effect of reducing GET /api/kernels requests against the gateway server
+ # by 50%!
+ # Note that should the redundant polling be consolidated, or replaced with an event-based
+ # notification model, this will need to be revisited.
km = self.kernel_manager.get_kernel(kernel_id)
- kernel = await km.refresh_model()
- except Exception: # Let exceptions here reflect culled kernel
+ except Exception:
+ # Let exceptions here reflect culled kernel
pass
- return kernel is None
+ return km is None
-"""KernelManager class to manage a kernel running on a Gateway Server via the REST API"""
-
-
-class GatewayKernelManager(AsyncKernelManager):
+class GatewayKernelManager(ServerKernelManager):
"""Manages a single kernel remotely via a Gateway Server."""
- kernel_id = None
+ kernel_id: Optional[str] = None # type:ignore[assignment]
kernel = None
@default("cache_ports")
@@ -319,13 +385,16 @@ def _default_cache_ports(self):
return False # no need to cache ports here
def __init__(self, **kwargs):
+ """Initialize the gateway kernel manager."""
super().__init__(**kwargs)
self.kernels_url = url_path_join(
- GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint
+ GatewayClient.instance().url or "", GatewayClient.instance().kernels_endpoint
)
- self.kernel_url = self.kernel = self.kernel_id = None
+ self.kernel_url: str
+ self.kernel = self.kernel_id = None
# simulate busy/activity markers:
- self.execution_state = self.last_activity = None
+ self.execution_state = "starting"
+ self.last_activity = utcnow()
@property
def has_kernel(self):
@@ -341,13 +410,13 @@ def has_kernel(self):
def client(self, **kwargs):
"""Create a client configured to connect to our kernel"""
- kw = {}
+ kw: dict[str, Any] = {}
kw.update(self.get_connection_info(session=True))
kw.update(
- dict(
- connection_file=self.connection_file,
- parent=self,
- )
+ {
+ "connection_file": self.connection_file,
+ "parent": self,
+ }
)
kw["kernel_id"] = self.kernel_id
@@ -387,9 +456,9 @@ async def refresh_model(self, model=None):
if isinstance(self.parent, AsyncMappingKernelManager):
# Update connections only if there's a mapping kernel manager parent for
# this kernel manager. The current kernel manager instance may not have
- # an parent instance if, say, a server extension is using another application
+ # a parent instance if, say, a server extension is using another application
# (e.g., papermill) that uses a KernelManager instance directly.
- self.parent._kernel_connections[self.kernel_id] = int(model["connections"])
+ self.parent._kernel_connections[self.kernel_id] = int(model["connections"]) # type:ignore[index]
self.kernel = model
return model
@@ -398,6 +467,9 @@ async def refresh_model(self, model=None):
# Kernel management
# --------------------------------------------------------------------------
+ @emit_kernel_action_event(
+ success_msg="Kernel {kernel_id} was started.",
+ )
async def start_kernel(self, **kwargs):
"""Starts a kernel via HTTP in an asynchronous manner.
@@ -415,24 +487,30 @@ async def start_kernel(self, **kwargs):
# Let KERNEL_USERNAME take precedent over http_user config option.
if os.environ.get("KERNEL_USERNAME") is None and GatewayClient.instance().http_user:
- os.environ["KERNEL_USERNAME"] = GatewayClient.instance().http_user
+ os.environ["KERNEL_USERNAME"] = GatewayClient.instance().http_user or ""
+
+ payload_envs = os.environ.copy()
+ payload_envs.update(kwargs.get("env", {})) # Add any env entries in this request
+ # Build the actual env payload, filtering allowed_envs and those starting with 'KERNEL_'
kernel_env = {
k: v
- for (k, v) in dict(os.environ).items()
- if k.startswith("KERNEL_") or k in GatewayClient.instance().env_whitelist.split(",")
+ for (k, v) in payload_envs.items()
+ if k.startswith("KERNEL_") or k in GatewayClient.instance().allowed_envs.split(",")
}
- # Add any env entries in this request
- kernel_env.update(kwargs.get("env", {}))
-
# Convey the full path to where this notebook file is located.
if kwargs.get("cwd") is not None and kernel_env.get("KERNEL_WORKING_DIR") is None:
kernel_env["KERNEL_WORKING_DIR"] = kwargs["cwd"]
json_body = json_encode({"name": kernel_name, "env": kernel_env})
- response = await gateway_request(self.kernels_url, method="POST", body=json_body)
+ response = await gateway_request(
+ self.kernels_url,
+ method="POST",
+ headers={"Content-Type": "application/json"},
+ body=json_body,
+ )
self.kernel = json_decode(response.body)
self.kernel_id = self.kernel["id"]
self.kernel_url = url_path_join(self.kernels_url, url_escape(str(self.kernel_id)))
@@ -443,28 +521,55 @@ async def start_kernel(self, **kwargs):
self.kernel = await self.refresh_model()
self.log.info(f"GatewayKernelManager using existing kernel: {self.kernel_id}")
+ @emit_kernel_action_event(
+ success_msg="Kernel {kernel_id} was shutdown.",
+ )
async def shutdown_kernel(self, now=False, restart=False):
"""Attempts to stop the kernel process cleanly via HTTP."""
if self.has_kernel:
self.log.debug("Request shutdown kernel at: %s", self.kernel_url)
- response = await gateway_request(self.kernel_url, method="DELETE")
- self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason)
+ try:
+ response = await gateway_request(self.kernel_url, method="DELETE")
+ self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason)
+ except web.HTTPError as error:
+ if error.status_code == 404:
+ self.log.debug("Shutdown kernel response: kernel not found (ignored)")
+ else:
+ raise
+ @emit_kernel_action_event(
+ success_msg="Kernel {kernel_id} was restarted.",
+ )
async def restart_kernel(self, **kw):
"""Restarts a kernel via HTTP."""
if self.has_kernel:
+ assert self.kernel_url is not None
kernel_url = self.kernel_url + "/restart"
self.log.debug("Request restart kernel at: %s", kernel_url)
- response = await gateway_request(kernel_url, method="POST", body=json_encode({}))
+ response = await gateway_request(
+ kernel_url,
+ method="POST",
+ headers={"Content-Type": "application/json"},
+ body=json_encode({}),
+ )
self.log.debug("Restart kernel response: %d %s", response.code, response.reason)
+ @emit_kernel_action_event(
+ success_msg="Kernel {kernel_id} was interrupted.",
+ )
async def interrupt_kernel(self):
"""Interrupts the kernel via an HTTP request."""
if self.has_kernel:
+ assert self.kernel_url is not None
kernel_url = self.kernel_url + "/interrupt"
self.log.debug("Request interrupt kernel at: %s", kernel_url)
- response = await gateway_request(kernel_url, method="POST", body=json_encode({}))
+ response = await gateway_request(
+ kernel_url,
+ method="POST",
+ headers={"Content-Type": "application/json"},
+ body=json_encode({}),
+ )
self.log.debug("Interrupt kernel response: %d %s", response.code, response.reason)
async def is_alive(self):
@@ -472,40 +577,67 @@ async def is_alive(self):
if self.has_kernel:
# Go ahead and issue a request to get the kernel
self.kernel = await self.refresh_model()
+ self.log.debug(f"The kernel: {self.kernel} is alive.")
return True
else: # we don't have a kernel
+ self.log.debug(f"The kernel: {self.kernel} no longer exists.")
return False
def cleanup_resources(self, restart=False):
"""Clean up resources when the kernel is shut down"""
- pass
KernelManagerABC.register(GatewayKernelManager)
-class ChannelQueue(Queue):
+class ChannelQueue(Queue): # type:ignore[type-arg]
+ """A queue for a named channel."""
- channel_name: str = None
+ channel_name: Optional[str] = None
+ response_router_finished: bool
def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: Logger):
+ """Initialize a channel queue."""
super().__init__()
self.channel_name = channel_name
self.channel_socket = channel_socket
self.log = log
+ self.response_router_finished = False
+
+ async def _async_get(self, timeout=None):
+ """Asynchronously get from the queue."""
+ if timeout is None:
+ timeout = float("inf")
+ elif timeout < 0:
+ msg = "'timeout' must be a non-negative number"
+ raise ValueError(msg)
+ end_time = monotonic() + timeout
+
+ while True:
+ try:
+ return self.get(block=False)
+ except Empty:
+ if self.response_router_finished:
+ msg = "Response router had finished"
+ raise RuntimeError(msg) from None
+ if monotonic() > end_time:
+ raise
+ await asyncio.sleep(0)
- async def get_msg(self, *args, **kwargs) -> dict:
+ async def get_msg(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
+ """Get a message from the queue."""
timeout = kwargs.get("timeout", 1)
- msg = self.get(timeout=timeout)
+ msg = await self._async_get(timeout=timeout)
self.log.debug(
"Received message on channel: {}, msg_id: {}, msg_type: {}".format(
self.channel_name, msg["msg_id"], msg["msg_type"] if msg else "null"
)
)
self.task_done()
- return msg
+ return cast("dict[str, Any]", msg)
- def send(self, msg: dict) -> None:
+ def send(self, msg: dict[str, Any]) -> None:
+ """Send a message to the queue."""
message = json.dumps(msg, default=ChannelQueue.serialize_datetime).replace("</", "<\\/")
self.log.debug(
"Sending message on channel: {}, msg_id: {}, msg_type: {}".format(
@@ -516,14 +648,16 @@ def send(self, msg: dict) -> None:
@staticmethod
def serialize_datetime(dt):
- if isinstance(dt, (datetime.date, datetime.datetime)):
+ """Serialize a datetime object."""
+ if isinstance(dt, datetime.datetime):
return dt.timestamp()
return None
def start(self) -> None:
- pass
+ """Start the queue."""
def stop(self) -> None:
+ """Stop the queue."""
if not self.empty():
# If unprocessed messages are detected, drain the queue collecting non-status
# messages. If any remain that are not 'shutdown_reply' and this is not iopub
@@ -543,11 +677,15 @@ def stop(self) -> None:
)
def is_alive(self) -> bool:
+ """Whether the queue is alive."""
return self.channel_socket is not None
class HBChannelQueue(ChannelQueue):
+ """A queue for the heartbeat channel."""
+
def is_beating(self) -> bool:
+ """Whether the channel is beating."""
# Just use the is_alive status for now
return self.is_alive()
@@ -571,14 +709,22 @@ class GatewayKernelClient(AsyncKernelClient):
# flag for whether execute requests should be allowed to call raw_input:
allow_stdin = False
- _channels_stopped = False
- _channel_queues = {}
-
- def __init__(self, **kwargs):
+ _channels_stopped: bool
+ _channel_queues: Optional[dict[str, ChannelQueue]]
+ _control_channel: Optional[ChannelQueue] # type:ignore[assignment]
+ _hb_channel: Optional[ChannelQueue] # type:ignore[assignment]
+ _stdin_channel: Optional[ChannelQueue] # type:ignore[assignment]
+ _iopub_channel: Optional[ChannelQueue] # type:ignore[assignment]
+ _shell_channel: Optional[ChannelQueue] # type:ignore[assignment]
+
+ def __init__(self, kernel_id, **kwargs):
+ """Initialize a gateway kernel client."""
super().__init__(**kwargs)
- self.kernel_id = kwargs["kernel_id"]
- self.channel_socket = None
- self.response_router = None
+ self.kernel_id = kernel_id
+ self.channel_socket: Optional[websocket.WebSocket] = None
+ self.response_router: Optional[Thread] = None
+ self._channels_stopped = False
+ self._channel_queues = {}
# --------------------------------------------------------------------------
# Channel management methods
@@ -588,21 +734,22 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont
"""Starts the channels for this kernel.
For this class, we establish a websocket connection to the destination
- and setup the channel-based queues on which applicable messages will
+ and set up the channel-based queues on which applicable messages will
be posted.
"""
ws_url = url_path_join(
- GatewayClient.instance().ws_url,
+ GatewayClient.instance().ws_url or "",
GatewayClient.instance().kernels_endpoint,
url_escape(self.kernel_id),
"channels",
)
# Gather cert info in case where ssl is desired...
- ssl_options = {}
- ssl_options["ca_certs"] = GatewayClient.instance().ca_certs
- ssl_options["certfile"] = GatewayClient.instance().client_cert
- ssl_options["keyfile"] = GatewayClient.instance().client_key
+ ssl_options = {
+ "ca_certs": GatewayClient.instance().ca_certs,
+ "certfile": GatewayClient.instance().client_cert,
+ "keyfile": GatewayClient.instance().client_key,
+ }
self.channel_socket = websocket.create_connection(
ws_url,
@@ -610,13 +757,14 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont
enable_multithread=True,
sslopt=ssl_options,
)
- self.response_router = Thread(target=self._route_responses)
- self.response_router.start()
await ensure_async(
super().start_channels(shell=shell, iopub=iopub, stdin=stdin, hb=hb, control=control)
)
+ self.response_router = Thread(target=self._route_responses)
+ self.response_router.start()
+
def stop_channels(self):
"""Stops all the running channels for this kernel.
@@ -627,7 +775,9 @@ def stop_channels(self):
self._channels_stopped = True
self.log.debug("Closing websocket connection")
+ assert self.channel_socket is not None
self.channel_socket.close()
+ assert self.response_router is not None
self.response_router.join()
if self._channel_queues:
@@ -641,7 +791,9 @@ def shell_channel(self):
"""Get the shell channel object for this kernel."""
if self._shell_channel is None:
self.log.debug("creating shell channel queue")
+ assert self.channel_socket is not None
self._shell_channel = ChannelQueue("shell", self.channel_socket, self.log)
+ assert self._channel_queues is not None
self._channel_queues["shell"] = self._shell_channel
return self._shell_channel
@@ -650,7 +802,9 @@ def iopub_channel(self):
"""Get the iopub channel object for this kernel."""
if self._iopub_channel is None:
self.log.debug("creating iopub channel queue")
+ assert self.channel_socket is not None
self._iopub_channel = ChannelQueue("iopub", self.channel_socket, self.log)
+ assert self._channel_queues is not None
self._channel_queues["iopub"] = self._iopub_channel
return self._iopub_channel
@@ -659,7 +813,9 @@ def stdin_channel(self):
"""Get the stdin channel object for this kernel."""
if self._stdin_channel is None:
self.log.debug("creating stdin channel queue")
+ assert self.channel_socket is not None
self._stdin_channel = ChannelQueue("stdin", self.channel_socket, self.log)
+ assert self._channel_queues is not None
self._channel_queues["stdin"] = self._stdin_channel
return self._stdin_channel
@@ -668,7 +824,9 @@ def hb_channel(self):
"""Get the hb channel object for this kernel."""
if self._hb_channel is None:
self.log.debug("creating hb channel queue")
+ assert self.channel_socket is not None
self._hb_channel = HBChannelQueue("hb", self.channel_socket, self.log)
+ assert self._channel_queues is not None
self._channel_queues["hb"] = self._hb_channel
return self._hb_channel
@@ -677,7 +835,9 @@ def control_channel(self):
"""Get the control channel object for this kernel."""
if self._control_channel is None:
self.log.debug("creating control channel queue")
+ assert self.channel_socket is not None
self._control_channel = ChannelQueue("control", self.channel_socket, self.log)
+ assert self._channel_queues is not None
self._channel_queues["control"] = self._control_channel
return self._control_channel
@@ -691,20 +851,27 @@ def _route_responses(self):
"""
try:
while not self._channels_stopped:
+ assert self.channel_socket is not None
raw_message = self.channel_socket.recv()
if not raw_message:
break
response_message = json_decode(utf8(raw_message))
channel = response_message["channel"]
+ assert self._channel_queues is not None
self._channel_queues[channel].put_nowait(response_message)
except websocket.WebSocketConnectionClosedException:
- pass # websocket closure most likely due to shutdown
+ pass # websocket closure most likely due to shutdown
except BaseException as be:
if not self._channels_stopped:
self.log.warning(f"Unexpected exception encountered ({be})")
+ # Notify channel queues that this thread has finished and no more messages are being received
+ assert self._channel_queues is not None
+ for channel_queue in self._channel_queues.values():
+ channel_queue.response_router_finished = True
+
self.log.debug("Response router thread exiting...")
diff --git a/jupyter_server/i18n/README.md b/jupyter_server/i18n/README.md
index 17a475ce47..28562c2748 100644
--- a/jupyter_server/i18n/README.md
+++ b/jupyter_server/i18n/README.md
@@ -29,7 +29,9 @@ if running Ubuntu 14, you should set environment variable `LANGUAGE="xx_XX"`.
**All i18n-related commands are done from the related directory :**
- cd notebook/i18n/
+```
+cd notebook/i18n/
+```
### Message extraction
@@ -69,7 +71,9 @@ pybabel compile -D nbui -f -l ${LANG} -i ${LANG}/LC_MESSAGES/nbui.po -o ${LANG}/
_nbjs.po_ needs to be converted to JSON for use within the JavaScript code, with _po2json_, as follows:
- po2json -p -F -f jed1.x -d nbjs ${LANG}/LC_MESSAGES/nbjs.po ${LANG}/LC_MESSAGES/nbjs.json
+```
+po2json -p -F -f jed1.x -d nbjs ${LANG}/LC_MESSAGES/nbjs.po ${LANG}/LC_MESSAGES/nbjs.json
+```
When new languages get added, their language codes should be added to _notebook/i18n/nbjs.json_
under the `supported_languages` element.
@@ -111,21 +115,25 @@ to handle these cases properly.
### Known issues and future evolutions
-1. Right now there are two different places where the desired language is set. At startup time, the Jupyter console's messages pay attention to the setting of the `${LANG}` environment variable
- as set in the shell at startup time. Unfortunately, this is also the time where the Jinja2
- environment is set up, which means that the template stuff will always come from this setting.
- We really want to be paying attention to the browser's settings for the stuff that happens in the
- browser, so we need to be able to retrieve this information after the browser is started and somehow
- communicate this back to Jinja2. So far, I haven't yet figured out how to do this, which means that if the ${LANG} at startup doesn't match the browser's settings, you could potentially get a mix
- of languages in the UI ( never a good thing ).
-
-2. We will need to decide if console messages should be translatable, and enable them if desired.
-3. The keyboard shortcut editor was implemented after the i18n work was completed, so that portion
- does not have translation support at this time.
-4. Babel's documentation has instructions on how to integrate messages extraction
- into your _setup.py_ so that eventually we can just do:
-
- ./setup.py extract_messages
+1. Right now there are two different places where the desired language is set. At startup time, the Jupyter console's messages pay attention to the setting of the `${LANG}` environment variable
+ as set in the shell at startup time. Unfortunately, this is also the time where the Jinja2
+ environment is set up, which means that the template stuff will always come from this setting.
+ We really want to be paying attention to the browser's settings for the stuff that happens in the
+ browser, so we need to be able to retrieve this information after the browser is started and somehow
+ communicate this back to Jinja2. So far, I haven't yet figured out how to do this, which means that if the ${LANG} at startup doesn't match the browser's settings, you could potentially get a mix
+ of languages in the UI ( never a good thing ).
+
+1. We will need to decide if console messages should be translatable, and enable them if desired.
+
+1. The keyboard shortcut editor was implemented after the i18n work was completed, so that portion
+ does not have translation support at this time.
+
+1. Babel's documentation has instructions on how to integrate messages extraction
+ into your _setup.py_ so that eventually we can just do:
+
+ ```
+ ./setup.py extract_messages
+ ```
I hope to get this working at some point in the near future.

1. The conversions from `.po` to `.mo` probably can and should be done using `setup.py install`.
diff --git a/jupyter_server/i18n/__init__.py b/jupyter_server/i18n/__init__.py
index e44aa11393..896f41c57c 100644
--- a/jupyter_server/i18n/__init__.py
+++ b/jupyter_server/i18n/__init__.py
@@ -1,11 +1,14 @@
"""Server functions for loading translations
"""
+from __future__ import annotations
+
import errno
import json
import re
from collections import defaultdict
from os.path import dirname
from os.path import join as pjoin
+from typing import Any
I18N_DIR = dirname(__file__)
# Cache structure:
@@ -15,7 +18,7 @@
# ...
# }
# }}
-TRANSLATIONS_CACHE = {"nbjs": {}}
+TRANSLATIONS_CACHE: dict[str, Any] = {"nbjs": {}}
_accept_lang_re = re.compile(
@@ -42,10 +45,7 @@ def parse_accept_lang_header(accept_lang):
lang, qvalue = m.group("lang", "qvalue")
# Browser header format is zh-CN, gettext uses zh_CN
lang = lang.replace("-", "_")
- if qvalue is None:
- qvalue = 1.0
- else:
- qvalue = float(qvalue)
+ qvalue = 1.0 if qvalue is None else float(qvalue)
if qvalue == 0:
continue # 0 means not accepted
by_q[qvalue].append(lang)
@@ -59,7 +59,7 @@ def parse_accept_lang_header(accept_lang):
def load(language, domain="nbjs"):
"""Load translations from an nbjs.json file"""
try:
- f = open(pjoin(I18N_DIR, language, "LC_MESSAGES", "nbjs.json"), encoding="utf-8")
+ f = open(pjoin(I18N_DIR, language, "LC_MESSAGES", "nbjs.json"), encoding="utf-8") # noqa: SIM115
except OSError as e:
if e.errno != errno.ENOENT:
raise
@@ -87,7 +87,7 @@ def combine_translations(accept_language, domain="nbjs"):
Returns data re-packaged in jed1.x format.
"""
lang_codes = parse_accept_lang_header(accept_language)
- combined = {}
+ combined: dict[str, Any] = {}
for language in lang_codes:
if language == "en":
# en is default, all translations are in frontend.
diff --git a/jupyter_server/i18n/notebook.pot b/jupyter_server/i18n/notebook.pot
index 333b40d76c..b8d588f964 100644
--- a/jupyter_server/i18n/notebook.pot
+++ b/jupyter_server/i18n/notebook.pot
@@ -280,7 +280,7 @@ msgid "server_extensions is deprecated, use jpserver_extensions"
msgstr ""
#: jupyter_server/serverapp.py:1040
-msgid "Dict of Python modules to load as notebook server extensions. Entry values can be used to enable and disable the loading ofthe extensions. The extensions will be loaded in alphabetical order."
+msgid "Dict of Python modules to load as notebook server extensions. Entry values can be used to enable and disable the loading of the extensions. The extensions will be loaded in alphabetical order."
msgstr ""
#: jupyter_server/serverapp.py:1049
diff --git a/jupyter_server/i18n/zh_CN/LC_MESSAGES/notebook.po b/jupyter_server/i18n/zh_CN/LC_MESSAGES/notebook.po
index ee74a2097c..8f65bd35bf 100644
--- a/jupyter_server/i18n/zh_CN/LC_MESSAGES/notebook.po
+++ b/jupyter_server/i18n/zh_CN/LC_MESSAGES/notebook.po
@@ -283,7 +283,7 @@ msgid "No such notebook dir: '%r'"
msgstr "没有找到路径: '%r' "
#: notebook/serverapp.py:1046
-msgid "Dict of Python modules to load as notebook server extensions.Entry values can be used to enable and disable the loading ofthe extensions. The extensions will be loaded in alphabetical order."
+msgid "Dict of Python modules to load as notebook server extensions.Entry values can be used to enable and disable the loading of the extensions. The extensions will be loaded in alphabetical order."
msgstr "将Python模块作为笔记本服务器扩展加载。可以使用条目值来启用和禁用扩展的加载。这些扩展将以字母顺序加载。"
#: notebook/serverapp.py:1055
diff --git a/jupyter_server/kernelspecs/handlers.py b/jupyter_server/kernelspecs/handlers.py
index 3ac8506a31..c7cb141459 100644
--- a/jupyter_server/kernelspecs/handlers.py
+++ b/jupyter_server/kernelspecs/handlers.py
@@ -1,6 +1,10 @@
+"""Kernelspecs API Handlers."""
+import mimetypes
+
+from jupyter_core.utils import ensure_async
from tornado import web
-from jupyter_server.auth import authorized
+from jupyter_server.auth.decorator import authorized
from ..base.handlers import JupyterHandler
from ..services.kernelspecs.handlers import kernel_name_regex
@@ -9,29 +13,55 @@
class KernelSpecResourceHandler(web.StaticFileHandler, JupyterHandler):
- SUPPORTED_METHODS = ("GET", "HEAD")
+ """A Kernelspec resource handler."""
+
+ SUPPORTED_METHODS = ("GET", "HEAD") # type:ignore[assignment]
auth_resource = AUTH_RESOURCE
def initialize(self):
+ """Initialize a kernelspec resource handler."""
web.StaticFileHandler.initialize(self, path="")
@web.authenticated
@authorized
- def get(self, kernel_name, path, include_body=True):
+ async def get(self, kernel_name, path, include_body=True):
+ """Get a kernelspec resource."""
ksm = self.kernel_spec_manager
if path.lower().endswith(".png"):
self.set_header("Cache-Control", f"max-age={60*60*24*30}")
+ ksm = self.kernel_spec_manager
+ if hasattr(ksm, "get_kernel_spec_resource"):
+ # If the kernel spec manager defines a method to get kernelspec resources,
+ # then use that instead of trying to read from disk.
+ kernel_spec_res = await ksm.get_kernel_spec_resource(kernel_name, path)
+ if kernel_spec_res is not None:
+ # We have to explicitly specify the `absolute_path` attribute so that
+ # the underlying StaticFileHandler methods can calculate an etag.
+ self.absolute_path = path
+ mimetype: str = mimetypes.guess_type(path)[0] or "text/plain"
+ self.set_header("Content-Type", mimetype)
+ self.finish(kernel_spec_res)
+ return None
+ else:
+ self.log.warning(
+ "Kernelspec resource '{}' for '{}' not found. Kernel spec manager may"
+ " not support resource serving. Falling back to reading from disk".format(
+ path, kernel_name
+ )
+ )
try:
- self.root = ksm.get_kernel_spec(kernel_name).resource_dir
+ kspec = await ensure_async(ksm.get_kernel_spec(kernel_name))
+ self.root = kspec.resource_dir
except KeyError as e:
raise web.HTTPError(404, "Kernel spec %s not found" % kernel_name) from e
self.log.debug("Serving kernel resource from: %s", self.root)
- return web.StaticFileHandler.get(self, path, include_body=include_body)
+ return await web.StaticFileHandler.get(self, path, include_body=include_body)
@web.authenticated
@authorized
- def head(self, kernel_name, path):
- return self.get(kernel_name, path, include_body=False)
+ async def head(self, kernel_name, path):
+ """Get the head info for a kernel resource."""
+ return await ensure_async(self.get(kernel_name, path, include_body=False))
default_handlers = [
diff --git a/jupyter_server/log.py b/jupyter_server/log.py
index d23799456d..705eaaf44c 100644
--- a/jupyter_server/log.py
+++ b/jupyter_server/log.py
@@ -1,15 +1,44 @@
+"""Log utilities."""
# -----------------------------------------------------------------------------
# Copyright (c) Jupyter Development Team
#
# Distributed under the terms of the BSD License. The full license is in
-# the file COPYING, distributed as part of this software.
+# the file LICENSE, distributed as part of this software.
# -----------------------------------------------------------------------------
import json
+from urllib.parse import urlparse, urlunparse
from tornado.log import access_log
+from .auth import User
from .prometheus.log_functions import prometheus_log_method
+# url params to be scrubbed if seen
+# any url param that *contains* one of these
+# will be scrubbed from logs
+_SCRUB_PARAM_KEYS = {"token", "auth", "key", "code", "state", "xsrf"}
+
+
+def _scrub_uri(uri: str) -> str:
+ """scrub auth info from uri"""
+ parsed = urlparse(uri)
+ if parsed.query:
+ # check for potentially sensitive url params
+ # use manual list + split rather than parsing
+ # to minimally perturb original
+ parts = parsed.query.split("&")
+ changed = False
+ for i, s in enumerate(parts):
+ key, sep, value = s.partition("=")
+ for substring in _SCRUB_PARAM_KEYS:
+ if substring in key:
+ parts[i] = f"{key}{sep}[secret]"
+ changed = True
+ if changed:
+ parsed = parsed._replace(query="&".join(parts))
+ return urlunparse(parsed)
+ return uri
+
def log_request(handler):
"""log a bit more information about each request than tornado's default
@@ -37,17 +66,27 @@ def log_request(handler):
log_method = logger.error
request_time = 1000.0 * handler.request.request_time()
- ns = dict(
- status=status,
- method=request.method,
- ip=request.remote_ip,
- uri=request.uri,
- request_time=request_time,
- )
- msg = "{status} {method} {uri} ({ip}) {request_time:.2f}ms"
+ ns = {
+ "status": status,
+ "method": request.method,
+ "ip": request.remote_ip,
+ "uri": _scrub_uri(request.uri),
+ "request_time": request_time,
+ }
+ # log username
+ # make sure we don't break anything
+ # in case mixins cause current_user to not be a User somehow
+ try:
+ user = handler.current_user
+ except Exception:
+ user = None
+ username = (user.username if isinstance(user, User) else "unknown") if user else ""
+ ns["username"] = username
+
+ msg = "{status} {method} {uri} ({username}@{ip}) {request_time:.2f}ms"
if status >= 400:
- # log bad referers
- ns["referer"] = request.headers.get("Referer", "None")
+ # log bad referrers
+ ns["referer"] = _scrub_uri(request.headers.get("Referer", "None"))
msg = msg + " referer={referer}"
if status >= 500 and status != 502:
# Log a subset of the headers if it caused an error.
diff --git a/jupyter_server/nbconvert/handlers.py b/jupyter_server/nbconvert/handlers.py
index c5e3840699..b7a39d0c8b 100644
--- a/jupyter_server/nbconvert/handlers.py
+++ b/jupyter_server/nbconvert/handlers.py
@@ -7,12 +7,12 @@
import zipfile
from anyio.to_thread import run_sync
+from jupyter_core.utils import ensure_async
from nbformat import from_dict
from tornado import web
from tornado.log import app_log
-from jupyter_server.auth import authorized
-from jupyter_server.utils import ensure_async
+from jupyter_server.auth.decorator import authorized
from ..base.handlers import FilesRedirectHandler, JupyterHandler, path_regex
@@ -27,6 +27,7 @@
def find_resource_files(output_files_dir):
+ """Find the resource files in a directory."""
files = []
for dirpath, _, filenames in os.walk(output_files_dir):
files.extend([os.path.join(dirpath, f) for f in filenames])
@@ -85,13 +86,15 @@ def get_exporter(format, **kwargs):
class NbconvertFileHandler(JupyterHandler):
+ """An nbconvert file handler."""
auth_resource = AUTH_RESOURCE
- SUPPORTED_METHODS = ("GET",)
+ SUPPORTED_METHODS = ("GET",) # type:ignore[assignment]
@web.authenticated
@authorized
async def get(self, format, path):
+ """Get a notebook file in a desired format."""
self.check_xsrf_cookie()
exporter = get_exporter(format, config=self.config, log=self.log)
@@ -132,11 +135,11 @@ async def get(self, format, path):
lambda: exporter.from_notebook_node(nb, resources=resource_dict)
)
except Exception as e:
- self.log.exception("nbconvert failed: %s", e)
+ self.log.exception("nbconvert failed: %r", e)
raise web.HTTPError(500, "nbconvert failed: %s" % e) from e
if respond_zip(self, name, output, resources):
- return
+ return None
# Force download if requested
if self.get_argument("download", "false").lower() == "true":
@@ -152,16 +155,19 @@ async def get(self, format, path):
class NbconvertPostHandler(JupyterHandler):
+ """An nbconvert post handler."""
- SUPPORTED_METHODS = ("POST",)
+ SUPPORTED_METHODS = ("POST",) # type:ignore[assignment]
auth_resource = AUTH_RESOURCE
@web.authenticated
@authorized
async def post(self, format):
+ """Convert a notebook file to a desired format."""
exporter = get_exporter(format, config=self.config)
model = self.get_json_body()
+ assert model is not None
name = model.get("name", "notebook.ipynb")
nbnode = from_dict(model["content"])
diff --git a/jupyter_server/prometheus/log_functions.py b/jupyter_server/prometheus/log_functions.py
index 4f0d497b6c..ac4bd620c1 100644
--- a/jupyter_server/prometheus/log_functions.py
+++ b/jupyter_server/prometheus/log_functions.py
@@ -1,4 +1,5 @@
-from .metrics import HTTP_REQUEST_DURATION_SECONDS
+"""Log functions for Prometheus."""
+from .metrics import HTTP_REQUEST_DURATION_SECONDS # type:ignore[unused-ignore]
def prometheus_log_method(handler):
diff --git a/jupyter_server/prometheus/metrics.py b/jupyter_server/prometheus/metrics.py
index ae98043c3e..1a02f86209 100644
--- a/jupyter_server/prometheus/metrics.py
+++ b/jupyter_server/prometheus/metrics.py
@@ -16,7 +16,6 @@
)
except ImportError:
-
from prometheus_client import Gauge, Histogram
HTTP_REQUEST_DURATION_SECONDS = Histogram(
@@ -35,3 +34,10 @@
"counter for how many kernels are running labeled by type",
["type"],
)
+
+
+__all__ = [
+ "HTTP_REQUEST_DURATION_SECONDS",
+ "TERMINAL_CURRENTLY_RUNNING_TOTAL",
+ "KERNEL_CURRENTLY_RUNNING_TOTAL",
+]
diff --git a/.gitmodules b/jupyter_server/py.typed
similarity index 100%
rename from .gitmodules
rename to jupyter_server/py.typed
diff --git a/jupyter_server/pytest_plugin.py b/jupyter_server/pytest_plugin.py
index 7b35795c63..f77448f866 100644
--- a/jupyter_server/pytest_plugin.py
+++ b/jupyter_server/pytest_plugin.py
@@ -1,409 +1,15 @@
+"""Pytest Fixtures exported by Jupyter Server."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
-import io
import json
-import logging
-import os
-import shutil
-import sys
-import urllib.parse
-from binascii import hexlify
+from pathlib import Path
-import jupyter_core.paths
-import nbformat
import pytest
-import tornado
-from tornado.escape import url_escape
-from traitlets.config import Config
-from jupyter_server.extension import serverextension
-from jupyter_server.serverapp import ServerApp
-from jupyter_server.services.contents.filemanager import FileContentsManager
-from jupyter_server.services.contents.largefilemanager import LargeFileManager
-from jupyter_server.utils import url_path_join
-
-# List of dependencies needed for this plugin.
-pytest_plugins = [
- "pytest_tornasync",
- # Once the chunk below moves to Jupyter Core, we'll uncomment
- # This plugin and use the fixtures directly from Jupyter Core.
- # "jupyter_core.pytest_plugin"
-]
-
-
-import asyncio
-
-if os.name == "nt" and sys.version_info >= (3, 7):
- asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-
-
-# ============ Move to Jupyter Core =============
-
-
-def mkdir(tmp_path, *parts):
- path = tmp_path.joinpath(*parts)
- if not path.exists():
- path.mkdir(parents=True)
- return path
-
-
-@pytest.fixture
-def jp_home_dir(tmp_path):
- """Provides a temporary HOME directory value."""
- return mkdir(tmp_path, "home")
-
-
-@pytest.fixture
-def jp_data_dir(tmp_path):
- """Provides a temporary Jupyter data dir directory value."""
- return mkdir(tmp_path, "data")
-
-
-@pytest.fixture
-def jp_config_dir(tmp_path):
- """Provides a temporary Jupyter config dir directory value."""
- return mkdir(tmp_path, "config")
-
-
-@pytest.fixture
-def jp_runtime_dir(tmp_path):
- """Provides a temporary Jupyter runtime dir directory value."""
- return mkdir(tmp_path, "runtime")
-
-
-@pytest.fixture
-def jp_system_jupyter_path(tmp_path):
- """Provides a temporary Jupyter system path value."""
- return mkdir(tmp_path, "share", "jupyter")
-
-
-@pytest.fixture
-def jp_env_jupyter_path(tmp_path):
- """Provides a temporary Jupyter env system path value."""
- return mkdir(tmp_path, "env", "share", "jupyter")
-
-
-@pytest.fixture
-def jp_system_config_path(tmp_path):
- """Provides a temporary Jupyter config path value."""
- return mkdir(tmp_path, "etc", "jupyter")
-
-
-@pytest.fixture
-def jp_env_config_path(tmp_path):
- """Provides a temporary Jupyter env config path value."""
- return mkdir(tmp_path, "env", "etc", "jupyter")
-
-
-@pytest.fixture
-def jp_environ(
- monkeypatch,
- tmp_path,
- jp_home_dir,
- jp_data_dir,
- jp_config_dir,
- jp_runtime_dir,
- jp_system_jupyter_path,
- jp_system_config_path,
- jp_env_jupyter_path,
- jp_env_config_path,
-):
- """Configures a temporary environment based on Jupyter-specific environment variables."""
- monkeypatch.setenv("HOME", str(jp_home_dir))
- monkeypatch.setenv("PYTHONPATH", os.pathsep.join(sys.path))
- # monkeypatch.setenv("JUPYTER_NO_CONFIG", "1")
- monkeypatch.setenv("JUPYTER_CONFIG_DIR", str(jp_config_dir))
- monkeypatch.setenv("JUPYTER_DATA_DIR", str(jp_data_dir))
- monkeypatch.setenv("JUPYTER_RUNTIME_DIR", str(jp_runtime_dir))
- monkeypatch.setattr(jupyter_core.paths, "SYSTEM_JUPYTER_PATH", [str(jp_system_jupyter_path)])
- monkeypatch.setattr(jupyter_core.paths, "ENV_JUPYTER_PATH", [str(jp_env_jupyter_path)])
- monkeypatch.setattr(jupyter_core.paths, "SYSTEM_CONFIG_PATH", [str(jp_system_config_path)])
- monkeypatch.setattr(jupyter_core.paths, "ENV_CONFIG_PATH", [str(jp_env_config_path)])
-
-
-# ================= End: Move to Jupyter core ================
-
-
-@pytest.fixture
-def jp_server_config():
- """Allows tests to setup their specific configuration values."""
- return {}
-
-
-@pytest.fixture
-def jp_root_dir(tmp_path):
- """Provides a temporary Jupyter root directory value."""
- return mkdir(tmp_path, "root_dir")
-
-
-@pytest.fixture
-def jp_template_dir(tmp_path):
- """Provides a temporary Jupyter templates directory value."""
- return mkdir(tmp_path, "templates")
-
-
-@pytest.fixture
-def jp_argv():
- """Allows tests to setup specific argv values."""
- return []
-
-
-@pytest.fixture
-def jp_extension_environ(jp_env_config_path, monkeypatch):
- """Monkeypatch a Jupyter Extension's config path into each test's environment variable"""
- monkeypatch.setattr(serverextension, "ENV_CONFIG_PATH", [str(jp_env_config_path)])
-
-
-@pytest.fixture
-def jp_http_port(http_server_port):
- """Returns the port value from the http_server_port fixture."""
- return http_server_port[-1]
-
-
-@pytest.fixture
-def jp_nbconvert_templates(jp_data_dir):
- """Setups up a temporary directory consisting of the nbconvert templates."""
-
- # Get path to nbconvert template directory *before*
- # monkeypatching the paths env variable via the jp_environ fixture.
- possible_paths = jupyter_core.paths.jupyter_path("nbconvert", "templates")
- nbconvert_path = None
- for path in possible_paths:
- if os.path.exists(path):
- nbconvert_path = path
- break
-
- nbconvert_target = jp_data_dir / "nbconvert" / "templates"
-
- # copy nbconvert templates to new tmp data_dir.
- if nbconvert_path:
- shutil.copytree(nbconvert_path, str(nbconvert_target))
-
-
-@pytest.fixture
-def jp_logging_stream():
- """StringIO stream intended to be used by the core
- Jupyter ServerApp logger's default StreamHandler. This
- helps avoid collision with stdout which is hijacked
- by Pytest.
- """
- logging_stream = io.StringIO()
- yield logging_stream
- output = logging_stream.getvalue()
- # If output exists, print it.
- if output:
- print(output)
- return output
-
-
-@pytest.fixture(scope="function")
-def jp_configurable_serverapp(
- jp_nbconvert_templates, # this fixture must preceed jp_environ
- jp_environ,
- jp_server_config,
- jp_argv,
- jp_http_port,
- jp_base_url,
- tmp_path,
- jp_root_dir,
- io_loop,
- jp_logging_stream,
-):
- """Starts a Jupyter Server instance based on
- the provided configuration values.
-
- The fixture is a factory; it can be called like
- a function inside a unit test. Here's a basic
- example of how use this fixture:
-
- .. code-block:: python
-
- def my_test(jp_configurable_serverapp):
-
- app = jp_configurable_serverapp(...)
- ...
- """
- ServerApp.clear_instance()
-
- def _configurable_serverapp(
- config=jp_server_config,
- base_url=jp_base_url,
- argv=jp_argv,
- environ=jp_environ,
- http_port=jp_http_port,
- tmp_path=tmp_path,
- root_dir=jp_root_dir,
- **kwargs,
- ):
- c = Config(config)
- c.NotebookNotary.db_file = ":memory:"
- token = hexlify(os.urandom(4)).decode("ascii")
- app = ServerApp.instance(
- # Set the log level to debug for testing purposes
- log_level="DEBUG",
- port=http_port,
- port_retries=0,
- open_browser=False,
- root_dir=str(root_dir),
- base_url=base_url,
- config=c,
- allow_root=True,
- token=token,
- **kwargs,
- )
-
- app.init_signal = lambda: None
- app.log.propagate = True
- app.log.handlers = []
- # Initialize app without httpserver
- app.initialize(argv=argv, new_httpserver=False)
- # Reroute all logging StreamHandlers away from stdin/stdout since pytest hijacks
- # these streams and closes them at unfortunate times.
- stream_handlers = [h for h in app.log.handlers if isinstance(h, logging.StreamHandler)]
- for handler in stream_handlers:
- handler.setStream(jp_logging_stream)
- app.log.propagate = True
- app.log.handlers = []
- # Start app without ioloop
- app.start_app()
- return app
-
- return _configurable_serverapp
-
-
-@pytest.fixture
-def jp_ensure_app_fixture(request):
- """Ensures that the 'app' fixture used by pytest-tornasync
- is set to `jp_web_app`, the Tornado Web Application returned
- by the ServerApp in Jupyter Server, provided by the jp_web_app
- fixture in this module.
-
- Note, this hardcodes the `app_fixture` option from
- pytest-tornasync to `jp_web_app`. If this value is configured
- to something other than the default, it will raise an exception.
- """
- app_option = request.config.getoption("app_fixture")
- if app_option not in ["app", "jp_web_app"]:
- raise Exception(
- "jp_serverapp requires the `app-fixture` option "
- "to be set to 'jp_web_app`. Try rerunning the "
- "current tests with the option `--app-fixture "
- "jp_web_app`."
- )
- elif app_option == "app":
- # Manually set the app_fixture to `jp_web_app` if it's
- # not set already.
- request.config.option.app_fixture = "jp_web_app"
-
-
-@pytest.fixture(scope="function")
-def jp_serverapp(jp_ensure_app_fixture, jp_server_config, jp_argv, jp_configurable_serverapp):
- """Starts a Jupyter Server instance based on the established configuration values."""
- app = jp_configurable_serverapp(config=jp_server_config, argv=jp_argv)
- yield app
- app.remove_server_info_file()
- app.remove_browser_open_files()
-
-
-@pytest.fixture
-def jp_web_app(jp_serverapp):
- """app fixture is needed by pytest_tornasync plugin"""
- return jp_serverapp.web_app
-
-
-@pytest.fixture
-def jp_auth_header(jp_serverapp):
- """Configures an authorization header using the token from the serverapp fixture."""
- return {"Authorization": f"token {jp_serverapp.token}"}
-
-
-@pytest.fixture
-def jp_base_url():
- """Returns the base url to use for the test."""
- return "/a%40b/"
-
-
-@pytest.fixture
-def jp_fetch(jp_serverapp, http_server_client, jp_auth_header, jp_base_url):
- """Sends an (asynchronous) HTTP request to a test server.
-
- The fixture is a factory; it can be called like
- a function inside a unit test. Here's a basic
- example of how use this fixture:
-
- .. code-block:: python
-
- async def my_test(jp_fetch):
-
- response = await jp_fetch("api", "spec.yaml")
- ...
- """
-
- def client_fetch(*parts, headers=None, params=None, **kwargs):
- if not headers:
- headers = {}
- if not params:
- params = {}
- # Handle URL strings
- path_url = url_escape(url_path_join(*parts), plus=False)
- base_path_url = url_path_join(jp_base_url, path_url)
- params_url = urllib.parse.urlencode(params)
- url = base_path_url + "?" + params_url
- # Add auth keys to header
- headers.update(jp_auth_header)
- # Make request.
- return http_server_client.fetch(url, headers=headers, request_timeout=20, **kwargs)
-
- return client_fetch
-
-
-@pytest.fixture
-def jp_ws_fetch(jp_serverapp, http_server_client, jp_auth_header, jp_http_port, jp_base_url):
- """Sends a websocket request to a test server.
-
- The fixture is a factory; it can be called like
- a function inside a unit test. Here's a basic
- example of how use this fixture:
-
- .. code-block:: python
-
- async def my_test(jp_fetch, jp_ws_fetch):
- # Start a kernel
- r = await jp_fetch(
- 'api', 'kernels',
- method='POST',
- body=json.dumps({
- 'name': "python3"
- })
- )
- kid = json.loads(r.body.decode())['id']
-
- # Open a websocket connection.
- ws = await jp_ws_fetch(
- 'api', 'kernels', kid, 'channels'
- )
- ...
- """
-
- def client_fetch(*parts, headers=None, params=None, **kwargs):
- if not headers:
- headers = {}
- if not params:
- params = {}
- # Handle URL strings
- path_url = url_escape(url_path_join(*parts), plus=False)
- base_path_url = url_path_join(jp_base_url, path_url)
- urlparts = urllib.parse.urlparse(f"ws://localhost:{jp_http_port}")
- urlparts = urlparts._replace(path=base_path_url, query=urllib.parse.urlencode(params))
- url = urlparts.geturl()
- # Add auth keys to header
- headers.update(jp_auth_header)
- # Make request.
- req = tornado.httpclient.HTTPRequest(url, headers=headers, connect_timeout=120)
- return tornado.websocket.websocket_connect(req)
-
- return client_fetch
+from jupyter_server.services.contents.filemanager import AsyncFileContentsManager
+from jupyter_server.services.contents.largefilemanager import AsyncLargeFileManager
+pytest_plugins = ["pytest_jupyter.jupyter_server"]
some_resource = "The very model of a modern major general"
sample_kernel_json = {
@@ -412,8 +18,8 @@ def client_fetch(*parts, headers=None, params=None, **kwargs):
}
-@pytest.fixture
-def jp_kernelspecs(jp_data_dir):
+@pytest.fixture() # type:ignore[misc]
+def jp_kernelspecs(jp_data_dir: Path) -> None: # noqa: PT004
"""Configures some sample kernelspecs in the Jupyter data directory."""
spec_names = ["sample", "sample2", "bad"]
for name in spec_names:
@@ -432,79 +38,11 @@ def jp_kernelspecs(jp_data_dir):
@pytest.fixture(params=[True, False])
def jp_contents_manager(request, tmp_path):
- """Returns a FileContentsManager instance based on the use_atomic_writing parameter value."""
- return FileContentsManager(root_dir=str(tmp_path), use_atomic_writing=request.param)
+ """Returns an AsyncFileContentsManager instance based on the use_atomic_writing parameter value."""
+ return AsyncFileContentsManager(root_dir=str(tmp_path), use_atomic_writing=request.param)
-@pytest.fixture
+@pytest.fixture()
def jp_large_contents_manager(tmp_path):
- """Returns a LargeFileManager instance."""
- return LargeFileManager(root_dir=str(tmp_path))
-
-
-@pytest.fixture
-def jp_create_notebook(jp_root_dir):
- """Creates a notebook in the test's home directory."""
-
- def inner(nbpath):
- nbpath = jp_root_dir.joinpath(nbpath)
- # Check that the notebook has the correct file extension.
- if nbpath.suffix != ".ipynb":
- raise Exception("File extension for notebook must be .ipynb")
- # If the notebook path has a parent directory, make sure it's created.
- parent = nbpath.parent
- parent.mkdir(parents=True, exist_ok=True)
- # Create a notebook string and write to file.
- nb = nbformat.v4.new_notebook()
- nbtext = nbformat.writes(nb, version=4)
- nbpath.write_text(nbtext)
-
- return inner
-
-
-@pytest.fixture(autouse=True)
-def jp_server_cleanup():
- yield
- ServerApp.clear_instance()
-
-
-@pytest.fixture
-def jp_cleanup_subprocesses(jp_serverapp):
- """Clean up subprocesses started by a Jupyter Server, i.e. kernels and terminal."""
-
- async def _():
- terminal_cleanup = jp_serverapp.web_app.settings["terminal_manager"].terminate_all
- kernel_cleanup = jp_serverapp.kernel_manager.shutdown_all
-
- async def kernel_cleanup_steps():
- # Try a graceful shutdown with a timeout
- try:
- await asyncio.wait_for(kernel_cleanup(), timeout=15.0)
- except asyncio.TimeoutError:
- # Now force a shutdown
- try:
- await asyncio.wait_for(kernel_cleanup(now=True), timeout=15.0)
- except asyncio.TimeoutError:
- print(Exception("Kernel never shutdown!"))
- except Exception as e:
- print(e)
-
- if asyncio.iscoroutinefunction(terminal_cleanup):
- try:
- await terminal_cleanup()
- except Exception as e:
- print(e)
- else:
- try:
- await terminal_cleanup()
- except Exception as e:
- print(e)
- if asyncio.iscoroutinefunction(kernel_cleanup):
- await kernel_cleanup_steps()
- else:
- try:
- kernel_cleanup()
- except Exception as e:
- print(e)
-
- return _
+ """Returns an AsyncLargeFileManager instance."""
+ return AsyncLargeFileManager(root_dir=str(tmp_path))
diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py
index f2c337d404..9e4a57375d 100644
--- a/jupyter_server/serverapp.py
+++ b/jupyter_server/serverapp.py
@@ -1,13 +1,13 @@
"""A tornado based Jupyter server."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
-import binascii
+from __future__ import annotations
+
import datetime
import errno
import gettext
import hashlib
import hmac
-import inspect
import ipaddress
import json
import logging
@@ -23,46 +23,28 @@
import sys
import threading
import time
+import typing as t
import urllib
import warnings
-import webbrowser
from base64 import encodebytes
+from pathlib import Path
-try:
- import resource
-except ImportError:
- # Windows
- resource = None
-
-from jinja2 import Environment, FileSystemLoader
-from jupyter_core.paths import secure_write
-
-from jupyter_server.transutils import _i18n, trans
-from jupyter_server.utils import pathname2url, run_sync_in_loop, urljoin
-
-# the minimum viable tornado version: needs to be kept in sync with setup.py
-MIN_TORNADO = (6, 1, 0)
-
-try:
- import tornado
-
- assert tornado.version_info >= MIN_TORNADO
-except (ImportError, AttributeError, AssertionError) as e: # pragma: no cover
- raise ImportError(_i18n("The Jupyter Server requires tornado >=%s.%s.%s") % MIN_TORNADO) from e
-
+import jupyter_client
+from jupyter_client.kernelspec import KernelSpecManager
+from jupyter_client.manager import KernelManager
+from jupyter_client.session import Session
+from jupyter_core.application import JupyterApp, base_aliases, base_flags
+from jupyter_core.paths import jupyter_runtime_dir
+from jupyter_events.logger import EventLogger
+from nbformat.sign import NotebookNotary
from tornado import httpserver, ioloop, web
from tornado.httputil import url_concat
from tornado.log import LogFormatter, access_log, app_log, gen_log
+from tornado.netutil import bind_sockets
if not sys.platform.startswith("win"):
from tornado.netutil import bind_unix_socket
-from jupyter_client import KernelManager
-from jupyter_client.kernelspec import KernelSpecManager
-from jupyter_client.session import Session
-from jupyter_core.application import JupyterApp, base_aliases, base_flags
-from jupyter_core.paths import jupyter_runtime_dir
-from nbformat.sign import NotebookNotary
from traitlets import (
Any,
Bool,
@@ -75,6 +57,7 @@
TraitError,
Type,
Unicode,
+ Union,
default,
observe,
validate,
@@ -83,14 +66,21 @@
from traitlets.config.application import boolean_flag, catch_config_error
from jupyter_server import (
+ DEFAULT_EVENTS_SCHEMA_PATH,
DEFAULT_JUPYTER_SERVER_PORT,
DEFAULT_STATIC_FILES_PATH,
DEFAULT_TEMPLATE_PATH_LIST,
+ JUPYTER_SERVER_EVENTS_URI,
__version__,
)
from jupyter_server._sysinfo import get_sys_info
from jupyter_server._tz import utcnow
from jupyter_server.auth.authorizer import AllowAllAuthorizer, Authorizer
+from jupyter_server.auth.identity import (
+ IdentityProvider,
+ LegacyIdentityProvider,
+ PasswordIdentityProvider,
+)
from jupyter_server.auth.login import LoginHandler
from jupyter_server.auth.logout import LogoutHandler
from jupyter_server.base.handlers import (
@@ -102,8 +92,9 @@
from jupyter_server.extension.config import ExtensionConfigManager
from jupyter_server.extension.manager import ExtensionManager
from jupyter_server.extension.serverextension import ServerExtensionApp
+from jupyter_server.gateway.connections import GatewayWebSocketConnection
+from jupyter_server.gateway.gateway_client import GatewayClient
from jupyter_server.gateway.managers import (
- GatewayClient,
GatewayKernelSpecManager,
GatewayMappingKernelManager,
GatewaySessionManager,
@@ -114,17 +105,15 @@
AsyncFileContentsManager,
FileContentsManager,
)
-from jupyter_server.services.contents.largefilemanager import LargeFileManager
-from jupyter_server.services.contents.manager import (
- AsyncContentsManager,
- ContentsManager,
-)
+from jupyter_server.services.contents.largefilemanager import AsyncLargeFileManager
+from jupyter_server.services.contents.manager import AsyncContentsManager, ContentsManager
+from jupyter_server.services.kernels.connection.base import BaseKernelWebsocketConnection
+from jupyter_server.services.kernels.connection.channels import ZMQChannelsWebsocketConnection
from jupyter_server.services.kernels.kernelmanager import (
AsyncMappingKernelManager,
MappingKernelManager,
)
from jupyter_server.services.sessions.sessionmanager import SessionManager
-from jupyter_server.traittypes import TypeFromClasses
from jupyter_server.utils import (
check_pid,
fetch,
@@ -134,13 +123,34 @@
urlencode_unix_socket_path,
)
-# Tolerate missing terminado package.
try:
- from jupyter_server.terminal import TerminalManager
+ import resource
+except ImportError:
+ # Windows
+ resource = None # type:ignore[assignment]
+
+from jinja2 import Environment, FileSystemLoader
+from jupyter_core.paths import secure_write
+from jupyter_core.utils import ensure_async
+
+from jupyter_server.transutils import _i18n, trans
+from jupyter_server.utils import pathname2url, urljoin
+
+# the minimum viable tornado version: needs to be kept in sync with setup.py
+MIN_TORNADO = (6, 1, 0)
+
+try:
+ import tornado
- terminado_available = True
+ assert tornado.version_info >= MIN_TORNADO
+except (ImportError, AttributeError, AssertionError) as e: # pragma: no cover
+ raise ImportError(_i18n("The Jupyter Server requires tornado >=%s.%s.%s") % MIN_TORNADO) from e
+
+try:
+ import resource
except ImportError:
- terminado_available = False
+ # Windows
+ resource = None # type:ignore[assignment]
# -----------------------------------------------------------------------------
# Module globals
@@ -152,26 +162,29 @@
jupyter server password # enter a password to protect the server
"""
-JUPYTER_SERVICE_HANDLERS = dict(
- auth=None,
- api=["jupyter_server.services.api.handlers"],
- config=["jupyter_server.services.config.handlers"],
- contents=["jupyter_server.services.contents.handlers"],
- files=["jupyter_server.files.handlers"],
- kernels=["jupyter_server.services.kernels.handlers"],
- kernelspecs=[
+JUPYTER_SERVICE_HANDLERS = {
+ "auth": None,
+ "api": ["jupyter_server.services.api.handlers"],
+ "config": ["jupyter_server.services.config.handlers"],
+ "contents": ["jupyter_server.services.contents.handlers"],
+ "files": ["jupyter_server.files.handlers"],
+ "kernels": [
+ "jupyter_server.services.kernels.handlers",
+ ],
+ "kernelspecs": [
"jupyter_server.kernelspecs.handlers",
"jupyter_server.services.kernelspecs.handlers",
],
- nbconvert=[
+ "nbconvert": [
"jupyter_server.nbconvert.handlers",
"jupyter_server.services.nbconvert.handlers",
],
- security=["jupyter_server.services.security.handlers"],
- sessions=["jupyter_server.services.sessions.handlers"],
- shutdown=["jupyter_server.services.shutdown"],
- view=["jupyter_server.view.handlers"],
-)
+ "security": ["jupyter_server.services.security.handlers"],
+ "sessions": ["jupyter_server.services.sessions.handlers"],
+ "shutdown": ["jupyter_server.services.shutdown"],
+ "view": ["jupyter_server.view.handlers"],
+ "events": ["jupyter_server.services.events.handlers"],
+}
# Added for backwards compatibility from classic notebook server.
DEFAULT_SERVER_PORT = DEFAULT_JUPYTER_SERVER_PORT
@@ -181,7 +194,7 @@
# -----------------------------------------------------------------------------
-def random_ports(port, n):
+def random_ports(port: int, n: int) -> t.Generator[int, None, None]:
"""Generate a list of n random ports near the given port.
The first 5 ports will be sequential, and the remaining n-5 will be
@@ -193,7 +206,7 @@ def random_ports(port, n):
yield max(1, port + random.randint(-2 * n, 2 * n))
-def load_handlers(name):
+def load_handlers(name: str) -> t.Any:
"""Load the (URL pattern, handler) tuples for each component."""
mod = __import__(name, fromlist=["default_handlers"])
return mod.default_handlers
@@ -205,6 +218,8 @@ def load_handlers(name):
class ServerWebApplication(web.Application):
+ """A server web application."""
+
def __init__(
self,
jupyter_app,
@@ -214,14 +229,28 @@ def __init__(
session_manager,
kernel_spec_manager,
config_manager,
+ event_logger,
extra_services,
log,
base_url,
default_url,
settings_overrides,
jinja_env_options,
+ *,
authorizer=None,
+ identity_provider=None,
+ kernel_websocket_connection_class=None,
):
+ """Initialize a server web application."""
+ if identity_provider is None:
+ warnings.warn(
+ "identity_provider unspecified. Using default IdentityProvider."
+ " Specify an identity_provider to avoid this message.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+ identity_provider = IdentityProvider(parent=jupyter_app)
+
if authorizer is None:
warnings.warn(
"authorizer unspecified. Using permissive AllowAllAuthorizer."
@@ -229,7 +258,7 @@ def __init__(
RuntimeWarning,
stacklevel=2,
)
- authorizer = AllowAllAuthorizer(jupyter_app)
+ authorizer = AllowAllAuthorizer(parent=jupyter_app, identity_provider=identity_provider)
settings = self.init_settings(
jupyter_app,
@@ -238,6 +267,7 @@ def __init__(
session_manager,
kernel_spec_manager,
config_manager,
+ event_logger,
extra_services,
log,
base_url,
@@ -245,6 +275,8 @@ def __init__(
settings_overrides,
jinja_env_options,
authorizer=authorizer,
+ identity_provider=identity_provider,
+ kernel_websocket_connection_class=kernel_websocket_connection_class,
)
handlers = self.init_handlers(default_services, settings)
@@ -258,15 +290,19 @@ def init_settings(
session_manager,
kernel_spec_manager,
config_manager,
+ event_logger,
extra_services,
log,
base_url,
default_url,
settings_overrides,
jinja_env_options=None,
+ *,
authorizer=None,
+ identity_provider=None,
+ kernel_websocket_connection_class=None,
):
-
+ """Initialize settings for the web application."""
_template_path = settings_overrides.get(
"template_path",
jupyter_app.template_file_path,
@@ -275,7 +311,7 @@ def init_settings(
_template_path = (_template_path,)
template_path = [os.path.expanduser(path) for path in _template_path]
- jenv_opt = {"autoescape": True}
+ jenv_opt: dict[str, t.Any] = {"autoescape": True}
jenv_opt.update(jinja_env_options if jinja_env_options else {})
env = Environment(
@@ -292,11 +328,12 @@ def init_settings(
env.install_gettext_translations(nbui, newstyle=False)
if sys_info["commit_source"] == "repository":
- # don't cache (rely on 304) when working from default branch
+ # don't cache (rely on 304) when working from master
version_hash = ""
else:
# reset the cache on server restart
- version_hash = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ utc = datetime.timezone.utc
+ version_hash = datetime.datetime.now(tz=utc).strftime("%Y%m%d%H%M%S")
now = utcnow()
@@ -306,64 +343,63 @@ def init_settings(
# collapse $HOME to ~
root_dir = "~" + root_dir[len(home) :]
- settings = dict(
+ settings = {
# basics
- log_function=log_request,
- base_url=base_url,
- default_url=default_url,
- template_path=template_path,
- static_path=jupyter_app.static_file_path,
- static_custom_path=jupyter_app.static_custom_path,
- static_handler_class=FileFindHandler,
- static_url_prefix=url_path_join(base_url, "/static/"),
- static_handler_args={
+ "log_function": log_request,
+ "base_url": base_url,
+ "default_url": default_url,
+ "template_path": template_path,
+ "static_path": jupyter_app.static_file_path,
+ "static_custom_path": jupyter_app.static_custom_path,
+ "static_handler_class": FileFindHandler,
+ "static_url_prefix": url_path_join(base_url, "/static/"),
+ "static_handler_args": {
# don't cache custom.js
"no_cache_paths": [url_path_join(base_url, "static", "custom")],
},
- version_hash=version_hash,
- # kernel message protocol over websoclet
- kernel_ws_protocol=jupyter_app.kernel_ws_protocol,
+ "version_hash": version_hash,
+ # kernel message protocol over websocket
+ "kernel_ws_protocol": jupyter_app.kernel_ws_protocol,
# rate limits
- limit_rate=jupyter_app.limit_rate,
- iopub_msg_rate_limit=jupyter_app.iopub_msg_rate_limit,
- iopub_data_rate_limit=jupyter_app.iopub_data_rate_limit,
- rate_limit_window=jupyter_app.rate_limit_window,
+ "limit_rate": jupyter_app.limit_rate,
+ "iopub_msg_rate_limit": jupyter_app.iopub_msg_rate_limit,
+ "iopub_data_rate_limit": jupyter_app.iopub_data_rate_limit,
+ "rate_limit_window": jupyter_app.rate_limit_window,
# authentication
- cookie_secret=jupyter_app.cookie_secret,
- login_url=url_path_join(base_url, "/login"),
- login_handler_class=jupyter_app.login_handler_class,
- logout_handler_class=jupyter_app.logout_handler_class,
- password=jupyter_app.password,
- xsrf_cookies=True,
- disable_check_xsrf=jupyter_app.disable_check_xsrf,
- allow_remote_access=jupyter_app.allow_remote_access,
- local_hostnames=jupyter_app.local_hostnames,
- authenticate_prometheus=jupyter_app.authenticate_prometheus,
+ "cookie_secret": jupyter_app.cookie_secret,
+ "login_url": url_path_join(base_url, "/login"),
+ "xsrf_cookies": True,
+ "disable_check_xsrf": jupyter_app.disable_check_xsrf,
+ "allow_remote_access": jupyter_app.allow_remote_access,
+ "local_hostnames": jupyter_app.local_hostnames,
+ "authenticate_prometheus": jupyter_app.authenticate_prometheus,
# managers
- kernel_manager=kernel_manager,
- contents_manager=contents_manager,
- session_manager=session_manager,
- kernel_spec_manager=kernel_spec_manager,
- config_manager=config_manager,
- authorizer=authorizer,
+ "kernel_manager": kernel_manager,
+ "contents_manager": contents_manager,
+ "session_manager": session_manager,
+ "kernel_spec_manager": kernel_spec_manager,
+ "config_manager": config_manager,
+ "authorizer": authorizer,
+ "identity_provider": identity_provider,
+ "event_logger": event_logger,
+ "kernel_websocket_connection_class": kernel_websocket_connection_class,
# handlers
- extra_services=extra_services,
+ "extra_services": extra_services,
# Jupyter stuff
- started=now,
+ "started": now,
# place for extensions to register activity
# so that they can prevent idle-shutdown
- last_activity_times={},
- jinja_template_vars=jupyter_app.jinja_template_vars,
- websocket_url=jupyter_app.websocket_url,
- shutdown_button=jupyter_app.quit_button,
- config=jupyter_app.config,
- config_dir=jupyter_app.config_dir,
- allow_password_change=jupyter_app.allow_password_change,
- server_root_dir=root_dir,
- jinja2_env=env,
- terminals_available=terminado_available and jupyter_app.terminals_enabled,
- serverapp=jupyter_app,
- )
+ "last_activity_times": {},
+ "jinja_template_vars": jupyter_app.jinja_template_vars,
+ "websocket_url": jupyter_app.websocket_url,
+ "shutdown_button": jupyter_app.quit_button,
+ "config": jupyter_app.config,
+ "config_dir": jupyter_app.config_dir,
+ "allow_password_change": jupyter_app.allow_password_change,
+ "server_root_dir": root_dir,
+ "jinja2_env": env,
+ "serverapp": jupyter_app,
+ }
# allow custom overrides for the tornado web app.
settings.update(settings_overrides)
@@ -381,11 +417,6 @@ def init_handlers(self, default_services, settings):
for service in settings["extra_services"]:
handlers.extend(load_handlers(service))
- # Add auth services.
- if "auth" in default_services:
- handlers.extend([(r"/login", settings["login_handler_class"])])
- handlers.extend([(r"/logout", settings["logout_handler_class"])])
-
# Load default services. Raise exception if service not
# found in JUPYTER_SERVICE_HANLDERS.
for service in default_services:
@@ -395,26 +426,18 @@ def init_handlers(self, default_services, settings):
for loc in locations:
handlers.extend(load_handlers(loc))
else:
- raise Exception(
- "{} is not recognized as a jupyter_server "
+ msg = (
+ f"{service} is not recognized as a jupyter_server "
"service. If this is a custom service, "
"try adding it to the "
- "`extra_services` list.".format(service)
+ "`extra_services` list."
)
+ raise Exception(msg)
# Add extra handlers from contents manager.
handlers.extend(settings["contents_manager"].get_extra_handlers())
-
- # If gateway mode is enabled, replace appropriate handlers to perform redirection
- if GatewayClient.instance().gateway_enabled:
- # for each handler required for gateway, locate its pattern
- # in the current list and replace that entry...
- gateway_handlers = load_handlers("jupyter_server.gateway.handlers")
- for _, gwh in enumerate(gateway_handlers):
- for j, h in enumerate(handlers):
- if gwh[0] == h[0]:
- handlers[j] = (gwh[0], gwh[1])
- break
+ # And from identity provider
+ handlers.extend(settings["identity_provider"].get_handlers())
# register base handlers last
handlers.extend(load_handlers("jupyter_server.base.handlers"))
@@ -438,7 +461,7 @@ def init_handlers(self, default_services, settings):
new_handlers = []
for handler in handlers:
pattern = url_path_join(settings["base_url"], handler[0])
- new_handler = tuple([pattern] + list(handler[1:]))
+ new_handler = (pattern, *list(handler[1:]))
new_handlers.append(new_handler)
# add 404 on the end, which will catch everything that falls through
new_handlers.append((r"(.*)", Template404))
@@ -454,14 +477,12 @@ def last_activity(self):
self.settings["started"],
self.settings["kernel_manager"].last_kernel_activity,
]
- try:
- sources.append(self.settings["api_last_activity"])
- except KeyError:
- pass
- try:
- sources.append(self.settings["terminal_last_activity"])
- except KeyError:
- pass
+ # Any setting with a key that ends with `_last_activity` is
+ # counted here. This provides a hook for extensions to add a last activity
+ # setting to the server.
+ sources.extend(
+ [val for key, val in self.settings.items() if key.endswith("_last_activity")]
+ )
sources.extend(self.settings["last_activity_times"].values())
return max(sources)
@@ -473,12 +494,14 @@ class JupyterPasswordApp(JupyterApp):
and removes the need for token-based authentication.
"""
- description = __doc__
+ description: str = __doc__
def _config_file_default(self):
+ """the default config file."""
return os.path.join(self.config_dir, "jupyter_server_config.json")
def start(self):
+ """Start the password app."""
from jupyter_server.auth.security import set_password
set_password(config_file=self.config_file)
@@ -500,11 +523,29 @@ def shutdown_server(server_info, timeout=5, log=None):
url = server_info["url"]
pid = server_info["pid"]
+ try:
+ shutdown_url = urljoin(url, "api/shutdown")
+ if log:
+ log.debug("POST request to %s", shutdown_url)
+ fetch(
+ shutdown_url,
+ method="POST",
+ body=b"",
+ headers={"Authorization": "token " + server_info["token"]},
+ )
+ except Exception as ex:
+ if not str(ex) == "Unknown URL scheme.":
+ raise ex
+ if log:
+ log.debug("Was not a HTTP scheme. Treating as socket instead.")
+ log.debug("POST request to %s", url)
+ fetch(
+ url,
+ method="POST",
+ body=b"",
+ headers={"Authorization": "token " + server_info["token"]},
+ )
- if log:
- log.debug("POST request to %sapi/shutdown", url)
-
- fetch(url, method="POST", headers={"Authorization": "token " + server_info["token"]})
# Poll to see if it shut down.
for _ in range(timeout * 10):
if not check_pid(pid):
@@ -535,9 +576,10 @@ def shutdown_server(server_info, timeout=5, log=None):
class JupyterServerStopApp(JupyterApp):
+ """An application to stop a Jupyter server."""
- version = __version__
- description = "Stop currently running Jupyter server for a given port"
+ version: str = __version__
+ description: str = "Stop currently running Jupyter server for a given port"
port = Integer(
DEFAULT_JUPYTER_SERVER_PORT,
@@ -548,6 +590,7 @@ class JupyterServerStopApp(JupyterApp):
sock = Unicode("", config=True, help="UNIX socket of the server to be killed.")
def parse_command_line(self, argv=None):
+ """Parse command line options."""
super().parse_command_line(argv)
if self.extra_args:
try:
@@ -557,21 +600,26 @@ def parse_command_line(self, argv=None):
self.sock = self.extra_args[0]
def shutdown_server(self, server):
+ """Shut down a server."""
return shutdown_server(server, log=self.log)
def _shutdown_or_exit(self, target_endpoint, server):
- print("Shutting down server on %s..." % target_endpoint)
+ """Handle a shutdown."""
+ self.log.info("Shutting down server on %s..." % target_endpoint)
if not self.shutdown_server(server):
sys.exit("Could not stop server on %s" % target_endpoint)
@staticmethod
def _maybe_remove_unix_socket(socket_path):
+ """Try to remove a socket path."""
try:
os.unlink(socket_path)
except OSError:
pass
def start(self):
+ """Start the server stop app."""
+ info = self.log.info
servers = list(list_running_servers(self.runtime_dir, log=self.log))
if not servers:
self.exit("There are no running servers (per %s)" % self.runtime_dir)
@@ -589,30 +637,29 @@ def start(self):
self._shutdown_or_exit(port, server)
return
current_endpoint = self.sock or self.port
- print(
- f"There is currently no server running on {current_endpoint}",
- file=sys.stderr,
- )
- print("Ports/sockets currently in use:", file=sys.stderr)
+ info(f"There is currently no server running on {current_endpoint}")
+ info("Ports/sockets currently in use:")
for server in servers:
- print(" - {}".format(server.get("sock") or server["port"]), file=sys.stderr)
+ info(" - {}".format(server.get("sock") or server["port"]))
self.exit(1)
class JupyterServerListApp(JupyterApp):
- version = __version__
- description = _i18n("List currently running Jupyter servers.")
+ """An application to list running Jupyter servers."""
- flags = dict(
- jsonlist=(
+ version: str = __version__
+ description: str = _i18n("List currently running Jupyter servers.")
+
+ flags = {
+ "jsonlist": (
{"JupyterServerListApp": {"jsonlist": True}},
_i18n("Produce machine-readable JSON list output."),
),
- json=(
+ "json": (
{"JupyterServerListApp": {"json": True}},
_i18n("Produce machine-readable JSON object on each line of output."),
),
- )
+ }
jsonlist = Bool(
False,
@@ -634,6 +681,7 @@ class JupyterServerListApp(JupyterApp):
)
def start(self):
+ """Start the server list application."""
serverinfo_list = list(list_running_servers(self.runtime_dir, log=self.log))
if self.jsonlist:
print(json.dumps(serverinfo_list, indent=2))
@@ -714,18 +762,19 @@ def start(self):
class ServerApp(JupyterApp):
+ """The Jupyter Server application class."""
name = "jupyter-server"
- version = __version__
- description = _i18n(
+ version: str = __version__
+ description: str = _i18n(
"""The Jupyter Server.
This launches a Tornado-based Jupyter Server."""
)
examples = _examples
- flags = Dict(flags)
- aliases = Dict(aliases)
+ flags = Dict(flags) # type:ignore[assignment]
+ aliases = Dict(aliases) # type:ignore[assignment]
classes = [
KernelManager,
@@ -741,18 +790,31 @@ class ServerApp(JupyterApp):
GatewayMappingKernelManager,
GatewayKernelSpecManager,
GatewaySessionManager,
+ GatewayWebSocketConnection,
GatewayClient,
Authorizer,
+ EventLogger,
+ ZMQChannelsWebsocketConnection,
]
- if terminado_available: # Only necessary when terminado is available
- classes.append(TerminalManager)
- subcommands = dict(
- list=(JupyterServerListApp, JupyterServerListApp.description.splitlines()[0]),
- stop=(JupyterServerStopApp, JupyterServerStopApp.description.splitlines()[0]),
- password=(JupyterPasswordApp, JupyterPasswordApp.description.splitlines()[0]),
- extension=(ServerExtensionApp, ServerExtensionApp.description.splitlines()[0]),
- )
+ subcommands: dict[str, t.Any] = {
+ "list": (
+ JupyterServerListApp,
+ JupyterServerListApp.description.splitlines()[0],
+ ),
+ "stop": (
+ JupyterServerStopApp,
+ JupyterServerStopApp.description.splitlines()[0],
+ ),
+ "password": (
+ JupyterPasswordApp,
+ JupyterPasswordApp.description.splitlines()[0],
+ ),
+ "extension": (
+ ServerExtensionApp,
+ ServerExtensionApp.description.splitlines()[0],
+ ),
+ }
# A list of services whose handlers will be exposed.
# Subclasses can override this list to
@@ -770,16 +832,18 @@ class ServerApp(JupyterApp):
"sessions",
"shutdown",
"view",
+ "events",
)
- _log_formatter_cls = LogFormatter
+ _log_formatter_cls = LogFormatter # type:ignore[assignment]
+ _stopping = Bool(False, help="Signal that we've begun stopping.")
@default("log_level")
- def _default_log_level(self):
+ def _default_log_level(self) -> int:
return logging.INFO
@default("log_format")
- def _default_log_format(self):
+ def _default_log_format(self) -> str:
"""override default log format to include date & time"""
return (
"%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s"
@@ -848,7 +912,7 @@ def _default_log_format(self):
)
@default("ip")
- def _default_ip(self):
+ def _default_ip(self) -> str:
"""Return localhost if available, 127.0.0.1 otherwise.
On some (horribly broken) systems, localhost cannot be bound.
@@ -866,8 +930,8 @@ def _default_ip(self):
return "localhost"
@validate("ip")
- def _validate_ip(self, proposal):
- value = proposal["value"]
+ def _validate_ip(self, proposal: t.Any) -> str:
+ value = t.cast(str, proposal["value"])
if value == "*":
value = ""
return value
@@ -898,7 +962,7 @@ def _validate_ip(self, proposal):
)
@default("port")
- def port_default(self):
+ def _port_default(self) -> int:
return int(os.getenv(self.port_env, self.port_default_value))
port_retries_env = "JUPYTER_PORT_RETRIES"
@@ -913,7 +977,7 @@ def port_default(self):
)
@default("port_retries")
- def port_retries_default(self):
+ def _port_retries_default(self) -> int:
return int(os.getenv(self.port_retries_env, self.port_retries_default_value))
sock = Unicode("", config=True, help="The UNIX socket the Jupyter server will listen on.")
@@ -925,7 +989,7 @@ def port_retries_default(self):
)
@validate("sock_mode")
- def _validate_sock_mode(self, proposal):
+ def _validate_sock_mode(self, proposal: t.Any) -> t.Any:
value = proposal["value"]
try:
converted_value = int(value.encode(), 8)
@@ -938,12 +1002,14 @@ def _validate_sock_mode(self, proposal):
converted_value <= 2**12,
)
)
- except ValueError:
- raise TraitError('invalid --sock-mode value: %s, please specify as e.g. "0600"' % value)
- except AssertionError:
+ except ValueError as e:
+ raise TraitError(
+ 'invalid --sock-mode value: %s, please specify as e.g. "0600"' % value
+ ) from e
+ except AssertionError as e:
raise TraitError(
"invalid --sock-mode value: %s, must have u+rw (0600) at a minimum" % value
- )
+ ) from e
return value
certfile = Unicode(
@@ -971,7 +1037,7 @@ def _validate_sock_mode(self, proposal):
)
@default("cookie_secret_file")
- def _default_cookie_secret_file(self):
+ def _default_cookie_secret_file(self) -> str:
return os.path.join(self.runtime_dir, "jupyter_cookie_secret")
cookie_secret = Bytes(
@@ -987,7 +1053,7 @@ def _default_cookie_secret_file(self):
)
@default("cookie_secret")
- def _default_cookie_secret(self):
+ def _default_cookie_secret(self) -> bytes:
if os.path.exists(self.cookie_secret_file):
with open(self.cookie_secret_file, "rb") as f:
key = f.read()
@@ -998,7 +1064,7 @@ def _default_cookie_secret(self):
h.update(self.password.encode())
return h.digest()
- def _write_cookie_secret_file(self, secret):
+ def _write_cookie_secret_file(self, secret: bytes) -> None:
"""write my secret to my secret_file"""
self.log.info(_i18n("Writing Jupyter server cookie secret to %s"), self.cookie_secret_file)
try:
@@ -1011,40 +1077,24 @@ def _write_cookie_secret_file(self, secret):
e,
)
- token = Unicode(
- "",
- help=_i18n(
- """Token used for authenticating first-time connections to the server.
-
- The token can be read from the file referenced by JUPYTER_TOKEN_FILE or set directly
- with the JUPYTER_TOKEN environment variable.
-
- When no password is enabled,
- the default is to generate a new, random token.
+ _token_set = False
- Setting to an empty string disables authentication altogether, which is NOT RECOMMENDED.
- """
- ),
- ).tag(config=True)
+ token = Unicode("", help=_i18n("""DEPRECATED. Use IdentityProvider.token""")).tag(
+ config=True
+ )
- _token_generated = True
+ @observe("token")
+ def _deprecated_token(self, change: t.Any) -> None:
+ self._warn_deprecated_config(change, "IdentityProvider")
@default("token")
- def _token_default(self):
- if os.getenv("JUPYTER_TOKEN"):
- self._token_generated = False
- return os.getenv("JUPYTER_TOKEN")
- if os.getenv("JUPYTER_TOKEN_FILE"):
- self._token_generated = False
- with open(os.getenv("JUPYTER_TOKEN_FILE")) as token_file:
- return token_file.read()
- if self.password:
- # no token if password is enabled
- self._token_generated = False
- return ""
- else:
- self._token_generated = True
- return binascii.hexlify(os.urandom(24)).decode("ascii")
+ def _deprecated_token_access(self) -> str:
+ warnings.warn(
+ "ServerApp.token config is deprecated in jupyter-server 2.0. Use IdentityProvider.token",
+ DeprecationWarning,
+ stacklevel=3,
+ )
+ return self.identity_provider.token
min_open_files_limit = Integer(
config=True,
@@ -1058,10 +1108,10 @@ def _token_default(self):
)
@default("min_open_files_limit")
- def _default_min_open_files_limit(self):
+ def _default_min_open_files_limit(self) -> t.Optional[int]:
if resource is None:
# Ignoring min_open_files_limit because the limit cannot be adjusted (for example, on Windows)
- return None
+ return None # type:ignore[unreachable]
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
@@ -1099,55 +1149,59 @@ def _default_min_open_files_limit(self):
""",
)
- @observe("token")
- def _token_changed(self, change):
- self._token_generated = False
-
password = Unicode(
"",
config=True,
- help="""Hashed password to use for web authentication.
-
- To generate, type in a python/IPython shell:
-
- from jupyter_server.auth import passwd; passwd()
-
- The string should be of the form type:salt:hashed-password.
- """,
+ help="""DEPRECATED in 2.0. Use PasswordIdentityProvider.hashed_password""",
)
password_required = Bool(
False,
config=True,
- help="""Forces users to use a password for the Jupyter server.
- This is useful in a multi user environment, for instance when
- everybody in the LAN can access each other's machine through ssh.
-
- In such a case, serving on localhost is not secure since
- any user can connect to the Jupyter server via ssh.
-
- """,
+ help="""DEPRECATED in 2.0. Use PasswordIdentityProvider.password_required""",
)
allow_password_change = Bool(
True,
config=True,
- help="""Allow password to be changed at login for the Jupyter server.
+ help="""DEPRECATED in 2.0. Use PasswordIdentityProvider.allow_password_change""",
+ )
+
+ def _warn_deprecated_config(
+ self, change: t.Any, clsname: str, new_name: t.Optional[str] = None
+ ) -> None:
+ """Warn on deprecated config."""
+ if new_name is None:
+ new_name = change.name
+ if clsname not in self.config or new_name not in self.config[clsname]:
+ # Deprecated config used, new config not used.
+ # Use deprecated config, warn about new name.
+ self.log.warning(
+ f"ServerApp.{change.name} config is deprecated in 2.0. Use {clsname}.{new_name}."
+ )
+ self.config[clsname][new_name] = change.new
+ # Deprecated config used, new config also used.
+ # Warn only if the values differ.
+ # If the values are the same, assume intentional backward-compatible config.
+ elif self.config[clsname][new_name] != change.new:
+ self.log.warning(
+ f"Ignoring deprecated ServerApp.{change.name} config. Using {clsname}.{new_name}."
+ )
- While logging in with a token, the Jupyter server UI will give the opportunity to
- the user to enter a new password at the same time that will replace
- the token login mechanism.
+ @observe("password")
+ def _deprecated_password(self, change: t.Any) -> None:
+ self._warn_deprecated_config(change, "PasswordIdentityProvider", new_name="hashed_password")
- This can be set to false to prevent changing password from the UI/API.
- """,
- )
+ @observe("password_required", "allow_password_change")
+ def _deprecated_password_config(self, change: t.Any) -> None:
+ self._warn_deprecated_config(change, "PasswordIdentityProvider")
disable_check_xsrf = Bool(
False,
config=True,
help="""Disable cross-site-request-forgery protection
- Jupyter notebook 4.3.1 introduces protection from cross-site request forgeries,
+ Jupyter server includes protection from cross-site request forgeries,
requiring API requests to either:
- originate from pages served by this server (validated with XSRF cookie and token), or
@@ -1178,7 +1232,7 @@ def _token_changed(self, change):
)
@default("allow_remote_access")
- def _default_allow_remote(self):
+ def _default_allow_remote(self) -> bool:
"""Disallow remote access if we're listening only on loopback addresses"""
# if blank, self.ip was configured to "*" meaning bind to all interfaces,
@@ -1191,10 +1245,10 @@ def _default_allow_remote(self):
except ValueError:
# Address is a hostname
for info in socket.getaddrinfo(self.ip, self.port, 0, socket.SOCK_STREAM):
- addr = info[4][0]
+ addr = info[4][0] # type:ignore[assignment]
try:
- parsed = ipaddress.ip_address(addr.split("%")[0])
+ parsed = ipaddress.ip_address(addr.split("%")[0]) # type:ignore[union-attr]
except ValueError:
self.log.warning("Unrecognised IP address: %r", addr)
continue
@@ -1202,7 +1256,9 @@ def _default_allow_remote(self):
# Macs map localhost to 'fe80::1%lo0', a link local address
# scoped to the loopback interface. For now, we'll assume that
# any scoped link-local address is effectively local.
- if not (parsed.is_loopback or (("%" in addr) and parsed.is_link_local)):
+ if not (
+ parsed.is_loopback or (("%" in addr) and parsed.is_link_local) # type:ignore[operator]
+ ):
return True
return False
else:
@@ -1300,24 +1356,24 @@ def _default_allow_remote(self):
),
)
terminado_settings = Dict(
+ Union([List(), Unicode()]),
config=True,
help=_i18n('Supply overrides for terminado. Currently only supports "shell_command".'),
)
cookie_options = Dict(
config=True,
- help=_i18n(
- "Extra keyword arguments to pass to `set_secure_cookie`."
- " See tornado's set_secure_cookie docs for details."
- ),
+ help=_i18n("DEPRECATED. Use IdentityProvider.cookie_options"),
)
get_secure_cookie_kwargs = Dict(
config=True,
- help=_i18n(
- "Extra keyword arguments to pass to `get_secure_cookie`."
- " See tornado's get_secure_cookie docs for details."
- ),
+ help=_i18n("DEPRECATED. Use IdentityProvider.get_secure_cookie_kwargs"),
)
+
+ @observe("cookie_options", "get_secure_cookie_kwargs")
+ def _deprecated_cookie_config(self, change: t.Any) -> None:
+ self._warn_deprecated_config(change, "IdentityProvider")
+
ssl_options = Dict(
allow_none=True,
config=True,
@@ -1348,8 +1404,8 @@ def _default_allow_remote(self):
)
@validate("base_url")
- def _update_base_url(self, proposal):
- value = proposal["value"]
+ def _update_base_url(self, proposal: t.Any) -> str:
+ value = t.cast(str, proposal["value"])
if not value.startswith("/"):
value = "/" + value
if not value.endswith("/"):
@@ -1366,14 +1422,14 @@ def _update_base_url(self, proposal):
)
@property
- def static_file_path(self):
+ def static_file_path(self) -> list[str]:
"""return extra paths + the default location"""
- return self.extra_static_paths + [DEFAULT_STATIC_FILES_PATH]
+ return [*self.extra_static_paths, DEFAULT_STATIC_FILES_PATH]
static_custom_path = List(Unicode(), help=_i18n("""Path to search for custom.js, css"""))
@default("static_custom_path")
- def _default_static_custom_path(self):
+ def _default_static_custom_path(self) -> list[str]:
return [os.path.join(d, "custom") for d in (self.config_dir, DEFAULT_STATIC_FILES_PATH)]
extra_template_paths = List(
@@ -1387,7 +1443,7 @@ def _default_static_custom_path(self):
)
@property
- def template_file_path(self):
+ def template_file_path(self) -> list[str]:
"""return extra paths + the default locations"""
return self.extra_template_paths + DEFAULT_TEMPLATE_PATH_LIST
@@ -1415,51 +1471,50 @@ def template_file_path(self):
help="""If True, display controls to shut down the Jupyter server, such as menu items or buttons.""",
)
- # REMOVE in VERSION 2.0
- # Temporarily allow content managers to inherit from the 'notebook'
- # package. We will deprecate this in the next major release.
- contents_manager_class = TypeFromClasses(
- default_value=LargeFileManager,
- klasses=[
- "jupyter_server.services.contents.manager.ContentsManager",
- "notebook.services.contents.manager.ContentsManager",
- ],
+ contents_manager_class = Type(
+ default_value=AsyncLargeFileManager,
+ klass=ContentsManager,
config=True,
help=_i18n("The content manager class to use."),
)
- # Throws a deprecation warning to notebook based contents managers.
- @observe("contents_manager_class")
- def _observe_contents_manager_class(self, change):
- new = change["new"]
- # If 'new' is a class, get a string representing the import
- # module path.
- if inspect.isclass(new):
- new = new.__module__
-
- if new.startswith("notebook"):
- self.log.warning(
- "The specified 'contents_manager_class' class inherits a manager from the "
- "'notebook' package. This is not guaranteed to work in future "
- "releases of Jupyter Server. Instead, consider switching the "
- "manager to inherit from the 'jupyter_server' managers. "
- "Jupyter Server will temporarily allow 'notebook' managers "
- "until its next major release (2.x)."
- )
-
kernel_manager_class = Type(
- default_value=AsyncMappingKernelManager,
klass=MappingKernelManager,
config=True,
help=_i18n("The kernel manager class to use."),
)
+ @default("kernel_manager_class")
+ def _default_kernel_manager_class(self) -> t.Union[str, type[AsyncMappingKernelManager]]:
+ if self.gateway_config.gateway_enabled:
+ return "jupyter_server.gateway.managers.GatewayMappingKernelManager"
+ return AsyncMappingKernelManager
+
session_manager_class = Type(
- default_value=SessionManager,
config=True,
help=_i18n("The session manager class to use."),
)
+ @default("session_manager_class")
+ def _default_session_manager_class(self) -> t.Union[str, type[SessionManager]]:
+ if self.gateway_config.gateway_enabled:
+ return "jupyter_server.gateway.managers.GatewaySessionManager"
+ return SessionManager
+
+ kernel_websocket_connection_class = Type(
+ klass=BaseKernelWebsocketConnection,
+ config=True,
+ help=_i18n("The kernel websocket connection class to use."),
+ )
+
+ @default("kernel_websocket_connection_class")
+ def _default_kernel_websocket_connection_class(
+ self,
+ ) -> t.Union[str, type[ZMQChannelsWebsocketConnection]]:
+ if self.gateway_config.gateway_enabled:
+ return "jupyter_server.gateway.connections.GatewayWebSocketConnection"
+ return ZMQChannelsWebsocketConnection
+
config_manager_class = Type(
default_value=ConfigManager,
config=True,
@@ -1469,7 +1524,6 @@ def _observe_contents_manager_class(self, change):
kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)
kernel_spec_manager_class = Type(
- default_value=KernelSpecManager,
config=True,
help="""
The kernel spec manager class to use. Should be a subclass
@@ -1480,9 +1534,16 @@ def _observe_contents_manager_class(self, change):
""",
)
+ @default("kernel_spec_manager_class")
+ def _default_kernel_spec_manager_class(self) -> t.Union[str, type[KernelSpecManager]]:
+ if self.gateway_config.gateway_enabled:
+ return "jupyter_server.gateway.managers.GatewayKernelSpecManager"
+ return KernelSpecManager
+
login_handler_class = Type(
default_value=LoginHandler,
klass=web.RequestHandler,
+ allow_none=True,
config=True,
help=_i18n("The login handler class to use."),
)
@@ -1490,9 +1551,11 @@ def _observe_contents_manager_class(self, change):
logout_handler_class = Type(
default_value=LogoutHandler,
klass=web.RequestHandler,
+ allow_none=True,
config=True,
help=_i18n("The logout handler class to use."),
)
+ # TODO: detect deprecated login handler config
authorizer_class = Type(
default_value=AllowAllAuthorizer,
@@ -1501,6 +1564,13 @@ def _observe_contents_manager_class(self, change):
help=_i18n("The authorizer class to use."),
)
+ identity_provider_class = Type(
+ default_value=PasswordIdentityProvider,
+ klass=IdentityProvider,
+ config=True,
+ help=_i18n("The identity provider class to use."),
+ )
+
trust_xheaders = Bool(
False,
config=True,
@@ -1512,24 +1582,34 @@ def _observe_contents_manager_class(self, change):
),
)
+ event_logger = Instance(
+ EventLogger,
+ allow_none=True,
+ help="An EventLogger for emitting structured event data from Jupyter Server and extensions.",
+ )
+
info_file = Unicode()
@default("info_file")
- def _default_info_file(self):
+ def _default_info_file(self) -> str:
info_file = "jpserver-%s.json" % os.getpid()
return os.path.join(self.runtime_dir, info_file)
+ no_browser_open_file = Bool(
+ False, help="If True, do not write redirect HTML file disk, or show in messages."
+ )
+
browser_open_file = Unicode()
@default("browser_open_file")
- def _default_browser_open_file(self):
+ def _default_browser_open_file(self) -> str:
basename = "jpserver-%s-open.html" % os.getpid()
return os.path.join(self.runtime_dir, basename)
browser_open_file_to_run = Unicode()
@default("browser_open_file_to_run")
- def _default_browser_open_file_to_run(self):
+ def _default_browser_open_file_to_run(self) -> str:
basename = "jpserver-file-to-run-%s-open.html" % os.getpid()
return os.path.join(self.runtime_dir, basename)
@@ -1544,12 +1624,9 @@ def _default_browser_open_file_to_run(self):
)
@observe("pylab")
- def _update_pylab(self, change):
+ def _update_pylab(self, change: t.Any) -> None:
"""when --pylab is specified, display a warning and exit"""
- if change["new"] != "warn":
- backend = " %s" % change["new"]
- else:
- backend = ""
+ backend = " %s" % change["new"] if change["new"] != "warn" else ""
self.log.error(
_i18n("Support for specifying --pylab on the command line has been removed.")
)
@@ -1563,25 +1640,46 @@ def _update_pylab(self, change):
notebook_dir = Unicode(config=True, help=_i18n("DEPRECATED, use root_dir."))
@observe("notebook_dir")
- def _update_notebook_dir(self, change):
+ def _update_notebook_dir(self, change: t.Any) -> None:
if self._root_dir_set:
# only use deprecated config if new config is not set
return
self.log.warning(_i18n("notebook_dir is deprecated, use root_dir"))
self.root_dir = change["new"]
+ external_connection_dir = Unicode(
+ None,
+ allow_none=True,
+ config=True,
+ help=_i18n(
+ "The directory to look at for external kernel connection files, if allow_external_kernels is True. "
+ "Defaults to Jupyter runtime_dir/external_kernels. "
+ "Make sure that this directory is not filled with left-over connection files, "
+ "that could result in unnecessary kernel manager creations."
+ ),
+ )
+
+ allow_external_kernels = Bool(
+ False,
+ config=True,
+ help=_i18n(
+ "Whether or not to allow external kernels, whose connection files are placed in external_connection_dir."
+ ),
+ )
+
root_dir = Unicode(config=True, help=_i18n("The directory to use for notebooks and kernels."))
_root_dir_set = False
@default("root_dir")
- def _default_root_dir(self):
+ def _default_root_dir(self) -> str:
if self.file_to_run:
self._root_dir_set = True
return os.path.dirname(os.path.abspath(self.file_to_run))
else:
return os.getcwd()
- def _normalize_dir(self, value):
+ def _normalize_dir(self, value: str) -> str:
+ """Normalize a directory."""
# Strip any trailing slashes
# *except* if it's root
_, path = os.path.splitdrive(value)
@@ -1594,46 +1692,36 @@ def _normalize_dir(self, value):
return value
@validate("root_dir")
- def _root_dir_validate(self, proposal):
+ def _root_dir_validate(self, proposal: t.Any) -> str:
value = self._normalize_dir(proposal["value"])
if not os.path.isdir(value):
raise TraitError(trans.gettext("No such directory: '%r'") % value)
return value
+ @observe("root_dir")
+ def _root_dir_changed(self, change: t.Any) -> None:
+ # record that root_dir is set,
+ # which affects loading of deprecated notebook_dir
+ self._root_dir_set = True
+
preferred_dir = Unicode(
config=True,
help=trans.gettext("Preferred starting directory to use for notebooks and kernels."),
)
@default("preferred_dir")
- def _default_prefered_dir(self):
+ def _default_prefered_dir(self) -> str:
return self.root_dir
@validate("preferred_dir")
- def _preferred_dir_validate(self, proposal):
+ def _preferred_dir_validate(self, proposal: t.Any) -> str:
value = self._normalize_dir(proposal["value"])
if not os.path.isdir(value):
raise TraitError(trans.gettext("No such preferred dir: '%r'") % value)
-
- # preferred_dir must be equal or a subdir of root_dir
- if not value.startswith(self.root_dir):
- raise TraitError(
- trans.gettext("preferred_dir must be equal or a subdir of root_dir: '%r'") % value
- )
-
return value
- @observe("root_dir")
- def _root_dir_changed(self, change):
- self._root_dir_set = True
- if not self.preferred_dir.startswith(change["new"]):
- self.log.warning(
- trans.gettext("Value of preferred_dir updated to use value of root_dir")
- )
- self.preferred_dir = change["new"]
-
@observe("server_extensions")
- def _update_server_extensions(self, change):
+ def _update_server_extensions(self, change: t.Any) -> None:
self.log.warning(_i18n("server_extensions is deprecated, use jpserver_extensions"))
self.server_extensions = change["new"]
@@ -1658,63 +1746,61 @@ def _update_server_extensions(self, change):
)
kernel_ws_protocol = Unicode(
- None,
allow_none=True,
config=True,
- help=_i18n(
- "Preferred kernel message protocol over websocket to use (default: None). "
- "If an empty string is passed, select the legacy protocol. If None, "
- "the selected protocol will depend on what the front-end supports "
- "(usually the most recent protocol supported by the back-end and the "
- "front-end)."
- ),
+ help=_i18n("DEPRECATED. Use ZMQChannelsWebsocketConnection.kernel_ws_protocol"),
)
+ @observe("kernel_ws_protocol")
+ def _deprecated_kernel_ws_protocol(self, change: t.Any) -> None:
+ self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection")
+
limit_rate = Bool(
- True,
+ allow_none=True,
config=True,
- help=_i18n(
- "Whether to limit the rate of IOPub messages (default: True). "
- "If True, use iopub_msg_rate_limit, iopub_data_rate_limit and/or rate_limit_window "
- "to tune the rate."
- ),
+ help=_i18n("DEPRECATED. Use ZMQChannelsWebsocketConnection.limit_rate"),
)
+ @observe("limit_rate")
+ def _deprecated_limit_rate(self, change: t.Any) -> None:
+ self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection")
+
iopub_msg_rate_limit = Float(
- 1000,
+ allow_none=True,
config=True,
- help=_i18n(
- """(msgs/sec)
- Maximum rate at which messages can be sent on iopub before they are
- limited."""
- ),
+ help=_i18n("DEPRECATED. Use ZMQChannelsWebsocketConnection.iopub_msg_rate_limit"),
)
+ @observe("iopub_msg_rate_limit")
+ def _deprecated_iopub_msg_rate_limit(self, change: t.Any) -> None:
+ self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection")
+
iopub_data_rate_limit = Float(
- 1000000,
+ allow_none=True,
config=True,
- help=_i18n(
- """(bytes/sec)
- Maximum rate at which stream output can be sent on iopub before they are
- limited."""
- ),
+ help=_i18n("DEPRECATED. Use ZMQChannelsWebsocketConnection.iopub_data_rate_limit"),
)
+ @observe("iopub_data_rate_limit")
+ def _deprecated_iopub_data_rate_limit(self, change: t.Any) -> None:
+ self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection")
+
rate_limit_window = Float(
- 3,
+ allow_none=True,
config=True,
- help=_i18n(
- """(sec) Time window used to
- check the message and data rate limits."""
- ),
+ help=_i18n("DEPRECATED. Use ZMQChannelsWebsocketConnection.rate_limit_window"),
)
+ @observe("rate_limit_window")
+ def _deprecated_rate_limit_window(self, change: t.Any) -> None:
+ self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection")
+
shutdown_no_activity_timeout = Integer(
0,
config=True,
help=(
- "Shut down the server after N seconds with no kernels or "
- "terminals running and no activity. "
+ "Shut down the server after N seconds with no kernels"
+ "running and no activity. "
"This can be used together with culling idle kernels "
"(MappingKernelManager.cull_idle_timeout) to "
"shutdown the Jupyter server when it's not in use. This is not "
@@ -1724,7 +1810,6 @@ def _update_server_extensions(self, change):
)
terminals_enabled = Bool(
- True,
config=True,
help=_i18n(
"""Set to False to disable terminals.
@@ -1738,14 +1823,9 @@ def _update_server_extensions(self, change):
),
)
- # Since use of terminals is also a function of whether the terminado package is
- # available, this variable holds the "final indication" of whether terminal functionality
- # should be considered (particularly during shutdown/cleanup). It is enabled only
- # once both the terminals "service" can be initialized and terminals_enabled is True.
- # Note: this variable is slightly different from 'terminals_available' in the web settings
- # in that this variable *could* remain false if terminado is available, yet the terminal
- # service's initialization still fails. As a result, this variable holds the truth.
- terminals_available = False
+ @default("terminals_enabled")
+ def _default_terminals_enabled(self) -> bool:
+ return True
authenticate_prometheus = Bool(
True,
@@ -1755,6 +1835,17 @@ def _update_server_extensions(self, change):
config=True,
)
+ static_immutable_cache = List(
+ Unicode(),
+ help="""
+ Paths to set up static files as immutable.
+
+ This allow setting up the cache control of static files as immutable.
+ It should be used for static file named with a hash for instance.
+ """,
+ config=True,
+ )
+
_starter_app = Instance(
default_value=None,
allow_none=True,
@@ -1762,12 +1853,12 @@ def _update_server_extensions(self, change):
)
@property
- def starter_app(self):
+ def starter_app(self) -> t.Any:
"""Get the Extension that started this server."""
return self._starter_app
- def parse_command_line(self, argv=None):
-
+ def parse_command_line(self, argv: t.Optional[list[str]] = None) -> None:
+ """Parse the command line options."""
super().parse_command_line(argv)
if self.extra_args:
@@ -1787,34 +1878,62 @@ def parse_command_line(self, argv=None):
c.ServerApp.file_to_run = f
self.update_config(c)
- def init_configurables(self):
-
+ def init_configurables(self) -> None:
+ """Initialize configurables."""
# If gateway server is configured, replace appropriate managers to perform redirection. To make
# this determination, instantiate the GatewayClient config singleton.
self.gateway_config = GatewayClient.instance(parent=self)
- if self.gateway_config.gateway_enabled:
- self.kernel_manager_class = (
- "jupyter_server.gateway.managers.GatewayMappingKernelManager"
+ if not issubclass(
+ self.kernel_manager_class,
+ AsyncMappingKernelManager,
+ ):
+ warnings.warn(
+ "The synchronous MappingKernelManager class is deprecated and will not be supported in Jupyter Server 3.0",
+ DeprecationWarning,
+ stacklevel=2,
)
- self.session_manager_class = "jupyter_server.gateway.managers.GatewaySessionManager"
- self.kernel_spec_manager_class = (
- "jupyter_server.gateway.managers.GatewayKernelSpecManager"
+
+ if not issubclass(
+ self.contents_manager_class,
+ AsyncContentsManager,
+ ):
+ warnings.warn(
+ "The synchronous ContentsManager classes are deprecated and will not be supported in Jupyter Server 3.0",
+ DeprecationWarning,
+ stacklevel=2,
)
self.kernel_spec_manager = self.kernel_spec_manager_class(
parent=self,
)
- self.kernel_manager = self.kernel_manager_class(
- parent=self,
- log=self.log,
- connection_dir=self.runtime_dir,
- kernel_spec_manager=self.kernel_spec_manager,
- )
+
+ kwargs = {
+ "parent": self,
+ "log": self.log,
+ "connection_dir": self.runtime_dir,
+ "kernel_spec_manager": self.kernel_spec_manager,
+ }
+ if jupyter_client.version_info > (8, 3, 0): # type:ignore[attr-defined]
+ if self.allow_external_kernels:
+ external_connection_dir = self.external_connection_dir
+ if external_connection_dir is None:
+ external_connection_dir = str(Path(self.runtime_dir) / "external_kernels")
+ kwargs["external_connection_dir"] = external_connection_dir
+ elif self.allow_external_kernels:
+ self.log.warning(
+ "Although allow_external_kernels=True, external kernels are not supported "
+ "because jupyter-client's version does not allow them (should be >8.3.0)."
+ )
+
+ self.kernel_manager = self.kernel_manager_class(**kwargs)
self.contents_manager = self.contents_manager_class(
parent=self,
log=self.log,
)
+ # Trigger a default/validation here explicitly while we still support the
+ # deprecated trait on ServerApp (FIXME remove when deprecation finalized)
+ self.contents_manager.preferred_dir # noqa: B018
self.session_manager = self.session_manager_class(
parent=self,
log=self.log,
@@ -1825,9 +1944,58 @@ def init_configurables(self):
parent=self,
log=self.log,
)
- self.authorizer = self.authorizer_class(parent=self, log=self.log)
+ identity_provider_kwargs = {"parent": self, "log": self.log}
+
+ if (
+ self.login_handler_class is not LoginHandler
+ and self.identity_provider_class is PasswordIdentityProvider
+ ):
+ # default identity provider, non-default LoginHandler
+ # this indicates legacy custom LoginHandler config.
+ # enable LegacyIdentityProvider, which defers to the LoginHandler for pre-2.0 behavior.
+ self.identity_provider_class = LegacyIdentityProvider
+ self.log.warning(
+ f"Customizing authentication via ServerApp.login_handler_class={self.login_handler_class}"
+ " is deprecated in Jupyter Server 2.0."
+ " Use ServerApp.identity_provider_class."
+ " Falling back on legacy authentication.",
+ )
+ identity_provider_kwargs["login_handler_class"] = self.login_handler_class
+ if self.logout_handler_class:
+ identity_provider_kwargs["logout_handler_class"] = self.logout_handler_class
+ elif self.login_handler_class is not LoginHandler:
+ # non-default login handler ignored because also explicitly set identity provider
+ self.log.warning(
+ f"Ignoring deprecated config ServerApp.login_handler_class={self.login_handler_class}."
+ " Superseded by ServerApp.identity_provider_class={self.identity_provider_class}."
+ )
+ self.identity_provider = self.identity_provider_class(**identity_provider_kwargs)
+
+ if self.identity_provider_class is LegacyIdentityProvider:
+ # legacy config stored the password in tornado_settings
+ self.tornado_settings["password"] = self.identity_provider.hashed_password # type:ignore[attr-defined]
+ self.tornado_settings["token"] = self.identity_provider.token
- def init_logging(self):
+ if self._token_set:
+ self.log.warning(
+ "ServerApp.token config is deprecated in jupyter-server 2.0. Use IdentityProvider.token"
+ )
+ if self.identity_provider.token_generated:
+ # default behavior: generated default token
+ # preserve deprecated ServerApp.token config
+ self.identity_provider.token_generated = False
+ self.identity_provider.token = self.token
+ else:
+ # identity_provider didn't generate a default token,
+ # that means it has some config that should take higher priority than deprecated ServerApp.token
+ self.log.warning("Ignoring deprecated ServerApp.token config")
+
+ self.authorizer = self.authorizer_class(
+ parent=self, log=self.log, identity_provider=self.identity_provider
+ )
+
+ def init_logging(self) -> None:
+ """Initialize logging."""
# This prevents double log messages because tornado use a root logger that
# self.log is a child of. The logging module dipatches log messages to a log
# and all of its ancenstors until propagate is set to False.
@@ -1842,7 +2010,25 @@ def init_logging(self):
logger.parent = self.log
logger.setLevel(self.log.level)
- def init_webapp(self):
+ def init_event_logger(self) -> None:
+ """Initialize the Event Bus."""
+ self.event_logger = EventLogger(parent=self)
+ # Load the core Jupyter Server event schemas
+ # All event schemas must start with Jupyter Server's
+ # events URI, `JUPYTER_SERVER_EVENTS_URI`.
+ schema_ids = [
+ "https://events.jupyter.org/jupyter_server/contents_service/v1",
+ "https://events.jupyter.org/jupyter_server/gateway_client/v1",
+ "https://events.jupyter.org/jupyter_server/kernel_actions/v1",
+ ]
+ for schema_id in schema_ids:
+ # Get the schema path from the schema ID.
+ rel_schema_path = schema_id.replace(JUPYTER_SERVER_EVENTS_URI + "/", "") + ".yaml"
+ schema_path = DEFAULT_EVENTS_SCHEMA_PATH / rel_schema_path
+ # Use this pathlib object to register the schema
+ self.event_logger.register_event_schema(schema_path)
+
+ def init_webapp(self) -> None:
"""initialize tornado webapp"""
self.tornado_settings["allow_origin"] = self.allow_origin
self.tornado_settings["websocket_compression_options"] = self.websocket_compression_options
@@ -1850,22 +2036,21 @@ def init_webapp(self):
self.tornado_settings["allow_origin_pat"] = re.compile(self.allow_origin_pat)
self.tornado_settings["allow_credentials"] = self.allow_credentials
self.tornado_settings["autoreload"] = self.autoreload
- self.tornado_settings["cookie_options"] = self.cookie_options
- self.tornado_settings["get_secure_cookie_kwargs"] = self.get_secure_cookie_kwargs
- self.tornado_settings["token"] = self.token
+
+ # deprecate accessing these directly, in favor of identity_provider?
+ self.tornado_settings["cookie_options"] = self.identity_provider.cookie_options
+ self.tornado_settings[
+ "get_secure_cookie_kwargs"
+ ] = self.identity_provider.get_secure_cookie_kwargs
+ self.tornado_settings["token"] = self.identity_provider.token
+
+ if self.static_immutable_cache:
+ self.tornado_settings["static_immutable_cache"] = self.static_immutable_cache
# ensure default_url starts with base_url
if not self.default_url.startswith(self.base_url):
self.default_url = url_path_join(self.base_url, self.default_url)
- if self.password_required and (not self.password):
- self.log.critical(
- _i18n("Jupyter servers are configured to only be run with a password.")
- )
- self.log.critical(_i18n("Hint: run the following command to set a password"))
- self.log.critical(_i18n("\t$ python -m jupyter_server.auth password"))
- sys.exit(1)
-
# Socket options validation.
if self.sock:
if self.port != DEFAULT_JUPYTER_SERVER_PORT:
@@ -1906,6 +2091,7 @@ def init_webapp(self):
self.session_manager,
self.kernel_spec_manager,
self.config_manager,
+ self.event_logger,
self.extra_services,
self.log,
self.base_url,
@@ -1913,6 +2099,8 @@ def init_webapp(self):
self.tornado_settings,
self.jinja_environment_options,
authorizer=self.authorizer,
+ identity_provider=self.identity_provider,
+ kernel_websocket_connection_class=self.kernel_websocket_connection_class,
)
if self.certfile:
self.ssl_options["certfile"] = self.certfile
@@ -1923,7 +2111,7 @@ def init_webapp(self):
if not self.ssl_options:
# could be an empty dict or None
# None indicates no SSL config
- self.ssl_options = None
+ self.ssl_options = None # type:ignore[assignment]
else:
# SSL may be missing, so only import it if it's to be used
import ssl
@@ -1936,12 +2124,16 @@ def init_webapp(self):
if self.ssl_options.get("ca_certs", False):
self.ssl_options.setdefault("cert_reqs", ssl.CERT_REQUIRED)
- self.login_handler_class.validate_security(self, ssl_options=self.ssl_options)
+ self.identity_provider.validate_security(self, ssl_options=self.ssl_options)
- def init_resources(self):
+ if isinstance(self.identity_provider, LegacyIdentityProvider):
+ # LegacyIdentityProvider needs access to the tornado settings dict
+ self.identity_provider.settings = self.web_app.settings
+
+ def init_resources(self) -> None:
"""initialize system resources"""
if resource is None:
- self.log.debug(
+ self.log.debug( # type:ignore[unreachable]
"Ignoring min_open_files_limit because the limit cannot be adjusted (for example, on Windows)"
)
return
@@ -1949,17 +2141,17 @@ def init_resources(self):
old_soft, old_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
soft = self.min_open_files_limit
hard = old_hard
- if old_soft < soft:
+ if soft is not None and old_soft < soft:
if hard < soft:
hard = soft
self.log.debug(
- "Raising open file limit: soft {}->{}; hard {}->{}".format(
- old_soft, soft, old_hard, hard
- )
+ f"Raising open file limit: soft {old_soft}->{soft}; hard {old_hard}->{hard}"
)
resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard))
- def _get_urlparts(self, path=None, include_token=False):
+ def _get_urlparts(
+ self, path: t.Optional[str] = None, include_token: bool = False
+ ) -> urllib.parse.ParseResult:
"""Constructs a urllib named tuple, ParseResult,
with default values set by server config.
The returned tuple can be manipulated using the `_replace` method.
@@ -1976,30 +2168,24 @@ def _get_urlparts(self, path=None, include_token=False):
else:
ip = f"[{self.ip}]" if ":" in self.ip else self.ip
netloc = f"{ip}:{self.port}"
- if self.certfile:
- scheme = "https"
- else:
- scheme = "http"
+ scheme = "https" if self.certfile else "http"
if not path:
path = self.default_url
query = None
- if include_token:
- if self.token: # Don't log full token if it came from config
- token = self.token if self._token_generated else "..."
- query = urllib.parse.urlencode({"token": token})
+ # Don't log full token if it came from config
+ if include_token and self.identity_provider.token:
+ token = (
+ self.identity_provider.token if self.identity_provider.token_generated else "..."
+ )
+ query = urllib.parse.urlencode({"token": token})
# Build the URL Parts to dump.
urlparts = urllib.parse.ParseResult(
- scheme=scheme,
- netloc=netloc,
- path=path,
- params=None,
- query=query,
- fragment=None,
+ scheme=scheme, netloc=netloc, path=path, query=query or "", params="", fragment=""
)
return urlparts
@property
- def public_url(self):
+ def public_url(self) -> str:
parts = self._get_urlparts(include_token=True)
# Update with custom pieces.
if self.custom_display_url:
@@ -2012,7 +2198,7 @@ def public_url(self):
return parts.geturl()
@property
- def local_url(self):
+ def local_url(self) -> str:
parts = self._get_urlparts(include_token=True)
# Update with custom pieces.
if not self.sock:
@@ -2020,37 +2206,25 @@ def local_url(self):
return parts.geturl()
@property
- def display_url(self):
+ def display_url(self) -> str:
"""Human readable string with URLs for interacting
with the running Jupyter Server
"""
- url = self.public_url + "\n or " + self.local_url
+ url = self.public_url + "\n " + self.local_url
return url
@property
- def connection_url(self):
+ def connection_url(self) -> str:
urlparts = self._get_urlparts(path=self.base_url)
return urlparts.geturl()
- def init_terminals(self):
- if not self.terminals_enabled:
- return
-
- try:
- from jupyter_server.terminal import initialize
-
- initialize(
- self.web_app,
- self.root_dir,
- self.connection_url,
- self.terminado_settings,
- )
- self.terminals_available = True
- except ImportError as e:
- self.log.warning(_i18n("Terminals not available (error was %s)"), e)
-
- def init_signal(self):
- if not sys.platform.startswith("win") and sys.stdin and sys.stdin.isatty():
+ def init_signal(self) -> None:
+ """Initialize signal handlers."""
+ if (
+ not sys.platform.startswith("win")
+ and sys.stdin # type:ignore[truthy-bool]
+ and sys.stdin.isatty()
+ ):
signal.signal(signal.SIGINT, self._handle_sigint)
signal.signal(signal.SIGTERM, self._signal_stop)
if hasattr(signal, "SIGUSR1"):
@@ -2060,7 +2234,7 @@ def init_signal(self):
# only on BSD-based systems
signal.signal(signal.SIGINFO, self._signal_info)
- def _handle_sigint(self, sig, frame):
+ def _handle_sigint(self, sig: t.Any, frame: t.Any) -> None:
"""SIGINT handler spawns confirmation dialog"""
# register more forceful signal handler for ^C^C case
signal.signal(signal.SIGINT, self._signal_stop)
@@ -2070,11 +2244,11 @@ def _handle_sigint(self, sig, frame):
thread.daemon = True
thread.start()
- def _restore_sigint_handler(self):
+ def _restore_sigint_handler(self) -> None:
"""callback for restoring original SIGINT handler"""
signal.signal(signal.SIGINT, self._handle_sigint)
- def _confirm_exit(self):
+ def _confirm_exit(self) -> None:
"""confirm shutdown on ^C
A second ^C, or answering 'y' within 5s will cause shutdown,
@@ -2091,7 +2265,7 @@ def _confirm_exit(self):
# since this might be called from a signal handler
self.stop(from_signal=True)
return
- print(self.running_server_info())
+ info(self.running_server_info())
yes = _i18n("y")
no = _i18n("n")
sys.stdout.write(_i18n("Shutdown this Jupyter server (%s/[%s])? ") % (yes, no))
@@ -2106,27 +2280,32 @@ def _confirm_exit(self):
self.stop(from_signal=True)
return
else:
- print(_i18n("No answer for 5s:"), end=" ")
- print(_i18n("resuming operation..."))
+ if self._stopping:
+ # don't show 'no answer' if we're actually stopping,
+ # e.g. ctrl-C ctrl-C
+ return
+ info(_i18n("No answer for 5s:"))
+ info(_i18n("resuming operation..."))
# no answer, or answer is no:
# set it back to original SIGINT handler
# use IOLoop.add_callback because signal.signal must be called
# from main thread
self.io_loop.add_callback_from_signal(self._restore_sigint_handler)
- def _signal_stop(self, sig, frame):
+ def _signal_stop(self, sig: t.Any, frame: t.Any) -> None:
+ """Handle a stop signal."""
self.log.critical(_i18n("received signal %s, stopping"), sig)
self.stop(from_signal=True)
- def _signal_info(self, sig, frame):
- print(self.running_server_info())
+ def _signal_info(self, sig: t.Any, frame: t.Any) -> None:
+ """Handle an info signal."""
+ self.log.info(self.running_server_info())
- def init_components(self):
+ def init_components(self) -> None:
"""Check the components submodule, and warn if it's unclean"""
# TODO: this should still check, but now we use bower, not git submodule
- pass
- def find_server_extensions(self):
+ def find_server_extensions(self) -> None:
"""
Searches Jupyter paths for jpserver_extensions.
"""
@@ -2149,7 +2328,7 @@ def find_server_extensions(self):
self.config.ServerApp.jpserver_extensions.update({modulename: enabled})
self.jpserver_extensions.update({modulename: enabled})
- def init_server_extensions(self):
+ def init_server_extensions(self) -> None:
"""
If an extension's metadata includes an 'app' key,
the value must be a subclass of ExtensionApp. An instance
@@ -2162,7 +2341,7 @@ def init_server_extensions(self):
self.extension_manager.from_jpserver_extensions(self.jpserver_extensions)
self.extension_manager.link_all_extensions()
- def load_server_extensions(self):
+ def load_server_extensions(self) -> None:
"""Load any extensions specified by config.
Import the module, then call the load_jupyter_server_extension function,
@@ -2172,7 +2351,7 @@ def load_server_extensions(self):
"""
self.extension_manager.load_all_extensions()
- def init_mime_overrides(self):
+ def init_mime_overrides(self) -> None:
# On some Windows machines, an application has registered incorrect
# mimetypes in the registry.
# Tornado uses this when serving .css and .js files, causing browsers to
@@ -2188,57 +2367,58 @@ def init_mime_overrides(self):
# for python <3.8
mimetypes.add_type("application/wasm", ".wasm")
- def shutdown_no_activity(self):
+ def shutdown_no_activity(self) -> None:
"""Shutdown server on timeout when there are no kernels or terminals."""
km = self.kernel_manager
if len(km) != 0:
return # Kernels still running
- if self.terminals_available:
- term_mgr = self.web_app.settings["terminal_manager"]
- if term_mgr.terminals:
- return # Terminals still running
+ if self.extension_manager.any_activity():
+ return
seconds_since_active = (utcnow() - self.web_app.last_activity()).total_seconds()
self.log.debug("No activity for %d seconds.", seconds_since_active)
if seconds_since_active > self.shutdown_no_activity_timeout:
self.log.info(
- "No kernels or terminals for %d seconds; shutting down.",
+ "No kernels for %d seconds; shutting down.",
seconds_since_active,
)
self.stop()
- def init_shutdown_no_activity(self):
+ def init_shutdown_no_activity(self) -> None:
+ """Initialize a shutdown on no activity."""
if self.shutdown_no_activity_timeout > 0:
self.log.info(
- "Will shut down after %d seconds with no kernels or terminals.",
+ "Will shut down after %d seconds with no kernels.",
self.shutdown_no_activity_timeout,
)
pc = ioloop.PeriodicCallback(self.shutdown_no_activity, 60000)
pc.start()
@property
- def http_server(self):
+ def http_server(self) -> httpserver.HTTPServer:
"""An instance of Tornado's HTTPServer class for the Server Web Application."""
try:
return self._http_server
- except AttributeError as e:
- raise AttributeError(
+ except AttributeError:
+ msg = (
"An HTTPServer instance has not been created for the "
"Server Web Application. To create an HTTPServer for this "
"application, call `.init_httpserver()`."
- ) from e
+ )
+ raise AttributeError(msg) from None
- def init_httpserver(self):
+ def init_httpserver(self) -> None:
"""Creates an instance of a Tornado HTTPServer for the Server Web Application
and sets the http_server attribute.
"""
# Check that a web_app has been initialized before starting a server.
if not hasattr(self, "web_app"):
- raise AttributeError(
+ msg = (
"A tornado web application has not be initialized. "
"Try calling `.init_webapp()` first."
)
+ raise AttributeError(msg)
# Create an instance of the server.
self._http_server = httpserver.HTTPServer(
@@ -2249,7 +2429,14 @@ def init_httpserver(self):
max_buffer_size=self.max_buffer_size,
)
- success = self._bind_http_server()
+ # binding sockets must be called from inside an event loop
+ if not self.sock:
+ self._find_http_port()
+ self.io_loop.add_callback(self._bind_http_server)
+
+ def _bind_http_server(self) -> None:
+ """Bind our http server."""
+ success = self._bind_http_server_unix() if self.sock else self._bind_http_server_tcp()
if not success:
self.log.critical(
_i18n(
@@ -2259,10 +2446,8 @@ def init_httpserver(self):
)
self.exit(1)
- def _bind_http_server(self):
- return self._bind_http_server_unix() if self.sock else self._bind_http_server_tcp()
-
- def _bind_http_server_unix(self):
+ def _bind_http_server_unix(self) -> bool:
+ """Bind an http server on unix."""
if unix_socket_in_use(self.sock):
self.log.warning(_i18n("The socket %s is already in use.") % self.sock)
return False
@@ -2282,11 +2467,19 @@ def _bind_http_server_unix(self):
else:
return True
- def _bind_http_server_tcp(self):
- success = None
+ def _bind_http_server_tcp(self) -> bool:
+ """Bind a tcp server."""
+ self.http_server.listen(self.port, self.ip)
+ return True
+
+ def _find_http_port(self) -> None:
+ """Find an available http port."""
+ success = False
+ port = self.port
for port in random_ports(self.port, self.port_retries + 1):
try:
- self.http_server.listen(port, self.ip)
+ sockets = bind_sockets(port, self.ip)
+ sockets[0].close()
except OSError as e:
if e.errno == errno.EADDRINUSE:
if self.port_retries:
@@ -2296,17 +2489,16 @@ def _bind_http_server_tcp(self):
else:
self.log.info(_i18n("The port %i is already in use.") % port)
continue
- elif e.errno in (
+ if e.errno in (
errno.EACCES,
getattr(errno, "WSAEACCES", errno.EACCES),
):
self.log.warning(_i18n("Permission to listen on port %i denied.") % port)
continue
- else:
- raise
+ raise
else:
- self.port = port
success = True
+ self.port = port
break
if not success:
if self.port_retries:
@@ -2325,10 +2517,9 @@ def _bind_http_server_tcp(self):
% port
)
self.exit(1)
- return success
@staticmethod
- def _init_asyncio_patch():
+ def _init_asyncio_patch() -> None:
"""set default asyncio policy to be compatible with tornado
Tornado 6.0 is not compatible with default asyncio
@@ -2343,10 +2534,7 @@ def _init_asyncio_patch():
import asyncio
try:
- from asyncio import (
- WindowsProactorEventLoopPolicy,
- WindowsSelectorEventLoopPolicy,
- )
+ from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy
except ImportError:
pass
# not affected
@@ -2358,11 +2546,11 @@ def _init_asyncio_patch():
@catch_config_error
def initialize(
self,
- argv=None,
- find_extensions=True,
- new_httpserver=True,
- starter_extension=None,
- ):
+ argv: t.Optional[list[str]] = None,
+ find_extensions: bool = True,
+ new_httpserver: bool = True,
+ starter_extension: t.Any = None,
+ ) -> None:
"""Initialize the Server application class, configurables, web application, and http server.
Parameters
@@ -2371,7 +2559,7 @@ def initialize(
CLI arguments to parse.
find_extensions : bool
If True, find and load extensions listed in Jupyter config paths. If False,
- only load extensions that are passed to ServerApp directy through
+ only load extensions that are passed to ServerApp directly through
the `argv`, `config`, or `jpserver_extensions` arguments.
new_httpserver : bool
If True, a tornado HTTPServer instance will be created and configured for the Server Web
@@ -2386,11 +2574,16 @@ def initialize(
super().initialize(argv=argv)
if self._dispatching:
return
+ # initialize io loop as early as possible,
+ # so configurables, extensions may reference the event loop
+ self.init_ioloop()
+
# Then, use extensions' config loading mechanism to
# update config. ServerApp config takes precedence.
if find_extensions:
self.find_server_extensions()
self.init_logging()
+ self.init_event_logger()
self.init_server_extensions()
# Special case the starter extension and load
@@ -2409,68 +2602,49 @@ def initialize(
self.init_configurables()
self.init_components()
self.init_webapp()
- self.init_terminals()
self.init_signal()
- self.init_ioloop()
self.load_server_extensions()
self.init_mime_overrides()
self.init_shutdown_no_activity()
if new_httpserver:
self.init_httpserver()
- async def cleanup_kernels(self):
+ async def cleanup_kernels(self) -> None:
"""Shutdown all kernels.
The kernels will shutdown themselves when this process no longer exists,
but explicit shutdown allows the KernelManagers to cleanup the connection files.
"""
+ if not getattr(self, "kernel_manager", None):
+ return
n_kernels = len(self.kernel_manager.list_kernel_ids())
kernel_msg = trans.ngettext(
"Shutting down %d kernel", "Shutting down %d kernels", n_kernels
)
self.log.info(kernel_msg % n_kernels)
- await run_sync_in_loop(self.kernel_manager.shutdown_all())
+ await ensure_async(self.kernel_manager.shutdown_all())
- async def cleanup_terminals(self):
- """Shutdown all terminals.
-
- The terminals will shutdown themselves when this process no longer exists,
- but explicit shutdown allows the TerminalManager to cleanup.
- """
- if not self.terminals_available:
- return
-
- terminal_manager = self.web_app.settings["terminal_manager"]
- n_terminals = len(terminal_manager.list())
- terminal_msg = trans.ngettext(
- "Shutting down %d terminal", "Shutting down %d terminals", n_terminals
- )
- self.log.info(terminal_msg % n_terminals)
- await run_sync_in_loop(terminal_manager.terminate_all())
-
- async def cleanup_extensions(self):
+ async def cleanup_extensions(self) -> None:
"""Call shutdown hooks in all extensions."""
+ if not getattr(self, "extension_manager", None):
+ return
n_extensions = len(self.extension_manager.extension_apps)
extension_msg = trans.ngettext(
"Shutting down %d extension", "Shutting down %d extensions", n_extensions
)
self.log.info(extension_msg % n_extensions)
- await run_sync_in_loop(self.extension_manager.stop_all_extensions())
+ await ensure_async(self.extension_manager.stop_all_extensions())
- def running_server_info(self, kernel_count=True):
- "Return the current working directory and the server url information"
- info = self.contents_manager.info_string() + "\n"
+ def running_server_info(self, kernel_count: bool = True) -> str:
+ """Return the current working directory and the server url information"""
+ info = t.cast(str, self.contents_manager.info_string()) + "\n"
if kernel_count:
n_kernels = len(self.kernel_manager.list_kernel_ids())
kernel_msg = trans.ngettext("%d active kernel", "%d active kernels", n_kernels)
info += kernel_msg % n_kernels
info += "\n"
# Format the info so that the URL fits on a single line in 80 char display
- info += _i18n(
- "Jupyter Server {version} is running at:\n{url}".format(
- version=ServerApp.version, url=self.display_url
- )
- )
+ info += _i18n(f"Jupyter Server {ServerApp.version} is running at:\n{self.display_url}")
if self.gateway_config.gateway_enabled:
info += (
_i18n("\nKernels will be managed by the Gateway server running at:\n%s")
@@ -2478,7 +2652,7 @@ def running_server_info(self, kernel_count=True):
)
return info
- def server_info(self):
+ def server_info(self) -> dict[str, t.Any]:
"""Return a JSONable dict of information about this server."""
return {
"url": self.connection_url,
@@ -2487,22 +2661,22 @@ def server_info(self):
"sock": self.sock,
"secure": bool(self.certfile),
"base_url": self.base_url,
- "token": self.token,
+ "token": self.identity_provider.token,
"root_dir": os.path.abspath(self.root_dir),
"password": bool(self.password),
"pid": os.getpid(),
"version": ServerApp.version,
}
- def write_server_info_file(self):
+ def write_server_info_file(self) -> None:
"""Write the result of server_info() to the JSON file info_file."""
try:
with secure_write(self.info_file) as f:
json.dump(self.server_info(), f, indent=2, sort_keys=True)
except OSError as e:
- self.log.error(_i18n("Failed to write server-info to %s: %s"), self.info_file, e)
+ self.log.error(_i18n("Failed to write server-info to %s: %r"), self.info_file, e)
- def remove_server_info_file(self):
+ def remove_server_info_file(self) -> None:
"""Remove the jpserver-.json file created for this server.
Ignores the error raised when the file has already been removed.
@@ -2513,7 +2687,7 @@ def remove_server_info_file(self):
if e.errno != errno.ENOENT:
raise
- def _resolve_file_to_run_and_root_dir(self):
+ def _resolve_file_to_run_and_root_dir(self) -> str:
"""Returns a relative path from file_to_run
to root_dir. If root_dir and file_to_run
are incompatible, i.e. on different subtrees,
@@ -2543,17 +2717,19 @@ def _resolve_file_to_run_and_root_dir(self):
"is on the same path as `root_dir`."
)
self.exit(1)
+ return ""
- def _write_browser_open_file(self, url, fh):
- if self.token:
- url = url_concat(url, {"token": self.token})
+ def _write_browser_open_file(self, url: str, fh: t.Any) -> None:
+ """Write the browser open file."""
+ if self.identity_provider.token:
+ url = url_concat(url, {"token": self.identity_provider.token})
url = url_path_join(self.connection_url, url)
jinja2_env = self.web_app.settings["jinja2_env"]
template = jinja2_env.get_template("browser-open.html")
fh.write(template.render(open_url=url, base_url=self.base_url))
- def write_browser_open_files(self):
+ def write_browser_open_files(self) -> None:
"""Write an `browser_open_file` and `browser_open_file_to_run` files
This can be used to open a file directly in a browser.
@@ -2574,7 +2750,7 @@ def write_browser_open_files(self):
with open(self.browser_open_file_to_run, "w", encoding="utf-8") as f:
self._write_browser_open_file(file_open_url, f)
- def write_browser_open_file(self):
+ def write_browser_open_file(self) -> None:
"""Write an jpserver--open.html file
This can be used to open the notebook in a browser
@@ -2585,7 +2761,7 @@ def write_browser_open_file(self):
with open(self.browser_open_file, "w", encoding="utf-8") as f:
self._write_browser_open_file(open_url, f)
- def remove_browser_open_files(self):
+ def remove_browser_open_files(self) -> None:
"""Remove the `browser_open_file` and `browser_open_file_to_run` files
created for this server.
@@ -2598,7 +2774,7 @@ def remove_browser_open_files(self):
if e.errno != errno.ENOENT:
raise
- def remove_browser_open_file(self):
+ def remove_browser_open_file(self) -> None:
"""Remove the jpserver--open.html file created for this server.
Ignores the error raised when the file has already been removed.
@@ -2609,14 +2785,15 @@ def remove_browser_open_file(self):
if e.errno != errno.ENOENT:
raise
- def _prepare_browser_open(self):
+ def _prepare_browser_open(self) -> tuple[str, t.Optional[str]]:
+ """Prepare to open the browser."""
if not self.use_redirect_file:
uri = self.default_url[len(self.base_url) :]
- if self.token:
- uri = url_concat(uri, {"token": self.token})
+ if self.identity_provider.token:
+ uri = url_concat(uri, {"token": self.identity_provider.token})
- if self.file_to_run:
+ if self.file_to_run: # noqa: SIM108
# Create a separate, temporary open-browser-file
# pointing at a specific file.
open_file = self.browser_open_file_to_run
@@ -2631,11 +2808,16 @@ def _prepare_browser_open(self):
return assembled_url, open_file
- def launch_browser(self):
+ def launch_browser(self) -> None:
+ """Launch the browser."""
+ # Deferred import for environments that do not have
+ # the webbrowser module.
+ import webbrowser
+
try:
browser = webbrowser.get(self.browser or None)
except webbrowser.Error as e:
- self.log.warning(_i18n("No web browser found: %s.") % e)
+ self.log.warning(_i18n("No web browser found: %r.") % e)
browser = None
if not browser:
@@ -2644,11 +2826,13 @@ def launch_browser(self):
assembled_url, _ = self._prepare_browser_open()
def target():
+ assert browser is not None
browser.open(assembled_url, new=self.webbrowser_open_new)
threading.Thread(target=target).start()
- def start_app(self):
+ def start_app(self) -> None:
+ """Start the Jupyter Server application."""
super().start()
if not self.allow_root:
@@ -2682,13 +2866,15 @@ def start_app(self):
)
self.write_server_info_file()
- self.write_browser_open_files()
+
+ if not self.no_browser_open_file:
+ self.write_browser_open_files()
# Handle the browser opening.
if self.open_browser and not self.sock:
self.launch_browser()
- if self.token and self._token_generated:
+ if self.identity_provider.token and self.identity_provider.token_generated:
# log full URL with generated token, so there's a copy/pasteable link
# with auth info.
if self.sock:
@@ -2700,27 +2886,35 @@ def start_app(self):
"",
(
"UNIX sockets are not browser-connectable, but you can tunnel to "
- "the instance via e.g.`ssh -L 8888:%s -N user@this_host` and then "
- "open e.g. %s in a browser."
- )
- % (self.sock, self.connection_url),
+ "the instance via e.g.`ssh -L 8888:{} -N user@this_host` and then "
+ "open e.g. {} in a browser."
+ ).format(self.sock, self.connection_url),
]
)
)
else:
- self.log.critical(
- "\n".join(
- [
- "\n",
+ if self.no_browser_open_file:
+ message = [
+ "\n",
+ _i18n("To access the server, copy and paste one of these URLs:"),
+ " %s" % self.display_url,
+ ]
+ else:
+ message = [
+ "\n",
+ _i18n(
"To access the server, open this file in a browser:",
- " %s" % urljoin("file:", pathname2url(self.browser_open_file)),
+ ),
+ " %s" % urljoin("file:", pathname2url(self.browser_open_file)),
+ _i18n(
"Or copy and paste one of these URLs:",
- " %s" % self.display_url,
- ]
- )
- )
+ ),
+ " %s" % self.display_url,
+ ]
- async def _cleanup(self):
+ self.log.critical("\n".join(message))
+
+ async def _cleanup(self) -> None:
"""General cleanup of files, extensions and kernels created
by this instance ServerApp.
"""
@@ -2728,9 +2922,28 @@ async def _cleanup(self):
self.remove_browser_open_files()
await self.cleanup_extensions()
await self.cleanup_kernels()
- await self.cleanup_terminals()
+ try:
+ await self.kernel_websocket_connection_class.close_all() # type:ignore[attr-defined]
+ except AttributeError:
+ # This can happen in two different scenarios:
+ #
+ # 1. During tests, where the _cleanup method is invoked without
+ # the corresponding initialize method having been invoked.
+ # 2. If the provided `kernel_websocket_connection_class` does not
+ # implement the `close_all` class method.
+ #
+ # In either case, we don't need to do anything and just want to treat
+ # the raised error as a no-op.
+ pass
+ if getattr(self, "kernel_manager", None):
+ self.kernel_manager.__del__()
+ if getattr(self, "session_manager", None):
+ self.session_manager.close()
+ if hasattr(self, "http_server"):
+ # Stop a server if its set.
+ self.http_server.stop()
- def start_ioloop(self):
+ def start_ioloop(self) -> None:
"""Start the IO Loop."""
if sys.platform.startswith("win"):
# add no-op to wake every 5s
@@ -2742,11 +2955,11 @@ def start_ioloop(self):
except KeyboardInterrupt:
self.log.info(_i18n("Interrupted..."))
- def init_ioloop(self):
+ def init_ioloop(self) -> None:
"""init self.io_loop so that an extension can use it by io_loop.call_later() to create background tasks"""
self.io_loop = ioloop.IOLoop.current()
- def start(self):
+ def start(self) -> None:
"""Start the Jupyter server app, after initialization
This method takes no arguments so all configuration and initialization
@@ -2754,14 +2967,17 @@ def start(self):
self.start_app()
self.start_ioloop()
- async def _stop(self):
+ async def _stop(self) -> None:
"""Cleanup resources and stop the IO Loop."""
await self._cleanup()
- self.io_loop.stop()
+ if getattr(self, "io_loop", None):
+ self.io_loop.stop()
- def stop(self, from_signal=False):
+ def stop(self, from_signal: bool = False) -> None:
"""Cleanup resources and stop the server."""
- if hasattr(self, "_http_server"):
+ # signal that stopping has begun
+ self._stopping = True
+ if hasattr(self, "http_server"):
# Stop a server if its set.
self.http_server.stop()
if getattr(self, "io_loop", None):
@@ -2773,7 +2989,9 @@ def stop(self, from_signal=False):
self.io_loop.add_callback(self._stop)
-def list_running_servers(runtime_dir=None, log=None):
+def list_running_servers(
+ runtime_dir: t.Optional[str] = None, log: t.Optional[logging.Logger] = None
+) -> t.Generator[t.Any, None, None]:
"""Iterate over the server info files of running Jupyter servers.
Given a runtime directory, find jpserver-* files in the security directory,
@@ -2790,7 +3008,11 @@ def list_running_servers(runtime_dir=None, log=None):
for file_name in os.listdir(runtime_dir):
if re.match("jpserver-(.+).json", file_name):
with open(os.path.join(runtime_dir, file_name), encoding="utf-8") as f:
- info = json.load(f)
+ # Handle race condition where file is being written.
+ try:
+ info = json.load(f)
+ except json.JSONDecodeError:
+ continue
# Simple check whether that process is really still running
# Also remove leftover files from IPython 2.x without a pid field
diff --git a/jupyter_server/services/api/api.yaml b/jupyter_server/services/api/api.yaml
index 844831e045..5ee5c416bd 100644
--- a/jupyter_server/services/api/api.yaml
+++ b/jupyter_server/services/api/api.yaml
@@ -33,6 +33,16 @@ parameters:
in: path
description: file path
type: string
+ permissions:
+ name: permissions
+ type: string
+ required: false
+ in: query
+ description: |
+ JSON-serialized dictionary of `{"resource": ["action",]}`
+ (dict of lists of strings) to check.
+ The same dictionary structure will be returned,
+ containing only the actions for which the user is authorized.
checkpoint_id:
name: checkpoint_id
required: true
@@ -53,6 +63,22 @@ parameters:
type: string
paths:
+ /api/:
+ get:
+ summary: Get the Jupyter Server version
+ description: |
+ This endpoint returns only the Jupyter Server version.
+ It does not require any authentication.
+ responses:
+ 200:
+ description: Jupyter Server version information
+ schema:
+ type: object
+ properties:
+ version:
+ type: string
+ description: The Jupyter Server version number as a string.
+
/api/contents/{path}:
parameters:
- $ref: "#/parameters/path"
@@ -80,6 +106,10 @@ paths:
in: query
description: "Return content (0 for no content, 1 for return content)"
type: integer
+ - name: hash
+ in: query
+          description: "May return a hexdigest hash string of the content and the hash algorithm (0 for no hash - default, 1 to return the hash). It may be ignored by the content manager."
+ type: integer
responses:
404:
description: No item found
@@ -578,7 +608,7 @@ paths:
- terminals
responses:
200:
- description: Succesfully created a new terminal
+ description: Successfully created a new terminal
schema:
$ref: "#/definitions/Terminal"
403:
@@ -611,12 +641,47 @@ paths:
- $ref: "#/parameters/terminal_id"
responses:
204:
- description: Succesfully deleted terminal session
+ description: Successfully deleted terminal session
403:
description: Forbidden to access
404:
description: Not found
-
+ /api/me:
+ get:
+ summary: |
+ Get the identity of the currently authenticated user.
+ If present, a `permissions` argument may be specified
+ to check what actions the user currently is authorized to take.
+ tags:
+ - identity
+ parameters:
+ - $ref: "#/parameters/permissions"
+ responses:
+ 200:
+ description: The user's identity and permissions
+ schema:
+ type: object
+ properties:
+ identity:
+ $ref: "#/definitions/Identity"
+ permissions:
+ $ref: "#/definitions/Permissions"
+ example:
+ identity:
+ username: minrk
+ name: Min Ragan-Kelley
+ display_name: Min RK
+ initials: MRK
+ avatar_url: null
+ color: null
+ permissions:
+ contents:
+ - read
+ - write
+ kernels:
+ - read
+ - write
+ - execute
/api/status:
get:
summary: Get the current status/activity of the server.
@@ -663,6 +728,53 @@ definitions:
type: number
description: |
The total number of running kernels.
+ Identity:
+ description: The identity of the currently authenticated user
+ properties:
+ username:
+ type: string
+ description: |
+ Unique string identifying the user
+ name:
+ type: string
+ description: |
+ For-humans name of the user.
+ May be the same as `username` in systems where
+ only usernames are available.
+ display_name:
+ type: string
+ description: |
+ Alternate rendering of name for display.
+ Often the same as `name`.
+ initials:
+ type: string
+ description: |
+ Short string of initials.
+ Initials should not be derived automatically due to localization issues.
+ May be `null` if unavailable.
+ avatar_url:
+ type: string
+ description: |
+ URL of an avatar to be used for the user.
+ May be `null` if unavailable.
+ color:
+ type: string
+ description: |
+ A CSS color string to use as a preferred color,
+ such as for collaboration cursors.
+ May be `null` if unavailable.
+ Permissions:
+ type: object
+ description: |
+ A dict of the form: `{"resource": ["action",]}`
+ containing only the AUTHORIZED subset of resource+actions
+ from the permissions specified in the request.
+ If no permission checks were made in the request,
+ this will be empty.
+ additionalProperties:
+ type: array
+ items:
+ type: string
KernelSpec:
description: Kernel spec (contents of kernel.json)
properties:
@@ -777,7 +889,7 @@ definitions:
kernel:
$ref: "#/definitions/Kernel"
Contents:
- description: "A contents object. The content and format keys may be null if content is not contained. If type is 'file', then the mimetype will be null."
+    description: "A contents object. The content and format keys may be null if content is not contained. The hash may be null if hash is not required. If type is 'file', then the mimetype will be null."
type: object
required:
- type
@@ -826,6 +938,12 @@ definitions:
format:
type: string
description: Format of content (one of null, 'text', 'base64', 'json')
+ hash:
+ type: string
+ description: "[optional] The hexdigest hash string of content, if requested (otherwise null). It cannot be null if hash_algorithm is defined."
+ hash_algorithm:
+ type: string
+ description: "[optional] The algorithm used to produce the hash, if requested (otherwise null). It cannot be null if hash is defined."
Checkpoints:
description: A checkpoint object.
type: object
diff --git a/jupyter_server/services/api/handlers.py b/jupyter_server/services/api/handlers.py
index 1c0cca5e19..efb361186c 100644
--- a/jupyter_server/services/api/handlers.py
+++ b/jupyter_server/services/api/handlers.py
@@ -3,12 +3,13 @@
# Distributed under the terms of the Modified BSD License.
import json
import os
+from typing import Any, Dict, List
+from jupyter_core.utils import ensure_async
from tornado import web
from jupyter_server._tz import isoformat, utcfromtimestamp
-from jupyter_server.auth import authorized
-from jupyter_server.utils import ensure_async
+from jupyter_server.auth.decorator import authorized
from ...base.handlers import APIHandler, JupyterHandler
@@ -16,22 +17,28 @@
class APISpecHandler(web.StaticFileHandler, JupyterHandler):
+ """A spec handler for the REST API."""
+
auth_resource = AUTH_RESOURCE
def initialize(self):
+ """Initialize the API spec handler."""
web.StaticFileHandler.initialize(self, path=os.path.dirname(__file__))
@web.authenticated
@authorized
def get(self):
+ """Get the API spec."""
self.log.warning("Serving api spec (experimental, incomplete)")
return web.StaticFileHandler.get(self, "api.yaml")
def get_content_type(self):
+ """Get the content type."""
return "text/x-yaml"
class APIStatusHandler(APIHandler):
+ """An API status handler."""
auth_resource = AUTH_RESOURCE
_track_activity = False
@@ -39,13 +46,14 @@ class APIStatusHandler(APIHandler):
@web.authenticated
@authorized
async def get(self):
+ """Get the API status."""
# if started was missing, use unix epoch
started = self.settings.get("started", utcfromtimestamp(0))
started = isoformat(started)
kernels = await ensure_async(self.kernel_manager.list_kernels())
total_connections = sum(k["connections"] for k in kernels)
- last_activity = isoformat(self.application.last_activity())
+ last_activity = isoformat(self.application.last_activity()) # type:ignore[attr-defined]
model = {
"started": started,
"last_activity": last_activity,
@@ -55,7 +63,50 @@ async def get(self):
self.finish(json.dumps(model, sort_keys=True))
+class IdentityHandler(APIHandler):
+ """Get the current user's identity model"""
+
+ @web.authenticated
+ def get(self):
+ """Get the identity model."""
+ permissions_json: str = self.get_argument("permissions", "")
+ bad_permissions_msg = f'permissions should be a JSON dict of {{"resource": ["action",]}}, got {permissions_json!r}'
+ if permissions_json:
+ try:
+ permissions_to_check = json.loads(permissions_json)
+ except ValueError as e:
+ raise web.HTTPError(400, bad_permissions_msg) from e
+ if not isinstance(permissions_to_check, dict):
+ raise web.HTTPError(400, bad_permissions_msg)
+ else:
+ permissions_to_check = {}
+
+ permissions: Dict[str, List[str]] = {}
+ user = self.current_user
+
+ for resource, actions in permissions_to_check.items():
+ if (
+ not isinstance(resource, str)
+ or not isinstance(actions, list)
+ or not all(isinstance(action, str) for action in actions)
+ ):
+ raise web.HTTPError(400, bad_permissions_msg)
+
+ allowed = permissions[resource] = []
+ for action in actions:
+ if self.authorizer.is_authorized(self, user=user, resource=resource, action=action):
+ allowed.append(action)
+
+ identity: Dict[str, Any] = self.identity_provider.identity_model(user)
+ model = {
+ "identity": identity,
+ "permissions": permissions,
+ }
+ self.write(json.dumps(model))
+
+
default_handlers = [
(r"/api/spec.yaml", APISpecHandler),
(r"/api/status", APIStatusHandler),
+ (r"/api/me", IdentityHandler),
]
diff --git a/jupyter_server/services/config/__init__.py b/jupyter_server/services/config/__init__.py
index 9a2aee241d..a28f60a2b3 100644
--- a/jupyter_server/services/config/__init__.py
+++ b/jupyter_server/services/config/__init__.py
@@ -1 +1,3 @@
-from .manager import ConfigManager # noqa
+from .manager import ConfigManager
+
+__all__ = ["ConfigManager"]
diff --git a/jupyter_server/services/config/handlers.py b/jupyter_server/services/config/handlers.py
index 385672b2b3..743c98ef0b 100644
--- a/jupyter_server/services/config/handlers.py
+++ b/jupyter_server/services/config/handlers.py
@@ -5,7 +5,7 @@
from tornado import web
-from jupyter_server.auth import authorized
+from jupyter_server.auth.decorator import authorized
from ...base.handlers import APIHandler
@@ -13,17 +13,21 @@
class ConfigHandler(APIHandler):
+ """A config API handler."""
+
auth_resource = AUTH_RESOURCE
@web.authenticated
@authorized
def get(self, section_name):
+ """Get config by section name."""
self.set_header("Content-Type", "application/json")
self.finish(json.dumps(self.config_manager.get(section_name)))
@web.authenticated
@authorized
def put(self, section_name):
+ """Set a config section by name."""
data = self.get_json_body() # Will raise 400 if content is not valid JSON
self.config_manager.set(section_name, data)
self.set_status(204)
@@ -31,6 +35,7 @@ def put(self, section_name):
@web.authenticated
@authorized
def patch(self, section_name):
+ """Update a config section by name."""
new_data = self.get_json_body()
section = self.config_manager.update(section_name, new_data)
self.finish(json.dumps(section))
diff --git a/jupyter_server/services/config/manager.py b/jupyter_server/services/config/manager.py
index 5f04925fe7..720c8e7bd7 100644
--- a/jupyter_server/services/config/manager.py
+++ b/jupyter_server/services/config/manager.py
@@ -3,6 +3,7 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os.path
+import typing as t
from jupyter_core.paths import jupyter_config_dir, jupyter_config_path
from traitlets import Instance, List, Unicode, default, observe
@@ -22,7 +23,7 @@ class ConfigManager(LoggingConfigurable):
def get(self, section_name):
"""Get the config from all config sections."""
- config = {}
+ config: t.Dict[str, t.Any] = {}
# step through back to front, to ensure front of the list is top priority
for p in self.read_config_path[::-1]:
cm = BaseJSONConfigManager(config_dir=p)
diff --git a/jupyter_server/services/contents/checkpoints.py b/jupyter_server/services/contents/checkpoints.py
index 09ef4a8e81..e251f7b232 100644
--- a/jupyter_server/services/contents/checkpoints.py
+++ b/jupyter_server/services/contents/checkpoints.py
@@ -22,23 +22,23 @@ class Checkpoints(LoggingConfigurable):
def create_checkpoint(self, contents_mgr, path):
"""Create a checkpoint."""
- raise NotImplementedError("must be implemented in a subclass")
+ raise NotImplementedError
def restore_checkpoint(self, contents_mgr, checkpoint_id, path):
"""Restore a checkpoint"""
- raise NotImplementedError("must be implemented in a subclass")
+ raise NotImplementedError
def rename_checkpoint(self, checkpoint_id, old_path, new_path):
"""Rename a single checkpoint from old_path to new_path."""
- raise NotImplementedError("must be implemented in a subclass")
+ raise NotImplementedError
def delete_checkpoint(self, checkpoint_id, path):
"""delete a checkpoint for a file"""
- raise NotImplementedError("must be implemented in a subclass")
+ raise NotImplementedError
def list_checkpoints(self, path):
"""Return a list of checkpoints for a given file"""
- raise NotImplementedError("must be implemented in a subclass")
+ raise NotImplementedError
def rename_all_checkpoints(self, old_path, new_path):
"""Rename all checkpoints for old_path to new_path."""
@@ -75,13 +75,13 @@ class GenericCheckpointsMixin:
def create_checkpoint(self, contents_mgr, path):
model = contents_mgr.get(path, content=True)
- type = model["type"]
- if type == "notebook":
+ type_ = model["type"]
+ if type_ == "notebook":
return self.create_notebook_checkpoint(
model["content"],
path,
)
- elif type == "file":
+ elif type_ == "file":
return self.create_file_checkpoint(
model["content"],
model["format"],
@@ -92,13 +92,13 @@ def create_checkpoint(self, contents_mgr, path):
def restore_checkpoint(self, contents_mgr, checkpoint_id, path):
"""Restore a checkpoint."""
- type = contents_mgr.get(path, content=False)["type"]
- if type == "notebook":
+ type_ = contents_mgr.get(path, content=False)["type"]
+ if type_ == "notebook":
model = self.get_notebook_checkpoint(checkpoint_id, path)
- elif type == "file":
+ elif type_ == "file":
model = self.get_file_checkpoint(checkpoint_id, path)
else:
- raise HTTPError(500, "Unexpected type %s" % type)
+ raise HTTPError(500, "Unexpected type %s" % type_)
contents_mgr.save(model, path)
# Required Methods
@@ -107,37 +107,39 @@ def create_file_checkpoint(self, content, format, path):
Returns a checkpoint model for the new checkpoint.
"""
- raise NotImplementedError("must be implemented in a subclass")
+ raise NotImplementedError
def create_notebook_checkpoint(self, nb, path):
"""Create a checkpoint of the current state of a file
Returns a checkpoint model for the new checkpoint.
"""
- raise NotImplementedError("must be implemented in a subclass")
+ raise NotImplementedError
def get_file_checkpoint(self, checkpoint_id, path):
"""Get the content of a checkpoint for a non-notebook file.
- Returns a dict of the form:
- {
- 'type': 'file',
- 'content': ,
- 'format': {'text','base64'},
- }
+ Returns a dict of the form::
+
+ {
+ 'type': 'file',
+ 'content': ,
+ 'format': {'text','base64'},
+ }
"""
- raise NotImplementedError("must be implemented in a subclass")
+ raise NotImplementedError
def get_notebook_checkpoint(self, checkpoint_id, path):
"""Get the content of a checkpoint for a notebook.
- Returns a dict of the form:
- {
- 'type': 'notebook',
- 'content':