Skip to content

Commit c1f4484

Browse files
committed
Move websocket routes into a separate app
This is mostly so we don't have to do any wacky patching but it also makes it so we don't have to keep checking if the request is a web socket request every time we add middleware.
1 parent 62a22ae commit c1f4484

File tree

8 files changed

+134
-145
lines changed

8 files changed

+134
-145
lines changed

src/node/app.ts

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ import { promises as fs } from "fs"
44
import http from "http"
55
import * as httpolyglot from "httpolyglot"
66
import { DefaultedArgs } from "./cli"
7-
import { handleUpgrade } from "./http"
7+
import { handleUpgrade } from "./wsRouter"
88

99
/**
1010
* Create an Express app and an HTTP/S server to serve it.
1111
*/
12-
export const createApp = async (args: DefaultedArgs): Promise<[Express, http.Server]> => {
12+
export const createApp = async (args: DefaultedArgs): Promise<[Express, Express, http.Server]> => {
1313
const app = express()
1414

1515
const server = args.cert
@@ -39,9 +39,10 @@ export const createApp = async (args: DefaultedArgs): Promise<[Express, http.Ser
3939
}
4040
})
4141

42-
handleUpgrade(app, server)
42+
const wsApp = express()
43+
handleUpgrade(wsApp, server)
4344

44-
return [app, server]
45+
return [app, wsApp, server]
4546
}
4647

4748
/**

src/node/entry.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@ const main = async (args: DefaultedArgs): Promise<void> => {
102102
throw new Error("Please pass in a password via the config file or $PASSWORD")
103103
}
104104

105-
const [app, server] = await createApp(args)
105+
const [app, wsApp, server] = await createApp(args)
106106
const serverAddress = ensureAddress(server)
107-
await register(app, server, args)
107+
await register(app, wsApp, server, args)
108108

109109
logger.info(`Using config file ${humanPath(args.config)}`)
110110
logger.info(`HTTP server listening on ${serverAddress} ${args.link ? "(randomized by --link)" : ""}`)

src/node/http.ts

Lines changed: 0 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import { field, logger } from "@coder/logger"
22
import * as express from "express"
33
import * as expressCore from "express-serve-static-core"
4-
import * as http from "http"
5-
import * as net from "net"
64
import qs from "qs"
75
import safeCompare from "safe-compare"
86
import { HttpCode, HttpError } from "../common/http"
@@ -135,111 +133,3 @@ export const getCookieDomain = (host: string, proxyDomains: string[]): string |
135133
logger.debug("got cookie doman", field("host", host))
136134
return host || undefined
137135
}
138-
139-
declare module "express" {
140-
function Router(options?: express.RouterOptions): express.Router & WithWebsocketMethod
141-
142-
type WebSocketRequestHandler = (
143-
req: express.Request & WithWebSocket,
144-
res: express.Response,
145-
next: express.NextFunction,
146-
) => void | Promise<void>
147-
148-
type WebSocketMethod<T> = (route: expressCore.PathParams, ...handlers: WebSocketRequestHandler[]) => T
149-
150-
interface WithWebSocket {
151-
ws: net.Socket
152-
head: Buffer
153-
}
154-
155-
interface WithWebsocketMethod {
156-
ws: WebSocketMethod<this>
157-
}
158-
}
159-
160-
interface WebsocketRequest extends express.Request, express.WithWebSocket {
161-
_ws_handled: boolean
162-
}
163-
164-
function isWebSocketRequest(req: express.Request): req is WebsocketRequest {
165-
return !!(req as WebsocketRequest).ws
166-
}
167-
168-
export const handleUpgrade = (app: express.Express, server: http.Server): void => {
169-
server.on("upgrade", (req, socket, head) => {
170-
socket.on("error", () => socket.destroy())
171-
172-
req.ws = socket
173-
req.head = head
174-
req._ws_handled = false
175-
176-
const res = new http.ServerResponse(req)
177-
res.writeHead = function writeHead(statusCode: number) {
178-
if (statusCode > 200) {
179-
socket.destroy(new Error(`${statusCode}`))
180-
}
181-
return res
182-
}
183-
184-
// Send the request off to be handled by Express.
185-
;(app as any).handle(req, res, () => {
186-
if (!req._ws_handled) {
187-
socket.destroy(new Error("Not found"))
188-
}
189-
})
190-
})
191-
}
192-
193-
/**
194-
* Patch Express routers to handle web sockets.
195-
*
196-
* Not using express-ws since the ws-wrapped sockets don't work with the proxy.
197-
*/
198-
function patchRouter(): void {
199-
// This works because Router is also the prototype assigned to the routers it
200-
// returns.
201-
202-
// Store this since the original method will be overridden.
203-
const originalGet = (express.Router as any).prototype.get
204-
205-
// Inject the `ws` method.
206-
;(express.Router as any).prototype.ws = function ws(
207-
route: expressCore.PathParams,
208-
...handlers: express.WebSocketRequestHandler[]
209-
) {
210-
originalGet.apply(this, [
211-
route,
212-
...handlers.map((handler) => {
213-
const wrapped: express.Handler = (req, res, next) => {
214-
if (isWebSocketRequest(req)) {
215-
req._ws_handled = true
216-
return handler(req, res, next)
217-
}
218-
next()
219-
}
220-
return wrapped
221-
}),
222-
])
223-
return this
224-
}
225-
// Overwrite `get` so we can distinguish between websocket and non-websocket
226-
// routes.
227-
;(express.Router as any).prototype.get = function get(route: expressCore.PathParams, ...handlers: express.Handler[]) {
228-
originalGet.apply(this, [
229-
route,
230-
...handlers.map((handler) => {
231-
const wrapped: express.Handler = (req, res, next) => {
232-
if (!isWebSocketRequest(req)) {
233-
return handler(req, res, next)
234-
}
235-
next()
236-
}
237-
return wrapped
238-
}),
239-
])
240-
return this
241-
}
242-
}
243-
244-
// This needs to happen before anything creates a router.
245-
patchRouter()

src/node/proxy.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { Request, Router } from "express"
22
import proxyServer from "http-proxy"
33
import { HttpCode, HttpError } from "../common/http"
44
import { authenticated, ensureAuthenticated } from "./http"
5+
import { Router as WsRouter } from "./wsRouter"
56

67
export const proxy = proxyServer.createProxyServer({})
78
proxy.on("error", (error, _, res) => {
@@ -82,7 +83,9 @@ router.all("*", (req, res, next) => {
8283
})
8384
})
8485

85-
router.ws("*", (req, _, next) => {
86+
export const wsRouter = WsRouter()
87+
88+
wsRouter.ws("*", (req, _, next) => {
8689
const port = maybeProxy(req)
8790
if (!port) {
8891
return next()

src/node/routes/index.ts

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { logger } from "@coder/logger"
22
import bodyParser from "body-parser"
33
import cookieParser from "cookie-parser"
4-
import { ErrorRequestHandler, Express } from "express"
4+
import * as express from "express"
55
import { promises as fs } from "fs"
66
import http from "http"
77
import * as path from "path"
@@ -15,6 +15,7 @@ import { replaceTemplates } from "../http"
1515
import { loadPlugins } from "../plugin"
1616
import * as domainProxy from "../proxy"
1717
import { getMediaMime, paths } from "../util"
18+
import { WebsocketRequest } from "../wsRouter"
1819
import * as health from "./health"
1920
import * as login from "./login"
2021
import * as proxy from "./proxy"
@@ -36,7 +37,12 @@ declare global {
3637
/**
3738
* Register all routes and middleware.
3839
*/
39-
export const register = async (app: Express, server: http.Server, args: DefaultedArgs): Promise<void> => {
40+
export const register = async (
41+
app: express.Express,
42+
wsApp: express.Express,
43+
server: http.Server,
44+
args: DefaultedArgs,
45+
): Promise<void> => {
4046
const heart = new Heart(path.join(paths.data, "heartbeat"), async () => {
4147
return new Promise((resolve, reject) => {
4248
server.getConnections((error, count) => {
@@ -50,14 +56,28 @@ export const register = async (app: Express, server: http.Server, args: Defaulte
5056
})
5157

5258
app.disable("x-powered-by")
59+
wsApp.disable("x-powered-by")
5360

5461
app.use(cookieParser())
62+
wsApp.use(cookieParser())
63+
5564
app.use(bodyParser.json())
5665
app.use(bodyParser.urlencoded({ extended: true }))
5766

58-
app.use(async (req, res, next) => {
67+
const common: express.RequestHandler = (req, _, next) => {
5968
heart.beat()
6069

70+
// Add common variables routes can use.
71+
req.args = args
72+
req.heart = heart
73+
74+
next()
75+
}
76+
77+
app.use(common)
78+
wsApp.use(common)
79+
80+
app.use(async (req, res, next) => {
6181
// If we're handling TLS ensure all requests are redirected to HTTPS.
6282
// TODO: This does *NOT* work if you have a base path since to specify the
6383
// protocol we need to specify the whole path.
@@ -72,31 +92,36 @@ export const register = async (app: Express, server: http.Server, args: Defaulte
7292
return res.send(await fs.readFile(resourcePath))
7393
}
7494

75-
// Add common variables routes can use.
76-
req.args = args
77-
req.heart = heart
78-
79-
return next()
95+
next()
8096
})
8197

8298
app.use("/", domainProxy.router)
99+
wsApp.use("/", domainProxy.wsRouter.router)
100+
83101
app.use("/", vscode.router)
102+
wsApp.use("/", vscode.wsRouter.router)
103+
app.use("/vscode", vscode.router)
104+
wsApp.use("/vscode", vscode.wsRouter.router)
105+
84106
app.use("/healthz", health.router)
107+
85108
if (args.auth === AuthType.Password) {
86109
app.use("/login", login.router)
87110
}
111+
88112
app.use("/proxy", proxy.router)
113+
wsApp.use("/proxy", proxy.wsRouter.router)
114+
89115
app.use("/static", _static.router)
90116
app.use("/update", update.router)
91-
app.use("/vscode", vscode.router)
92117

93118
await loadPlugins(app, args)
94119

95120
app.use(() => {
96121
throw new HttpError("Not Found", HttpCode.NotFound)
97122
})
98123

99-
const errorHandler: ErrorRequestHandler = async (err, req, res, next) => {
124+
const errorHandler: express.ErrorRequestHandler = async (err, req, res, next) => {
100125
const resourcePath = path.resolve(rootPath, "src/browser/pages/error.html")
101126
res.set("Content-Type", getMediaMime(resourcePath))
102127
try {
@@ -117,4 +142,11 @@ export const register = async (app: Express, server: http.Server, args: Defaulte
117142
}
118143

119144
app.use(errorHandler)
145+
146+
const wsErrorHandler: express.ErrorRequestHandler = async (err, req) => {
147+
logger.error(`${err.message} ${err.stack}`)
148+
;(req as WebsocketRequest).ws.destroy(err)
149+
}
150+
151+
wsApp.use(wsErrorHandler)
120152
}

src/node/routes/proxy.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import qs from "qs"
33
import { HttpCode, HttpError } from "../../common/http"
44
import { authenticated, redirect } from "../http"
55
import { proxy } from "../proxy"
6+
import { Router as WsRouter } from "../wsRouter"
67

78
export const router = Router()
89

@@ -35,7 +36,9 @@ router.all("/(:port)(/*)?", (req, res) => {
3536
})
3637
})
3738

38-
router.ws("/(:port)(/*)?", (req) => {
39+
export const wsRouter = WsRouter()
40+
41+
wsRouter.ws("/(:port)(/*)?", (req) => {
3942
proxy.ws(req, req.ws, req.head, {
4043
ignorePath: true,
4144
target: getProxyTarget(req, true),

src/node/routes/vscode.ts

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { commit, rootPath, version } from "../constants"
66
import { authenticated, ensureAuthenticated, redirect, replaceTemplates } from "../http"
77
import { getMediaMime, pathToFsPath } from "../util"
88
import { VscodeProvider } from "../vscode"
9+
import { Router as WsRouter } from "../wsRouter"
910

1011
export const router = Router()
1112

@@ -53,23 +54,6 @@ router.get("/", async (req, res) => {
5354
)
5455
})
5556

56-
router.ws("/", ensureAuthenticated, async (req) => {
57-
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
58-
const reply = crypto
59-
.createHash("sha1")
60-
.update(req.headers["sec-websocket-key"] + magic)
61-
.digest("base64")
62-
req.ws.write(
63-
[
64-
"HTTP/1.1 101 Switching Protocols",
65-
"Upgrade: websocket",
66-
"Connection: Upgrade",
67-
`Sec-WebSocket-Accept: ${reply}`,
68-
].join("\r\n") + "\r\n\r\n",
69-
)
70-
await vscode.sendWebsocket(req.ws, req.query)
71-
})
72-
7357
/**
7458
* TODO: Might currently be unused.
7559
*/
@@ -103,3 +87,22 @@ router.get("/webview/*", ensureAuthenticated, async (req, res) => {
10387
await fs.readFile(path.join(vscode.vsRootPath, "out/vs/workbench/contrib/webview/browser/pre", req.params[0])),
10488
)
10589
})
90+
91+
export const wsRouter = WsRouter()
92+
93+
wsRouter.ws("/", ensureAuthenticated, async (req) => {
94+
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
95+
const reply = crypto
96+
.createHash("sha1")
97+
.update(req.headers["sec-websocket-key"] + magic)
98+
.digest("base64")
99+
req.ws.write(
100+
[
101+
"HTTP/1.1 101 Switching Protocols",
102+
"Upgrade: websocket",
103+
"Connection: Upgrade",
104+
`Sec-WebSocket-Accept: ${reply}`,
105+
].join("\r\n") + "\r\n\r\n",
106+
)
107+
await vscode.sendWebsocket(req.ws, req.query)
108+
})

0 commit comments

Comments
 (0)