You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

253 lines
9.9 KiB

import asyncio
import os
import time
import urllib.parse
from typing import Dict, Optional, Any

import jwt
import requests
from jwt.algorithms import RSAAlgorithm
from nicegui import app

from services.log import access_logger as logger, basic_logger


class OIDCConfig:
    """OpenID Connect client configuration and authorization-code-flow helpers.

    The provider's discovery document and JWKS are fetched lazily and cached on
    the instance. Per-user token state is read from and written to NiceGUI's
    ``app.storage.user``. Each client setting can be overridden with an
    environment variable; the literals below are the previous hard-coded values,
    kept as defaults so existing deployments keep working.
    """

    # Timeout (seconds) for discovery, JWKS and logout requests.
    HTTP_TIMEOUT = 10
    # Timeout (seconds) for token-endpoint requests (code exchange and refresh).
    TOKEN_REQUEST_TIMEOUT = 30
    # Used when TOKEN_REFRESH_MINIMUM_INTERVAL is unset or not a number.
    DEFAULT_MIN_REFRESH_INTERVAL = 300.0
    # Minimum seconds between JWKS re-fetches triggered by an unknown ``kid``, so
    # forged tokens carrying random key IDs cannot make us hammer the provider.
    JWKS_REFETCH_COOLDOWN = 60.0
    # Request form fields whose values must never be written to logs.
    _SENSITIVE_FIELDS = frozenset({'client_secret', 'code', 'refresh_token'})
    # Headers sent with every token-endpoint request.
    _TOKEN_HEADERS = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/json',
    }

    def __init__(self):
        # SECURITY: the default client secret below is committed to source
        # control. Rotate it at the provider and supply the new value through
        # OIDC_CLIENT_SECRET instead of relying on this default.
        self.client_id = os.getenv(
            'OIDC_CLIENT_ID',
            "PTU7L29N5ws6GAfjYVbd0rtVMa6oliaPxuqEUK4Q95jUD1CIL3uX0zUBvIOVm5Ht",
        )
        self.client_secret = os.getenv(
            'OIDC_CLIENT_SECRET',
            "E0HRxJRDGJtdDueEulvA6Y46oNco0gkaw75a2cRUPFTdVRjQ7RhLPXg3PRfIJ3N2",
        )
        self.discovery_url = os.getenv(
            'OIDC_DISCOVERY_URL',
            "https://cloud.enne2.net/index.php/.well-known/openid-configuration",
        )
        self.redirect_uri = os.getenv('OIDC_REDIRECT_URI', "http://127.0.0.1:8080/auth/callback")
        self.scope = os.getenv('OIDC_SCOPE', "openid profile email")
        # Lazily populated caches for the discovery document and the JWKS.
        self._config_cache = None
        self._jwks_cache = None
        # time.time() of the last successful JWKS fetch (0.0 = never fetched).
        self._jwks_fetched_at = 0.0
        basic_logger.info(f"OIDC Config initialized with discovery URL: {self.discovery_url}")

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _redact(cls, form: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *form* with sensitive values masked, safe for logging."""
        return {k: ('***' if k in cls._SENSITIVE_FIELDS else v) for k, v in form.items()}

    @classmethod
    def _min_refresh_interval(cls) -> float:
        """Return the minimum seconds between refreshes.

        The value comes from TOKEN_REFRESH_MINIMUM_INTERVAL and is re-read on every call.
        Environment values are strings, so the value is parsed. If it is missing
        or unparsable, DEFAULT_MIN_REFRESH_INTERVAL is used.
        """
        raw = os.getenv('TOKEN_REFRESH_MINIMUM_INTERVAL')
        if raw is None:
            return cls.DEFAULT_MIN_REFRESH_INTERVAL
        try:
            return float(raw)
        except ValueError:
            basic_logger.warning(
                f"Invalid TOKEN_REFRESH_MINIMUM_INTERVAL {raw!r}; "
                f"using {cls.DEFAULT_MIN_REFRESH_INTERVAL}"
            )
            return cls.DEFAULT_MIN_REFRESH_INTERVAL

    def _get_token_endpoint(self) -> Optional[str]:
        """Return the discovered token endpoint forced to HTTPS, or None if absent.

        Only the leading ``http://`` scheme is rewritten. No fallback URL is
        invented: posting client credentials to a made-up host would leak them.
        """
        token_url = self.get_oidc_config().get('token_endpoint')
        if not token_url:
            return None
        if token_url.startswith('http://'):
            token_url = 'https://' + token_url[len('http://'):]
        return token_url

    def _post_token_request(self, token_url: str, form: Dict[str, Any]) -> 'requests.Response':
        """POST a form-encoded request to the token endpoint (blocking call)."""
        return requests.post(
            token_url,
            data=form,
            headers=self._TOKEN_HEADERS,
            timeout=self.TOKEN_REQUEST_TIMEOUT,
        )

    def _find_public_key(self, kid: str):
        """Return the verification key for *kid* from the (cached) JWKS, or None.

        The result is a key object built by ``RSAAlgorithm.from_jwk``.
        """
        for key in self.get_jwks().get('keys', []):
            if key.get('kid') == kid:
                return RSAAlgorithm.from_jwk(key)
        return None

    # ------------------------------------------------------------------ discovery

    def get_oidc_config(self) -> Dict:
        """Return the provider's discovery document, fetching and caching it on first use.

        Raises:
            Exception: any fetch or JSON-decoding error is logged and re-raised.
        """
        if not self._config_cache:
            basic_logger.info(f"Fetching OIDC configuration from {self.discovery_url}")
            try:
                response = requests.get(self.discovery_url, timeout=self.HTTP_TIMEOUT)
                response.raise_for_status()
                self._config_cache = response.json()
                basic_logger.info("OIDC configuration fetched successfully")
            except Exception as e:
                basic_logger.error(f"Failed to fetch OIDC configuration: {e}")
                raise
        return self._config_cache

    def get_jwks(self) -> Dict:
        """Return the provider's JSON Web Key Set, fetching and caching it on first use.

        Raises:
            KeyError: if the discovery document has no ``jwks_uri``.
            Exception: any fetch or JSON-decoding error is logged and re-raised.
        """
        if not self._jwks_cache:
            config = self.get_oidc_config()
            jwks_uri = config['jwks_uri']
            basic_logger.info(f"Fetching JWKS from {jwks_uri}")
            try:
                response = requests.get(jwks_uri, timeout=self.HTTP_TIMEOUT)
                response.raise_for_status()
                self._jwks_cache = response.json()
                self._jwks_fetched_at = time.time()
                basic_logger.info("JWKS fetched successfully")
            except Exception as e:
                basic_logger.error(f"Failed to fetch JWKS: {e}")
                raise
        return self._jwks_cache

    # ------------------------------------------------------------------ login flow

    def get_authorization_url(self, state: str) -> str:
        """Build the provider authorization URL for the authorization-code flow.

        Args:
            state: Opaque value echoed back on the callback (CSRF protection).
        """
        config = self.get_oidc_config()
        params = {
            'response_type': 'code',
            'client_id': self.client_id,
            'redirect_uri': self.redirect_uri,
            'scope': self.scope,
            'state': state,
        }
        auth_url = f"{config['authorization_endpoint']}?{urllib.parse.urlencode(params)}"
        logger.info(f"Generated authorization URL with state: {state}")
        return auth_url

    async def exchange_code_for_tokens(self, code: str, redirect_uri: str) -> Optional[Dict[str, Any]]:
        """Exchange an authorization code for tokens.

        The HTTP work uses the blocking ``requests`` library. It runs on the
        default executor so the event loop (shared with the NiceGUI UI) is not
        stalled while the provider responds.

        Returns:
            The token endpoint's JSON response. Returns None on any
            ``requests`` error, or if no token endpoint is advertised.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._exchange_code_blocking, code, redirect_uri)

    def _exchange_code_blocking(self, code: str, redirect_uri: str) -> Optional[Dict[str, Any]]:
        """Synchronous body of ``exchange_code_for_tokens``."""
        try:
            form = {
                'grant_type': 'authorization_code',
                'client_id': self.client_id,
                'client_secret': self.client_secret,
                'code': code,
                'redirect_uri': redirect_uri,
            }
            token_url = self._get_token_endpoint()
            if not token_url:
                logger.error("No token endpoint found in OIDC configuration")
                return None
            # Secrets and the one-time code are masked; never log them in clear.
            logger.debug(f"Token exchange request data: {self._redact(form)}")
            logger.debug(f"Token exchange request URL: {token_url}")
            logger.debug(f"Token exchange request headers: {self._TOKEN_HEADERS}")
            response = self._post_token_request(token_url, form)
            logger.debug(f"Token exchange response status: {response.status_code}")
            # A successful body contains live tokens; only error bodies are logged.
            if not response.ok:
                logger.debug(f"Token exchange response content: {response.text}")
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to exchange code for tokens: {e}")
            return None

    def refresh_access_token(self) -> Dict:
        """Refresh the current user's access token using the stored refresh token.

        Refreshes are throttled. If the last refresh happened less than
        TOKEN_REFRESH_MINIMUM_INTERVAL seconds ago (default 300), no request is made.

        Returns:
            The token endpoint's JSON response after a refresh. On the
            throttled path it instead returns the stored ``access_token`` value
            (``{}`` if absent). NOTE(review): that is a different type from the
            refresh path; callers must handle both.

        Raises:
            RuntimeError: if the discovery document has no ``token_endpoint``.
            Exception: any error from the token request is logged and re-raised.
        """
        last_refreshed = app.storage.user.get('last_refreshed', 0)
        if time.time() - last_refreshed < self._min_refresh_interval():
            return app.storage.user.get('access_token', {})
        token_url = self._get_token_endpoint()
        refresh_token = app.storage.user.get('refresh_token')
        logger.info("Refreshing access token")
        if not token_url:
            logger.error("No token endpoint found in OIDC configuration")
            raise RuntimeError("OIDC discovery document has no token_endpoint")
        form = {
            'grant_type': 'refresh_token',
            'client_id': self.client_id,
            'client_secret': self.client_secret,
            'refresh_token': refresh_token,
        }
        try:
            response = self._post_token_request(token_url, form)
            response.raise_for_status()
            tokens = response.json()
            # Providers may omit refresh_token (no rotation) and id_token from a
            # refresh response. Keep the stored values rather than overwriting
            # them with None, which would break the next refresh and logout.
            session_update = {
                key: tokens[key]
                for key in ('access_token', 'refresh_token', 'expires_in', 'id_token')
                if tokens.get(key) is not None
            }
            session_update['last_refreshed'] = time.time()
            app.storage.user.update(session_update)
            logger.info("Successfully refreshed access token")
            return tokens
        except Exception as e:
            logger.error(f"Failed to refresh access token: {e}")
            raise

    # ------------------------------------------------------------------ validation

    def validate_token(self, token: str, *, verify_audience: bool = False) -> Optional[Dict]:
        """Verify a JWT's RS256 signature, expiry and issuer, and return its claims.

        The issuer is checked against the ``issuer`` value in the discovery
        document, when one is present. If the token's ``kid`` is not in the
        cached JWKS, the JWKS is re-fetched once (at most every
        JWKS_REFETCH_COOLDOWN seconds) to pick up provider key rotation.

        Args:
            token: The encoded JWT.
            verify_audience: Also require ``aud`` to contain ``client_id``.
                Off by default to preserve the previous behaviour.

        Returns:
            The decoded payload, or None if the token is invalid for any reason.
        """
        try:
            kid = jwt.get_unverified_header(token).get('kid')
            if not kid:
                logger.warning("No key ID found in token header")
                return None
            public_key = self._find_public_key(kid)
            if public_key is None and time.time() - self._jwks_fetched_at >= self.JWKS_REFETCH_COOLDOWN:
                # Unknown kid: the provider may have rotated keys since we cached the JWKS.
                self._jwks_cache = None
                public_key = self._find_public_key(kid)
            if public_key is None:
                logger.warning(f"No matching public key found for kid: {kid}")
                return None
            # PyJWT only checks `iss` when an issuer is supplied; passing it makes
            # verify_iss effective.
            issuer = self.get_oidc_config().get('issuer')
            payload = jwt.decode(
                token,
                public_key,
                algorithms=['RS256'],
                audience=self.client_id,
                issuer=issuer,
                options={
                    "verify_signature": True,
                    "verify_exp": True,
                    "verify_aud": verify_audience,
                    "verify_iss": True,
                },
            )
            logger.debug("Token validated successfully with signature verification")
            return payload
        except jwt.ExpiredSignatureError:
            logger.warning("Token has expired")
            return None
        except jwt.InvalidAudienceError:
            logger.warning("Invalid token audience")
            return None
        except jwt.InvalidIssuerError:
            logger.warning("Invalid token issuer")
            return None
        except jwt.InvalidTokenError as e:
            logger.warning(f"Invalid token: {e}")
            return None
        except Exception as e:
            logger.error(f"Token validation error: {e}")
            return None

    # ------------------------------------------------------------------ logout

    def logout_user(self) -> bool:
        """Log the user out at the provider (best effort) and clear the local session.

        Returns:
            Always True. The local session is cleared even if the provider
            call fails or no ``end_session_endpoint`` is advertised.
        """
        config = self.get_oidc_config()
        logout_url = config.get('end_session_endpoint')
        if not logout_url:
            logger.error("No end session endpoint found in OIDC configuration")
            # Still clear local session even if remote logout fails
            app.storage.user.clear()
            return True
        # `or {}` also covers a stored None, which a .get(..., {}) default would not.
        user_info = app.storage.user.get('user_info') or {}
        user_id = user_info.get('sub') or user_info.get('preferred_username', 'unknown')
        params = {
            'id_token_hint': app.storage.user.get('id_token'),
            'client_id': self.client_id,
            'post_logout_redirect_uri': self.redirect_uri.replace('/auth/callback', '/login'),
        }
        # Drop parameters we do not have.
        params = {k: v for k, v in params.items() if v is not None}
        logout_url = f"{logout_url}?{urllib.parse.urlencode(params)}"
        # NOTE(review): this GET is issued server-side, so it does not carry the
        # user's browser cookies for the provider. RP-initiated logout normally
        # redirects the browser to this URL instead. Confirm the provider
        # session really ends.
        try:
            response = requests.get(logout_url, timeout=self.HTTP_TIMEOUT)
            response.raise_for_status()
            logger.info(f"User {user_id} logged out successfully from OIDC provider")
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to log out user {user_id} from OIDC provider: {e}")
            # Continue with local logout even if remote logout fails
        app.storage.user.clear()
        logger.info(f"Local session cleared for user {user_id}")
        return True
# Global OIDC config instance, shared as a module-level singleton (presumably
# imported by the auth routes; verify against callers). Construction only sets
# attributes and logs. The discovery document and JWKS are fetched lazily on first use.
oidc_config = OIDCConfig()