diff --git a/src/node/http.ts b/src/node/http.ts index 071ccd97e..42896f26b 100644 --- a/src/node/http.ts +++ b/src/node/http.ts @@ -34,12 +34,15 @@ export const replaceTemplates = ( } /** - * Throw an error if not authorized. + * Throw an error if not authorized. Call `next` if provided. */ -export const ensureAuthenticated = (req: express.Request): void => { +export const ensureAuthenticated = (req: express.Request, _?: express.Response, next?: express.NextFunction): void => { if (!authenticated(req)) { throw new HttpError("Unauthorized", HttpCode.Unauthorized) } + if (next) { + next() + } } /** @@ -136,20 +139,32 @@ export const getCookieDomain = (host: string, proxyDomains: string[]): string | declare module "express" { function Router(options?: express.RouterOptions): express.Router & WithWebsocketMethod - type WebsocketRequestHandler = ( - socket: net.Socket, - head: Buffer, - req: express.Request, + type WebSocketRequestHandler = ( + req: express.Request & WithWebSocket, + res: express.Response, next: express.NextFunction, ) => void | Promise - type WebsocketMethod = (route: expressCore.PathParams, ...handlers: WebsocketRequestHandler[]) => T + type WebSocketMethod = (route: expressCore.PathParams, ...handlers: WebSocketRequestHandler[]) => T + + interface WithWebSocket { + ws: net.Socket + head: Buffer + } interface WithWebsocketMethod { - ws: WebsocketMethod + ws: WebSocketMethod } } +interface WebsocketRequest extends express.Request, express.WithWebSocket { + _ws_handled: boolean +} + +function isWebSocketRequest(req: express.Request): req is WebsocketRequest { + return !!(req as WebsocketRequest).ws +} + export const handleUpgrade = (app: express.Express, server: http.Server): void => { server.on("upgrade", (req, socket, head) => { socket.on("error", () => socket.destroy()) @@ -193,15 +208,15 @@ function patchRouter(): void { // Inject the `ws` method. ;(express.Router as any).ws = function ws( route: expressCore.PathParams, - ...handlers: express.WebsocketRequestHandler[] + ...handlers: express.WebSocketRequestHandler[] ) { originalGet.apply(this, [ route, ...handlers.map((handler) => { - const wrapped: express.Handler = (req, _, next) => { - if ((req as any).ws) { - ;(req as any)._ws_handled = true - Promise.resolve(handler((req as any).ws, (req as any).head, req, next)).catch(next) + const wrapped: express.Handler = (req, res, next) => { + if (isWebSocketRequest(req)) { + req._ws_handled = true + Promise.resolve(handler(req, res, next)).catch(next) } else { next() } @@ -218,7 +233,7 @@ function patchRouter(): void { route, ...handlers.map((handler) => { const wrapped: express.Handler = (req, res, next) => { - if (!(req as any).ws) { + if (!isWebSocketRequest(req)) { Promise.resolve(handler(req, res, next)).catch(next) } else { next() diff --git a/src/node/proxy.ts b/src/node/proxy.ts index d53de6927..bfc6af5b3 100644 --- a/src/node/proxy.ts +++ b/src/node/proxy.ts @@ -82,7 +82,7 @@ router.all("*", (req, res, next) => { }) }) -router.ws("*", (socket, head, req, next) => { +router.ws("*", (req, _, next) => { const port = maybeProxy(req) if (!port) { return next() @@ -91,7 +91,7 @@ router.ws("*", (socket, head, req, next) => { // Must be authenticated to use the proxy. ensureAuthenticated(req) - proxy.ws(req, socket, head, { + proxy.ws(req, req.ws, req.head, { ignorePath: true, target: `http://0.0.0.0:${port}${req.originalUrl}`, }) diff --git a/src/node/routes/proxy.ts b/src/node/routes/proxy.ts index 29aa999ae..59db92d97 100644 --- a/src/node/routes/proxy.ts +++ b/src/node/routes/proxy.ts @@ -35,8 +35,8 @@ router.all("/(:port)(/*)?", (req, res) => { }) }) -router.ws("/(:port)(/*)?", (socket, head, req) => { - proxy.ws(req, socket, head, { +router.ws("/(:port)(/*)?", (req) => { + proxy.ws(req, req.ws, req.head, { ignorePath: true, target: getProxyTarget(req, true), }) diff --git a/src/node/routes/update.ts b/src/node/routes/update.ts index b4fbc197e..ac1ddc413 100644 --- a/src/node/routes/update.ts +++ b/src/node/routes/update.ts @@ -7,12 +7,7 @@ export const router = Router() const provider = new UpdateProvider() -router.use((req, _, next) => { - ensureAuthenticated(req) - next() -}) - -router.get("/", async (req, res) => { +router.get("/", ensureAuthenticated, async (req, res) => { const update = await provider.getUpdate(req.query.force === "true") res.json({ checked: update.checked, diff --git a/src/node/routes/vscode.ts b/src/node/routes/vscode.ts index e7842a297..c936571c5 100644 --- a/src/node/routes/vscode.ts +++ b/src/node/routes/vscode.ts @@ -53,14 +53,13 @@ router.get("/", async (req, res) => { ) }) -router.ws("/", async (socket, _, req) => { - ensureAuthenticated(req) +router.ws("/", ensureAuthenticated, async (req) => { const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" const reply = crypto .createHash("sha1") .update(req.headers["sec-websocket-key"] + magic) .digest("base64") - socket.write( + req.ws.write( [ "HTTP/1.1 101 Switching Protocols", "Upgrade: websocket", @@ -68,14 +67,13 @@ router.ws("/", async (socket, _, req) => { `Sec-WebSocket-Accept: ${reply}`, ].join("\r\n") + "\r\n\r\n", ) - await vscode.sendWebsocket(socket, req.query) + await vscode.sendWebsocket(req.ws, req.query) }) /** * TODO: Might currently be unused. */ -router.get("/resource(/*)?", async (req, res) => { - ensureAuthenticated(req) +router.get("/resource(/*)?", ensureAuthenticated, async (req, res) => { if (typeof req.query.path === "string") { res.set("Content-Type", getMediaMime(req.query.path)) res.send(await fs.readFile(pathToFsPath(req.query.path))) @@ -85,8 +83,7 @@ router.get("/resource(/*)?", async (req, res) => { /** * Used by VS Code to load files. */ -router.get("/vscode-remote-resource(/*)?", async (req, res) => { - ensureAuthenticated(req) +router.get("/vscode-remote-resource(/*)?", ensureAuthenticated, async (req, res) => { if (typeof req.query.path === "string") { res.set("Content-Type", getMediaMime(req.query.path)) res.send(await fs.readFile(pathToFsPath(req.query.path))) @@ -97,8 +94,7 @@ router.get("/vscode-remote-resource(/*)?", async (req, res) => { * VS Code webviews use these paths to load files and to load webview assets * like HTML and JavaScript. */ -router.get("/webview/*", async (req, res) => { - ensureAuthenticated(req) +router.get("/webview/*", ensureAuthenticated, async (req, res) => { res.set("Content-Type", getMediaMime(req.path)) if (/^vscode-resource/.test(req.params[0])) { return res.send(await fs.readFile(req.params[0].replace(/^vscode-resource(\/file)?/, "")))