Files
Aetheel/memory/embeddings.py
Tanmay Karande ec8bd80a3d first commit
2026-02-13 23:56:09 -05:00

89 lines
2.7 KiB
Python

"""
Embedding provider for the memory system.
Uses fastembed (ONNX) for fully local, zero-API-call embeddings.
Inspired by OpenClaw's src/memory/embeddings.ts, simplified to:
• Single provider: fastembed with BAAI/bge-small-en-v1.5 (384-dim)
• Local only — no OpenAI/Voyage/Gemini API calls
• Thread-safe lazy initialization
"""
import logging
import threading
from memory.internal import normalize_embedding
logger = logging.getLogger("aetheel.memory.embeddings")
# The fastembed model is loaded lazily on first use
_model_lock = threading.Lock()
_model = None
_model_name: str | None = None
def _ensure_model(model_name: str = "BAAI/bge-small-en-v1.5"):
    """Lazy-load the fastembed model (thread-safe).

    Args:
        model_name: fastembed model identifier to load.

    Returns:
        The cached ``TextEmbedding`` instance for ``model_name``.

    Raises:
        ImportError: if the ``fastembed`` package is not installed.
    """
    global _model, _model_name
    # Always check under the lock. The previous unlocked fast path read
    # (_model, _model_name) as two separate globals, so a concurrent reload
    # to a *different* model name could interleave with the two assignments
    # below and hand back a model that does not match the requested name.
    # An uncontended lock acquire is negligible next to embedding work.
    with _model_lock:
        if _model is not None and _model_name == model_name:
            return _model
        try:
            from fastembed import TextEmbedding
        except ImportError as err:
            raise ImportError(
                "fastembed is required for local embeddings.\n"
                "Install with: uv add fastembed\n"
                "Or: pip install fastembed"
            ) from err
        # Lazy %-args so the message is only formatted if the level is enabled.
        logger.info("Loading embedding model: %s...", model_name)
        _model = TextEmbedding(model_name=model_name)
        _model_name = model_name
        logger.info("Embedding model loaded: %s", model_name)
        return _model
def embed_query(text: str, model_name: str = "BAAI/bge-small-en-v1.5") -> list[float]:
    """
    Generate an embedding vector for a single query string.
    Returns a normalized 384-dimensional vector.
    """
    encoder = _ensure_model(model_name)
    # query_embed yields one numpy array per input; grab the first (only) one.
    first = next(iter(encoder.query_embed([text])), None)
    if first is None:
        return []
    return normalize_embedding(first.tolist())
def embed_batch(
    texts: list[str],
    model_name: str = "BAAI/bge-small-en-v1.5",
) -> list[list[float]]:
    """
    Generate embedding vectors for a batch of text strings.
    Returns a list of normalized 384-dimensional vectors.
    """
    # Nothing to embed: skip model loading entirely.
    if not texts:
        return []
    encoder = _ensure_model(model_name)
    vectors: list[list[float]] = []
    for embedding in encoder.passage_embed(texts):
        vectors.append(normalize_embedding(embedding.tolist()))
    return vectors
def get_embedding_dims(model_name: str = "BAAI/bge-small-en-v1.5") -> int:
    """Get the dimensionality of the embedding model.

    Unrecognized model names fall back to 384 (the bge-small default).
    """
    if model_name == "BAAI/bge-base-en-v1.5":
        return 768
    # Both the default small model and MiniLM produce 384-dim vectors,
    # and 384 is also the fallback for anything we do not recognize.
    return 384