"""ZMQ client interface for Open Ephys streams (pyoephys.interface._zmq_client)."""

from __future__ import annotations

import time
import json
import uuid
import warnings
import numpy as np
import zmq
from typing import Iterable, Optional, List, Tuple
from pyoephys.logging import get_logger
from collections import deque
import threading
from ._gui_client import GUIClient
from ._gui_events import Event, Spike

log = get_logger("ZMQClient")


def _norm_name(s: str) -> str:
    return "".join(str(s).split()).upper()


def _parse_int(v) -> Optional[int]:
    if v is None:
        return None
    if isinstance(v, (int, np.integer)):
        return int(v)
    if isinstance(v, str):
        return int(v) if v.isdigit() else None
    try:
        return int(v)
    except Exception:
        return None


class NotReadyError(RuntimeError):
    """Raised when getters are called before required channels are ready."""
def _addr(host: str, endpoint: str | int) -> str: """Return a valid ZMQ endpoint. If endpoint already has '://', return as-is. Otherwise, treat it as a port and build 'tcp://{host}:{port}'.""" ep = str(endpoint) if "://" in ep: return ep # if host already has '://', don't prepend another 'tcp://' if "://" in host: return f"{host}:{int(ep)}" return f"tcp://{host}:{int(ep)}"
class ZMQClient:
    """
    Open Ephys–compatible ZMQ client with per-channel ring buffers (deque-based).
    """

    def __init__(self,
                 host_ip: str = "127.0.0.1",
                 data_port: str = "5556",
                 heartbeat_port: Optional[str] = None,
                 buffer_seconds: float = 30.0,
                 expected_channel_count: Optional[int] = None,
                 expected_channel_names: Optional[Iterable[str]] = None,
                 required_fraction: float = 1.0,
                 max_channels: int = 256,
                 auto_start: bool = False,
                 set_index_looping: bool = True,
                 align_to_header_index: bool = False,
                 fill_value: float = np.nan,
                 verbose: bool = False):
        # config
        self.host_ip = str(host_ip)
        self.data_port = str(data_port)
        # heartbeat defaults to data_port + 1 (Open Ephys convention)
        self.hb_endpoint = str(int(data_port) + 1) if heartbeat_port is None else str(heartbeat_port)
        self.buffer_seconds = float(buffer_seconds)
        self.expected_count = int(expected_channel_count) if expected_channel_count else None
        self.expected_names = list(expected_channel_names) if expected_channel_names else None
        self.required_fraction = float(required_fraction)
        self.verbose = bool(verbose)
        self.max_channels = int(max_channels)
        self.name = "ZMQClient"
        self.type = "ZMQClient"
        # NOTE(review): duplicate of the assignment a few lines above — harmless but redundant.
        self.expected_names = list(expected_channel_names) if expected_channel_names else None
        self.expected_names_norm = [_norm_name(n) for n in self.expected_names] if self.expected_names else None
        self.seen_names_norm = set()  # normalized names seen so far
        # zmq
        self._ctx = zmq.Context.instance()
        self._poller = zmq.Poller()
        self._data_sock: Optional[zmq.Socket] = None
        self._hb_sock: Optional[zmq.Socket] = None
        # state/threading
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self.ready_event = threading.Event()           # set on first data frame
        self.channels_ready_event = threading.Event()  # set when enough channels seen
        self._lock = threading.Lock()                  # guards buffers + clocks
        # stream / buffers
        self.fs: float = 2000.0  # default, updated from headers
        self._deque_len = self._target_deque_len(self.fs)
        self.buffers: List[deque] = [deque(maxlen=self._deque_len) for _ in range(self.max_channels)]
        self._name_by_index: dict[int, str] = {}
        self.seen_nums: set[int] = set()
        self.seen_names: set[str] = set()
        self._ref_clock_ch: Optional[int] = None  # first channel seen acts as the "clock"
        self.total_samples_written: int = 0
        self.channel_index: Optional[List[int]] = None  # selection
        self.N_channels: int = 0
        self.N_samples = self._deque_len  # Support older code
        self._drain_last_total = 0  # for drain_new()
        # Explicit sample index tracking
        # NOTE(review): sample_index is initialized here but never updated in this file — possibly vestigial.
        self.sample_index = 0
        self.global_sample_index = 0  # Total across all channels
        self._last_index_log = 0.0
        self.index_log_interval_s = 0.25
        self._last_header_index = {}
        self.loop_global_index = False
        self.loop_cycle = 0
        self.loop_sample_index = 0
        self._index_offset = 0       # accumulated samples from completed playback loops
        self._last_ref_s0 = None     # last packet-start header index on the reference channel
        self._last_ref_end = None    # last packet-end header index on the reference channel
        self.set_index_looping(enabled=set_index_looping)
        self.align_to_header_index = bool(align_to_header_index)
        self.fill_value = np.float32(fill_value)
        if self.align_to_header_index:
            # Ring + validity mask for header-index-aligned writes (see _run()).
            self._ring = np.full((self.max_channels, self._deque_len), self.fill_value, dtype=np.float32)
            self._valid = np.zeros((self.max_channels, self._deque_len), dtype=np.bool_)
        # heartbeat tracking
        self._last_hb_send = 0.0
        self._waiting_hb_reply = False
        # https connection (if enabled)
        self.gui = GUIClient(host=self.host_ip)
        if auto_start:
            self.start()
            # Optional: let first frame arrive (non-blocking feel but helpful for UI)
            self.ready_event.wait(timeout=5.0)

    # --- internal ----
    def _setup(self) -> None:
        """Create and register the data SUB socket and optional heartbeat REQ socket."""
        self._teardown()
        # data SUB socket: subscribe-all like the working client
        data_addr = _addr(self.host_ip, self.data_port)
        self._data_sock = self._ctx.socket(zmq.SUB)
        self._data_sock.connect(data_addr)
        self._data_sock.setsockopt(zmq.SUBSCRIBE, b"")
        self._data_sock.setsockopt(zmq.RCVTIMEO, 1000)
        self._poller.register(self._data_sock, zmq.POLLIN)
        # optional heartbeat REQ socket
        if self.hb_endpoint:
            hb_addr = _addr(self.host_ip, self.hb_endpoint)
            self._hb_sock = self._ctx.socket(zmq.REQ)
            self._hb_sock.connect(hb_addr)
            self._hb_sock.setsockopt(zmq.RCVTIMEO, 2000)
            self._poller.register(self._hb_sock, zmq.POLLIN)
        # reset heartbeat bookkeeping so the first heartbeat fires promptly
        self._last_hb_send = 0.0
        self._waiting_hb_reply = False

    def _teardown(self) -> None:
        """Unregister and close both sockets; always leaves them as None."""
        try:
            if self._data_sock:
                self._poller.unregister(self._data_sock)
        except Exception:
            pass
        try:
            if self._hb_sock:
                self._poller.unregister(self._hb_sock)
        except Exception:
            pass
        try:
            if self._data_sock:
                self._data_sock.close(0)  # linger=0: drop pending messages
        finally:
            self._data_sock = None
        try:
            if self._hb_sock:
                self._hb_sock.close(0)
        finally:
            self._hb_sock = None

    def _target_deque_len(self, fs: float) -> int:
        """Buffer length (samples) covering buffer_seconds at rate *fs*; at least 1."""
        return max(1, int(round(fs * self.buffer_seconds)))

    def _rebuild_deques_if_needed(self, new_fs: float) -> None:
        """Resize per-channel buffers when the sampling rate changes, keeping the newest samples."""
        new_len = self._target_deque_len(new_fs)
        if new_len == self._deque_len:
            return
        # rebuild each deque to the new maxlen, keeping most recent samples
        for ch in range(self.max_channels):
            old = self.buffers[ch]
            if len(old) == 0:
                self.buffers[ch] = deque(maxlen=new_len)
                continue
            take = min(len(old), new_len)
            recent = list(old)[-take:]
            self.buffers[ch] = deque(recent, maxlen=new_len)
        self._deque_len = new_len
        self.N_samples = new_len
        if getattr(self, "align_to_header_index", False):
            # Ring contents are discarded on resize (indices no longer line up).
            self._ring = np.full((self.max_channels, self._deque_len), self.fill_value, dtype=np.float32)
            self._valid = np.zeros((self.max_channels, self._deque_len), dtype=np.bool_)

    def _send_heartbeat_if_due(self) -> None:
        """Send a heartbeat at most every 2 s; REQ socket requires waiting for a reply in between."""
        if not self._hb_sock:
            return
        now = time.time()
        if self._waiting_hb_reply:
            return
        if now - self._last_hb_send >= 2.0:
            try:
                msg = json.dumps({"application": "NewZMQClient", "type": "heartbeat"})
                self._hb_sock.send(msg.encode("utf-8"))
                self._last_hb_send = now
                self._waiting_hb_reply = True
                #if self.verbose:
                #    print("[HB] sent")
            except Exception as e:
                if self.verbose:
                    print(f"[HB] send error: {e}")

    def _run(self) -> None:
        """Background worker: service heartbeat, ingest data/event/spike frames, update clocks."""
        while not self._stop.is_set():
            try:
                self._send_heartbeat_if_due()
                socks = dict(self._poller.poll(10))
                # heartbeat reply
                if self._hb_sock and self._hb_sock in socks and self._waiting_hb_reply:
                    try:
                        _ = self._hb_sock.recv(flags=zmq.NOBLOCK)
                        self._waiting_hb_reply = False
                        #if self.verbose:
                        #    print("[HB] reply")
                    except zmq.Again:
                        pass
                    except Exception as e:
                        if self.verbose:
                            print(f"[HB] recv error: {e}")
                # data frames
                if self._data_sock and self._data_sock in socks:
                    try:
                        frames = self._data_sock.recv_multipart(flags=zmq.NOBLOCK)
                        if len(frames) < 2:
                            continue
                        # frames[1]: JSON header; frames[2]: payload (if present)
                        try:
                            header = json.loads(frames[1].decode("utf-8", errors="ignore"))
                        except Exception:
                            continue
                        typ = header.get("type", "")
                        if typ == "data":
                            content = header.get("content", {})
                            ch = int(content.get("channel_num", -1))
                            ch_name = content.get("channel_name", f"CH{ch + 1}")
                            rate = float(content.get("sample_rate", self.fs))
                            s0 = _parse_int(content.get("sample_num"))   # first sample index in packet
                            ns = _parse_int(content.get("num_samples"))  # number of samples in packet
                            payload = frames[2] if len(frames) >= 3 else b""
                            # assumes payload is raw little-endian float32 samples — TODO confirm against producer
                            samples = np.frombuffer(payload, dtype=np.float32)
                            # Update fs & buffer sizes if the stream rate changes
                            if rate > 0.0 and rate != self.fs:
                                with self._lock:
                                    self.fs = rate
                                    self._rebuild_deques_if_needed(self.fs)
                            if 0 <= ch < self.max_channels and samples.size:
                                with self._lock:
                                    # --- buffer/meta updates ---
                                    self.buffers[ch].extend(samples.tolist())
                                    self._name_by_index[ch] = ch_name
                                    self.seen_nums.add(ch)
                                    self.seen_names.add(ch_name)
                                    self.seen_names_norm.add(_norm_name(ch_name))
                                    if self._ref_clock_ch is None:
                                        self._ref_clock_ch = ch
                                    if ch == self._ref_clock_ch:
                                        self.total_samples_written += samples.size
                                    # --- header-based clocks with loop detection ---
                                    end_idx = None
                                    if (s0 is not None) and (ns is not None):
                                        end_idx = s0 + ns  # end-of-packet (exclusive)
                                    # Detect playback loop on the reference channel
                                    if (ch == self._ref_clock_ch) and (end_idx is not None):
                                        if (self._last_ref_s0 is not None) and (s0 < self._last_ref_s0):
                                            # loop just started
                                            self.loop_cycle += 1
                                            if self._last_ref_end is not None:
                                                self._index_offset += int(self._last_ref_end)
                                            # Optional: epoch reset on loop
                                            # for dq in self.buffers: dq.clear()
                                            # self._drain_last_total = 0
                                        self._last_ref_s0 = s0
                                        self._last_ref_end = end_idx
                                    if s0 is not None:
                                        self._last_header_index[ch] = s0
                                    # Maintain both clocks: loop-relative and monotonic
                                    if end_idx is not None:
                                        self.loop_sample_index = int(end_idx)
                                        if self.loop_global_index:
                                            # exported clock loops with playback
                                            if ch == self._ref_clock_ch:
                                                self.global_sample_index = self.loop_sample_index
                                        else:
                                            # exported clock is monotonic across loops
                                            mono = self._index_offset + int(end_idx)
                                            if mono > self.global_sample_index:
                                                self.global_sample_index = mono
                                    else:
                                        # No header indices; fall back
                                        if self.loop_global_index:
                                            if ch == self._ref_clock_ch:
                                                self.loop_sample_index += int(samples.size)
                                                self.global_sample_index = self.loop_sample_index
                                        else:
                                            self.global_sample_index += int(samples.size)
                                    # --- index-aligned ring write (optional) ---
                                    if getattr(self, "align_to_header_index", False):
                                        # Absolute start index for this packet (monotonic)
                                        if s0 is not None:
                                            abs_start = int(self._index_offset) + int(s0)
                                        else:
                                            # fallback: estimate from current end
                                            abs_end = (int(self._index_offset) + int(self.loop_sample_index)) \
                                                if self.loop_global_index else int(self.global_sample_index)
                                            abs_start = abs_end - int(samples.size)
                                        n = int(samples.size)
                                        if n > 0:
                                            L = self._deque_len
                                            pos0 = abs_start % L
                                            if pos0 + n <= L:
                                                # contiguous write
                                                self._ring[ch, pos0:pos0 + n] = samples
                                                self._valid[ch, pos0:pos0 + n] = True
                                            else:
                                                # wrap-around write in two pieces
                                                first = L - pos0
                                                self._ring[ch, pos0:L] = samples[:first]
                                                self._valid[ch, pos0:L] = True
                                                rest = n - first
                                                self._ring[ch, 0:rest] = samples[first:]
                                                self._valid[ch, 0:rest] = True
                            # Mark ready and channels-ready
                            if not self.ready_event.is_set():
                                self.ready_event.set()
                            if self._channels_complete_enough_unlocked():
                                self.channels_ready_event.set()
                        elif typ == "event":
                            evt = Event(header.get("content", {}), frames[2] if len(frames) >= 3 else None)
                            if self.verbose:
                                print(evt)
                        elif typ == "spike":
                            spk = Spike(header.get("spike", {}), frames[2] if len(frames) >= 3 else None)
                            if self.verbose:
                                print(spk)
                    except zmq.Again:
                        pass
                    except Exception as e:
                        if self.verbose:
                            print(f"[Data] error: {e}")
                time.sleep(0.01)
                # periodic global index logger
                now = time.time()
                if self.verbose and (now - self._last_index_log) >= self.index_log_interval_s and self.ready_event.is_set():
                    with self._lock:
                        gidx = int(self.global_sample_index)
                        lidx = int(self.loop_sample_index)
                        cycle = int(self.loop_cycle)
                        fs_local = float(self.fs)
                        gt = (gidx / fs_local) if fs_local > 0 else float('nan')
                        lt = (lidx / fs_local) if fs_local > 0 else float('nan')
                        if self.loop_global_index:
                            print(f"[IDX] (LOOP) global={gidx} t={gt:.3f}s | cycle={cycle} | fs={fs_local:.2f}Hz")
                        else:
                            print(
                                f"[IDX] (MONO) global={gidx} t={gt:.3f}s | loop_idx={lidx} loop_t={lt:.3f}s cycle={cycle} | fs={fs_local:.2f}Hz")
                        self._last_index_log = now
            except Exception as e:
                if self.verbose:
                    print(f"[Loop] error: {e}")
                time.sleep(0.1)
[docs] def wait_for_expected_channels(self, timeout: float = 15.0) -> bool: """Block until expected channels (by name) have been seen (per required_fraction).""" if not self.expected_names_norm: return True return self.channels_ready_event.wait(timeout=timeout)
[docs] def set_index_looping(self, enabled: bool) -> None: """If True, global_sample_index restarts at each playback loop; else stays monotonic.""" self.loop_global_index = bool(enabled)
def _channels_complete_enough_unlocked(self) -> bool: if self.expected_names_norm: need = set(self.expected_names_norm) have = self.seen_names_norm frac = len(need & have) / max(1, len(need)) return frac >= self.required_fraction if self.expected_count: frac = len(self.seen_nums) / float(self.expected_count) return frac >= self.required_fraction return len(self.seen_nums) > 0
    def start(self) -> None:
        """Open the sockets and launch the background receive thread.

        Idempotent: returns immediately if the worker thread is already alive.
        """
        if self._thread and self._thread.is_alive():
            return
        self._setup()
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, name="ZMQClient", daemon=True)
        self._thread.start()
        if self.verbose:
            print(
                f"[ZMQClient] started; data=tcp://{self.host_ip}:{self.data_port} hb=tcp://{self.host_ip}:{self.hb_endpoint or 'None'}")
    def stop(self, timeout: Optional[float] = 2.0) -> None:
        """Signal the worker to exit, join it for up to *timeout* seconds, close sockets."""
        self._stop.set()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=timeout)
        self._teardown()
        if self.verbose:
            print("[ZMQClient] stopped")
    def close(self) -> None:
        """Alias for stop(); safe to call multiple times."""
        self.stop()
        # do not terminate the shared Context (instance()) globally

    # ---------- selection & info ----------
[docs] def set_channel_index(self, indices: Iterable[int]) -> None: with self._lock: idx = [int(i) for i in indices] for i in idx: if i < 0 or i >= self.max_channels: raise ValueError(f"Channel index {i} out of range [0,{self.max_channels - 1}]") self.channel_index = idx self.N_channels = len(idx)
@property def channel_names(self) -> List[str]: with self._lock: n_tot = max(self.seen_nums) + 1 if self.seen_nums else 0 return [self._name_by_index.get(i, f"CH{i + 1}") for i in range(n_tot)]
[docs] def fs_estimate(self, n_last: int = 2000) -> float: return float(self.fs)
[docs] def get_latest_window(self, window_ms: int = 500) -> np.ndarray: nsamples = int(round(self.fs * window_ms / 1000.0)) if nsamples < 1: raise ValueError("Window size must be at least 1 ms.") if not self.ready_event.is_set(): raise NotReadyError("NewZMQClient not ready; no data received yet.") return self.get_latest(nsamples)[0]
    def get_latest(self, n: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return the latest n samples and their absolute timestamps.
        Y: (C_selected, n), t: (n,)

        Raises NotReadyError if no data has arrived, or (in deque mode) if none
        of the selected channels has any samples yet.
        """
        if not self.ready_event.is_set():
            raise NotReadyError("NewZMQClient not ready; no data received yet.")
        with self._lock:
            if not self.channel_index:
                # default to all seen channels, numeric order
                self.channel_index = sorted(self.seen_nums)
                self.N_channels = len(self.channel_index)
            n = max(1, int(n))
            if getattr(self, "align_to_header_index", False):
                # Index-aligned read from the circular ring using absolute indices
                total_abs = (int(self._index_offset) + int(self.loop_sample_index)) \
                    if self.loop_global_index else int(self.global_sample_index)
                start_abs = total_abs - n
                L = self._deque_len
                # wrap absolute positions into ring coordinates
                idxs = (np.arange(start_abs, total_abs, dtype=np.int64) % L)
                Y = np.empty((self.N_channels, n), dtype=np.float32)
                for i, ch in enumerate(self.channel_index):
                    Y[i, :] = self._ring[ch, idxs]
            else:
                # Legacy: read from per-channel deques (tail); missing data stays zero
                Y = np.zeros((self.N_channels, n), dtype=np.float32)
                have = 0
                for i, ch in enumerate(self.channel_index):
                    buf = self.buffers[ch]
                    if not buf:
                        continue
                    m = min(len(buf), n)
                    Y[i, -m:] = list(buf)[-m:]
                    have = max(have, m)
                if have < 1:
                    raise NotReadyError("No samples yet for selected channels.")
            # timestamps from absolute index
            total = (int(self._index_offset) + int(self.loop_sample_index)) \
                if self.loop_global_index else int(self.global_sample_index)
            idx = np.arange(total - n, total, dtype=np.float64)
            t = idx / self.fs
            return Y, t
    def latest(self) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Return (t_rel, Y) like the old client:
        - t_rel: (M,) seconds ending at 0 ([-window, 0])
        - Y: (N_channels, M)
        Uses self.N_samples as the window size.

        Returns (None, None) when no channels are selected or fewer than two
        samples are available on every selected channel.
        """
        if self.channel_index is None or len(self.channel_index) == 0:
            return None, None
        M = int(self.N_samples)
        with self._lock:
            Y = np.zeros((len(self.channel_index), M), dtype=np.float32)
            have = 0  # most samples available on any selected channel
            for i, ch in enumerate(self.channel_index):
                buf = list(self.buffers[ch])
                n = min(len(buf), M)
                if n > 0:
                    # right-align the tail of each buffer so t=0 is "now"
                    Y[i, -n:] = buf[-n:]
                    have = max(have, n)
            if have < 2:
                return None, None
            t_rel = (np.arange(-M, 0, dtype=np.float64) / self.fs)[-have:]
            return t_rel, Y[:, -have:]
    def drain_new(self) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Return only NEW samples since last call:
        - t_abs: (K,) seconds since stream start (based on total_samples_written / fs)
        - Y_new: (N_channels, K)

        Returns (None, None) when no channels are selected or nothing new arrived.
        """
        if self.channel_index is None or len(self.channel_index) == 0:
            return None, None
        with self._lock:
            # total = int(self.total_samples_written)
            # NOTE(review): uses the header-derived global clock rather than the
            # per-reference-channel write counter above — confirm intended.
            total = int(self.global_sample_index)
            n_new = total - self._drain_last_total
            if n_new <= 0:
                return None, None
            K = n_new
            Y_new = np.zeros((len(self.channel_index), K), dtype=np.float32)
            for i, ch in enumerate(self.channel_index):
                lst = list(self.buffers[ch])
                if len(lst) >= K:
                    Y_new[i, :] = lst[-K:]
                elif len(lst) > 0:
                    # left-pad with zeros when a channel has fewer than K samples
                    pad = K - len(lst)
                    Y_new[i, pad:] = lst
                # else: leave zeros for channels that don't have data yet
            t_new = np.arange(self._drain_last_total, total, dtype=np.float64) / self.fs
            self._drain_last_total = total
            return t_new, Y_new
class WorkingZMQClient:
    """
    Real-time Open Ephys ZMQ client with ring buffers and LSL-like API.

    Key features:
    - Robust JSON header handling (skips non-JSON frames).
    - Tracks seen channels and (optionally) blocks until a required set is present.
    - 'latest()' returns a rolling window; 'drain_new()' returns only new samples since last call.
    - Channel selection via 'channel_index' (e.g., [7, 8, 9, 10]).
    - Time base based on sample index and fs (Open Ephys does not send per-sample timestamps).

    Parameters
    ----------
    zqm_ip : str
        ZMQ endpoint prefix (e.g. 'tcp://localhost').
    http_ip : str
        Host/IP for the HTTP side of GUIClient (optional helper).
    data_port : int
        Data stream port.
    heartbeat_port : int
        Heartbeat REQ port.
    window_secs : float
        Default time window length for plotting helpers.
    channels : Iterable[int] | None
        Selected channel indices to expose via latest() / drain_new().
        If None, we’ll fill once we know 'n_channels_total' or use all seen channels.
    auto_start : bool
        Start streaming worker on construction.
    verbose : bool
        Chatty logs.
    expected_channel_names : Iterable[str] | None
        If provided, we’ll consider the stream “ready” when these names have appeared.
    expected_channel_count : int | None
        If names are not provided, you can require a specific count instead (e.g., 128).
    require_complete : bool
        If True, latest()/drain_new() will yield nothing until readiness criteria are met.
    required_fraction : float
        Fraction of expected channels required to mark ready (1.0 = all, 0.95 = 95%).
    max_channels : int
        Hard ceiling for buffer allocation (avoid reallocations when late channels appear).
    """

    def __init__(
            self,
            zqm_ip: str = "tcp://localhost",
            http_ip: str = "127.0.0.1",
            data_port: str = "5556",
            heartbeat_port: Optional[str] = None,
            window_secs: float = 5.0,
            channels: Optional[Iterable[int]] = None,
            auto_start: bool = True,
            verbose: bool = False,
            expected_channel_names: Optional[Iterable[str]] = None,
            expected_channel_count: Optional[int] = None,
            require_complete: bool = False,
            required_fraction: float = 1.0,
            max_channels: int = 256,
    ):
        # --- config / state
        warnings.warn(
            "WorkingZMQClient is deprecated and will be removed in a future version. "
            "Use ZMQClient instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.ip = zqm_ip
        self.data_port = int(data_port)
        if heartbeat_port is None:
            # heartbeat defaults to data_port + 1 (Open Ephys convention)
            heartbeat_port = int(data_port) + 1
        self.heartbeat_port = int(heartbeat_port)
        self.window_secs = float(window_secs)
        self.verbose = bool(verbose)
        self.name = "ZMQClient"
        self.type = "ZMQ Data Stream"
        # readiness/connection
        self.connection_lost = True
        self.reconnect_attempts = 0
        self.max_reconnect_attempts = 5
        self.context = zmq.Context()
        self.poller = zmq.Poller()
        self.heartbeat_socket = None
        self.data_socket = None
        self.uuid = str(uuid.uuid4())
        self.last_reply_time = time.time()
        self.last_heartbeat_time = time.time()
        self.socket_waits_reply = False
        # threading/locks
        self.lock = threading.Lock()
        self.ready_event = threading.Event()           # first data packet seen
        self.channels_ready_event = threading.Event()  # enough channels have been seen
        self.streaming = False
        self.streaming_thread = None
        # sampling/time base
        self.sampling_rate = 2000.0  # default; updated from headers when available
        self.fs = float(self.sampling_rate)
        self.N_samples = int(max(1, round(self.fs * self.window_secs)))
        self.total_samples_written = 0
        self._drain_last_total = 0  # last drained sample index
        self._ref_clock_ch = None  # channel we use as a "clock" (first-seen, or 0)
        # channels / buffers
        self.max_channels = int(max_channels)
        self._name_by_index = {}  # {ch_idx: "CH#"}
        self.seen_nums = set()  # {0,1,2,...}
        self.seen_names = set()  # {"CH1", ...}
        # expected/required channel set
        self.expected_names = list(expected_channel_names) if expected_channel_names else None
        self.expected_count = int(expected_channel_count) if expected_channel_count else None
        self.require_complete = bool(require_complete)
        self.required_fraction = float(required_fraction)
        # Known channel count so far; will grow as data arrives
        self.n_channels_total = 0
        # Channel selection (indices into "physical" channels)
        self.channel_index = list(channels) if channels is not None else None
        self.N_channels = len(self.channel_index) if self.channel_index is not None else 0
        # ring buffers per physical channel index [0..max_channels-1]
        maxlen = int(self.fs * self.window_secs)
        self.buffers: List[deque] = [deque(maxlen=maxlen) for _ in range(self.max_channels)]
        # init ZMQ
        self._initialize_sockets()
        # GUI control
        self.gui = GUIClient(host=http_ip)
        if auto_start:
            self.start()
            # Optional: wait for first data frame so downstream code can proceed
            self.ready_event.wait(timeout=5.0)
            # Optionally wait for required channel set
            if self.require_complete:
                self.wait_for_channels(timeout_sec=15.0)
        if self.verbose:
            self._print_metadata()

    def _print_metadata(self):
        """Log connection metadata to stdout."""
        print(f"[ZMQClient] Connected to {self.ip}:{self.data_port}")
        print(f" UUID: {self.uuid}")
        print(f" Channels: {self.n_channels_total}")
        print(f" Sampling Rate: {self.sampling_rate} Hz")
        print(f" Channel Names: {self.channel_names}")
        print(f" Type: ZMQ Data Stream")

    # ------------------- sockets / heartbeat -------------------
    def _initialize_sockets(self):
        """Create and register the data SUB and heartbeat REQ sockets (idempotent per socket)."""
        try:
            if not self.data_socket:
                addr = f"{self.ip}:{self.data_port}"
                self.data_socket = self.context.socket(zmq.SUB)
                self.data_socket.connect(addr)
                self.data_socket.setsockopt(zmq.SUBSCRIBE, b"")
                self.data_socket.setsockopt(zmq.RCVTIMEO, 1000)
                self.poller.register(self.data_socket, zmq.POLLIN)
                print(f"[ZMQClient] Initialized data socket on {addr}")
            if not self.heartbeat_socket:
                addr = f"{self.ip}:{self.heartbeat_port}"
                self.heartbeat_socket = self.context.socket(zmq.REQ)
                self.heartbeat_socket.connect(addr)
                self.heartbeat_socket.setsockopt(zmq.RCVTIMEO, 2000)
                self.poller.register(self.heartbeat_socket, zmq.POLLIN)
                print(f"[ZMQClient] Initialized heartbeat socket on {addr}")
            self.connection_lost = False
            self.reconnect_attempts = 0
        except Exception as e:
            print(f"[Socket Error] Failed to initialize sockets: {e}")
            self.connection_lost = True

    def _reconnect_sockets(self) -> bool:
        """Tear down and re-create both sockets; returns True when reconnection succeeded."""
        if self.reconnect_attempts >= self.max_reconnect_attempts:
            print(f"[Connection] Max reconnection attempts reached ({self.max_reconnect_attempts})")
            return False
        print(f"[Connection] Attempting reconnection ({self.reconnect_attempts + 1}/{self.max_reconnect_attempts})")
        try:
            if self.data_socket:
                self.data_socket.close()
                self.data_socket = None
            if self.heartbeat_socket:
                self.heartbeat_socket.close()
                self.heartbeat_socket = None
            time.sleep(0.5)
            self._initialize_sockets()
            self.reconnect_attempts += 1
            return not self.connection_lost
        except Exception as e:
            print(f"[Reconnection Error] {e}")
            self.reconnect_attempts += 1
            return False

    def _send_heartbeat(self):
        """Send one heartbeat request; REQ socket then waits for a reply before the next send."""
        if self.connection_lost:
            return
        try:
            msg = json.dumps({"application": self.name, "uuid": self.uuid, "type": "heartbeat"})
            self.heartbeat_socket.send(msg.encode("utf-8"))
            self.last_heartbeat_time = time.time()
            self.socket_waits_reply = True
            if self.verbose:
                print("[Heartbeat] Sent")
        except Exception as e:
            print(f"[Heartbeat Error] {e}")
            self.connection_lost = True

    # ------------------- readiness helpers -------------------
    def _channels_complete_enough_unlocked(self) -> bool:
        """Check without acquiring lock."""
        if self.expected_names:
            need = set(self.expected_names)
            have = self.seen_names
            frac = len(need & have) / max(1, len(need))
            return frac >= self.required_fraction
        if self.expected_count:
            frac = len(self.seen_nums) / float(self.expected_count)
            return frac >= self.required_fraction
        # No expectations set: any channel seen marks readiness.
        return len(self.seen_nums) > 0

    def wait_for_channels(self, timeout_sec: float = 10.0) -> bool:
        """Block up to *timeout_sec* seconds until enough channels have been seen."""
        # Fast path
        with self.lock:
            if self._channels_complete_enough_unlocked():
                self.channels_ready_event.set()
                return True
        # Block until enough channels have been seen
        end = time.time() + timeout_sec
        while time.time() < end:
            if self.channels_ready_event.wait(timeout=0.1):
                return True
        return False

    # ------------------- public controls -------------------
    def start(self):
        """Launch the streaming worker thread (no-op if already streaming)."""
        if self.streaming:
            if self.verbose:
                print("Already streaming")
            return
        self.streaming = True
        self.streaming_thread = threading.Thread(target=self._streaming_worker, daemon=True)
        self.streaming_thread.start()

    def stop(self):
        """Ask the worker to stop and wait for it to finish."""
        self.streaming = False
        if self.streaming_thread:
            self.streaming_thread.join()

    def close(self):
        """Stop streaming, close sockets and terminate this client's private Context."""
        self.stop()
        if self.data_socket:
            self.data_socket.close()
        if self.heartbeat_socket:
            self.heartbeat_socket.close()
        self.context.term()

    # ------------------- worker loop -------------------
    def _streaming_worker(self):
        """Background loop: reconnect if needed, service heartbeat, ingest frames."""
        while self.streaming:
            if self.connection_lost and not self._reconnect_sockets():
                continue
            if (time.time() - self.last_heartbeat_time) > 2.0:
                self._send_heartbeat()
            try:
                socks = dict(self.poller.poll(10))
                # heartbeat replies
                if self.heartbeat_socket in socks and self.socket_waits_reply:
                    try:
                        _ = self.heartbeat_socket.recv()
                        self.socket_waits_reply = False
                        self.last_reply_time = time.time()
                    except zmq.Again:
                        pass
                    except Exception as e:
                        print(f"[Heartbeat Error] {e}")
                        self.connection_lost = True
                # data stream
                if self.data_socket in socks:
                    try:
                        msg = self.data_socket.recv_multipart(zmq.NOBLOCK)
                        if len(msg) < 2:
                            continue
                        # Guard decode: some frames may not be JSON headers
                        try:
                            header = json.loads(msg[1].decode("utf-8"))
                        except Exception:
                            continue
                        typ = header.get("type", "")
                        if typ == "data":
                            if not self.ready_event.is_set():
                                self.ready_event.set()
                            content = header.get("content", {})
                            ch = int(content.get("channel_num", -1))
                            ch_name = content.get("channel_name", f"CH{ch + 1}")
                            rate = float(content.get("sample_rate", self.fs))
                            # Update fs if needed
                            if rate > 0 and rate != self.fs:
                                self.fs = float(rate)
                                # do not resize deques; just update sizes for latest()
                                self.N_samples = int(max(1, round(self.fs * self.window_secs)))
                            # assumes payload is raw float32 samples — TODO confirm against producer
                            samples = np.frombuffer(msg[2], dtype=np.float32, count=-1)
                            if 0 <= ch < self.max_channels and samples.size:
                                with self.lock:
                                    # extend buffer for this physical channel
                                    self.buffers[ch].extend(samples.tolist())
                                    # track names/count
                                    self._name_by_index[ch] = ch_name
                                    self.seen_nums.add(ch)
                                    self.seen_names.add(ch_name)
                                    self.n_channels_total = max(self.n_channels_total, ch + 1)
                                    # choose a reference clock channel if we haven't yet
                                    if self._ref_clock_ch is None:
                                        self._ref_clock_ch = ch
                                    # advance "global" samples when reference channel arrives
                                    if ch == self._ref_clock_ch:
                                        self.total_samples_written += samples.size
                                    # selection defaulting
                                    if self.channel_index is None and self.n_channels_total > 0:
                                        # default to all seen channels in numeric order
                                        self.channel_index = sorted(list(self.seen_nums))
                                        self.N_channels = len(self.channel_index)
                                    # readiness criteria
                                    if self._channels_complete_enough_unlocked():
                                        self.channels_ready_event.set()
                        elif typ == "event":
                            evt = Event(header.get("content", {}), msg[2] if header.get("data_size", 0) > 0 else None)
                            if self.verbose:
                                print(evt)
                        elif typ == "spike":
                            spk = Spike(header.get("spike", {}), msg[2])
                            if self.verbose:
                                print(spk)
                    except zmq.Again:
                        pass
                    except Exception as e:
                        print(f"[Data Error] {e}")
            except Exception as e:
                print(f"Streaming worker error: {e}")
                self.connection_lost = True
                time.sleep(0.1)

    # ------------------- LSL-like API -------------------
    def fs_estimate(self, n_last: int = 2000) -> float:
        # Header-derived rate; n_last accepted for API compatibility, ignored.
        return float(self.fs)

    @property
    def channel_names(self) -> List[str]:
        # Derived from name map; fill unknowns as CH#
        return [self._name_by_index.get(i, f"CH{i + 1}") for i in range(self.n_channels_total)]

    def set_channel_index(self, indices: Iterable[int]):
        """Select physical channels; raises ValueError for out-of-range indices."""
        with self.lock:
            idx = [int(i) for i in indices]
            for i in idx:
                if i < 0 or i >= self.max_channels:
                    raise ValueError(f"Channel index {i} out of range [0,{self.max_channels - 1}]")
            self.channel_index = idx
            self.N_channels = len(idx)

    def latest(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Return (t_rel, Y) where:
        - t_rel: (M,) seconds, ending at 0 ([-window_secs, 0])
        - Y: (N_channels, M)
        """
        if self.require_complete and not self.channels_ready_event.is_set():
            return None, None
        M = self.N_samples
        with self.lock:
            if not self.channel_index:
                return None, None
            Y = np.zeros((self.N_channels, M), dtype=np.float32)
            have = 0
            for i, ch in enumerate(self.channel_index):
                # if channel hasn't arrived yet, its buffer will be empty => zeros
                buf = list(self.buffers[ch])
                n = min(len(buf), M)
                if n > 0:
                    Y[i, -n:] = buf[-n:]
                    have = max(have, n)
            if have < 2:
                return None, None
            t_rel = (np.arange(-M, 0, dtype=np.float64) / self.fs)[-have:]
            return t_rel, Y[:, -have:]

    def drain_new(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Return only NEW samples since last call:
        - t_abs: (K,) seconds since stream start (based on reference channel sample index / fs)
        - Y_new: (N_channels, K)
        """
        if self.require_complete and not self.channels_ready_event.is_set():
            return None, None
        with self.lock:
            if not self.channel_index:
                return None, None
            total = int(self.total_samples_written)
            n_new = total - self._drain_last_total
            if n_new <= 0:
                return None, None
            K = n_new
            Y_new = np.zeros((self.N_channels, K), dtype=np.float32)
            for i, ch in enumerate(self.channel_index):
                lst = list(self.buffers[ch])
                if len(lst) >= K:
                    Y_new[i, :] = lst[-K:]
                elif len(lst) > 0:
                    # left-pad with zeros for channels with fewer than K samples
                    pad = K - len(lst)
                    Y_new[i, pad:] = lst
                # else: keep zeros for channels with no data yet
            t_new = np.arange(self._drain_last_total, total, dtype=np.float64) / self.fs
            self._drain_last_total = total
            return t_new, Y_new

    # ------------------- convenience -------------------
    def get_latest_window(self, window_ms: int) -> np.ndarray:
        """Return most-recent window for SELECTED channels (C, N)."""
        n_samples = int(self.fs * window_ms / 1000.0)
        with self.lock:
            if not self.channel_index:
                return np.zeros((0, n_samples), dtype=np.float32)
            out = np.zeros((self.N_channels, n_samples), dtype=np.float32)
            for i, ch in enumerate(self.channel_index):
                buf = list(self.buffers[ch])
                if len(buf) >= n_samples:
                    out[i, :] = buf[-n_samples:]
                elif len(buf) > 0:
                    pad = n_samples - len(buf)
                    out[i, :pad] = 0.0
                    out[i, pad:] = buf
                else:
                    out[i, :] = 0.0
            return out

    def get_connection_status(self):
        """Return a snapshot dict of connection/streaming state (lock-protected)."""
        with self.lock:
            return {
                "connected": not self.connection_lost,
                "streaming": self.streaming,
                "reconnect_attempts": self.reconnect_attempts,
                "total_samples": self.total_samples_written,
                "seen_channels": sorted(list(self.seen_nums)),
                "n_channels_total": self.n_channels_total,
            }