59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
"""Compute sentence embeddings for each generation in each condition."""
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import yaml
|
|
from sentence_transformers import SentenceTransformer
|
|
from tqdm import tqdm
|
|
|
|
|
|
def load_config(path: str = "config.yaml") -> dict:
    """Read a YAML configuration file and return its parsed contents.

    Args:
        path: Path to the YAML file. Defaults to ``config.yaml``.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the contents are not valid YAML.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # locale encoding, so non-ASCII values load the same way on every OS.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
|
|
|
|
|
|
def load_outputs(outputs_dir: Path) -> list[str]:
    """Return the UTF-8 text of every ``*.txt`` file directly inside a directory.

    Files are visited in sorted path order, so the position of each text in
    the returned list follows the sorted filenames.

    Args:
        outputs_dir: Directory holding one condition's generated outputs.

    Returns:
        One string per matching file; empty when nothing matches.
    """
    contents = []
    for txt_path in sorted(outputs_dir.glob("*.txt")):
        contents.append(txt_path.read_text(encoding="utf-8"))
    return contents
|
|
|
|
|
|
def embed_condition(
    model: SentenceTransformer,
    texts: list[str],
    *,
    batch_size: int = 8,
    show_progress_bar: bool = True,
) -> np.ndarray:
    """Return (N, D) embedding matrix. L2-normalized for cosine similarity.

    Args:
        model: Encoder whose ``encode`` method produces the embeddings.
        texts: The N strings to embed, in the order the rows should appear.
        batch_size: Number of texts passed to the encoder per batch. The
            default of 8 matches the previously hard-coded value.
        show_progress_bar: Whether the encoder displays a progress bar while
            encoding. Defaults to ``True``, as before.

    Returns:
        The array returned by ``model.encode``, requested as a NumPy array
        with normalized embeddings.
    """
    # `texts` is passed positionally on purpose: the encoder's first parameter
    # is not named `texts`, so a keyword argument would not bind to it.
    return model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=show_progress_bar,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
|
|
|
|
|
|
def main() -> None:
    """Embed the outputs of each condition and save one ``.npy`` file per condition.

    Reads the configuration via ``load_config()``, builds the encoder named by
    ``embedding.model``, and for the ``sparse`` and ``dense`` conditions loads
    the texts under ``paths.outputs_dir/<condition>``, embeds them, and writes
    the matrix to ``paths.embeddings_dir/<condition>.npy``. A condition with
    no ``.txt`` outputs is reported and skipped.
    """
    config = load_config()
    encoder = SentenceTransformer(config["embedding"]["model"])

    path_settings = config["paths"]
    source_root = Path(path_settings["outputs_dir"])
    target_root = Path(path_settings["embeddings_dir"])
    # Create the destination (and any missing parents) before encoding starts.
    target_root.mkdir(parents=True, exist_ok=True)

    for condition in ("sparse", "dense"):
        texts = load_outputs(source_root / condition)
        if texts:
            print(f"Embedding {len(texts)} {condition} outputs...")
            embeddings = embed_condition(encoder, texts)
            np.save(target_root / f"{condition}.npy", embeddings)
            print(f"Saved {condition}.npy with shape {embeddings.shape}")
        else:
            print(f"No outputs found for {condition}; skipping.")
|
|
|
|
|
|
# Run the embedding pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|