"""
|
|
Embedding provider for the memory system.
|
|
Uses fastembed (ONNX) for fully local, zero-API-call embeddings.
|
|
|
|
Inspired by OpenClaw's src/memory/embeddings.ts, simplified to:
|
|
• Single provider: fastembed with BAAI/bge-small-en-v1.5 (384-dim)
|
|
• Local only — no OpenAI/Voyage/Gemini API calls
|
|
• Thread-safe lazy initialization
|
|
"""
import logging
import threading

from memory.internal import normalize_embedding

logger = logging.getLogger("aetheel.memory.embeddings")

# The fastembed model is loaded lazily on first use
# Lock guarding lazy initialization of the cached model singleton below.
_model_lock = threading.Lock()
# Cached fastembed TextEmbedding instance; None until the first load.
_model = None
# Model name the cached instance was loaded with (kept in sync with _model).
_model_name: str | None = None


def _ensure_model(model_name: str = "BAAI/bge-small-en-v1.5"):
    """Lazy-load the fastembed model (thread-safe).

    Uses double-checked locking: a cheap unlocked read first, then a
    re-check under ``_model_lock`` so concurrent callers never load the
    model twice. Requesting a different ``model_name`` replaces the
    cached instance.

    Returns the cached/loaded ``fastembed.TextEmbedding`` instance.

    Raises:
        ImportError: if the optional ``fastembed`` dependency is missing
            (chained to the original import failure).
    """
    global _model, _model_name

    # Fast path: already loaded with the requested name — no lock needed.
    if _model is not None and _model_name == model_name:
        return _model

    with _model_lock:
        # Double-check after acquiring the lock: another thread may have
        # finished loading while we were waiting.
        if _model is not None and _model_name == model_name:
            return _model

        try:
            from fastembed import TextEmbedding
        except ImportError as err:
            # Chain the original error so the real import failure stays
            # visible in the traceback (PEP 3134).
            raise ImportError(
                "fastembed is required for local embeddings.\n"
                "Install with: uv add fastembed\n"
                "Or: pip install fastembed"
            ) from err

        # Lazy %-style args keep logging cheap when INFO is disabled.
        logger.info("Loading embedding model: %s...", model_name)
        _model = TextEmbedding(model_name=model_name)
        _model_name = model_name
        logger.info("Embedding model loaded: %s", model_name)
        return _model


def embed_query(text: str, model_name: str = "BAAI/bge-small-en-v1.5") -> list[float]:
    """
    Embed a single query string.

    Loads the model on first use via ``_ensure_model``, runs the
    query-side embedding, and normalizes the resulting vector
    (384 dims for the default model). Returns an empty list if the
    model yields no output.
    """
    model = _ensure_model(model_name)
    # Iterate the (possibly lazy) result; the first vector wins.
    for vector in model.query_embed([text]):
        return normalize_embedding(vector.tolist())
    return []


def embed_batch(
    texts: list[str],
    model_name: str = "BAAI/bge-small-en-v1.5",
) -> list[list[float]]:
    """
    Embed a batch of passage strings.

    Returns one normalized vector per input text (384 dims for the
    default model). An empty input returns an empty list without
    touching the model, so the lazy load is not triggered needlessly.
    """
    if not texts:
        return []

    model = _ensure_model(model_name)
    normalized = (
        normalize_embedding(vector.tolist())
        for vector in model.passage_embed(texts)
    )
    return list(normalized)


def get_embedding_dims(model_name: str = "BAAI/bge-small-en-v1.5") -> int:
    """Return the output dimensionality of the embedding model.

    Looks the model up in a small table of known models; unrecognized
    names fall back to 384 (the default model's size).
    """
    if model_name == "BAAI/bge-base-en-v1.5":
        return 768
    if model_name in (
        "BAAI/bge-small-en-v1.5",
        "sentence-transformers/all-MiniLM-L6-v2",
    ):
        return 384
    # Unknown model: assume the default 384-dim size.
    return 384