diff --git a/src/node/app.ts b/src/node/app.ts index 7c868c2bc..c1b1006df 100644 --- a/src/node/app.ts +++ b/src/node/app.ts @@ -22,40 +22,35 @@ export interface App extends Disposable { server: http.Server } -const listen = (server: http.Server, { host, port, socket, "socket-mode": mode }: ListenOptions) => { - return new Promise(async (resolve, reject) => { +export const listen = async (server: http.Server, { host, port, socket, "socket-mode": mode }: ListenOptions) => { + if (socket) { + try { + await fs.unlink(socket) + } catch (error: any) { + handleArgsSocketCatchError(error) + } + } + await new Promise(async (resolve, reject) => { server.on("error", reject) - const onListen = () => { // Promise resolved earlier so this is an unrelated error. server.off("error", reject) server.on("error", (err) => util.logError(logger, "http server error", err)) - - if (socket && mode) { - fs.chmod(socket, mode) - .then(resolve) - .catch((err) => { - util.logError(logger, "socket chmod", err) - reject(err) - }) - } else { - resolve() - } + resolve() } - if (socket) { - try { - await fs.unlink(socket) - } catch (error: any) { - handleArgsSocketCatchError(error) - } - server.listen(socket, onListen) } else { // [] is the correct format when using :: but Node errors with them. server.listen(port, host.replace(/^\[|\]$/g, ""), onListen) } }) + + // NOTE@jsjoeio: we need to chmod after the server is finished + // listening. Otherwise, the socket may not have been created yet. + if (socket && mode) { + await fs.chmod(socket, mode) + } } /** @@ -138,6 +133,6 @@ export const handleServerError = (resolved: boolean, err: Error, reject: (err: E */ export const handleArgsSocketCatchError = (error: any) => { if (!isNodeJSErrnoException(error) || error.code !== "ENOENT") { - logger.error(error.message ? error.message : error) + throw Error(error.message ? error.message : error) } } diff --git a/test/unit/node/app.test.ts b/test/unit/node/app.test.ts index 29811d4f9..62b2887f6 100644 --- a/test/unit/node/app.test.ts +++ b/test/unit/node/app.test.ts @@ -3,7 +3,7 @@ import { promises } from "fs" import * as http from "http" import * as https from "https" import * as path from "path" -import { createApp, ensureAddress, handleArgsSocketCatchError, handleServerError } from "../../../src/node/app" +import { createApp, ensureAddress, handleArgsSocketCatchError, handleServerError, listen } from "../../../src/node/app" import { OptionalString, setDefaults } from "../../../src/node/cli" import { generateCertificate } from "../../../src/node/util" import { clean, mockLogger, getAvailablePort, tmpdir } from "../../utils/helpers" @@ -201,31 +201,33 @@ describe("handleArgsSocketCatchError", () => { }) it("should log an error if its not an NodeJS.ErrnoException", () => { - const error = new Error() + const message = "other message" + const error = new Error(message) - handleArgsSocketCatchError(error) - - expect(logger.error).toHaveBeenCalledTimes(1) - expect(logger.error).toHaveBeenCalledWith(error) + expect(() => { + handleArgsSocketCatchError(error) + }).toThrowError(error) }) it("should log an error if its not an NodeJS.ErrnoException (and the error has a message)", () => { const errorMessage = "handleArgsSocketCatchError Error" const error = new Error(errorMessage) - handleArgsSocketCatchError(error) - - expect(logger.error).toHaveBeenCalledTimes(1) - expect(logger.error).toHaveBeenCalledWith(errorMessage) + expect(() => { + handleArgsSocketCatchError(error) + }).toThrowError(error) }) - it("should not log an error if its a iNodeJS.ErrnoException", () => { - const error: NodeJS.ErrnoException = new Error() - error.code = "ENOENT" + it("should not log an error if its a NodeJS.ErrnoException", () => { + const code = "ENOENT" + const error: NodeJS.ErrnoException = new Error(code) + error.code = code handleArgsSocketCatchError(error) - expect(logger.error).toHaveBeenCalledTimes(0) + expect(() => { + handleArgsSocketCatchError(error) + }).not.toThrowError() }) it("should log an error if the code is not ENOENT (and the error has a message)", () => { @@ -234,19 +236,50 @@ describe("handleArgsSocketCatchError", () => { error.code = "EACCESS" error.message = errorMessage - handleArgsSocketCatchError(error) - - expect(logger.error).toHaveBeenCalledTimes(1) - expect(logger.error).toHaveBeenCalledWith(errorMessage) + expect(() => { + handleArgsSocketCatchError(error) + }).toThrowError(error) }) it("should log an error if the code is not ENOENT", () => { - const error: NodeJS.ErrnoException = new Error() - error.code = "EACCESS" + const code = "EACCESS" + const error: NodeJS.ErrnoException = new Error(code) + error.code = code - handleArgsSocketCatchError(error) - - expect(logger.error).toHaveBeenCalledTimes(1) - expect(logger.error).toHaveBeenCalledWith(error) + expect(() => { + handleArgsSocketCatchError(error) + }).toThrowError(error) + }) +}) + +describe("listen", () => { + let tmpDirPath: string + let mockServer: http.Server + + const testName = "listen" + + beforeEach(async () => { + await clean(testName) + mockLogger() + tmpDirPath = await tmpdir(testName) + mockServer = http.createServer() + }) + + afterEach(() => { + mockServer.close() + jest.clearAllMocks() + }) + + it("should throw an error if a directory is passed in instead of a file", async () => { + const errorMessage = "EISDIR: illegal operation on a directory, unlink" + const port = await getAvailablePort() + const mockArgs = { port, host: "0.0.0.0", socket: tmpDirPath } + + try { + await listen(mockServer, mockArgs) + } catch (error) { + expect(error).toBeInstanceOf(Error) + expect((error as any).message).toMatch(errorMessage) + } }) }) diff --git a/test/unit/node/constants.test.ts b/test/unit/node/constants.test.ts index 24501cbd2..d2aa68ab5 100644 --- a/test/unit/node/constants.test.ts +++ b/test/unit/node/constants.test.ts @@ -1,6 +1,7 @@ import { logger } from "@coder/logger" import { mockLogger } from "../../utils/helpers" import * as semver from "semver" +import path from "path" describe("constants", () => { let constants: typeof import("../../../src/node/constants") @@ -20,14 +21,18 @@ describe("constants", () => { } beforeAll(() => { + jest.clearAllMocks() mockLogger() - jest.mock("../../../package.json", () => mockPackageJson, { virtual: true }) - jest.mock("../../../vendor/modules/code-oss-dev/package.json", () => mockCodePackageJson, { virtual: true }) + jest.mock(path.resolve(__dirname, "../../../package.json"), () => mockPackageJson, { virtual: true }) + jest.mock( + path.resolve(__dirname, "../../../vendor/modules/code-oss-dev/package.json"), + () => mockCodePackageJson, + { virtual: true }, + ) constants = require("../../../src/node/constants") }) afterAll(() => { - jest.clearAllMocks() jest.resetModules() }) @@ -106,13 +111,17 @@ describe("constants", () => { } beforeAll(() => { - jest.mock("../../../package.json", () => mockPackageJson, { virtual: true }) - jest.mock("../../../vendor/modules/code-oss-dev/package.json", () => mockCodePackageJson, { virtual: true }) + jest.clearAllMocks() + jest.mock(path.resolve(__dirname, "../../../package.json"), () => mockPackageJson, { virtual: true }) + jest.mock( + path.resolve(__dirname, "../../../vendor/modules/code-oss-dev/package.json"), + () => mockCodePackageJson, + { virtual: true }, + ) constants = require("../../../src/node/constants") }) afterAll(() => { - jest.clearAllMocks() jest.resetModules() })