"""Generate completions for sparse and dense prompts via LMStudio.
|
|
|
|
LMStudio exposes an OpenAI-compatible server (default: localhost:1234).
|
|
Start the server from LMStudio's "Local Server" tab before running.
|
|
"""

from __future__ import annotations

import json
from pathlib import Path

import yaml
from openai import OpenAI
from tqdm import tqdm
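
# A sketch of the config.yaml this script expects, inferred from the keys
# accessed below; the values shown are illustrative, not verified defaults:
#
#   lmstudio:
#     base_url: http://localhost:1234/v1
#     api_key: lm-studio   # LMStudio ignores the key; the client needs a string
#     model: my-local-model
#   generation:
#     temperature: 0.7
#     top_p: 0.95
#     max_tokens: 512
#   paths:
#     prompts_dir: prompts
#     outputs_dir: outputs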


def load_config(path: str = "config.yaml") -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def make_client(cfg: dict) -> OpenAI:
    # LMStudio does not validate the API key, but the OpenAI client
    # requires a non-empty string, so config.yaml must still provide one.
    return OpenAI(
        base_url=cfg["lmstudio"]["base_url"],
        api_key=cfg["lmstudio"]["api_key"],
    )


def generate_one(
    client: OpenAI,
    model: str,
    prompt: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
    seed: int,
) -> str:
    """Single completion. Returns the assistant message content."""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        seed=seed,  # best-effort determinism; support depends on the backend
    )
    return response.choices[0].message.content or ""
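

# Assumed input format (inferred from json.load and the per-item loop below):
# each prompts file, e.g. prompts/sparse.json, holds a JSON array of prompt
# strings, with one completion generated per string.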
def run_condition(
    client: OpenAI,
    cfg: dict,
    condition: str,
) -> None:
    prompts_path = Path(cfg["paths"]["prompts_dir"]) / f"{condition}.json"
    outputs_dir = Path(cfg["paths"]["outputs_dir"]) / condition
    outputs_dir.mkdir(parents=True, exist_ok=True)

    with open(prompts_path, "r", encoding="utf-8") as f:
        prompts = json.load(f)

    gen_cfg = cfg["generation"]
    model = cfg["lmstudio"]["model"]

    for i, prompt in enumerate(tqdm(prompts, desc=condition)):
        out_file = outputs_dir / f"{i:02d}.txt"
        if out_file.exists():
            continue  # resume support: skip items generated on a previous run
        text = generate_one(
            client=client,
            model=model,
            prompt=prompt,
            temperature=gen_cfg["temperature"],
            top_p=gen_cfg["top_p"],
            max_tokens=gen_cfg["max_tokens"],
            seed=i,  # per-prompt seed, stable across resumed runs
        )
        out_file.write_text(text, encoding="utf-8")


def main() -> None:
    cfg = load_config()
    client = make_client(cfg)
    for condition in ("sparse", "dense"):
        run_condition(client, cfg, condition)
    print("Generation complete.")


if __name__ == "__main__":
    main()
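
# Usage sketch, assuming this file is saved as generate.py with config.yaml
# and the prompts directory beside it:
#
#   python generate.py
#
# Outputs land in {outputs_dir}/{condition}/{NN}.txt, one file per prompt;
# re-running skips any output file that already exists.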