"""
MockDragonfly - Simulates DragonflyDB/Redis for testing.

Provides deterministic state management, locks, and pub/sub without
requiring a real Redis instance.
"""

from typing import Dict, List, Any, Optional, Callable, Set
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
import fnmatch
import json
import logging
import threading
import time

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Same value as the deprecated ``datetime.utcnow()``; kept naive so it
    compares cleanly with naive timestamps stored on ``MockLock``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


@dataclass
class MockLock:
    """Represents a distributed lock"""

    key: str
    owner: str
    acquired_at: datetime  # naive UTC time of the last acquire/refresh
    ttl: int  # seconds
    released: bool = False

    def is_valid(self) -> bool:
        """Return True if the lock was not released and its TTL has not elapsed."""
        if self.released:
            return False
        return _utcnow() <= self.acquired_at + timedelta(seconds=self.ttl)


class MockDragonfly:
    """
    Mock DragonflyDB/Redis implementation for testing.

    Simulates:
    - Key-value storage (strings, hashes, lists, sets)
    - Key expiration (TTL), enforced lazily whenever a key is touched
    - Distributed locks
    - Pub/Sub messaging
    - Atomic operations (every public method runs under one internal lock)
    """

    def __init__(self):
        self._strings: Dict[str, Any] = {}
        self._hashes: Dict[str, Dict[str, Any]] = {}
        self._lists: Dict[str, List[Any]] = {}
        self._sets: Dict[str, Set[str]] = {}
        self._expiry: Dict[str, datetime] = {}
        self._locks: Dict[str, MockLock] = {}
        self._subscribers: Dict[str, List[Callable]] = {}
        self._lock = threading.Lock()

    # === Internal helpers (caller must already hold self._lock) ===

    def _key_exists(self, key: str) -> bool:
        """Return True if ``key`` is present in any of the typed stores."""
        return (
            key in self._strings
            or key in self._hashes
            or key in self._lists
            or key in self._sets
        )

    def _check_expiry(self, key: str) -> bool:
        """Check if key has expired. Returns True if key is valid.

        An expired key is purged from every store as a side effect, so
        callers can treat the stores as up to date afterwards.
        """
        if key in self._expiry and _utcnow() > self._expiry[key]:
            self._delete_key(key)
            return False
        return True

    def _delete_key(self, key: str):
        """Delete key from all stores"""
        self._strings.pop(key, None)
        self._hashes.pop(key, None)
        self._lists.pop(key, None)
        self._sets.pop(key, None)
        self._expiry.pop(key, None)

    def _drop_if_empty(self, store: Dict[str, Any], key: str) -> None:
        """Remove an emptied container, as Redis deletes empty aggregates."""
        if key in store and not store[key]:
            del store[key]
            if not self._key_exists(key):
                self._expiry.pop(key, None)

    def _purge_expired(self) -> None:
        """Drop every key whose TTL has elapsed."""
        for key in list(self._expiry):
            self._check_expiry(key)

    @staticmethod
    def _redis_slice(items: List[Any], start: int, stop: int) -> List[Any]:
        """Slice ``items`` using Redis' inclusive ``stop`` semantics."""
        # stop == -1 means "through the end"; stop + 1 would be 0 and give [].
        if stop == -1:
            return items[start:]
        return items[start:stop + 1]

    # === String Operations ===

    def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool:
        """
        Set a string value.

        Like Redis SET, this overwrites the key regardless of its previous
        type and discards any previous TTL unless ``ex`` supplies a new one.

        Args:
            key: Key name
            value: Value to store
            ex: Expiry in seconds (falsy means no expiry)
            nx: Only set if key does not exist (of any type)

        Returns:
            True if the value was stored, False if ``nx`` blocked the write.
        """
        with self._lock:
            self._check_expiry(key)
            if nx and self._key_exists(key):
                return False
            self._delete_key(key)
            self._strings[key] = value
            if ex:
                self._expiry[key] = _utcnow() + timedelta(seconds=ex)
            return True

    def get(self, key: str) -> Optional[Any]:
        """Get a string value, or None if missing or expired."""
        with self._lock:
            if not self._check_expiry(key):
                return None
            return self._strings.get(key)

    def incr(self, key: str, amount: int = 1) -> int:
        """Increment a numeric value (an expired value restarts from 0)."""
        with self._lock:
            self._check_expiry(key)
            new_value = int(self._strings.get(key, 0)) + amount
            self._strings[key] = new_value
            return new_value

    def decr(self, key: str, amount: int = 1) -> int:
        """Decrement a numeric value"""
        return self.incr(key, -amount)

    # === Hash Operations ===

    def hset(self, key: str, field: Optional[str] = None, value: Any = None,
             mapping: Optional[Dict[str, Any]] = None) -> int:
        """
        Set hash field(s).

        Can be called as:
        - hset(key, field, value)
        - hset(key, mapping={...})
        - hset(key, field, value, mapping={...})  (both are applied;
          ``mapping`` wins for a field that appears in both)

        Returns:
            Number of fields that did not exist before.
        """
        with self._lock:
            self._check_expiry(key)
            updates: Dict[str, Any] = {}
            if field is not None:
                updates[field] = value
            if mapping:
                updates.update(mapping)
            if not updates:
                return 0
            target = self._hashes.setdefault(key, {})
            added = sum(1 for f in updates if f not in target)
            target.update(updates)
            return added

    def hget(self, key: str, field: str) -> Optional[Any]:
        """Get a hash field"""
        with self._lock:
            if not self._check_expiry(key):
                return None
            return self._hashes.get(key, {}).get(field)

    def hgetall(self, key: str) -> Dict[str, Any]:
        """Get all hash fields (a copy)."""
        with self._lock:
            if not self._check_expiry(key):
                return {}
            return self._hashes.get(key, {}).copy()

    def hincrby(self, key: str, field: str, amount: int = 1) -> int:
        """Increment a hash field"""
        with self._lock:
            self._check_expiry(key)
            target = self._hashes.setdefault(key, {})
            new_value = int(target.get(field, 0)) + amount
            target[field] = new_value
            return new_value

    def hdel(self, key: str, *fields: str) -> int:
        """Delete hash fields; returns how many were removed."""
        with self._lock:
            if not self._check_expiry(key) or key not in self._hashes:
                return 0
            target = self._hashes[key]
            removed = 0
            for f in fields:
                if f in target:
                    del target[f]
                    removed += 1
            self._drop_if_empty(self._hashes, key)
            return removed

    # === List Operations ===

    def lpush(self, key: str, *values: Any) -> int:
        """Push values to left of list; returns the new length.

        Each value is inserted at the head in turn, so the last argument
        ends up first: ``lpush(k, "a", "b", "c")`` yields ``["c", "b", "a"]``.
        """
        with self._lock:
            self._check_expiry(key)
            items = self._lists.setdefault(key, [])
            items[:0] = reversed(values)
            length = len(items)
            self._drop_if_empty(self._lists, key)
            return length

    def rpush(self, key: str, *values: Any) -> int:
        """Push values to right of list; returns the new length."""
        with self._lock:
            self._check_expiry(key)
            items = self._lists.setdefault(key, [])
            items.extend(values)
            length = len(items)
            self._drop_if_empty(self._lists, key)
            return length

    def lpop(self, key: str) -> Optional[Any]:
        """Pop from left of list"""
        with self._lock:
            if not self._check_expiry(key) or not self._lists.get(key):
                return None
            value = self._lists[key].pop(0)
            self._drop_if_empty(self._lists, key)
            return value

    def rpop(self, key: str) -> Optional[Any]:
        """Pop from right of list"""
        with self._lock:
            if not self._check_expiry(key) or not self._lists.get(key):
                return None
            value = self._lists[key].pop()
            self._drop_if_empty(self._lists, key)
            return value

    def lrange(self, key: str, start: int, stop: int) -> List[Any]:
        """Get range of list elements (``stop`` is inclusive, as in Redis)."""
        with self._lock:
            if not self._check_expiry(key) or key not in self._lists:
                return []
            return self._redis_slice(self._lists[key], start, stop)

    def llen(self, key: str) -> int:
        """Get list length"""
        with self._lock:
            if not self._check_expiry(key):
                return 0
            return len(self._lists.get(key, []))

    def ltrim(self, key: str, start: int, stop: int) -> bool:
        """Trim list to specified range (``stop`` inclusive)."""
        with self._lock:
            if not self._check_expiry(key) or key not in self._lists:
                return True
            self._lists[key] = self._redis_slice(self._lists[key], start, stop)
            self._drop_if_empty(self._lists, key)
            return True

    # === Set Operations ===

    def sadd(self, key: str, *members: str) -> int:
        """Add members to set; returns how many were new."""
        with self._lock:
            self._check_expiry(key)
            target = self._sets.setdefault(key, set())
            before = len(target)
            target.update(members)
            added = len(target) - before
            self._drop_if_empty(self._sets, key)
            return added

    def srem(self, key: str, *members: str) -> int:
        """Remove members from set; returns how many were removed."""
        with self._lock:
            if not self._check_expiry(key) or key not in self._sets:
                return 0
            target = self._sets[key]
            before = len(target)
            target -= set(members)
            removed = before - len(target)
            self._drop_if_empty(self._sets, key)
            return removed

    def smembers(self, key: str) -> Set[str]:
        """Get all set members (a copy)."""
        with self._lock:
            if not self._check_expiry(key):
                return set()
            return self._sets.get(key, set()).copy()

    def sismember(self, key: str, member: str) -> bool:
        """Check if member is in set"""
        with self._lock:
            if not self._check_expiry(key):
                return False
            return member in self._sets.get(key, set())

    # === Key Operations ===

    def exists(self, *keys: str) -> int:
        """Count how many of ``keys`` exist and are not expired."""
        with self._lock:
            return sum(
                1 for key in keys
                if self._check_expiry(key) and self._key_exists(key)
            )

    def delete(self, *keys: str) -> int:
        """Delete keys; returns how many live keys were removed."""
        with self._lock:
            count = 0
            for key in keys:
                if self._check_expiry(key) and self._key_exists(key):
                    self._delete_key(key)
                    count += 1
            return count

    def expire(self, key: str, seconds: int) -> bool:
        """Set key expiration; returns False if the key does not exist."""
        with self._lock:
            if not (self._check_expiry(key) and self._key_exists(key)):
                return False
            self._expiry[key] = _utcnow() + timedelta(seconds=seconds)
            return True

    def ttl(self, key: str) -> int:
        """Get TTL in seconds (-1 if no expiry, -2 if key doesn't exist)"""
        with self._lock:
            if not (self._check_expiry(key) and self._key_exists(key)):
                return -2
            if key not in self._expiry:
                return -1
            remaining = (self._expiry[key] - _utcnow()).total_seconds()
            return max(0, int(remaining))

    def keys(self, pattern: str = "*") -> List[str]:
        """Get live keys matching a glob ``pattern``.

        Matching uses ``fnmatch.fnmatchcase`` (case-sensitive on every OS).
        Note that negated character classes are written ``[!x]`` there,
        whereas Redis also accepts ``[^x]``.
        """
        with self._lock:
            self._purge_expired()
            all_keys = (
                set(self._strings) | set(self._hashes)
                | set(self._lists) | set(self._sets)
            )
            if pattern == "*":
                return list(all_keys)
            return [k for k in all_keys if fnmatch.fnmatchcase(k, pattern)]

    # === Distributed Locks ===

    def acquire_lock(self, key: str, owner: str, ttl: int = 30) -> bool:
        """
        Acquire a distributed lock.

        Re-acquiring a lock already held by the same owner succeeds and
        restarts its TTL.

        Args:
            key: Lock key
            owner: Lock owner identifier
            ttl: Lock TTL in seconds
        """
        with self._lock:
            current = self._locks.get(key)
            if current is not None and current.is_valid() and current.owner != owner:
                return False
            self._locks[key] = MockLock(
                key=key,
                owner=owner,
                acquired_at=_utcnow(),
                ttl=ttl,
            )
            return True

    def release_lock(self, key: str, owner: str) -> bool:
        """Release a distributed lock.

        Returns False if the lock is unknown, held by someone else, or is
        no longer valid (already released or expired).
        """
        with self._lock:
            lock = self._locks.get(key)
            if lock is None or lock.owner != owner or not lock.is_valid():
                return False
            lock.released = True
            return True

    def refresh_lock(self, key: str, owner: str, ttl: int = 30) -> bool:
        """Refresh lock TTL"""
        with self._lock:
            lock = self._locks.get(key)
            if lock is None or lock.owner != owner or not lock.is_valid():
                return False
            lock.acquired_at = _utcnow()
            lock.ttl = ttl
            return True

    # === Pub/Sub ===

    def subscribe(self, channel: str, callback: Callable[[str, Any], None]):
        """Subscribe to a channel"""
        with self._lock:
            self._subscribers.setdefault(channel, []).append(callback)

    def unsubscribe(self, channel: str, callback: Optional[Callable] = None):
        """Unsubscribe one callback, or every callback when none is given."""
        with self._lock:
            if channel not in self._subscribers:
                return
            if callback:
                self._subscribers[channel] = [
                    cb for cb in self._subscribers[channel] if cb != callback
                ]
            else:
                del self._subscribers[channel]

    def publish(self, channel: str, message: Any) -> int:
        """Publish message to channel.

        Callbacks run synchronously, outside the internal lock, so they may
        call back into this instance. A callback that raises is logged and
        not counted.

        Returns:
            Number of callbacks that handled the message without raising.
        """
        with self._lock:
            callbacks = list(self._subscribers.get(channel, []))

        delivered = 0
        for callback in callbacks:
            try:
                callback(channel, message)
            except Exception:
                # Delivery stays best-effort, but a failing subscriber is
                # usually a test bug, so surface it instead of hiding it.
                logger.exception(
                    "Subscriber %r failed on channel %r", callback, channel
                )
            else:
                delivered += 1
        return delivered

    # === Test Helpers ===

    def reset(self):
        """Reset all state for testing"""
        with self._lock:
            self._strings.clear()
            self._hashes.clear()
            self._lists.clear()
            self._sets.clear()
            self._expiry.clear()
            self._locks.clear()
            self._subscribers.clear()

    def get_all_state(self) -> Dict[str, Any]:
        """Get complete (non-expired) state for test assertions"""
        with self._lock:
            self._purge_expired()
            return {
                "strings": self._strings.copy(),
                "hashes": {k: v.copy() for k, v in self._hashes.items()},
                "lists": {k: v.copy() for k, v in self._lists.items()},
                "sets": {k: v.copy() for k, v in self._sets.items()},
                "locks": {k: {"owner": v.owner, "valid": v.is_valid()}
                          for k, v in self._locks.items()},
            }

    def inject_state(self, state: Dict[str, Any]):
        """Inject state for testing"""
        with self._lock:
            if "strings" in state:
                self._strings.update(state["strings"])
            if "hashes" in state:
                for k, v in state["hashes"].items():
                    self._hashes[k] = v.copy()
            if "lists" in state:
                for k, v in state["lists"].items():
                    self._lists[k] = v.copy()
            if "sets" in state:
                for k, v in state["sets"].items():
                    self._sets[k] = set(v)