#!/usr/bin/env python3
"""
Tenant-Aware Redis Client
=========================

Provides tenant-prefixed Redis operations for multi-tenant data isolation.

Usage:
    from lib.tenant_redis import TenantRedis

    r = TenantRedis(tenant_id="acme", project_id="web-app")
    r.set("checkpoint:latest", data)
    # Stored as: tenant:acme:project:web-app:checkpoint:latest
"""

import json
import os
from typing import Optional, Any, List

import redis


class TenantRedis:
    """Redis client with automatic tenant/project key prefixing.

    Project-scoped methods prefix keys with
    ``tenant:{tenant_id}:project:{project_id}:``; ``tenant_*`` methods use
    ``tenant:{tenant_id}:``; ``global_*`` methods use keys unchanged.
    Dict and list values are JSON-encoded before being written.
    """

    # Characters Redis treats as glob metacharacters in KEYS patterns.
    _GLOB_SPECIAL = frozenset("*?[]\\")

    def __init__(
        self,
        tenant_id: str = "default",
        project_id: str = "default",
        host: str = "127.0.0.1",
        port: int = 6379,
        password: Optional[str] = None,
        db: int = 0,
        client: Optional[redis.Redis] = None,
    ):
        """Create a client scoped to *tenant_id* / *project_id*.

        Args:
            tenant_id: Tenant segment used in every prefixed key.
            project_id: Project segment used in project-scoped keys.
            host, port, db: Connection settings; ignored when *client* is given.
            password: Redis password. When None it is read from the
                ``DRAGONFLY_PASSWORD`` environment variable, falling back to a
                built-in default. Ignored when *client* is given.
            client: Optional existing ``redis.Redis`` instance to reuse (its
                connection pool and settings are shared). It should be created
                with ``decode_responses=True`` so results are ``str``, matching
                clients built by this constructor.
        """
        # NOTE(review): ids containing ":" can make two different
        # (tenant, project) pairs yield the same prefix; consider validating
        # ids at the call sites.
        self.tenant_id = tenant_id
        self.project_id = project_id

        if client is not None:
            self._client = client
            return

        if password is None:
            # Password comes from the environment. NOTE(review): the literal
            # fallback is a hard-coded credential; consider failing instead.
            password = os.environ.get("DRAGONFLY_PASSWORD", "governance2026")

        self._client = redis.Redis(
            host=host,
            port=port,
            password=password,
            db=db,
            decode_responses=True,
        )

    def _prefix(self, key: str) -> str:
        """Add tenant/project prefix to key"""
        return f"tenant:{self.tenant_id}:project:{self.project_id}:{key}"

    def _global_prefix(self, key: str) -> str:
        """Add tenant-only prefix for cross-project data"""
        return f"tenant:{self.tenant_id}:{key}"

    @staticmethod
    def _serialize(value: Any) -> Any:
        """JSON-encode dicts and lists; return any other value unchanged."""
        if isinstance(value, (dict, list)):
            return json.dumps(value)
        return value

    @classmethod
    def _escape_glob(cls, text: str) -> str:
        """Backslash-escape Redis glob metacharacters so *text* matches literally."""
        return "".join("\\" + ch if ch in cls._GLOB_SPECIAL else ch for ch in text)

    # =========================================================================
    # Basic Operations (Project-Scoped)
    # =========================================================================

    def get(self, key: str) -> Optional[str]:
        """Get value by key (project-scoped)"""
        return self._client.get(self._prefix(key))

    def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool:
        """Set value by key (project-scoped); dicts/lists are JSON-encoded."""
        return self._client.set(self._prefix(key), self._serialize(value), ex=ex, nx=nx)

    def delete(self, key: str) -> int:
        """Delete key (project-scoped)"""
        return self._client.delete(self._prefix(key))

    def exists(self, key: str) -> bool:
        """Check if key exists (project-scoped)"""
        return self._client.exists(self._prefix(key)) > 0

    def keys(self, pattern: str) -> List[str]:
        """Find keys matching *pattern* (project-scoped), with the prefix stripped.

        Only *pattern* is interpreted as a glob. The tenant/project prefix is
        escaped so that glob characters in the ids cannot widen the match into
        other tenants' or projects' keys. Uses the KEYS command, which scans
        the whole keyspace.
        """
        prefix = self._prefix("")
        full_pattern = self._escape_glob(prefix) + pattern
        prefix_len = len(prefix)
        return [k[prefix_len:] for k in self._client.keys(full_pattern)]

    def expire(self, key: str, seconds: int) -> bool:
        """Set key expiration (project-scoped)"""
        return self._client.expire(self._prefix(key), seconds)

    def ttl(self, key: str) -> int:
        """Get key TTL (project-scoped)"""
        return self._client.ttl(self._prefix(key))

    # =========================================================================
    # Hash Operations (Project-Scoped)
    # =========================================================================

    def hget(self, key: str, field: str) -> Optional[str]:
        """Get hash field"""
        return self._client.hget(self._prefix(key), field)

    def hset(self, key: str, field: str, value: Any) -> int:
        """Set hash field; dicts/lists are JSON-encoded."""
        return self._client.hset(self._prefix(key), field, self._serialize(value))

    def hgetall(self, key: str) -> dict:
        """Get all hash fields"""
        return self._client.hgetall(self._prefix(key))

    def hincrby(self, key: str, field: str, amount: int = 1) -> int:
        """Increment hash field"""
        return self._client.hincrby(self._prefix(key), field, amount)

    # =========================================================================
    # List Operations (Project-Scoped)
    # =========================================================================

    def lpush(self, key: str, *values) -> int:
        """Push to list head; dicts/lists are JSON-encoded."""
        return self._client.lpush(self._prefix(key), *map(self._serialize, values))

    def rpush(self, key: str, *values) -> int:
        """Push to list tail; dicts/lists are JSON-encoded."""
        return self._client.rpush(self._prefix(key), *map(self._serialize, values))

    def lpop(self, key: str) -> Optional[str]:
        """Pop from list head"""
        return self._client.lpop(self._prefix(key))

    def rpop(self, key: str) -> Optional[str]:
        """Pop from list tail"""
        return self._client.rpop(self._prefix(key))

    def lrange(self, key: str, start: int, end: int) -> List[str]:
        """Get list range"""
        return self._client.lrange(self._prefix(key), start, end)

    def llen(self, key: str) -> int:
        """Get list length"""
        return self._client.llen(self._prefix(key))

    def ltrim(self, key: str, start: int, end: int) -> bool:
        """Trim list to range"""
        return self._client.ltrim(self._prefix(key), start, end)

    # =========================================================================
    # Set Operations (Project-Scoped)
    # =========================================================================

    def sadd(self, key: str, *members) -> int:
        """Add to set"""
        return self._client.sadd(self._prefix(key), *members)

    def srem(self, key: str, *members) -> int:
        """Remove from set"""
        return self._client.srem(self._prefix(key), *members)

    def smembers(self, key: str) -> set:
        """Get set members"""
        return self._client.smembers(self._prefix(key))

    def sismember(self, key: str, member: str) -> bool:
        """Check set membership"""
        return self._client.sismember(self._prefix(key), member)

    # =========================================================================
    # Tenant-Level Operations (Cross-Project)
    # =========================================================================

    def tenant_get(self, key: str) -> Optional[str]:
        """Get value by key (tenant-level, cross-project)"""
        return self._client.get(self._global_prefix(key))

    def tenant_set(self, key: str, value: Any, ex: Optional[int] = None) -> bool:
        """Set value by key (tenant-level, cross-project); dicts/lists are JSON-encoded."""
        return self._client.set(self._global_prefix(key), self._serialize(value), ex=ex)

    def tenant_lpush(self, key: str, *values) -> int:
        """Push to list (tenant-level); dicts/lists are JSON-encoded."""
        return self._client.lpush(self._global_prefix(key), *map(self._serialize, values))

    def tenant_lrange(self, key: str, start: int, end: int) -> List[str]:
        """Get list range (tenant-level)"""
        return self._client.lrange(self._global_prefix(key), start, end)

    # =========================================================================
    # Global Operations (No Prefix - Admin Only)
    # =========================================================================

    def global_get(self, key: str) -> Optional[str]:
        """Get value without prefix (admin use)"""
        return self._client.get(key)

    def global_set(self, key: str, value: Any, ex: Optional[int] = None) -> bool:
        """Set value without prefix (admin use); dicts/lists are JSON-encoded."""
        return self._client.set(key, self._serialize(value), ex=ex)

    def global_keys(self, pattern: str) -> List[str]:
        """Find keys without prefix (admin use)"""
        return self._client.keys(pattern)

    # =========================================================================
    # Utility
    # =========================================================================

    def ping(self) -> bool:
        """Check connection; return False on connection/server errors."""
        try:
            return self._client.ping()
        except (redis.RedisError, OSError):
            # Narrowed from a bare except so Ctrl-C / SystemExit still propagate.
            return False

    def info(self) -> dict:
        """Get Redis info"""
        return self._client.info()

    def switch_context(
        self, tenant_id: Optional[str] = None, project_id: Optional[str] = None
    ) -> 'TenantRedis':
        """Return a client for a different tenant/project on the same connection.

        Omitted (or empty) ids keep the current value. The new instance shares
        this instance's underlying ``redis.Redis`` client, so every connection
        setting is preserved and no extra connection pool is created.
        """
        return TenantRedis(
            tenant_id=tenant_id or self.tenant_id,
            project_id=project_id or self.project_id,
            client=self._client,
        )

    @property
    def prefix(self) -> str:
        """Get current key prefix"""
        return self._prefix("")
# =============================================================================
# Factory function for easy creation
# =============================================================================

def get_tenant_redis(
    tenant_id: str = "default",
    project_id: str = "default"
) -> "TenantRedis":
    """
    Create a TenantRedis client with given context.

    Environment variables:
        DRAGONFLY_HOST: Redis host (default: 127.0.0.1)
        DRAGONFLY_PORT: Redis port (default: 6379)
        DRAGONFLY_PASSWORD: Redis password (resolved by TenantRedis)
    """
    return TenantRedis(
        tenant_id=tenant_id,
        project_id=project_id,
        host=os.environ.get("DRAGONFLY_HOST", "127.0.0.1"),
        port=int(os.environ.get("DRAGONFLY_PORT", "6379")),
        # None when unset: TenantRedis then applies its own DRAGONFLY_PASSWORD
        # lookup and fallback, so the default lives in one place only.
        password=os.environ.get("DRAGONFLY_PASSWORD"),
    )


# =============================================================================
# Migration helper - for transitioning existing keys
# =============================================================================

# Redis value types migrate_keys_to_tenant knows how to copy.
_MIGRATABLE_TYPES = ("string", "list", "hash", "set", "zset")


def _copy_key_value(client, key: str, new_key: str, key_type: str) -> None:
    """Copy the value (and remaining TTL) of *key* into *new_key*."""
    if key_type == "string":
        client.set(new_key, client.get(key))
    elif key_type == "list":
        values = client.lrange(key, 0, -1)
        if values:
            client.rpush(new_key, *values)
    elif key_type == "hash":
        values = client.hgetall(key)
        if values:
            client.hset(new_key, mapping=values)
    elif key_type == "set":
        values = client.smembers(key)
        if values:
            client.sadd(new_key, *values)
    elif key_type == "zset":
        values = client.zrange(key, 0, -1, withscores=True)
        if values:
            client.zadd(new_key, dict(values))
    else:
        raise ValueError(f"unsupported key type: {key_type}")

    # PTTL is -1 (no expiry) or -2 (key gone); only positive values are copied.
    ttl_ms = client.pttl(key)
    if ttl_ms is not None and ttl_ms > 0:
        client.pexpire(new_key, ttl_ms)


def migrate_keys_to_tenant(
    source_redis: "redis.Redis",
    tenant_id: str = "default",
    project_id: str = "default",
    patterns: Optional[List[str]] = None,
    dry_run: bool = True
) -> dict:
    """
    Migrate existing non-prefixed keys to tenant-prefixed format.

    Keys are copied to ``tenant:{tenant_id}:project:{project_id}:{key}``. The
    original keys are left in place. Remaining TTLs are preserved. The
    function is safe to re-run:
    - a key whose target already exists is skipped rather than overwritten,
      which also avoids duplicating list contents;
    - a key matched by several patterns is processed only once.

    Args:
        source_redis: Existing Redis connection (expected to return ``str``
            keys, i.e. ``decode_responses=True``)
        tenant_id: Target tenant ID
        project_id: Target project ID
        patterns: Key patterns to migrate (default: common governance keys)
        dry_run: If True, only report what would be migrated

    Returns:
        dict with the following lists:
        - "migrated": ``{"old", "new"}`` entries, plus ``"dry_run": True``
          on a dry run;
        - "skipped": keys that are already prefixed, whose target already
          exists, or that vanished before they could be copied;
        - "errors": ``{"key", "error"}`` entries, including keys of
          unsupported types (e.g. streams).
    """
    if patterns is None:
        patterns = [
            "checkpoint:*",
            "agent:*",
            "task:*",
            "orchestration:*",
            "revocations:*",
            "alerts:*"
        ]

    prefix = f"tenant:{tenant_id}:project:{project_id}:"
    results = {
        "migrated": [],
        "skipped": [],
        "errors": []
    }
    seen = set()

    for pattern in patterns:
        for key in source_redis.keys(pattern):
            if key in seen:
                continue
            seen.add(key)

            # Skip if already prefixed
            if key.startswith("tenant:"):
                results["skipped"].append(key)
                continue

            new_key = f"{prefix}{key}"
            try:
                if source_redis.exists(new_key):
                    # Never clobber data already written under the new key.
                    results["skipped"].append(key)
                    continue

                key_type = source_redis.type(key)
                if isinstance(key_type, bytes):
                    key_type = key_type.decode()
                if key_type == "none":
                    # Key expired/deleted since the KEYS call.
                    results["skipped"].append(key)
                    continue
                if key_type not in _MIGRATABLE_TYPES:
                    results["errors"].append(
                        {"key": key, "error": f"unsupported key type: {key_type}"}
                    )
                    continue

                if dry_run:
                    results["migrated"].append({"old": key, "new": new_key, "dry_run": True})
                    continue

                _copy_key_value(source_redis, key, new_key, key_type)
                results["migrated"].append({"old": key, "new": new_key})
            except Exception as e:
                # Per-key reporting boundary: record and move on to the next key.
                results["errors"].append({"key": key, "error": str(e)})

    return results