feat: auto-refresh oauth tokens for upstream servers

Co-authored-by: samanhappy <2755122+samanhappy@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-12-13 14:31:39 +00:00
parent af44eac40c
commit 2ab60bf7a9
5 changed files with 195 additions and 15 deletions

View File

@@ -26,6 +26,7 @@ import {
getRegisteredClient,
removeRegisteredClient,
fetchScopesFromServer,
refreshAccessToken,
} from './oauthClientRegistration.js';
import {
clearOAuthData,
@@ -292,21 +293,8 @@ export class MCPHubOAuthProvider implements OAuthClientProvider {
/**
* Get stored OAuth tokens
*/
tokens(): OAuthTokens | undefined {
// Use cached config only (tokens are updated via saveTokens which updates cache)
const serverConfig = this.serverConfig;
if (!serverConfig?.oauth?.accessToken) {
return undefined;
}
return {
access_token: serverConfig.oauth.accessToken,
token_type: 'Bearer',
refresh_token: serverConfig.oauth.refreshToken,
// Note: expires_in is not typically stored, only the token itself
// The SDK will handle token refresh when needed
};
async tokens(): Promise<OAuthTokens | undefined> {
return this.getValidTokens();
}
/**
@@ -330,6 +318,7 @@ export class MCPHubOAuthProvider implements OAuthClientProvider {
const updatedConfig = await persistTokens(this.serverName, {
accessToken: tokens.access_token,
refreshToken: refreshTokenProvided ? (tokens.refresh_token ?? null) : undefined,
expiresIn: tokens.expires_in,
clearPendingAuthorization: hadPending,
});
@@ -348,6 +337,92 @@ export class MCPHubOAuthProvider implements OAuthClientProvider {
console.log(`Saved OAuth tokens for server: ${this.serverName}`);
}
/**
* Returns tokens refreshed when expired or close to expiring.
* Falls back to stored tokens if refresh cannot be performed.
*/
private async getValidTokens(): Promise<OAuthTokens | undefined> {
const oauth = this.serverConfig.oauth;
if (!oauth) {
return undefined;
}
if (!oauth.accessToken) {
return this.refreshAccessTokenIfNeeded(oauth.refreshToken);
}
// Refresh if token is expired or about to expire
const expiresAt = this.getAccessTokenExpiryMs(oauth);
const now = Date.now();
if ((expiresAt && expiresAt - now <= 60_000) || !oauth.accessToken) {
const refreshed = await this.refreshAccessTokenIfNeeded(oauth.refreshToken);
if (refreshed) {
return refreshed;
}
}
return {
access_token: oauth.accessToken,
token_type: 'Bearer',
refresh_token: oauth.refreshToken,
};
}
private getAccessTokenExpiryMs(oauth: NonNullable<ServerConfig['oauth']>): number | undefined {
const { accessTokenExpiresAt } = oauth;
if (!accessTokenExpiresAt) return undefined;
if (typeof accessTokenExpiresAt === 'number') return accessTokenExpiresAt;
if (typeof accessTokenExpiresAt === 'string') {
const parsed = Date.parse(accessTokenExpiresAt);
return Number.isNaN(parsed) ? undefined : parsed;
}
if (accessTokenExpiresAt instanceof Date) {
return accessTokenExpiresAt.getTime();
}
return undefined;
}
private async refreshAccessTokenIfNeeded(
refreshToken?: string | null,
): Promise<OAuthTokens | undefined> {
if (!refreshToken) {
return undefined;
}
try {
const clientInfo = await initializeOAuthForServer(this.serverName, this.serverConfig);
if (!clientInfo) {
return undefined;
}
const tokens = await refreshAccessToken(
this.serverName,
this.serverConfig,
clientInfo,
refreshToken,
);
// Reload latest config to sync updated tokens/expiry
const updatedConfig = await loadServerConfig(this.serverName);
if (updatedConfig) {
this.serverConfig = updatedConfig;
}
return {
access_token: tokens.accessToken,
refresh_token: tokens.refreshToken ?? refreshToken,
token_type: 'Bearer',
expires_in: tokens.expiresIn,
};
} catch (error) {
console.warn(
`Failed to auto-refresh OAuth token for server ${this.serverName}:`,
error instanceof Error ? error.message : error,
);
return undefined;
}
}
/**
* Redirect to authorization URL
* In a server environment, we can't directly redirect the user

View File

@@ -397,6 +397,7 @@ export const exchangeCodeForToken = async (
await persistTokens(serverName, {
accessToken: tokens.access_token,
refreshToken: tokens.refresh_token ?? undefined,
expiresIn: tokens.expires_in,
});
return {
@@ -437,6 +438,7 @@ export const refreshAccessToken = async (
await persistTokens(serverName, {
accessToken: tokens.access_token,
refreshToken: tokens.refresh_token ?? undefined,
expiresIn: tokens.expires_in,
});
return {

View File

@@ -100,12 +100,17 @@ export const persistTokens = async (
tokens: {
accessToken: string;
refreshToken?: string | null;
expiresIn?: number;
clearPendingAuthorization?: boolean;
},
): Promise<ServerConfigWithOAuth | undefined> => {
return mutateOAuthSettings(serverName, ({ oauth }) => {
oauth.accessToken = tokens.accessToken;
if (tokens.expiresIn !== undefined) {
oauth.accessTokenExpiresAt = Date.now() + tokens.expiresIn * 1000;
}
if (tokens.refreshToken !== undefined) {
if (tokens.refreshToken) {
oauth.refreshToken = tokens.refreshToken;
@@ -147,6 +152,7 @@ export const clearOAuthData = async (
if (scope === 'tokens' || scope === 'all') {
delete oauth.accessToken;
delete oauth.refreshToken;
delete oauth.accessTokenExpiresAt;
}
if (scope === 'client' || scope === 'all') {

View File

@@ -293,6 +293,7 @@ export interface ServerConfig {
scopes?: string[]; // Required OAuth scopes
accessToken?: string; // Pre-obtained access token (if available)
refreshToken?: string; // Refresh token for renewing access
accessTokenExpiresAt?: number; // Access token expiration timestamp (ms since epoch)
// Dynamic client registration (RFC7591)
// If not explicitly configured, will auto-detect via WWW-Authenticate header on 401 responses

View File

@@ -0,0 +1,96 @@
jest.mock('../../src/services/oauthClientRegistration.js', () => ({
initializeOAuthForServer: jest.fn(),
getRegisteredClient: jest.fn(),
removeRegisteredClient: jest.fn(),
fetchScopesFromServer: jest.fn(),
refreshAccessToken: jest.fn(),
}));
jest.mock('../../src/services/oauthSettingsStore.js', () => ({
loadServerConfig: jest.fn(),
mutateOAuthSettings: jest.fn(),
persistTokens: jest.fn(),
updatePendingAuthorization: jest.fn(),
}));
jest.mock('../../src/services/mcpService.js', () => ({
getServerByName: jest.fn(),
}));
jest.mock('../../src/dao/index.js', () => ({
getSystemConfigDao: jest.fn(() => ({ get: jest.fn() })),
}));
import { MCPHubOAuthProvider } from '../../src/services/mcpOAuthProvider.js';
import * as oauthRegistration from '../../src/services/oauthClientRegistration.js';
import * as oauthSettingsStore from '../../src/services/oauthSettingsStore.js';
describe('MCPHubOAuthProvider token refresh', () => {
beforeEach(() => {
jest.clearAllMocks();
});
const baseConfig = {
url: 'https://example.com/v1/sse',
oauth: {
clientId: 'client-id',
accessToken: 'old-access',
refreshToken: 'refresh-token',
},
};
it('refreshes access token when expired', async () => {
const expiredConfig = {
...baseConfig,
oauth: {
...baseConfig.oauth,
accessTokenExpiresAt: Date.now() - 1_000,
},
};
const refreshedConfig = {
...expiredConfig,
oauth: {
...expiredConfig.oauth,
accessToken: 'new-access',
refreshToken: 'new-refresh',
accessTokenExpiresAt: Date.now() + 3_600_000,
},
};
(oauthRegistration.initializeOAuthForServer as jest.Mock).mockResolvedValue({
config: {},
});
(oauthRegistration.refreshAccessToken as jest.Mock).mockResolvedValue({
accessToken: 'new-access',
refreshToken: 'new-refresh',
expiresIn: 3600,
});
(oauthSettingsStore.loadServerConfig as jest.Mock).mockResolvedValue(refreshedConfig);
const provider = new MCPHubOAuthProvider('atlassian-work', expiredConfig as any);
const tokens = await provider.tokens();
expect(oauthRegistration.refreshAccessToken).toHaveBeenCalledTimes(1);
expect(oauthSettingsStore.loadServerConfig).toHaveBeenCalledTimes(1);
expect(tokens?.access_token).toBe('new-access');
expect(tokens?.refresh_token).toBe('new-refresh');
});
it('returns cached token when not expired', async () => {
const freshConfig = {
...baseConfig,
oauth: {
...baseConfig.oauth,
accessTokenExpiresAt: Date.now() + 10 * 60 * 1_000,
},
};
const provider = new MCPHubOAuthProvider('atlassian-work', freshConfig as any);
const tokens = await provider.tokens();
expect(tokens?.access_token).toBe('old-access');
expect(oauthRegistration.refreshAccessToken).not.toHaveBeenCalled();
});
});