where-simd-helps/analysis/cmd/analyze-simd/main.go

243 lines
6.6 KiB
Go

// analyze-simd computes speedup ratios from aggregated pqc-bench results.
//
// Usage:
//
// analyze-simd [--baseline ref] [--in results.json] [--out speedups.json]
//
// It reads the JSON produced by 'aggregate', computes per-operation speedups
// relative to the baseline variant, and emits both a human-readable table
// (on stderr) and structured JSON (to the --out file, or to stdout when --out
// is not given) suitable for downstream plotting.
package main
import (
"cmp"
"encoding/json"
"flag"
"fmt"
"math"
"os"
"slices"
"strings"
"text/tabwriter"
)
// Record is one entry of the JSON array written by the aggregate tool,
// limited to the fields this program decodes.
type Record struct {
	// Algorithm, Variant and Operation together identify a measurement;
	// records are indexed by this triple.
	Algorithm string `json:"algorithm"`
	Variant   string `json:"variant"`
	Operation string `json:"operation"`

	// Median is the median measurement. The summary table labels the
	// baseline's value as cycles.
	Median float64 `json:"median"`

	// CI95 holds the interval bounds; the speedup computation treats
	// index 0 as the lower bound and index 1 as the upper bound.
	CI95 [2]float64 `json:"ci95"`

	// NRuns is decoded from "n_runs"; nothing in this file reads it.
	NRuns int `json:"n_runs"`
}
// Speedup describes how one non-baseline variant compares with the baseline
// for a single (algorithm, operation) pair.
type Speedup struct {
	// Variant names the implementation being compared against the baseline.
	Variant string `json:"variant"`

	// Median is the compared variant's own median measurement.
	Median float64 `json:"median"`

	// Speedup is the baseline median divided by this variant's median.
	Speedup float64 `json:"speedup"`

	// SpeedupCI is the conservative interval
	// [baseline_lo / variant_hi, baseline_hi / variant_lo].
	// A bound is left at 0 when the variant bound it divides by is not positive.
	SpeedupCI [2]float64 `json:"speedup_ci95"`
}
// Result is a single row of the output: the baseline measurement for one
// (algorithm, operation) pair plus every variant compared against it.
type Result struct {
	Algorithm string `json:"algorithm"`
	Operation string `json:"operation"`

	// BaselineVariant is the variant used as the speedup denominator;
	// BaselineMedian and BaselineCI95 are copied from its record.
	BaselineVariant string     `json:"baseline_variant"`
	BaselineMedian  float64    `json:"baseline_median"`
	BaselineCI95    [2]float64 `json:"baseline_ci95"`

	// Comparisons holds one entry per non-baseline variant that has a
	// usable record for this pair.
	Comparisons []Speedup `json:"comparisons"`
}
// main parses the command-line flags, loads the aggregated benchmark records
// from --in, computes the speedup of every non-baseline variant relative to
// the --baseline variant for each (algorithm, operation) pair, prints a
// summary table to stderr, and emits the results as indented JSON to --out
// (or to stdout when --out is empty).
//
// A failure to read, parse, marshal or write is reported on stderr and ends
// the process with exit status 1.
func main() {
	baseline := flag.String("baseline", "ref", "variant to use as the speedup denominator")
	inFile := flag.String("in", "results/kyber.json", "input JSON from aggregate")
	outFile := flag.String("out", "", "write speedup JSON to this file (default: stdout)")
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: analyze-simd [--baseline VARIANT] [--in FILE] [--out FILE]\n")
		flag.PrintDefaults()
	}
	flag.Parse()

	raw, err := os.ReadFile(*inFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "error reading %s: %v\n", *inFile, err)
		os.Exit(1)
	}
	var records []Record
	if err := json.Unmarshal(raw, &records); err != nil {
		fmt.Fprintf(os.Stderr, "error parsing JSON: %v\n", err)
		os.Exit(1)
	}

	// Index by (algorithm, variant, operation). If the input repeats a
	// triple, the last record wins.
	type key struct{ algorithm, variant, operation string }
	idx := make(map[key]Record, len(records))
	for _, r := range records {
		idx[key{r.Algorithm, r.Variant, r.Operation}] = r
	}

	// Collect sorted unique values for stable output.
	algorithms := unique(records, func(r Record) string { return r.Algorithm })
	operations := unique(records, func(r Record) string { return r.Operation })
	variants := unique(records, func(r Record) string { return r.Variant })

	// Remove baseline from comparison variants.
	variants = slices.DeleteFunc(variants, func(v string) bool { return v == *baseline })

	// Build results. The slice is created non-nil so that an empty result
	// set is encoded as the JSON array [] rather than null.
	results := make([]Result, 0, len(algorithms)*len(operations))
	for _, alg := range algorithms {
		for _, op := range operations {
			baseRec, ok := idx[key{alg, *baseline, op}]
			if !ok || baseRec.Median == 0 {
				// No baseline (or a zero baseline) means there is nothing
				// meaningful to compare against for this pair.
				continue
			}
			res := Result{
				Algorithm:       alg,
				Operation:       op,
				BaselineVariant: *baseline,
				BaselineMedian:  baseRec.Median,
				BaselineCI95:    baseRec.CI95,
			}
			for _, v := range variants {
				cmpRec, ok := idx[key{alg, v, op}]
				if !ok || cmpRec.Median == 0 {
					// A zero median would make the ratio undefined.
					continue
				}
				sp := baseRec.Median / cmpRec.Median

				// Conservative CI: ratio of interval bounds.
				//   speedup_lo = baseline_lo / cmp_hi
				//   speedup_hi = baseline_hi / cmp_lo
				// A bound stays 0 when its divisor is not positive.
				var spCI [2]float64
				if cmpRec.CI95[1] > 0 {
					spCI[0] = safeDiv(baseRec.CI95[0], cmpRec.CI95[1])
				}
				if cmpRec.CI95[0] > 0 {
					spCI[1] = safeDiv(baseRec.CI95[1], cmpRec.CI95[0])
				}
				res.Comparisons = append(res.Comparisons, Speedup{
					Variant:   v,
					Median:    cmpRec.Median,
					Speedup:   sp,
					SpeedupCI: spCI,
				})
			}
			if len(res.Comparisons) > 0 {
				results = append(results, res)
			}
		}
	}

	// An empty result set usually means the baseline name does not match any
	// variant in the input; say so instead of silently emitting nothing.
	if len(results) == 0 {
		fmt.Fprintf(os.Stderr, "warning: no speedups computed; check that baseline variant %q appears in %s\n", *baseline, *inFile)
	}

	// Print human-readable table to stderr.
	printTable(os.Stderr, results, variants, *baseline)

	// Emit JSON.
	out, err := json.MarshalIndent(results, "", " ")
	if err != nil {
		fmt.Fprintf(os.Stderr, "error marshalling JSON: %v\n", err)
		os.Exit(1)
	}
	if *outFile == "" {
		fmt.Println(string(out))
		return
	}
	if err := os.WriteFile(*outFile, out, 0o644); err != nil {
		fmt.Fprintf(os.Stderr, "error writing %s: %v\n", *outFile, err)
		os.Exit(1)
	}
	fmt.Fprintf(os.Stderr, "wrote %d results to %s\n", len(results), *outFile)
}
// printTable writes a human-readable speedup table to w.
//
// Results are grouped by algorithm and the groups are printed in alphabetical
// order. Each group has a banner naming the algorithm and baseline, a header
// row, and one row per operation showing the baseline median followed by the
// speedup of every entry in variants ("---" where a row has no comparison for
// that variant).
//
// Rows are ordered by descending avx2 speedup. Rows with no avx2 comparison
// come after all rows that have one, and ties (including rows that all lack
// avx2) are ordered alphabetically by operation name.
func printTable(w *os.File, results []Result, variants []string, baseline string) {
	tw := tabwriter.NewWriter(w, 0, 0, 2, ' ', 0)

	// Group by algorithm.
	byAlg := make(map[string][]Result)
	for _, r := range results {
		byAlg[r.Algorithm] = append(byAlg[r.Algorithm], r)
	}
	algs := make([]string, 0, len(byAlg))
	for a := range byAlg {
		algs = append(algs, a)
	}
	slices.Sort(algs)

	// The header and rule are identical for every algorithm; build them once.
	var hdr strings.Builder
	fmt.Fprintf(&hdr, "%-38s\t%12s", "operation", baseline+"(cycles)")
	for _, v := range variants {
		fmt.Fprintf(&hdr, "\t%10s", v)
	}
	header := hdr.String()
	rule := strings.Repeat("-", 38+13+11*len(variants))

	for _, alg := range algs {
		fmt.Fprintf(tw, "\n── %s (baseline: %s) ──\n", strings.ToUpper(alg), baseline)
		fmt.Fprintln(tw, header)
		fmt.Fprintln(tw, rule)

		rows := byAlg[alg]
		slices.SortFunc(rows, func(a, b Result) int {
			// Descending avx2 speedup. speedupFor yields NaN when a row has
			// no avx2 entry; cmp.Compare orders NaN below every number and
			// treats two NaNs as equal, so such rows sink to the bottom and
			// fall through to the alphabetical tie-break. Comparing with
			// != here would not work because NaN != NaN is always true.
			if c := cmp.Compare(speedupFor(b, "avx2"), speedupFor(a, "avx2")); c != 0 {
				return c
			}
			return strings.Compare(a.Operation, b.Operation)
		})

		for _, r := range rows {
			var line strings.Builder
			fmt.Fprintf(&line, "%-38s\t%12s", r.Operation, formatCycles(r.BaselineMedian))
			for _, v := range variants {
				sp := speedupFor(r, v)
				if math.IsNaN(sp) {
					fmt.Fprintf(&line, "\t%10s", "---")
				} else {
					fmt.Fprintf(&line, "\t%9.2fx", sp)
				}
			}
			fmt.Fprintln(tw, line.String())
		}
	}

	// Nothing reaches w until the tabwriter is flushed, so a write failure
	// only surfaces here.
	if err := tw.Flush(); err != nil {
		fmt.Fprintf(os.Stderr, "error writing table: %v\n", err)
	}
}
// speedupFor looks up the speedup recorded in r for the named variant.
// It returns the first matching comparison's Speedup, or NaN when r holds no
// comparison for that variant.
func speedupFor(r Result, variant string) float64 {
	found := math.NaN()
	for i := range r.Comparisons {
		if entry := &r.Comparisons[i]; entry.Variant == variant {
			found = entry.Speedup
			break
		}
	}
	return found
}
// formatCycles renders a cycle count compactly: values of at least one
// million as "N.NNM", values of at least one thousand as "N.NK", and anything
// smaller as a whole number.
func formatCycles(c float64) string {
	const (
		thousand = 1_000
		million  = 1_000_000
	)
	switch {
	case c >= million:
		return fmt.Sprintf("%.2fM", c/million)
	case c >= thousand:
		return fmt.Sprintf("%.1fK", c/thousand)
	default:
		return fmt.Sprintf("%.0f", c)
	}
}
// safeDiv returns a divided by b, substituting 0 for the quotient when the
// divisor is zero.
func safeDiv(a, b float64) float64 {
	if b != 0 {
		return a / b
	}
	return 0
}
// unique applies fn to every record and returns the distinct resulting
// strings in ascending order. The returned slice is empty, but not nil, when
// there are no records.
func unique(records []Record, fn func(Record) string) []string {
	vals := make([]string, 0, len(records))
	for _, rec := range records {
		vals = append(vals, fn(rec))
	}
	// After sorting, duplicates are adjacent, so Compact drops them all.
	slices.Sort(vals)
	return slices.Compact(vals)
}