test(manager,market,trader): cover previously untested core paths

- manager: 15 tests incl. concurrent map access under -race (was 0 tests)
- market: timeframe normalization regression tests
- trader: FIFO position rebuild tests (partial closes, hedge/one-way mode,
  PnL-fallback entry price, invalid input)
This commit is contained in:
tinkle-community
2026-06-11 00:37:45 +08:00
parent c0d8a9a375
commit 41c2625bb2
3 changed files with 1089 additions and 0 deletions

View File

@@ -0,0 +1,480 @@
package manager
import (
"errors"
"fmt"
"sort"
"strings"
"sync"
"testing"
"nofx/store"
"nofx/trader"
)
// newIdleTrader returns a zero-value AutoTrader. It is safe to store in the
// manager map for map-semantics tests: GetStatus works on a zero value and
// Stop returns early because the trader is not running. It must NOT be used
// for anything that touches an exchange (Run, GetAccountInfo, ...).
func newIdleTrader() *trader.AutoTrader {
return &trader.AutoTrader{}
}
// insertTrader places a trader directly into the manager's internal map,
// bypassing store loading (same-package access).
func insertTrader(tm *TraderManager, id string, t *trader.AutoTrader) {
tm.mu.Lock()
defer tm.mu.Unlock()
tm.traders[id] = t
}
func TestNewTraderManager(t *testing.T) {
tm := NewTraderManager()
if tm == nil {
t.Fatal("NewTraderManager() returned nil")
}
if tm.traders == nil {
t.Error("traders map should be initialized, got nil")
}
if len(tm.traders) != 0 {
t.Errorf("traders map should be empty, got %d entries", len(tm.traders))
}
if tm.loadErrors == nil {
t.Error("loadErrors map should be initialized, got nil")
}
if len(tm.loadErrors) != 0 {
t.Errorf("loadErrors map should be empty, got %d entries", len(tm.loadErrors))
}
if tm.competitionCache == nil {
t.Fatal("competitionCache should be initialized, got nil")
}
if tm.competitionCache.data == nil {
t.Error("competitionCache.data should be initialized, got nil")
}
if !tm.competitionCache.timestamp.IsZero() {
t.Errorf("competitionCache.timestamp should be zero, got %v", tm.competitionCache.timestamp)
}
}
func TestGetTrader(t *testing.T) {
tm := NewTraderManager()
t.Run("missing ID returns error", func(t *testing.T) {
got, err := tm.GetTrader("does-not-exist")
if err == nil {
t.Fatal("GetTrader on missing ID expected error, got nil")
}
if got != nil {
t.Errorf("GetTrader on missing ID should return nil trader, got %v", got)
}
if !strings.Contains(err.Error(), "does-not-exist") {
t.Errorf("error %q should mention the trader ID", err.Error())
}
})
t.Run("existing ID returns same instance", func(t *testing.T) {
at := newIdleTrader()
insertTrader(tm, "trader-1", at)
got, err := tm.GetTrader("trader-1")
if err != nil {
t.Fatalf("GetTrader unexpected error: %v", err)
}
if got != at {
t.Errorf("GetTrader returned %p, want the stored instance %p", got, at)
}
})
}
func TestGetLoadError(t *testing.T) {
tm := NewTraderManager()
t.Run("unknown trader returns nil", func(t *testing.T) {
if err := tm.GetLoadError("unknown"); err != nil {
t.Errorf("GetLoadError for unknown trader = %v, want nil", err)
}
})
t.Run("stored error is returned", func(t *testing.T) {
wantErr := errors.New("failed to create trader: boom")
tm.mu.Lock()
tm.loadErrors["trader-x"] = wantErr
tm.mu.Unlock()
if got := tm.GetLoadError("trader-x"); !errors.Is(got, wantErr) {
t.Errorf("GetLoadError = %v, want %v", got, wantErr)
}
})
}
func TestGetAllTradersReturnsCopy(t *testing.T) {
tm := NewTraderManager()
at1 := newIdleTrader()
at2 := newIdleTrader()
insertTrader(tm, "t1", at1)
insertTrader(tm, "t2", at2)
all := tm.GetAllTraders()
if len(all) != 2 {
t.Fatalf("GetAllTraders returned %d entries, want 2", len(all))
}
if all["t1"] != at1 || all["t2"] != at2 {
t.Error("GetAllTraders should return the same trader instances")
}
// Mutating the returned map must not affect internal state.
delete(all, "t1")
all["t3"] = newIdleTrader()
if _, err := tm.GetTrader("t1"); err != nil {
t.Errorf("deleting from returned map leaked into internal state: %v", err)
}
if _, err := tm.GetTrader("t3"); err == nil {
t.Error("adding to returned map leaked into internal state")
}
if got := len(tm.GetTraderIDs()); got != 2 {
t.Errorf("internal trader count = %d after mutating returned map, want 2", got)
}
}
func TestGetTraderIDs(t *testing.T) {
tm := NewTraderManager()
t.Run("empty manager returns empty non-nil slice", func(t *testing.T) {
ids := tm.GetTraderIDs()
if ids == nil {
t.Fatal("GetTraderIDs should return an empty slice, got nil")
}
if len(ids) != 0 {
t.Errorf("GetTraderIDs = %v, want empty", ids)
}
})
t.Run("returns all IDs", func(t *testing.T) {
want := []string{"a", "b", "c"}
for _, id := range want {
insertTrader(tm, id, newIdleTrader())
}
got := tm.GetTraderIDs()
sort.Strings(got)
if len(got) != len(want) {
t.Fatalf("GetTraderIDs returned %d IDs, want %d", len(got), len(want))
}
for i := range want {
if got[i] != want[i] {
t.Errorf("GetTraderIDs[%d] = %q, want %q", i, got[i], want[i])
}
}
})
}
func TestRemoveTrader(t *testing.T) {
t.Run("removes existing non-running trader", func(t *testing.T) {
tm := NewTraderManager()
insertTrader(tm, "t1", newIdleTrader())
tm.RemoveTrader("t1")
if _, err := tm.GetTrader("t1"); err == nil {
t.Error("trader t1 should be removed")
}
if got := len(tm.GetTraderIDs()); got != 0 {
t.Errorf("trader count after removal = %d, want 0", got)
}
})
t.Run("missing ID is a no-op", func(t *testing.T) {
tm := NewTraderManager()
insertTrader(tm, "t1", newIdleTrader())
tm.RemoveTrader("missing") // must not panic
if _, err := tm.GetTrader("t1"); err != nil {
t.Errorf("unrelated trader was removed: %v", err)
}
})
}
func TestStartAllEmpty(t *testing.T) {
tm := NewTraderManager()
tm.StartAll() // must not panic with no traders
}
func TestStopAllWithIdleTraders(t *testing.T) {
tm := NewTraderManager()
tm.StopAll() // empty: must not panic
insertTrader(tm, "t1", newIdleTrader())
insertTrader(tm, "t2", newIdleTrader())
tm.StopAll() // not-running traders: Stop is an early-return no-op
}
func TestTraderLogTag(t *testing.T) {
tests := []struct {
name string
traderID string
traderName string
want string
}{
{
name: "with name",
traderID: "abc-123",
traderName: "MyBot",
want: "[trader_id=abc-123 trader_name=MyBot]",
},
{
name: "without name",
traderID: "abc-123",
want: "[trader_id=abc-123]",
},
{
name: "both empty",
want: "[trader_id=]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := traderLogTag(tt.traderID, tt.traderName); got != tt.want {
t.Errorf("traderLogTag(%q, %q) = %q, want %q", tt.traderID, tt.traderName, got, tt.want)
}
})
}
}
func TestEnsureHyperliquidNativeStrategy(t *testing.T) {
t.Run("nil config does not panic", func(t *testing.T) {
ensureHyperliquidNativeStrategy("bot", "hyperliquid", nil)
})
t.Run("non-hyperliquid exchange is untouched", func(t *testing.T) {
cfg := &store.StrategyConfig{
CoinSource: store.CoinSourceConfig{
SourceType: "ai500",
UseAI500: true,
},
}
ensureHyperliquidNativeStrategy("bot", "binance", cfg)
if cfg.CoinSource.SourceType != "ai500" || !cfg.CoinSource.UseAI500 {
t.Errorf("non-hyperliquid config was modified: %+v", cfg.CoinSource)
}
})
t.Run("native sources are kept as-is", func(t *testing.T) {
nativeSources := []string{"hyper_rank", "static", "hyper_all", "hyper_main", " Hyper_Rank "}
for _, src := range nativeSources {
cfg := &store.StrategyConfig{
CoinSource: store.CoinSourceConfig{SourceType: src},
}
ensureHyperliquidNativeStrategy("bot", "hyperliquid", cfg)
if cfg.CoinSource.SourceType != src {
t.Errorf("native source %q was rewritten to %q", src, cfg.CoinSource.SourceType)
}
}
})
t.Run("legacy source on hyperliquid is forced to hyper_rank with defaults", func(t *testing.T) {
cfg := &store.StrategyConfig{
CoinSource: store.CoinSourceConfig{
SourceType: "ai500",
UseAI500: true,
UseOITop: true,
UseOILow: true,
UseHyperAll: true,
UseHyperMain: true,
},
}
ensureHyperliquidNativeStrategy("bot", "hyperliquid", cfg)
cs := cfg.CoinSource
if cs.SourceType != "hyper_rank" {
t.Errorf("SourceType = %q, want hyper_rank", cs.SourceType)
}
if cs.UseAI500 || cs.UseOITop || cs.UseOILow || cs.UseHyperAll || cs.UseHyperMain {
t.Errorf("legacy source flags should all be cleared: %+v", cs)
}
if cs.HyperRankCategory != "stock" {
t.Errorf("HyperRankCategory = %q, want stock", cs.HyperRankCategory)
}
if cs.HyperRankDirection != "gainers" {
t.Errorf("HyperRankDirection = %q, want gainers", cs.HyperRankDirection)
}
if cs.HyperRankLimit != 5 {
t.Errorf("HyperRankLimit = %d, want 5", cs.HyperRankLimit)
}
})
t.Run("existing hyper_rank settings are preserved when forcing", func(t *testing.T) {
cfg := &store.StrategyConfig{
CoinSource: store.CoinSourceConfig{
SourceType: "oi_top",
HyperRankCategory: "crypto",
HyperRankDirection: "losers",
HyperRankLimit: 8,
},
}
ensureHyperliquidNativeStrategy("bot", "hyperliquid", cfg)
cs := cfg.CoinSource
if cs.SourceType != "hyper_rank" {
t.Errorf("SourceType = %q, want hyper_rank", cs.SourceType)
}
if cs.HyperRankCategory != "crypto" {
t.Errorf("HyperRankCategory = %q, want crypto (preserved)", cs.HyperRankCategory)
}
if cs.HyperRankDirection != "losers" {
t.Errorf("HyperRankDirection = %q, want losers (preserved)", cs.HyperRankDirection)
}
if cs.HyperRankLimit != 8 {
t.Errorf("HyperRankLimit = %d, want 8 (preserved)", cs.HyperRankLimit)
}
})
t.Run("exchange type is matched case-insensitively with whitespace", func(t *testing.T) {
cfg := &store.StrategyConfig{
CoinSource: store.CoinSourceConfig{SourceType: "ai500"},
}
ensureHyperliquidNativeStrategy("bot", " HyperLiquid ", cfg)
if cfg.CoinSource.SourceType != "hyper_rank" {
t.Errorf("SourceType = %q, want hyper_rank for case-insensitive exchange match", cfg.CoinSource.SourceType)
}
})
}
func TestGetCompetitionDataEmptyAndCache(t *testing.T) {
tm := NewTraderManager()
first, err := tm.GetCompetitionData()
if err != nil {
t.Fatalf("GetCompetitionData unexpected error: %v", err)
}
if got := first["count"]; got != 0 {
t.Errorf("count = %v, want 0", got)
}
if got := first["total_count"]; got != 0 {
t.Errorf("total_count = %v, want 0", got)
}
tm.competitionCache.mu.RLock()
cachedTimestamp := tm.competitionCache.timestamp
tm.competitionCache.mu.RUnlock()
if cachedTimestamp.IsZero() {
t.Error("competition cache timestamp should be set after first call")
}
// Second call within 30s must be served from the cache.
second, err := tm.GetCompetitionData()
if err != nil {
t.Fatalf("GetCompetitionData (cached) unexpected error: %v", err)
}
if got := second["count"]; got != 0 {
t.Errorf("cached count = %v, want 0", got)
}
tm.competitionCache.mu.RLock()
timestampAfterSecond := tm.competitionCache.timestamp
tm.competitionCache.mu.RUnlock()
if !timestampAfterSecond.Equal(cachedTimestamp) {
t.Error("cached call should not refresh the cache timestamp")
}
}
func TestGetTopTradersDataEmpty(t *testing.T) {
tm := NewTraderManager()
result, err := tm.GetTopTradersData()
if err != nil {
t.Fatalf("GetTopTradersData unexpected error: %v", err)
}
if got := result["count"]; got != 0 {
t.Errorf("count = %v, want 0", got)
}
traders, ok := result["traders"].([]map[string]interface{})
if !ok {
t.Fatalf("traders has type %T, want []map[string]interface{}", result["traders"])
}
if len(traders) != 0 {
t.Errorf("traders length = %d, want 0", len(traders))
}
}
func TestGetComparisonDataEmpty(t *testing.T) {
tm := NewTraderManager()
result, err := tm.GetComparisonData()
if err != nil {
t.Fatalf("GetComparisonData unexpected error: %v", err)
}
if got := result["count"]; got != 0 {
t.Errorf("count = %v, want 0", got)
}
}
// TestConcurrentAccess exercises the RWMutex by hammering the read paths
// while traders are removed concurrently. Run with -race.
func TestConcurrentAccess(t *testing.T) {
tm := NewTraderManager()
const traderCount = 16
ids := make([]string, 0, traderCount)
for i := 0; i < traderCount; i++ {
id := fmt.Sprintf("trader-%d", i)
ids = append(ids, id)
insertTrader(tm, id, newIdleTrader())
}
const (
goroutinesPerKind = 8
iterations = 200
)
var wg sync.WaitGroup
// Readers: GetTrader / GetLoadError
for g := 0; g < goroutinesPerKind; g++ {
wg.Add(1)
go func(seed int) {
defer wg.Done()
for i := 0; i < iterations; i++ {
id := ids[(seed+i)%traderCount]
_, _ = tm.GetTrader(id)
_ = tm.GetLoadError(id)
}
}(g)
}
// Readers: GetAllTraders / GetTraderIDs
for g := 0; g < goroutinesPerKind; g++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < iterations; i++ {
_ = tm.GetAllTraders()
_ = tm.GetTraderIDs()
}
}()
}
// Writers: RemoveTrader (including repeated removal of the same ID)
for g := 0; g < goroutinesPerKind; g++ {
wg.Add(1)
go func(seed int) {
defer wg.Done()
for i := 0; i < iterations; i++ {
tm.RemoveTrader(ids[(seed+i)%traderCount])
}
}(g)
}
wg.Wait()
if got := len(tm.GetTraderIDs()); got != 0 {
t.Errorf("all traders should be removed after concurrent removal, %d left", got)
}
}

123
market/timeframe_test.go Normal file
View File

@@ -0,0 +1,123 @@
package market
import (
"slices"
"testing"
"time"
)
func TestNormalizeTimeframe(t *testing.T) {
tests := []struct {
name string
input string
want string
wantErr bool
}{
{name: "valid lowercase minute", input: "1m", want: "1m"},
{name: "valid lowercase hour", input: "4h", want: "4h"},
{name: "valid lowercase day", input: "1d", want: "1d"},
{name: "uppercase normalized", input: "1H", want: "1h"},
{name: "mixed case normalized", input: "15M", want: "15m"},
{name: "uppercase day", input: "1D", want: "1d"},
{name: "leading and trailing whitespace", input: " 30m ", want: "30m"},
{name: "whitespace and uppercase", input: " 12H ", want: "12h"},
{name: "empty string", input: "", wantErr: true},
{name: "whitespace only", input: " ", wantErr: true},
{name: "unsupported value", input: "7m", wantErr: true},
{name: "unsupported week", input: "1w", wantErr: true},
{name: "garbage input", input: "abc", wantErr: true},
{name: "internal whitespace not trimmed", input: "1 m", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NormalizeTimeframe(tt.input)
if tt.wantErr {
if err == nil {
t.Fatalf("NormalizeTimeframe(%q) = %q, want error", tt.input, got)
}
return
}
if err != nil {
t.Fatalf("NormalizeTimeframe(%q) unexpected error: %v", tt.input, err)
}
if got != tt.want {
t.Errorf("NormalizeTimeframe(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestTFDuration(t *testing.T) {
tests := []struct {
name string
input string
want time.Duration
wantErr bool
}{
{name: "one minute", input: "1m", want: time.Minute},
{name: "three minutes", input: "3m", want: 3 * time.Minute},
{name: "five minutes", input: "5m", want: 5 * time.Minute},
{name: "fifteen minutes", input: "15m", want: 15 * time.Minute},
{name: "thirty minutes", input: "30m", want: 30 * time.Minute},
{name: "one hour", input: "1h", want: time.Hour},
{name: "two hours", input: "2h", want: 2 * time.Hour},
{name: "four hours", input: "4h", want: 4 * time.Hour},
{name: "six hours", input: "6h", want: 6 * time.Hour},
{name: "twelve hours", input: "12h", want: 12 * time.Hour},
{name: "one day", input: "1d", want: 24 * time.Hour},
{name: "uppercase with whitespace", input: " 1D ", want: 24 * time.Hour},
{name: "empty string", input: "", wantErr: true},
{name: "unsupported value", input: "2d", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := TFDuration(tt.input)
if tt.wantErr {
if err == nil {
t.Fatalf("TFDuration(%q) = %v, want error", tt.input, got)
}
return
}
if err != nil {
t.Fatalf("TFDuration(%q) unexpected error: %v", tt.input, err)
}
if got != tt.want {
t.Errorf("TFDuration(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestSupportedTimeframes(t *testing.T) {
got := SupportedTimeframes()
if len(got) == 0 {
t.Fatal("SupportedTimeframes() returned empty slice")
}
if !slices.IsSorted(got) {
t.Errorf("SupportedTimeframes() not sorted: %v", got)
}
for _, required := range []string{"1m", "1d"} {
if !slices.Contains(got, required) {
t.Errorf("SupportedTimeframes() missing %q: %v", required, got)
}
}
// Every advertised timeframe must round-trip through NormalizeTimeframe and TFDuration.
for _, tf := range got {
norm, err := NormalizeTimeframe(tf)
if err != nil {
t.Errorf("NormalizeTimeframe(%q) unexpected error: %v", tf, err)
}
if norm != tf {
t.Errorf("NormalizeTimeframe(%q) = %q, want identity", tf, norm)
}
if d, err := TFDuration(tf); err != nil || d <= 0 {
t.Errorf("TFDuration(%q) = %v, %v; want positive duration and nil error", tf, d, err)
}
}
}

View File

@@ -0,0 +1,486 @@
package trader
import (
"math"
"testing"
"time"
)
// testTime returns a deterministic timestamp offset by n minutes.
func testTime(n int) time.Time {
return time.Date(2026, 1, 2, 10, n, 0, 0, time.UTC)
}
func floatsClose(a, b float64) bool {
return math.Abs(a-b) < 1e-9
}
func TestRebuildPositionsFromTrades_EmptyInput(t *testing.T) {
if got := RebuildPositionsFromTrades(nil); got != nil {
t.Errorf("RebuildPositionsFromTrades(nil) = %v, want nil", got)
}
if got := RebuildPositionsFromTrades([]TradeRecord{}); got != nil {
t.Errorf("RebuildPositionsFromTrades([]) = %v, want nil", got)
}
}
func TestRebuildPositionsFromTrades_SimpleLongOpenClose(t *testing.T) {
trades := []TradeRecord{
{
TradeID: "t1",
Symbol: "BTCUSDT",
Side: "BUY",
Price: 100.0,
Quantity: 1.0,
Fee: 0.1,
Time: testTime(0),
},
{
TradeID: "t2",
Symbol: "BTCUSDT",
Side: "SELL",
Price: 110.0,
Quantity: 1.0,
RealizedPnL: 10.0,
Fee: 0.2,
Time: testTime(1),
},
}
records := RebuildPositionsFromTrades(trades)
if len(records) != 1 {
t.Fatalf("got %d records, want 1", len(records))
}
r := records[0]
if r.Symbol != "BTCUSDT" {
t.Errorf("Symbol = %q, want BTCUSDT", r.Symbol)
}
if r.Side != "long" {
t.Errorf("Side = %q, want long", r.Side)
}
if !floatsClose(r.EntryPrice, 100.0) {
t.Errorf("EntryPrice = %v, want 100", r.EntryPrice)
}
if !floatsClose(r.ExitPrice, 110.0) {
t.Errorf("ExitPrice = %v, want 110", r.ExitPrice)
}
if !floatsClose(r.Quantity, 1.0) {
t.Errorf("Quantity = %v, want 1", r.Quantity)
}
if !floatsClose(r.RealizedPnL, 10.0) {
t.Errorf("RealizedPnL = %v, want 10", r.RealizedPnL)
}
// Fee should be entry fee + exit fee.
if !floatsClose(r.Fee, 0.3) {
t.Errorf("Fee = %v, want 0.3 (entry 0.1 + exit 0.2)", r.Fee)
}
if !r.EntryTime.Equal(testTime(0)) {
t.Errorf("EntryTime = %v, want %v", r.EntryTime, testTime(0))
}
if !r.ExitTime.Equal(testTime(1)) {
t.Errorf("ExitTime = %v, want %v", r.ExitTime, testTime(1))
}
if r.OrderID != "t2" || r.ExchangeID != "t2" {
t.Errorf("OrderID/ExchangeID = %q/%q, want t2/t2", r.OrderID, r.ExchangeID)
}
if r.CloseType != "unknown" {
t.Errorf("CloseType = %q, want unknown", r.CloseType)
}
}
func TestRebuildPositionsFromTrades_PartialClose(t *testing.T) {
trades := []TradeRecord{
{
TradeID: "open1",
Symbol: "ETHUSDT",
Side: "BUY",
Price: 100.0,
Quantity: 2.0,
Fee: 0.4,
Time: testTime(0),
},
{
TradeID: "close1",
Symbol: "ETHUSDT",
Side: "SELL",
Price: 110.0,
Quantity: 1.0,
RealizedPnL: 10.0,
Fee: 0.1,
Time: testTime(1),
},
{
TradeID: "close2",
Symbol: "ETHUSDT",
Side: "SELL",
Price: 120.0,
Quantity: 1.0,
RealizedPnL: 20.0,
Fee: 0.1,
Time: testTime(2),
},
}
records := RebuildPositionsFromTrades(trades)
if len(records) != 2 {
t.Fatalf("got %d records, want 2", len(records))
}
for i, r := range records {
if r.Side != "long" {
t.Errorf("records[%d].Side = %q, want long", i, r.Side)
}
// FIFO: both partial closes consume the single open at 100.
if !floatsClose(r.EntryPrice, 100.0) {
t.Errorf("records[%d].EntryPrice = %v, want 100", i, r.EntryPrice)
}
if !floatsClose(r.Quantity, 1.0) {
t.Errorf("records[%d].Quantity = %v, want 1", i, r.Quantity)
}
if !r.EntryTime.Equal(testTime(0)) {
t.Errorf("records[%d].EntryTime = %v, want %v", i, r.EntryTime, testTime(0))
}
}
if !floatsClose(records[0].ExitPrice, 110.0) {
t.Errorf("records[0].ExitPrice = %v, want 110", records[0].ExitPrice)
}
if !floatsClose(records[1].ExitPrice, 120.0) {
t.Errorf("records[1].ExitPrice = %v, want 120", records[1].ExitPrice)
}
// First partial close: exit fee 0.1 + proportional entry fee 0.4*(1/2) = 0.3.
if !floatsClose(records[0].Fee, 0.3) {
t.Errorf("records[0].Fee = %v, want 0.3", records[0].Fee)
}
// NOTE: documents current behavior. The open trade's Fee field is not
// reduced when partially consumed, so the second close re-attributes the
// full remaining ratio of the original fee: 0.1 + 0.4*(1/1) = 0.5.
// Total attributed entry fee across both closes is 0.6 > 0.4 actually paid.
if !floatsClose(records[1].Fee, 0.5) {
t.Errorf("records[1].Fee = %v, want 0.5 (current over-attribution behavior)", records[1].Fee)
}
}
func TestRebuildPositionsFromTrades_MultipleOpensWeightedEntry(t *testing.T) {
trades := []TradeRecord{
{
TradeID: "open1",
Symbol: "BTCUSDT",
Side: "BUY",
Price: 100.0,
Quantity: 1.0,
Fee: 0.1,
Time: testTime(0),
},
{
TradeID: "open2",
Symbol: "BTCUSDT",
Side: "BUY",
Price: 110.0,
Quantity: 1.0,
Fee: 0.1,
Time: testTime(1),
},
{
TradeID: "close1",
Symbol: "BTCUSDT",
Side: "SELL",
Price: 120.0,
Quantity: 2.0,
RealizedPnL: 30.0,
Fee: 0.2,
Time: testTime(2),
},
}
records := RebuildPositionsFromTrades(trades)
if len(records) != 1 {
t.Fatalf("got %d records, want 1", len(records))
}
r := records[0]
// Weighted average: (100*1 + 110*1) / 2 = 105.
if !floatsClose(r.EntryPrice, 105.0) {
t.Errorf("EntryPrice = %v, want 105", r.EntryPrice)
}
if !floatsClose(r.Quantity, 2.0) {
t.Errorf("Quantity = %v, want 2", r.Quantity)
}
// Exit fee 0.2 + both entry fees 0.1 + 0.1.
if !floatsClose(r.Fee, 0.4) {
t.Errorf("Fee = %v, want 0.4", r.Fee)
}
// EntryTime is the first matched open trade's time.
if !r.EntryTime.Equal(testTime(0)) {
t.Errorf("EntryTime = %v, want %v", r.EntryTime, testTime(0))
}
}
func TestRebuildPositionsFromTrades_HedgeMode(t *testing.T) {
trades := []TradeRecord{
{
TradeID: "lo",
Symbol: "BTCUSDT",
Side: "BUY",
PositionSide: "LONG",
Price: 100.0,
Quantity: 1.0,
Time: testTime(0),
},
{
TradeID: "so",
Symbol: "BTCUSDT",
Side: "SELL",
PositionSide: "SHORT",
Price: 100.0,
Quantity: 1.0,
Time: testTime(1),
},
{
TradeID: "lc",
Symbol: "BTCUSDT",
Side: "SELL",
PositionSide: "LONG",
Price: 110.0,
Quantity: 1.0,
RealizedPnL: 10.0,
Time: testTime(2),
},
{
TradeID: "sc",
Symbol: "BTCUSDT",
Side: "BUY",
PositionSide: "SHORT",
Price: 90.0,
Quantity: 1.0,
RealizedPnL: 10.0,
Time: testTime(3),
},
}
records := RebuildPositionsFromTrades(trades)
if len(records) != 2 {
t.Fatalf("got %d records, want 2", len(records))
}
long := records[0]
if long.Side != "long" {
t.Fatalf("records[0].Side = %q, want long", long.Side)
}
if !floatsClose(long.EntryPrice, 100.0) || !floatsClose(long.ExitPrice, 110.0) {
t.Errorf("long entry/exit = %v/%v, want 100/110", long.EntryPrice, long.ExitPrice)
}
short := records[1]
if short.Side != "short" {
t.Fatalf("records[1].Side = %q, want short", short.Side)
}
if !floatsClose(short.EntryPrice, 100.0) || !floatsClose(short.ExitPrice, 90.0) {
t.Errorf("short entry/exit = %v/%v, want 100/90", short.EntryPrice, short.ExitPrice)
}
}
func TestRebuildPositionsFromTrades_OneWayModeShort(t *testing.T) {
trades := []TradeRecord{
{
TradeID: "open1",
Symbol: "SOLUSDT",
Side: "SELL", // sell with zero PnL opens a short in one-way mode
Price: 100.0,
Quantity: 1.0,
Fee: 0.05,
Time: testTime(0),
},
{
TradeID: "close1",
Symbol: "SOLUSDT",
Side: "BUY", // buy with non-zero PnL closes the short
Price: 90.0,
Quantity: 1.0,
RealizedPnL: 10.0,
Fee: 0.05,
Time: testTime(1),
},
}
records := RebuildPositionsFromTrades(trades)
if len(records) != 1 {
t.Fatalf("got %d records, want 1", len(records))
}
r := records[0]
if r.Side != "short" {
t.Errorf("Side = %q, want short", r.Side)
}
if !floatsClose(r.EntryPrice, 100.0) {
t.Errorf("EntryPrice = %v, want 100", r.EntryPrice)
}
if !floatsClose(r.ExitPrice, 90.0) {
t.Errorf("ExitPrice = %v, want 90", r.ExitPrice)
}
if !floatsClose(r.Fee, 0.1) {
t.Errorf("Fee = %v, want 0.1", r.Fee)
}
}
func TestRebuildPositionsFromTrades_PnLFallbackEntryPrice(t *testing.T) {
tests := []struct {
name string
trade TradeRecord
wantSide string
wantEntry float64
}{
{
name: "long fallback: entry = exit - pnl/qty",
trade: TradeRecord{
TradeID: "lone-long",
Symbol: "BTCUSDT",
Side: "SELL",
Price: 110.0,
Quantity: 2.0,
RealizedPnL: 20.0,
Time: testTime(0),
},
wantSide: "long",
wantEntry: 100.0, // 110 - 20/2
},
{
name: "short fallback: entry = exit + pnl/qty",
trade: TradeRecord{
TradeID: "lone-short",
Symbol: "BTCUSDT",
Side: "BUY",
Price: 95.0,
Quantity: 1.0,
RealizedPnL: 5.0,
Time: testTime(0),
},
wantSide: "short",
wantEntry: 100.0, // 95 + 5/1
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
records := RebuildPositionsFromTrades([]TradeRecord{tt.trade})
if len(records) != 1 {
t.Fatalf("got %d records, want 1", len(records))
}
r := records[0]
if r.Side != tt.wantSide {
t.Errorf("Side = %q, want %q", r.Side, tt.wantSide)
}
if !floatsClose(r.EntryPrice, tt.wantEntry) {
t.Errorf("EntryPrice = %v, want %v", r.EntryPrice, tt.wantEntry)
}
// Without a matching open trade, entry time falls back to exit time.
if !r.EntryTime.Equal(r.ExitTime) {
t.Errorf("EntryTime = %v, want exit time %v", r.EntryTime, r.ExitTime)
}
})
}
}
func TestRebuildPositionsFromTrades_InvalidTrades(t *testing.T) {
tests := []struct {
name string
trades []TradeRecord
}{
{
name: "closing trade with zero quantity",
trades: []TradeRecord{
{
TradeID: "zq",
Symbol: "BTCUSDT",
Side: "SELL",
Price: 110.0,
Quantity: 0,
RealizedPnL: 10.0,
Time: testTime(0),
},
},
},
{
name: "closing trade with zero price",
trades: []TradeRecord{
{
TradeID: "zp",
Symbol: "BTCUSDT",
Side: "SELL",
Price: 0,
Quantity: 1.0,
RealizedPnL: 10.0,
Time: testTime(0),
},
},
},
{
name: "trade with unrecognized side is skipped",
trades: []TradeRecord{
{
TradeID: "bad-side",
Symbol: "BTCUSDT",
Side: "HOLD",
Price: 100.0,
Quantity: 1.0,
Time: testTime(0),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
records := RebuildPositionsFromTrades(tt.trades)
if len(records) != 0 {
t.Errorf("got %d records, want 0: %+v", len(records), records)
}
})
}
}
func TestRebuildPositionsFromTrades_UnsortedInputUsesChronologicalFIFO(t *testing.T) {
// Deliberately out of chronological order: close first, opens reversed.
trades := []TradeRecord{
{
TradeID: "close1",
Symbol: "BTCUSDT",
Side: "SELL",
Price: 120.0,
Quantity: 1.0,
RealizedPnL: 20.0,
Time: testTime(2),
},
{
TradeID: "open2",
Symbol: "BTCUSDT",
Side: "BUY",
Price: 110.0,
Quantity: 1.0,
Time: testTime(1),
},
{
TradeID: "open1",
Symbol: "BTCUSDT",
Side: "BUY",
Price: 100.0,
Quantity: 1.0,
Time: testTime(0),
},
}
records := RebuildPositionsFromTrades(trades)
if len(records) != 1 {
t.Fatalf("got %d records, want 1", len(records))
}
// FIFO after time sort: the earliest open (price 100) is matched first.
if !floatsClose(records[0].EntryPrice, 100.0) {
t.Errorf("EntryPrice = %v, want 100 (earliest open via FIFO)", records[0].EntryPrice)
}
if !records[0].EntryTime.Equal(testTime(0)) {
t.Errorf("EntryTime = %v, want %v", records[0].EntryTime, testTime(0))
}
}