first commit
This commit is contained in:
88
memory/embeddings.py
Normal file
88
memory/embeddings.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
Embedding provider for the memory system.
Uses fastembed (ONNX) for fully local, zero-API-call embeddings.

Inspired by OpenClaw's src/memory/embeddings.ts, simplified to:
  • Single provider: fastembed with BAAI/bge-small-en-v1.5 (384-dim)
  • Local only — no OpenAI/Voyage/Gemini API calls
  • Thread-safe lazy initialization
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from memory.internal import normalize_embedding
|
||||
|
||||
logger = logging.getLogger("aetheel.memory.embeddings")

# The fastembed model is loaded lazily on first use.
# _model_lock guards the load in _ensure_model so concurrent first
# callers initialize the model at most once.
_model_lock = threading.Lock()
# Cached fastembed TextEmbedding instance (None until first use).
_model = None
# Model identifier currently cached in _model, or None if unloaded.
_model_name: str | None = None
|
||||
|
||||
|
||||
def _ensure_model(model_name: str = "BAAI/bge-small-en-v1.5"):
    """Lazy-load the fastembed model (thread-safe).

    Uses double-checked locking: the fast path reads the cached model
    without taking the lock (safe under CPython's GIL), and the slow
    path re-checks after acquiring it so concurrent first callers load
    the model at most once.

    Args:
        model_name: fastembed model identifier to load. Requesting a
            different name than the cached one reloads the model.

    Returns:
        The cached ``TextEmbedding`` instance for *model_name*.

    Raises:
        ImportError: if the ``fastembed`` package is not installed.
    """
    global _model, _model_name

    # Fast path: model already loaded for this name.
    if _model is not None and _model_name == model_name:
        return _model

    with _model_lock:
        # Double-check after acquiring the lock: another thread may
        # have finished loading while we were waiting.
        if _model is not None and _model_name == model_name:
            return _model

        try:
            from fastembed import TextEmbedding
        except ImportError as err:
            # Chain the original error so the real import failure
            # (e.g. a broken ONNX runtime) is not masked.
            raise ImportError(
                "fastembed is required for local embeddings.\n"
                "Install with: uv add fastembed\n"
                "Or: pip install fastembed"
            ) from err

        # Lazy %-style args avoid string formatting when INFO is off.
        logger.info("Loading embedding model: %s...", model_name)
        _model = TextEmbedding(model_name=model_name)
        _model_name = model_name
        logger.info("Embedding model loaded: %s", model_name)
        return _model
|
||||
|
||||
|
||||
def embed_query(text: str, model_name: str = "BAAI/bge-small-en-v1.5") -> list[float]:
    """
    Generate an embedding vector for a single query string.
    Returns a normalized 384-dimensional vector, or [] if the model
    yields no embedding for the input.
    """
    model = _ensure_model(model_name)
    results = list(model.query_embed([text]))
    if results:
        return normalize_embedding(results[0].tolist())
    return []
|
||||
|
||||
|
||||
def embed_batch(
    texts: list[str],
    model_name: str = "BAAI/bge-small-en-v1.5",
) -> list[list[float]]:
    """
    Generate embedding vectors for a batch of text strings.
    Returns a list of normalized 384-dimensional vectors; an empty
    input batch yields an empty list without loading the model.
    """
    if not texts:
        return []
    vectors = _ensure_model(model_name).passage_embed(texts)
    return [normalize_embedding(vec.tolist()) for vec in vectors]
|
||||
|
||||
|
||||
def get_embedding_dims(model_name: str = "BAAI/bge-small-en-v1.5") -> int:
    """Get the dimensionality of the embedding model.

    Looks the model up in a small table of known models and falls
    back to 384 for anything unrecognized.
    """
    dims_by_model = {
        "BAAI/bge-small-en-v1.5": 384,
        "BAAI/bge-base-en-v1.5": 768,
        "sentence-transformers/all-MiniLM-L6-v2": 384,
    }
    try:
        return dims_by_model[model_name]
    except KeyError:
        return 384
|
||||
Reference in New Issue
Block a user