diff --git a/manager/trader_manager_test.go b/manager/trader_manager_test.go new file mode 100644 index 00000000..15ea6a01 --- /dev/null +++ b/manager/trader_manager_test.go @@ -0,0 +1,480 @@ +package manager + +import ( + "errors" + "fmt" + "sort" + "strings" + "sync" + "testing" + + "nofx/store" + "nofx/trader" +) + +// newIdleTrader returns a zero-value AutoTrader. It is safe to store in the +// manager map for map-semantics tests: GetStatus works on a zero value and +// Stop returns early because the trader is not running. It must NOT be used +// for anything that touches an exchange (Run, GetAccountInfo, ...). +func newIdleTrader() *trader.AutoTrader { + return &trader.AutoTrader{} +} + +// insertTrader places a trader directly into the manager's internal map, +// bypassing store loading (same-package access). +func insertTrader(tm *TraderManager, id string, t *trader.AutoTrader) { + tm.mu.Lock() + defer tm.mu.Unlock() + tm.traders[id] = t +} + +func TestNewTraderManager(t *testing.T) { + tm := NewTraderManager() + + if tm == nil { + t.Fatal("NewTraderManager() returned nil") + } + if tm.traders == nil { + t.Error("traders map should be initialized, got nil") + } + if len(tm.traders) != 0 { + t.Errorf("traders map should be empty, got %d entries", len(tm.traders)) + } + if tm.loadErrors == nil { + t.Error("loadErrors map should be initialized, got nil") + } + if len(tm.loadErrors) != 0 { + t.Errorf("loadErrors map should be empty, got %d entries", len(tm.loadErrors)) + } + if tm.competitionCache == nil { + t.Fatal("competitionCache should be initialized, got nil") + } + if tm.competitionCache.data == nil { + t.Error("competitionCache.data should be initialized, got nil") + } + if !tm.competitionCache.timestamp.IsZero() { + t.Errorf("competitionCache.timestamp should be zero, got %v", tm.competitionCache.timestamp) + } +} + +func TestGetTrader(t *testing.T) { + tm := NewTraderManager() + + t.Run("missing ID returns error", func(t *testing.T) { + got, err := tm.GetTrader("does-not-exist") + if err == nil { + t.Fatal("GetTrader on missing ID expected error, got nil") + } + if got != nil { + t.Errorf("GetTrader on missing ID should return nil trader, got %v", got) + } + if !strings.Contains(err.Error(), "does-not-exist") { + t.Errorf("error %q should mention the trader ID", err.Error()) + } + }) + + t.Run("existing ID returns same instance", func(t *testing.T) { + at := newIdleTrader() + insertTrader(tm, "trader-1", at) + + got, err := tm.GetTrader("trader-1") + if err != nil { + t.Fatalf("GetTrader unexpected error: %v", err) + } + if got != at { + t.Errorf("GetTrader returned %p, want the stored instance %p", got, at) + } + }) +} + +func TestGetLoadError(t *testing.T) { + tm := NewTraderManager() + + t.Run("unknown trader returns nil", func(t *testing.T) { + if err := tm.GetLoadError("unknown"); err != nil { + t.Errorf("GetLoadError for unknown trader = %v, want nil", err) + } + }) + + t.Run("stored error is returned", func(t *testing.T) { + wantErr := errors.New("failed to create trader: boom") + tm.mu.Lock() + tm.loadErrors["trader-x"] = wantErr + tm.mu.Unlock() + + if got := tm.GetLoadError("trader-x"); !errors.Is(got, wantErr) { + t.Errorf("GetLoadError = %v, want %v", got, wantErr) + } + }) +} + +func TestGetAllTradersReturnsCopy(t *testing.T) { + tm := NewTraderManager() + at1 := newIdleTrader() + at2 := newIdleTrader() + insertTrader(tm, "t1", at1) + insertTrader(tm, "t2", at2) + + all := tm.GetAllTraders() + + if len(all) != 2 { + t.Fatalf("GetAllTraders returned %d entries, want 2", len(all)) + } + if all["t1"] != at1 || all["t2"] != at2 { + t.Error("GetAllTraders should return the same trader instances") + } + + // Mutating the returned map must not affect internal state. + delete(all, "t1") + all["t3"] = newIdleTrader() + + if _, err := tm.GetTrader("t1"); err != nil { + t.Errorf("deleting from returned map leaked into internal state: %v", err) + } + if _, err := tm.GetTrader("t3"); err == nil { + t.Error("adding to returned map leaked into internal state") + } + if got := len(tm.GetTraderIDs()); got != 2 { + t.Errorf("internal trader count = %d after mutating returned map, want 2", got) + } +} + +func TestGetTraderIDs(t *testing.T) { + tm := NewTraderManager() + + t.Run("empty manager returns empty non-nil slice", func(t *testing.T) { + ids := tm.GetTraderIDs() + if ids == nil { + t.Fatal("GetTraderIDs should return an empty slice, got nil") + } + if len(ids) != 0 { + t.Errorf("GetTraderIDs = %v, want empty", ids) + } + }) + + t.Run("returns all IDs", func(t *testing.T) { + want := []string{"a", "b", "c"} + for _, id := range want { + insertTrader(tm, id, newIdleTrader()) + } + + got := tm.GetTraderIDs() + sort.Strings(got) + if len(got) != len(want) { + t.Fatalf("GetTraderIDs returned %d IDs, want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("GetTraderIDs[%d] = %q, want %q", i, got[i], want[i]) + } + } + }) +} + +func TestRemoveTrader(t *testing.T) { + t.Run("removes existing non-running trader", func(t *testing.T) { + tm := NewTraderManager() + insertTrader(tm, "t1", newIdleTrader()) + + tm.RemoveTrader("t1") + + if _, err := tm.GetTrader("t1"); err == nil { + t.Error("trader t1 should be removed") + } + if got := len(tm.GetTraderIDs()); got != 0 { + t.Errorf("trader count after removal = %d, want 0", got) + } + }) + + t.Run("missing ID is a no-op", func(t *testing.T) { + tm := NewTraderManager() + insertTrader(tm, "t1", newIdleTrader()) + + tm.RemoveTrader("missing") // must not panic + + if _, err := tm.GetTrader("t1"); err != nil { + t.Errorf("unrelated trader was removed: %v", err) + } + }) +} + +func TestStartAllEmpty(t *testing.T) { + tm := NewTraderManager() + tm.StartAll() // must not panic with no traders +} + +func TestStopAllWithIdleTraders(t *testing.T) { + tm := NewTraderManager() + tm.StopAll() // empty: must not panic + + insertTrader(tm, "t1", newIdleTrader()) + insertTrader(tm, "t2", newIdleTrader()) + tm.StopAll() // not-running traders: Stop is an early-return no-op +} + +func TestTraderLogTag(t *testing.T) { + tests := []struct { + name string + traderID string + traderName string + want string + }{ + { + name: "with name", + traderID: "abc-123", + traderName: "MyBot", + want: "[trader_id=abc-123 trader_name=MyBot]", + }, + { + name: "without name", + traderID: "abc-123", + want: "[trader_id=abc-123]", + }, + { + name: "both empty", + want: "[trader_id=]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := traderLogTag(tt.traderID, tt.traderName); got != tt.want { + t.Errorf("traderLogTag(%q, %q) = %q, want %q", tt.traderID, tt.traderName, got, tt.want) + } + }) + } +} + +func TestEnsureHyperliquidNativeStrategy(t *testing.T) { + t.Run("nil config does not panic", func(t *testing.T) { + ensureHyperliquidNativeStrategy("bot", "hyperliquid", nil) + }) + + t.Run("non-hyperliquid exchange is untouched", func(t *testing.T) { + cfg := &store.StrategyConfig{ + CoinSource: store.CoinSourceConfig{ + SourceType: "ai500", + UseAI500: true, + }, + } + ensureHyperliquidNativeStrategy("bot", "binance", cfg) + + if cfg.CoinSource.SourceType != "ai500" || !cfg.CoinSource.UseAI500 { + t.Errorf("non-hyperliquid config was modified: %+v", cfg.CoinSource) + } + }) + + t.Run("native sources are kept as-is", func(t *testing.T) { + nativeSources := []string{"hyper_rank", "static", "hyper_all", "hyper_main", " Hyper_Rank "} + for _, src := range nativeSources { + cfg := &store.StrategyConfig{ + CoinSource: store.CoinSourceConfig{SourceType: src}, + } + ensureHyperliquidNativeStrategy("bot", "hyperliquid", cfg) + + if cfg.CoinSource.SourceType != src { + t.Errorf("native source %q was rewritten to %q", src, cfg.CoinSource.SourceType) + } + } + }) + + t.Run("legacy source on hyperliquid is forced to hyper_rank with defaults", func(t *testing.T) { + cfg := &store.StrategyConfig{ + CoinSource: store.CoinSourceConfig{ + SourceType: "ai500", + UseAI500: true, + UseOITop: true, + UseOILow: true, + UseHyperAll: true, + UseHyperMain: true, + }, + } + ensureHyperliquidNativeStrategy("bot", "hyperliquid", cfg) + + cs := cfg.CoinSource + if cs.SourceType != "hyper_rank" { + t.Errorf("SourceType = %q, want hyper_rank", cs.SourceType) + } + if cs.UseAI500 || cs.UseOITop || cs.UseOILow || cs.UseHyperAll || cs.UseHyperMain { + t.Errorf("legacy source flags should all be cleared: %+v", cs) + } + if cs.HyperRankCategory != "stock" { + t.Errorf("HyperRankCategory = %q, want stock", cs.HyperRankCategory) + } + if cs.HyperRankDirection != "gainers" { + t.Errorf("HyperRankDirection = %q, want gainers", cs.HyperRankDirection) + } + if cs.HyperRankLimit != 5 { + t.Errorf("HyperRankLimit = %d, want 5", cs.HyperRankLimit) + } + }) + + t.Run("existing hyper_rank settings are preserved when forcing", func(t *testing.T) { + cfg := &store.StrategyConfig{ + CoinSource: store.CoinSourceConfig{ + SourceType: "oi_top", + HyperRankCategory: "crypto", + HyperRankDirection: "losers", + HyperRankLimit: 8, + }, + } + ensureHyperliquidNativeStrategy("bot", "hyperliquid", cfg) + + cs := cfg.CoinSource + if cs.SourceType != "hyper_rank" { + t.Errorf("SourceType = %q, want hyper_rank", cs.SourceType) + } + if cs.HyperRankCategory != "crypto" { + t.Errorf("HyperRankCategory = %q, want crypto (preserved)", cs.HyperRankCategory) + } + if cs.HyperRankDirection != "losers" { + t.Errorf("HyperRankDirection = %q, want losers (preserved)", cs.HyperRankDirection) + } + if cs.HyperRankLimit != 8 { + t.Errorf("HyperRankLimit = %d, want 8 (preserved)", cs.HyperRankLimit) + } + }) + + t.Run("exchange type is matched case-insensitively with whitespace", func(t *testing.T) { + cfg := &store.StrategyConfig{ + CoinSource: store.CoinSourceConfig{SourceType: "ai500"}, + } + ensureHyperliquidNativeStrategy("bot", " HyperLiquid ", cfg) + + if cfg.CoinSource.SourceType != "hyper_rank" { + t.Errorf("SourceType = %q, want hyper_rank for case-insensitive exchange match", cfg.CoinSource.SourceType) + } + }) +} + +func TestGetCompetitionDataEmptyAndCache(t *testing.T) { + tm := NewTraderManager() + + first, err := tm.GetCompetitionData() + if err != nil { + t.Fatalf("GetCompetitionData unexpected error: %v", err) + } + if got := first["count"]; got != 0 { + t.Errorf("count = %v, want 0", got) + } + if got := first["total_count"]; got != 0 { + t.Errorf("total_count = %v, want 0", got) + } + + tm.competitionCache.mu.RLock() + cachedTimestamp := tm.competitionCache.timestamp + tm.competitionCache.mu.RUnlock() + if cachedTimestamp.IsZero() { + t.Error("competition cache timestamp should be set after first call") + } + + // Second call within 30s must be served from the cache. + second, err := tm.GetCompetitionData() + if err != nil { + t.Fatalf("GetCompetitionData (cached) unexpected error: %v", err) + } + if got := second["count"]; got != 0 { + t.Errorf("cached count = %v, want 0", got) + } + + tm.competitionCache.mu.RLock() + timestampAfterSecond := tm.competitionCache.timestamp + tm.competitionCache.mu.RUnlock() + if !timestampAfterSecond.Equal(cachedTimestamp) { + t.Error("cached call should not refresh the cache timestamp") + } +} + +func TestGetTopTradersDataEmpty(t *testing.T) { + tm := NewTraderManager() + + result, err := tm.GetTopTradersData() + if err != nil { + t.Fatalf("GetTopTradersData unexpected error: %v", err) + } + if got := result["count"]; got != 0 { + t.Errorf("count = %v, want 0", got) + } + traders, ok := result["traders"].([]map[string]interface{}) + if !ok { + t.Fatalf("traders has type %T, want []map[string]interface{}", result["traders"]) + } + if len(traders) != 0 { + t.Errorf("traders length = %d, want 0", len(traders)) + } +} + +func TestGetComparisonDataEmpty(t *testing.T) { + tm := NewTraderManager() + + result, err := tm.GetComparisonData() + if err != nil { + t.Fatalf("GetComparisonData unexpected error: %v", err) + } + if got := result["count"]; got != 0 { + t.Errorf("count = %v, want 0", got) + } +} + +// TestConcurrentAccess exercises the RWMutex by hammering the read paths +// while traders are removed concurrently. Run with -race. +func TestConcurrentAccess(t *testing.T) { + tm := NewTraderManager() + + const traderCount = 16 + ids := make([]string, 0, traderCount) + for i := 0; i < traderCount; i++ { + id := fmt.Sprintf("trader-%d", i) + ids = append(ids, id) + insertTrader(tm, id, newIdleTrader()) + } + + const ( + goroutinesPerKind = 8 + iterations = 200 + ) + + var wg sync.WaitGroup + + // Readers: GetTrader / GetLoadError + for g := 0; g < goroutinesPerKind; g++ { + wg.Add(1) + go func(seed int) { + defer wg.Done() + for i := 0; i < iterations; i++ { + id := ids[(seed+i)%traderCount] + _, _ = tm.GetTrader(id) + _ = tm.GetLoadError(id) + } + }(g) + } + + // Readers: GetAllTraders / GetTraderIDs + for g := 0; g < goroutinesPerKind; g++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + _ = tm.GetAllTraders() + _ = tm.GetTraderIDs() + } + }() + } + + // Writers: RemoveTrader (including repeated removal of the same ID) + for g := 0; g < goroutinesPerKind; g++ { + wg.Add(1) + go func(seed int) { + defer wg.Done() + for i := 0; i < iterations; i++ { + tm.RemoveTrader(ids[(seed+i)%traderCount]) + } + }(g) + } + + wg.Wait() + + if got := len(tm.GetTraderIDs()); got != 0 { + t.Errorf("all traders should be removed after concurrent removal, %d left", got) + } +} diff --git a/market/timeframe_test.go b/market/timeframe_test.go new file mode 100644 index 00000000..3d881832 --- /dev/null +++ b/market/timeframe_test.go @@ -0,0 +1,123 @@ +package market + +import ( + "slices" + "testing" + "time" +) + +func TestNormalizeTimeframe(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + {name: "valid lowercase minute", input: "1m", want: "1m"}, + {name: "valid lowercase hour", input: "4h", want: "4h"}, + {name: "valid lowercase day", input: "1d", want: "1d"}, + {name: "uppercase normalized", input: "1H", want: "1h"}, + {name: "mixed case normalized", input: "15M", want: "15m"}, + {name: "uppercase day", input: "1D", want: "1d"}, + {name: "leading and trailing whitespace", input: " 30m ", want: "30m"}, + {name: "whitespace and uppercase", input: " 12H ", want: "12h"}, + {name: "empty string", input: "", wantErr: true}, + {name: "whitespace only", input: " ", wantErr: true}, + {name: "unsupported value", input: "7m", wantErr: true}, + {name: "unsupported week", input: "1w", wantErr: true}, + {name: "garbage input", input: "abc", wantErr: true}, + {name: "internal whitespace not trimmed", input: "1 m", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeTimeframe(tt.input) + if tt.wantErr { + if err == nil { + t.Fatalf("NormalizeTimeframe(%q) = %q, want error", tt.input, got) + } + return + } + if err != nil { + t.Fatalf("NormalizeTimeframe(%q) unexpected error: %v", tt.input, err) + } + if got != tt.want { + t.Errorf("NormalizeTimeframe(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestTFDuration(t *testing.T) { + tests := []struct { + name string + input string + want time.Duration + wantErr bool + }{ + {name: "one minute", input: "1m", want: time.Minute}, + {name: "three minutes", input: "3m", want: 3 * time.Minute}, + {name: "five minutes", input: "5m", want: 5 * time.Minute}, + {name: "fifteen minutes", input: "15m", want: 15 * time.Minute}, + {name: "thirty minutes", input: "30m", want: 30 * time.Minute}, + {name: "one hour", input: "1h", want: time.Hour}, + {name: "two hours", input: "2h", want: 2 * time.Hour}, + {name: "four hours", input: "4h", want: 4 * time.Hour}, + {name: "six hours", input: "6h", want: 6 * time.Hour}, + {name: "twelve hours", input: "12h", want: 12 * time.Hour}, + {name: "one day", input: "1d", want: 24 * time.Hour}, + {name: "uppercase with whitespace", input: " 1D ", want: 24 * time.Hour}, + {name: "empty string", input: "", wantErr: true}, + {name: "unsupported value", input: "2d", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := TFDuration(tt.input) + if tt.wantErr { + if err == nil { + t.Fatalf("TFDuration(%q) = %v, want error", tt.input, got) + } + return + } + if err != nil { + t.Fatalf("TFDuration(%q) unexpected error: %v", tt.input, err) + } + if got != tt.want { + t.Errorf("TFDuration(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestSupportedTimeframes(t *testing.T) { + got := SupportedTimeframes() + + if len(got) == 0 { + t.Fatal("SupportedTimeframes() returned empty slice") + } + + if !slices.IsSorted(got) { + t.Errorf("SupportedTimeframes() not sorted: %v", got) + } + + for _, required := range []string{"1m", "1d"} { + if !slices.Contains(got, required) { + t.Errorf("SupportedTimeframes() missing %q: %v", required, got) + } + } + + // Every advertised timeframe must round-trip through NormalizeTimeframe and TFDuration. + for _, tf := range got { + norm, err := NormalizeTimeframe(tf) + if err != nil { + t.Errorf("NormalizeTimeframe(%q) unexpected error: %v", tf, err) + } + if norm != tf { + t.Errorf("NormalizeTimeframe(%q) = %q, want identity", tf, norm) + } + if d, err := TFDuration(tf); err != nil || d <= 0 { + t.Errorf("TFDuration(%q) = %v, %v; want positive duration and nil error", tf, d, err) + } + } +} diff --git a/trader/position_rebuild_test.go b/trader/position_rebuild_test.go new file mode 100644 index 00000000..2df811a6 --- /dev/null +++ b/trader/position_rebuild_test.go @@ -0,0 +1,486 @@ +package trader + +import ( + "math" + "testing" + "time" +) + +// testTime returns a deterministic timestamp offset by n minutes. +func testTime(n int) time.Time { + return time.Date(2026, 1, 2, 10, n, 0, 0, time.UTC) +} + +func floatsClose(a, b float64) bool { + return math.Abs(a-b) < 1e-9 +} + +func TestRebuildPositionsFromTrades_EmptyInput(t *testing.T) { + if got := RebuildPositionsFromTrades(nil); got != nil { + t.Errorf("RebuildPositionsFromTrades(nil) = %v, want nil", got) + } + if got := RebuildPositionsFromTrades([]TradeRecord{}); got != nil { + t.Errorf("RebuildPositionsFromTrades([]) = %v, want nil", got) + } +} + +func TestRebuildPositionsFromTrades_SimpleLongOpenClose(t *testing.T) { + trades := []TradeRecord{ + { + TradeID: "t1", + Symbol: "BTCUSDT", + Side: "BUY", + Price: 100.0, + Quantity: 1.0, + Fee: 0.1, + Time: testTime(0), + }, + { + TradeID: "t2", + Symbol: "BTCUSDT", + Side: "SELL", + Price: 110.0, + Quantity: 1.0, + RealizedPnL: 10.0, + Fee: 0.2, + Time: testTime(1), + }, + } + + records := RebuildPositionsFromTrades(trades) + if len(records) != 1 { + t.Fatalf("got %d records, want 1", len(records)) + } + + r := records[0] + if r.Symbol != "BTCUSDT" { + t.Errorf("Symbol = %q, want BTCUSDT", r.Symbol) + } + if r.Side != "long" { + t.Errorf("Side = %q, want long", r.Side) + } + if !floatsClose(r.EntryPrice, 100.0) { + t.Errorf("EntryPrice = %v, want 100", r.EntryPrice) + } + if !floatsClose(r.ExitPrice, 110.0) { + t.Errorf("ExitPrice = %v, want 110", r.ExitPrice) + } + if !floatsClose(r.Quantity, 1.0) { + t.Errorf("Quantity = %v, want 1", r.Quantity) + } + if !floatsClose(r.RealizedPnL, 10.0) { + t.Errorf("RealizedPnL = %v, want 10", r.RealizedPnL) + } + // Fee should be entry fee + exit fee. + if !floatsClose(r.Fee, 0.3) { + t.Errorf("Fee = %v, want 0.3 (entry 0.1 + exit 0.2)", r.Fee) + } + if !r.EntryTime.Equal(testTime(0)) { + t.Errorf("EntryTime = %v, want %v", r.EntryTime, testTime(0)) + } + if !r.ExitTime.Equal(testTime(1)) { + t.Errorf("ExitTime = %v, want %v", r.ExitTime, testTime(1)) + } + if r.OrderID != "t2" || r.ExchangeID != "t2" { + t.Errorf("OrderID/ExchangeID = %q/%q, want t2/t2", r.OrderID, r.ExchangeID) + } + if r.CloseType != "unknown" { + t.Errorf("CloseType = %q, want unknown", r.CloseType) + } +} + +func TestRebuildPositionsFromTrades_PartialClose(t *testing.T) { + trades := []TradeRecord{ + { + TradeID: "open1", + Symbol: "ETHUSDT", + Side: "BUY", + Price: 100.0, + Quantity: 2.0, + Fee: 0.4, + Time: testTime(0), + }, + { + TradeID: "close1", + Symbol: "ETHUSDT", + Side: "SELL", + Price: 110.0, + Quantity: 1.0, + RealizedPnL: 10.0, + Fee: 0.1, + Time: testTime(1), + }, + { + TradeID: "close2", + Symbol: "ETHUSDT", + Side: "SELL", + Price: 120.0, + Quantity: 1.0, + RealizedPnL: 20.0, + Fee: 0.1, + Time: testTime(2), + }, + } + + records := RebuildPositionsFromTrades(trades) + if len(records) != 2 { + t.Fatalf("got %d records, want 2", len(records)) + } + + for i, r := range records { + if r.Side != "long" { + t.Errorf("records[%d].Side = %q, want long", i, r.Side) + } + // FIFO: both partial closes consume the single open at 100. + if !floatsClose(r.EntryPrice, 100.0) { + t.Errorf("records[%d].EntryPrice = %v, want 100", i, r.EntryPrice) + } + if !floatsClose(r.Quantity, 1.0) { + t.Errorf("records[%d].Quantity = %v, want 1", i, r.Quantity) + } + if !r.EntryTime.Equal(testTime(0)) { + t.Errorf("records[%d].EntryTime = %v, want %v", i, r.EntryTime, testTime(0)) + } + } + + if !floatsClose(records[0].ExitPrice, 110.0) { + t.Errorf("records[0].ExitPrice = %v, want 110", records[0].ExitPrice) + } + if !floatsClose(records[1].ExitPrice, 120.0) { + t.Errorf("records[1].ExitPrice = %v, want 120", records[1].ExitPrice) + } + + // First partial close: exit fee 0.1 + proportional entry fee 0.4*(1/2) = 0.3. + if !floatsClose(records[0].Fee, 0.3) { + t.Errorf("records[0].Fee = %v, want 0.3", records[0].Fee) + } + // NOTE: documents current behavior. The open trade's Fee field is not + // reduced when partially consumed, so the second close re-attributes the + // full remaining ratio of the original fee: 0.1 + 0.4*(1/1) = 0.5. + // Total attributed entry fee across both closes is 0.6 > 0.4 actually paid. + if !floatsClose(records[1].Fee, 0.5) { + t.Errorf("records[1].Fee = %v, want 0.5 (current over-attribution behavior)", records[1].Fee) + } +} + +func TestRebuildPositionsFromTrades_MultipleOpensWeightedEntry(t *testing.T) { + trades := []TradeRecord{ + { + TradeID: "open1", + Symbol: "BTCUSDT", + Side: "BUY", + Price: 100.0, + Quantity: 1.0, + Fee: 0.1, + Time: testTime(0), + }, + { + TradeID: "open2", + Symbol: "BTCUSDT", + Side: "BUY", + Price: 110.0, + Quantity: 1.0, + Fee: 0.1, + Time: testTime(1), + }, + { + TradeID: "close1", + Symbol: "BTCUSDT", + Side: "SELL", + Price: 120.0, + Quantity: 2.0, + RealizedPnL: 30.0, + Fee: 0.2, + Time: testTime(2), + }, + } + + records := RebuildPositionsFromTrades(trades) + if len(records) != 1 { + t.Fatalf("got %d records, want 1", len(records)) + } + + r := records[0] + // Weighted average: (100*1 + 110*1) / 2 = 105. + if !floatsClose(r.EntryPrice, 105.0) { + t.Errorf("EntryPrice = %v, want 105", r.EntryPrice) + } + if !floatsClose(r.Quantity, 2.0) { + t.Errorf("Quantity = %v, want 2", r.Quantity) + } + // Exit fee 0.2 + both entry fees 0.1 + 0.1. + if !floatsClose(r.Fee, 0.4) { + t.Errorf("Fee = %v, want 0.4", r.Fee) + } + // EntryTime is the first matched open trade's time. + if !r.EntryTime.Equal(testTime(0)) { + t.Errorf("EntryTime = %v, want %v", r.EntryTime, testTime(0)) + } +} + +func TestRebuildPositionsFromTrades_HedgeMode(t *testing.T) { + trades := []TradeRecord{ + { + TradeID: "lo", + Symbol: "BTCUSDT", + Side: "BUY", + PositionSide: "LONG", + Price: 100.0, + Quantity: 1.0, + Time: testTime(0), + }, + { + TradeID: "so", + Symbol: "BTCUSDT", + Side: "SELL", + PositionSide: "SHORT", + Price: 100.0, + Quantity: 1.0, + Time: testTime(1), + }, + { + TradeID: "lc", + Symbol: "BTCUSDT", + Side: "SELL", + PositionSide: "LONG", + Price: 110.0, + Quantity: 1.0, + RealizedPnL: 10.0, + Time: testTime(2), + }, + { + TradeID: "sc", + Symbol: "BTCUSDT", + Side: "BUY", + PositionSide: "SHORT", + Price: 90.0, + Quantity: 1.0, + RealizedPnL: 10.0, + Time: testTime(3), + }, + } + + records := RebuildPositionsFromTrades(trades) + if len(records) != 2 { + t.Fatalf("got %d records, want 2", len(records)) + } + + long := records[0] + if long.Side != "long" { + t.Fatalf("records[0].Side = %q, want long", long.Side) + } + if !floatsClose(long.EntryPrice, 100.0) || !floatsClose(long.ExitPrice, 110.0) { + t.Errorf("long entry/exit = %v/%v, want 100/110", long.EntryPrice, long.ExitPrice) + } + + short := records[1] + if short.Side != "short" { + t.Fatalf("records[1].Side = %q, want short", short.Side) + } + if !floatsClose(short.EntryPrice, 100.0) || !floatsClose(short.ExitPrice, 90.0) { + t.Errorf("short entry/exit = %v/%v, want 100/90", short.EntryPrice, short.ExitPrice) + } +} + +func TestRebuildPositionsFromTrades_OneWayModeShort(t *testing.T) { + trades := []TradeRecord{ + { + TradeID: "open1", + Symbol: "SOLUSDT", + Side: "SELL", // sell with zero PnL opens a short in one-way mode + Price: 100.0, + Quantity: 1.0, + Fee: 0.05, + Time: testTime(0), + }, + { + TradeID: "close1", + Symbol: "SOLUSDT", + Side: "BUY", // buy with non-zero PnL closes the short + Price: 90.0, + Quantity: 1.0, + RealizedPnL: 10.0, + Fee: 0.05, + Time: testTime(1), + }, + } + + records := RebuildPositionsFromTrades(trades) + if len(records) != 1 { + t.Fatalf("got %d records, want 1", len(records)) + } + + r := records[0] + if r.Side != "short" { + t.Errorf("Side = %q, want short", r.Side) + } + if !floatsClose(r.EntryPrice, 100.0) { + t.Errorf("EntryPrice = %v, want 100", r.EntryPrice) + } + if !floatsClose(r.ExitPrice, 90.0) { + t.Errorf("ExitPrice = %v, want 90", r.ExitPrice) + } + if !floatsClose(r.Fee, 0.1) { + t.Errorf("Fee = %v, want 0.1", r.Fee) + } +} + +func TestRebuildPositionsFromTrades_PnLFallbackEntryPrice(t *testing.T) { + tests := []struct { + name string + trade TradeRecord + wantSide string + wantEntry float64 + }{ + { + name: "long fallback: entry = exit - pnl/qty", + trade: TradeRecord{ + TradeID: "lone-long", + Symbol: "BTCUSDT", + Side: "SELL", + Price: 110.0, + Quantity: 2.0, + RealizedPnL: 20.0, + Time: testTime(0), + }, + wantSide: "long", + wantEntry: 100.0, // 110 - 20/2 + }, + { + name: "short fallback: entry = exit + pnl/qty", + trade: TradeRecord{ + TradeID: "lone-short", + Symbol: "BTCUSDT", + Side: "BUY", + Price: 95.0, + Quantity: 1.0, + RealizedPnL: 5.0, + Time: testTime(0), + }, + wantSide: "short", + wantEntry: 100.0, // 95 + 5/1 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + records := RebuildPositionsFromTrades([]TradeRecord{tt.trade}) + if len(records) != 1 { + t.Fatalf("got %d records, want 1", len(records)) + } + r := records[0] + if r.Side != tt.wantSide { + t.Errorf("Side = %q, want %q", r.Side, tt.wantSide) + } + if !floatsClose(r.EntryPrice, tt.wantEntry) { + t.Errorf("EntryPrice = %v, want %v", r.EntryPrice, tt.wantEntry) + } + // Without a matching open trade, entry time falls back to exit time. + if !r.EntryTime.Equal(r.ExitTime) { + t.Errorf("EntryTime = %v, want exit time %v", r.EntryTime, r.ExitTime) + } + }) + } +} + +func TestRebuildPositionsFromTrades_InvalidTrades(t *testing.T) { + tests := []struct { + name string + trades []TradeRecord + }{ + { + name: "closing trade with zero quantity", + trades: []TradeRecord{ + { + TradeID: "zq", + Symbol: "BTCUSDT", + Side: "SELL", + Price: 110.0, + Quantity: 0, + RealizedPnL: 10.0, + Time: testTime(0), + }, + }, + }, + { + name: "closing trade with zero price", + trades: []TradeRecord{ + { + TradeID: "zp", + Symbol: "BTCUSDT", + Side: "SELL", + Price: 0, + Quantity: 1.0, + RealizedPnL: 10.0, + Time: testTime(0), + }, + }, + }, + { + name: "trade with unrecognized side is skipped", + trades: []TradeRecord{ + { + TradeID: "bad-side", + Symbol: "BTCUSDT", + Side: "HOLD", + Price: 100.0, + Quantity: 1.0, + Time: testTime(0), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + records := RebuildPositionsFromTrades(tt.trades) + if len(records) != 0 { + t.Errorf("got %d records, want 0: %+v", len(records), records) + } + }) + } +} + +func TestRebuildPositionsFromTrades_UnsortedInputUsesChronologicalFIFO(t *testing.T) { + // Deliberately out of chronological order: close first, opens reversed. + trades := []TradeRecord{ + { + TradeID: "close1", + Symbol: "BTCUSDT", + Side: "SELL", + Price: 120.0, + Quantity: 1.0, + RealizedPnL: 20.0, + Time: testTime(2), + }, + { + TradeID: "open2", + Symbol: "BTCUSDT", + Side: "BUY", + Price: 110.0, + Quantity: 1.0, + Time: testTime(1), + }, + { + TradeID: "open1", + Symbol: "BTCUSDT", + Side: "BUY", + Price: 100.0, + Quantity: 1.0, + Time: testTime(0), + }, + } + + records := RebuildPositionsFromTrades(trades) + if len(records) != 1 { + t.Fatalf("got %d records, want 1", len(records)) + } + + // FIFO after time sort: the earliest open (price 100) is matched first. + if !floatsClose(records[0].EntryPrice, 100.0) { + t.Errorf("EntryPrice = %v, want 100 (earliest open via FIFO)", records[0].EntryPrice) + } + if !records[0].EntryTime.Equal(testTime(0)) { + t.Errorf("EntryTime = %v, want %v", records[0].EntryTime, testTime(0)) + } +}