mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2026-07-03 02:50:59 +08:00
test(manager,market,trader): cover previously untested core paths
- manager: 15 tests incl. concurrent map access under -race (was 0 tests) - market: timeframe normalization regression tests - trader: FIFO position rebuild tests (partial closes, hedge/one-way mode, PnL-fallback entry price, invalid input)
This commit is contained in:
480
manager/trader_manager_test.go
Normal file
480
manager/trader_manager_test.go
Normal file
@@ -0,0 +1,480 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"nofx/store"
|
||||
"nofx/trader"
|
||||
)
|
||||
|
||||
// newIdleTrader returns a zero-value AutoTrader. It is safe to store in the
|
||||
// manager map for map-semantics tests: GetStatus works on a zero value and
|
||||
// Stop returns early because the trader is not running. It must NOT be used
|
||||
// for anything that touches an exchange (Run, GetAccountInfo, ...).
|
||||
func newIdleTrader() *trader.AutoTrader {
|
||||
return &trader.AutoTrader{}
|
||||
}
|
||||
|
||||
// insertTrader places a trader directly into the manager's internal map,
|
||||
// bypassing store loading (same-package access).
|
||||
func insertTrader(tm *TraderManager, id string, t *trader.AutoTrader) {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
tm.traders[id] = t
|
||||
}
|
||||
|
||||
func TestNewTraderManager(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
|
||||
if tm == nil {
|
||||
t.Fatal("NewTraderManager() returned nil")
|
||||
}
|
||||
if tm.traders == nil {
|
||||
t.Error("traders map should be initialized, got nil")
|
||||
}
|
||||
if len(tm.traders) != 0 {
|
||||
t.Errorf("traders map should be empty, got %d entries", len(tm.traders))
|
||||
}
|
||||
if tm.loadErrors == nil {
|
||||
t.Error("loadErrors map should be initialized, got nil")
|
||||
}
|
||||
if len(tm.loadErrors) != 0 {
|
||||
t.Errorf("loadErrors map should be empty, got %d entries", len(tm.loadErrors))
|
||||
}
|
||||
if tm.competitionCache == nil {
|
||||
t.Fatal("competitionCache should be initialized, got nil")
|
||||
}
|
||||
if tm.competitionCache.data == nil {
|
||||
t.Error("competitionCache.data should be initialized, got nil")
|
||||
}
|
||||
if !tm.competitionCache.timestamp.IsZero() {
|
||||
t.Errorf("competitionCache.timestamp should be zero, got %v", tm.competitionCache.timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTrader(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
|
||||
t.Run("missing ID returns error", func(t *testing.T) {
|
||||
got, err := tm.GetTrader("does-not-exist")
|
||||
if err == nil {
|
||||
t.Fatal("GetTrader on missing ID expected error, got nil")
|
||||
}
|
||||
if got != nil {
|
||||
t.Errorf("GetTrader on missing ID should return nil trader, got %v", got)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "does-not-exist") {
|
||||
t.Errorf("error %q should mention the trader ID", err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("existing ID returns same instance", func(t *testing.T) {
|
||||
at := newIdleTrader()
|
||||
insertTrader(tm, "trader-1", at)
|
||||
|
||||
got, err := tm.GetTrader("trader-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetTrader unexpected error: %v", err)
|
||||
}
|
||||
if got != at {
|
||||
t.Errorf("GetTrader returned %p, want the stored instance %p", got, at)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetLoadError(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
|
||||
t.Run("unknown trader returns nil", func(t *testing.T) {
|
||||
if err := tm.GetLoadError("unknown"); err != nil {
|
||||
t.Errorf("GetLoadError for unknown trader = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stored error is returned", func(t *testing.T) {
|
||||
wantErr := errors.New("failed to create trader: boom")
|
||||
tm.mu.Lock()
|
||||
tm.loadErrors["trader-x"] = wantErr
|
||||
tm.mu.Unlock()
|
||||
|
||||
if got := tm.GetLoadError("trader-x"); !errors.Is(got, wantErr) {
|
||||
t.Errorf("GetLoadError = %v, want %v", got, wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAllTradersReturnsCopy(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
at1 := newIdleTrader()
|
||||
at2 := newIdleTrader()
|
||||
insertTrader(tm, "t1", at1)
|
||||
insertTrader(tm, "t2", at2)
|
||||
|
||||
all := tm.GetAllTraders()
|
||||
|
||||
if len(all) != 2 {
|
||||
t.Fatalf("GetAllTraders returned %d entries, want 2", len(all))
|
||||
}
|
||||
if all["t1"] != at1 || all["t2"] != at2 {
|
||||
t.Error("GetAllTraders should return the same trader instances")
|
||||
}
|
||||
|
||||
// Mutating the returned map must not affect internal state.
|
||||
delete(all, "t1")
|
||||
all["t3"] = newIdleTrader()
|
||||
|
||||
if _, err := tm.GetTrader("t1"); err != nil {
|
||||
t.Errorf("deleting from returned map leaked into internal state: %v", err)
|
||||
}
|
||||
if _, err := tm.GetTrader("t3"); err == nil {
|
||||
t.Error("adding to returned map leaked into internal state")
|
||||
}
|
||||
if got := len(tm.GetTraderIDs()); got != 2 {
|
||||
t.Errorf("internal trader count = %d after mutating returned map, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTraderIDs(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
|
||||
t.Run("empty manager returns empty non-nil slice", func(t *testing.T) {
|
||||
ids := tm.GetTraderIDs()
|
||||
if ids == nil {
|
||||
t.Fatal("GetTraderIDs should return an empty slice, got nil")
|
||||
}
|
||||
if len(ids) != 0 {
|
||||
t.Errorf("GetTraderIDs = %v, want empty", ids)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all IDs", func(t *testing.T) {
|
||||
want := []string{"a", "b", "c"}
|
||||
for _, id := range want {
|
||||
insertTrader(tm, id, newIdleTrader())
|
||||
}
|
||||
|
||||
got := tm.GetTraderIDs()
|
||||
sort.Strings(got)
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("GetTraderIDs returned %d IDs, want %d", len(got), len(want))
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Errorf("GetTraderIDs[%d] = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemoveTrader(t *testing.T) {
|
||||
t.Run("removes existing non-running trader", func(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
insertTrader(tm, "t1", newIdleTrader())
|
||||
|
||||
tm.RemoveTrader("t1")
|
||||
|
||||
if _, err := tm.GetTrader("t1"); err == nil {
|
||||
t.Error("trader t1 should be removed")
|
||||
}
|
||||
if got := len(tm.GetTraderIDs()); got != 0 {
|
||||
t.Errorf("trader count after removal = %d, want 0", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing ID is a no-op", func(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
insertTrader(tm, "t1", newIdleTrader())
|
||||
|
||||
tm.RemoveTrader("missing") // must not panic
|
||||
|
||||
if _, err := tm.GetTrader("t1"); err != nil {
|
||||
t.Errorf("unrelated trader was removed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStartAllEmpty(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
tm.StartAll() // must not panic with no traders
|
||||
}
|
||||
|
||||
func TestStopAllWithIdleTraders(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
tm.StopAll() // empty: must not panic
|
||||
|
||||
insertTrader(tm, "t1", newIdleTrader())
|
||||
insertTrader(tm, "t2", newIdleTrader())
|
||||
tm.StopAll() // not-running traders: Stop is an early-return no-op
|
||||
}
|
||||
|
||||
func TestTraderLogTag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
traderID string
|
||||
traderName string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "with name",
|
||||
traderID: "abc-123",
|
||||
traderName: "MyBot",
|
||||
want: "[trader_id=abc-123 trader_name=MyBot]",
|
||||
},
|
||||
{
|
||||
name: "without name",
|
||||
traderID: "abc-123",
|
||||
want: "[trader_id=abc-123]",
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
want: "[trader_id=]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := traderLogTag(tt.traderID, tt.traderName); got != tt.want {
|
||||
t.Errorf("traderLogTag(%q, %q) = %q, want %q", tt.traderID, tt.traderName, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureHyperliquidNativeStrategy(t *testing.T) {
|
||||
t.Run("nil config does not panic", func(t *testing.T) {
|
||||
ensureHyperliquidNativeStrategy("bot", "hyperliquid", nil)
|
||||
})
|
||||
|
||||
t.Run("non-hyperliquid exchange is untouched", func(t *testing.T) {
|
||||
cfg := &store.StrategyConfig{
|
||||
CoinSource: store.CoinSourceConfig{
|
||||
SourceType: "ai500",
|
||||
UseAI500: true,
|
||||
},
|
||||
}
|
||||
ensureHyperliquidNativeStrategy("bot", "binance", cfg)
|
||||
|
||||
if cfg.CoinSource.SourceType != "ai500" || !cfg.CoinSource.UseAI500 {
|
||||
t.Errorf("non-hyperliquid config was modified: %+v", cfg.CoinSource)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("native sources are kept as-is", func(t *testing.T) {
|
||||
nativeSources := []string{"hyper_rank", "static", "hyper_all", "hyper_main", " Hyper_Rank "}
|
||||
for _, src := range nativeSources {
|
||||
cfg := &store.StrategyConfig{
|
||||
CoinSource: store.CoinSourceConfig{SourceType: src},
|
||||
}
|
||||
ensureHyperliquidNativeStrategy("bot", "hyperliquid", cfg)
|
||||
|
||||
if cfg.CoinSource.SourceType != src {
|
||||
t.Errorf("native source %q was rewritten to %q", src, cfg.CoinSource.SourceType)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy source on hyperliquid is forced to hyper_rank with defaults", func(t *testing.T) {
|
||||
cfg := &store.StrategyConfig{
|
||||
CoinSource: store.CoinSourceConfig{
|
||||
SourceType: "ai500",
|
||||
UseAI500: true,
|
||||
UseOITop: true,
|
||||
UseOILow: true,
|
||||
UseHyperAll: true,
|
||||
UseHyperMain: true,
|
||||
},
|
||||
}
|
||||
ensureHyperliquidNativeStrategy("bot", "hyperliquid", cfg)
|
||||
|
||||
cs := cfg.CoinSource
|
||||
if cs.SourceType != "hyper_rank" {
|
||||
t.Errorf("SourceType = %q, want hyper_rank", cs.SourceType)
|
||||
}
|
||||
if cs.UseAI500 || cs.UseOITop || cs.UseOILow || cs.UseHyperAll || cs.UseHyperMain {
|
||||
t.Errorf("legacy source flags should all be cleared: %+v", cs)
|
||||
}
|
||||
if cs.HyperRankCategory != "stock" {
|
||||
t.Errorf("HyperRankCategory = %q, want stock", cs.HyperRankCategory)
|
||||
}
|
||||
if cs.HyperRankDirection != "gainers" {
|
||||
t.Errorf("HyperRankDirection = %q, want gainers", cs.HyperRankDirection)
|
||||
}
|
||||
if cs.HyperRankLimit != 5 {
|
||||
t.Errorf("HyperRankLimit = %d, want 5", cs.HyperRankLimit)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("existing hyper_rank settings are preserved when forcing", func(t *testing.T) {
|
||||
cfg := &store.StrategyConfig{
|
||||
CoinSource: store.CoinSourceConfig{
|
||||
SourceType: "oi_top",
|
||||
HyperRankCategory: "crypto",
|
||||
HyperRankDirection: "losers",
|
||||
HyperRankLimit: 8,
|
||||
},
|
||||
}
|
||||
ensureHyperliquidNativeStrategy("bot", "hyperliquid", cfg)
|
||||
|
||||
cs := cfg.CoinSource
|
||||
if cs.SourceType != "hyper_rank" {
|
||||
t.Errorf("SourceType = %q, want hyper_rank", cs.SourceType)
|
||||
}
|
||||
if cs.HyperRankCategory != "crypto" {
|
||||
t.Errorf("HyperRankCategory = %q, want crypto (preserved)", cs.HyperRankCategory)
|
||||
}
|
||||
if cs.HyperRankDirection != "losers" {
|
||||
t.Errorf("HyperRankDirection = %q, want losers (preserved)", cs.HyperRankDirection)
|
||||
}
|
||||
if cs.HyperRankLimit != 8 {
|
||||
t.Errorf("HyperRankLimit = %d, want 8 (preserved)", cs.HyperRankLimit)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exchange type is matched case-insensitively with whitespace", func(t *testing.T) {
|
||||
cfg := &store.StrategyConfig{
|
||||
CoinSource: store.CoinSourceConfig{SourceType: "ai500"},
|
||||
}
|
||||
ensureHyperliquidNativeStrategy("bot", " HyperLiquid ", cfg)
|
||||
|
||||
if cfg.CoinSource.SourceType != "hyper_rank" {
|
||||
t.Errorf("SourceType = %q, want hyper_rank for case-insensitive exchange match", cfg.CoinSource.SourceType)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetCompetitionDataEmptyAndCache(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
|
||||
first, err := tm.GetCompetitionData()
|
||||
if err != nil {
|
||||
t.Fatalf("GetCompetitionData unexpected error: %v", err)
|
||||
}
|
||||
if got := first["count"]; got != 0 {
|
||||
t.Errorf("count = %v, want 0", got)
|
||||
}
|
||||
if got := first["total_count"]; got != 0 {
|
||||
t.Errorf("total_count = %v, want 0", got)
|
||||
}
|
||||
|
||||
tm.competitionCache.mu.RLock()
|
||||
cachedTimestamp := tm.competitionCache.timestamp
|
||||
tm.competitionCache.mu.RUnlock()
|
||||
if cachedTimestamp.IsZero() {
|
||||
t.Error("competition cache timestamp should be set after first call")
|
||||
}
|
||||
|
||||
// Second call within 30s must be served from the cache.
|
||||
second, err := tm.GetCompetitionData()
|
||||
if err != nil {
|
||||
t.Fatalf("GetCompetitionData (cached) unexpected error: %v", err)
|
||||
}
|
||||
if got := second["count"]; got != 0 {
|
||||
t.Errorf("cached count = %v, want 0", got)
|
||||
}
|
||||
|
||||
tm.competitionCache.mu.RLock()
|
||||
timestampAfterSecond := tm.competitionCache.timestamp
|
||||
tm.competitionCache.mu.RUnlock()
|
||||
if !timestampAfterSecond.Equal(cachedTimestamp) {
|
||||
t.Error("cached call should not refresh the cache timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTopTradersDataEmpty(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
|
||||
result, err := tm.GetTopTradersData()
|
||||
if err != nil {
|
||||
t.Fatalf("GetTopTradersData unexpected error: %v", err)
|
||||
}
|
||||
if got := result["count"]; got != 0 {
|
||||
t.Errorf("count = %v, want 0", got)
|
||||
}
|
||||
traders, ok := result["traders"].([]map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("traders has type %T, want []map[string]interface{}", result["traders"])
|
||||
}
|
||||
if len(traders) != 0 {
|
||||
t.Errorf("traders length = %d, want 0", len(traders))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetComparisonDataEmpty(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
|
||||
result, err := tm.GetComparisonData()
|
||||
if err != nil {
|
||||
t.Fatalf("GetComparisonData unexpected error: %v", err)
|
||||
}
|
||||
if got := result["count"]; got != 0 {
|
||||
t.Errorf("count = %v, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentAccess exercises the RWMutex by hammering the read paths
|
||||
// while traders are removed concurrently. Run with -race.
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
tm := NewTraderManager()
|
||||
|
||||
const traderCount = 16
|
||||
ids := make([]string, 0, traderCount)
|
||||
for i := 0; i < traderCount; i++ {
|
||||
id := fmt.Sprintf("trader-%d", i)
|
||||
ids = append(ids, id)
|
||||
insertTrader(tm, id, newIdleTrader())
|
||||
}
|
||||
|
||||
const (
|
||||
goroutinesPerKind = 8
|
||||
iterations = 200
|
||||
)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Readers: GetTrader / GetLoadError
|
||||
for g := 0; g < goroutinesPerKind; g++ {
|
||||
wg.Add(1)
|
||||
go func(seed int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iterations; i++ {
|
||||
id := ids[(seed+i)%traderCount]
|
||||
_, _ = tm.GetTrader(id)
|
||||
_ = tm.GetLoadError(id)
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
|
||||
// Readers: GetAllTraders / GetTraderIDs
|
||||
for g := 0; g < goroutinesPerKind; g++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iterations; i++ {
|
||||
_ = tm.GetAllTraders()
|
||||
_ = tm.GetTraderIDs()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Writers: RemoveTrader (including repeated removal of the same ID)
|
||||
for g := 0; g < goroutinesPerKind; g++ {
|
||||
wg.Add(1)
|
||||
go func(seed int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iterations; i++ {
|
||||
tm.RemoveTrader(ids[(seed+i)%traderCount])
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if got := len(tm.GetTraderIDs()); got != 0 {
|
||||
t.Errorf("all traders should be removed after concurrent removal, %d left", got)
|
||||
}
|
||||
}
|
||||
123
market/timeframe_test.go
Normal file
123
market/timeframe_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package market
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNormalizeTimeframe(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "valid lowercase minute", input: "1m", want: "1m"},
|
||||
{name: "valid lowercase hour", input: "4h", want: "4h"},
|
||||
{name: "valid lowercase day", input: "1d", want: "1d"},
|
||||
{name: "uppercase normalized", input: "1H", want: "1h"},
|
||||
{name: "mixed case normalized", input: "15M", want: "15m"},
|
||||
{name: "uppercase day", input: "1D", want: "1d"},
|
||||
{name: "leading and trailing whitespace", input: " 30m ", want: "30m"},
|
||||
{name: "whitespace and uppercase", input: " 12H ", want: "12h"},
|
||||
{name: "empty string", input: "", wantErr: true},
|
||||
{name: "whitespace only", input: " ", wantErr: true},
|
||||
{name: "unsupported value", input: "7m", wantErr: true},
|
||||
{name: "unsupported week", input: "1w", wantErr: true},
|
||||
{name: "garbage input", input: "abc", wantErr: true},
|
||||
{name: "internal whitespace not trimmed", input: "1 m", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NormalizeTimeframe(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("NormalizeTimeframe(%q) = %q, want error", tt.input, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NormalizeTimeframe(%q) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("NormalizeTimeframe(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTFDuration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want time.Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "one minute", input: "1m", want: time.Minute},
|
||||
{name: "three minutes", input: "3m", want: 3 * time.Minute},
|
||||
{name: "five minutes", input: "5m", want: 5 * time.Minute},
|
||||
{name: "fifteen minutes", input: "15m", want: 15 * time.Minute},
|
||||
{name: "thirty minutes", input: "30m", want: 30 * time.Minute},
|
||||
{name: "one hour", input: "1h", want: time.Hour},
|
||||
{name: "two hours", input: "2h", want: 2 * time.Hour},
|
||||
{name: "four hours", input: "4h", want: 4 * time.Hour},
|
||||
{name: "six hours", input: "6h", want: 6 * time.Hour},
|
||||
{name: "twelve hours", input: "12h", want: 12 * time.Hour},
|
||||
{name: "one day", input: "1d", want: 24 * time.Hour},
|
||||
{name: "uppercase with whitespace", input: " 1D ", want: 24 * time.Hour},
|
||||
{name: "empty string", input: "", wantErr: true},
|
||||
{name: "unsupported value", input: "2d", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := TFDuration(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("TFDuration(%q) = %v, want error", tt.input, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("TFDuration(%q) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("TFDuration(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedTimeframes(t *testing.T) {
|
||||
got := SupportedTimeframes()
|
||||
|
||||
if len(got) == 0 {
|
||||
t.Fatal("SupportedTimeframes() returned empty slice")
|
||||
}
|
||||
|
||||
if !slices.IsSorted(got) {
|
||||
t.Errorf("SupportedTimeframes() not sorted: %v", got)
|
||||
}
|
||||
|
||||
for _, required := range []string{"1m", "1d"} {
|
||||
if !slices.Contains(got, required) {
|
||||
t.Errorf("SupportedTimeframes() missing %q: %v", required, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Every advertised timeframe must round-trip through NormalizeTimeframe and TFDuration.
|
||||
for _, tf := range got {
|
||||
norm, err := NormalizeTimeframe(tf)
|
||||
if err != nil {
|
||||
t.Errorf("NormalizeTimeframe(%q) unexpected error: %v", tf, err)
|
||||
}
|
||||
if norm != tf {
|
||||
t.Errorf("NormalizeTimeframe(%q) = %q, want identity", tf, norm)
|
||||
}
|
||||
if d, err := TFDuration(tf); err != nil || d <= 0 {
|
||||
t.Errorf("TFDuration(%q) = %v, %v; want positive duration and nil error", tf, d, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
486
trader/position_rebuild_test.go
Normal file
486
trader/position_rebuild_test.go
Normal file
@@ -0,0 +1,486 @@
|
||||
package trader
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// testTime returns a deterministic timestamp offset by n minutes.
|
||||
func testTime(n int) time.Time {
|
||||
return time.Date(2026, 1, 2, 10, n, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func floatsClose(a, b float64) bool {
|
||||
return math.Abs(a-b) < 1e-9
|
||||
}
|
||||
|
||||
func TestRebuildPositionsFromTrades_EmptyInput(t *testing.T) {
|
||||
if got := RebuildPositionsFromTrades(nil); got != nil {
|
||||
t.Errorf("RebuildPositionsFromTrades(nil) = %v, want nil", got)
|
||||
}
|
||||
if got := RebuildPositionsFromTrades([]TradeRecord{}); got != nil {
|
||||
t.Errorf("RebuildPositionsFromTrades([]) = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRebuildPositionsFromTrades_SimpleLongOpenClose(t *testing.T) {
|
||||
trades := []TradeRecord{
|
||||
{
|
||||
TradeID: "t1",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "BUY",
|
||||
Price: 100.0,
|
||||
Quantity: 1.0,
|
||||
Fee: 0.1,
|
||||
Time: testTime(0),
|
||||
},
|
||||
{
|
||||
TradeID: "t2",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "SELL",
|
||||
Price: 110.0,
|
||||
Quantity: 1.0,
|
||||
RealizedPnL: 10.0,
|
||||
Fee: 0.2,
|
||||
Time: testTime(1),
|
||||
},
|
||||
}
|
||||
|
||||
records := RebuildPositionsFromTrades(trades)
|
||||
if len(records) != 1 {
|
||||
t.Fatalf("got %d records, want 1", len(records))
|
||||
}
|
||||
|
||||
r := records[0]
|
||||
if r.Symbol != "BTCUSDT" {
|
||||
t.Errorf("Symbol = %q, want BTCUSDT", r.Symbol)
|
||||
}
|
||||
if r.Side != "long" {
|
||||
t.Errorf("Side = %q, want long", r.Side)
|
||||
}
|
||||
if !floatsClose(r.EntryPrice, 100.0) {
|
||||
t.Errorf("EntryPrice = %v, want 100", r.EntryPrice)
|
||||
}
|
||||
if !floatsClose(r.ExitPrice, 110.0) {
|
||||
t.Errorf("ExitPrice = %v, want 110", r.ExitPrice)
|
||||
}
|
||||
if !floatsClose(r.Quantity, 1.0) {
|
||||
t.Errorf("Quantity = %v, want 1", r.Quantity)
|
||||
}
|
||||
if !floatsClose(r.RealizedPnL, 10.0) {
|
||||
t.Errorf("RealizedPnL = %v, want 10", r.RealizedPnL)
|
||||
}
|
||||
// Fee should be entry fee + exit fee.
|
||||
if !floatsClose(r.Fee, 0.3) {
|
||||
t.Errorf("Fee = %v, want 0.3 (entry 0.1 + exit 0.2)", r.Fee)
|
||||
}
|
||||
if !r.EntryTime.Equal(testTime(0)) {
|
||||
t.Errorf("EntryTime = %v, want %v", r.EntryTime, testTime(0))
|
||||
}
|
||||
if !r.ExitTime.Equal(testTime(1)) {
|
||||
t.Errorf("ExitTime = %v, want %v", r.ExitTime, testTime(1))
|
||||
}
|
||||
if r.OrderID != "t2" || r.ExchangeID != "t2" {
|
||||
t.Errorf("OrderID/ExchangeID = %q/%q, want t2/t2", r.OrderID, r.ExchangeID)
|
||||
}
|
||||
if r.CloseType != "unknown" {
|
||||
t.Errorf("CloseType = %q, want unknown", r.CloseType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRebuildPositionsFromTrades_PartialClose(t *testing.T) {
|
||||
trades := []TradeRecord{
|
||||
{
|
||||
TradeID: "open1",
|
||||
Symbol: "ETHUSDT",
|
||||
Side: "BUY",
|
||||
Price: 100.0,
|
||||
Quantity: 2.0,
|
||||
Fee: 0.4,
|
||||
Time: testTime(0),
|
||||
},
|
||||
{
|
||||
TradeID: "close1",
|
||||
Symbol: "ETHUSDT",
|
||||
Side: "SELL",
|
||||
Price: 110.0,
|
||||
Quantity: 1.0,
|
||||
RealizedPnL: 10.0,
|
||||
Fee: 0.1,
|
||||
Time: testTime(1),
|
||||
},
|
||||
{
|
||||
TradeID: "close2",
|
||||
Symbol: "ETHUSDT",
|
||||
Side: "SELL",
|
||||
Price: 120.0,
|
||||
Quantity: 1.0,
|
||||
RealizedPnL: 20.0,
|
||||
Fee: 0.1,
|
||||
Time: testTime(2),
|
||||
},
|
||||
}
|
||||
|
||||
records := RebuildPositionsFromTrades(trades)
|
||||
if len(records) != 2 {
|
||||
t.Fatalf("got %d records, want 2", len(records))
|
||||
}
|
||||
|
||||
for i, r := range records {
|
||||
if r.Side != "long" {
|
||||
t.Errorf("records[%d].Side = %q, want long", i, r.Side)
|
||||
}
|
||||
// FIFO: both partial closes consume the single open at 100.
|
||||
if !floatsClose(r.EntryPrice, 100.0) {
|
||||
t.Errorf("records[%d].EntryPrice = %v, want 100", i, r.EntryPrice)
|
||||
}
|
||||
if !floatsClose(r.Quantity, 1.0) {
|
||||
t.Errorf("records[%d].Quantity = %v, want 1", i, r.Quantity)
|
||||
}
|
||||
if !r.EntryTime.Equal(testTime(0)) {
|
||||
t.Errorf("records[%d].EntryTime = %v, want %v", i, r.EntryTime, testTime(0))
|
||||
}
|
||||
}
|
||||
|
||||
if !floatsClose(records[0].ExitPrice, 110.0) {
|
||||
t.Errorf("records[0].ExitPrice = %v, want 110", records[0].ExitPrice)
|
||||
}
|
||||
if !floatsClose(records[1].ExitPrice, 120.0) {
|
||||
t.Errorf("records[1].ExitPrice = %v, want 120", records[1].ExitPrice)
|
||||
}
|
||||
|
||||
// First partial close: exit fee 0.1 + proportional entry fee 0.4*(1/2) = 0.3.
|
||||
if !floatsClose(records[0].Fee, 0.3) {
|
||||
t.Errorf("records[0].Fee = %v, want 0.3", records[0].Fee)
|
||||
}
|
||||
// NOTE: documents current behavior. The open trade's Fee field is not
|
||||
// reduced when partially consumed, so the second close re-attributes the
|
||||
// full remaining ratio of the original fee: 0.1 + 0.4*(1/1) = 0.5.
|
||||
// Total attributed entry fee across both closes is 0.6 > 0.4 actually paid.
|
||||
if !floatsClose(records[1].Fee, 0.5) {
|
||||
t.Errorf("records[1].Fee = %v, want 0.5 (current over-attribution behavior)", records[1].Fee)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRebuildPositionsFromTrades_MultipleOpensWeightedEntry(t *testing.T) {
|
||||
trades := []TradeRecord{
|
||||
{
|
||||
TradeID: "open1",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "BUY",
|
||||
Price: 100.0,
|
||||
Quantity: 1.0,
|
||||
Fee: 0.1,
|
||||
Time: testTime(0),
|
||||
},
|
||||
{
|
||||
TradeID: "open2",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "BUY",
|
||||
Price: 110.0,
|
||||
Quantity: 1.0,
|
||||
Fee: 0.1,
|
||||
Time: testTime(1),
|
||||
},
|
||||
{
|
||||
TradeID: "close1",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "SELL",
|
||||
Price: 120.0,
|
||||
Quantity: 2.0,
|
||||
RealizedPnL: 30.0,
|
||||
Fee: 0.2,
|
||||
Time: testTime(2),
|
||||
},
|
||||
}
|
||||
|
||||
records := RebuildPositionsFromTrades(trades)
|
||||
if len(records) != 1 {
|
||||
t.Fatalf("got %d records, want 1", len(records))
|
||||
}
|
||||
|
||||
r := records[0]
|
||||
// Weighted average: (100*1 + 110*1) / 2 = 105.
|
||||
if !floatsClose(r.EntryPrice, 105.0) {
|
||||
t.Errorf("EntryPrice = %v, want 105", r.EntryPrice)
|
||||
}
|
||||
if !floatsClose(r.Quantity, 2.0) {
|
||||
t.Errorf("Quantity = %v, want 2", r.Quantity)
|
||||
}
|
||||
// Exit fee 0.2 + both entry fees 0.1 + 0.1.
|
||||
if !floatsClose(r.Fee, 0.4) {
|
||||
t.Errorf("Fee = %v, want 0.4", r.Fee)
|
||||
}
|
||||
// EntryTime is the first matched open trade's time.
|
||||
if !r.EntryTime.Equal(testTime(0)) {
|
||||
t.Errorf("EntryTime = %v, want %v", r.EntryTime, testTime(0))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRebuildPositionsFromTrades_HedgeMode(t *testing.T) {
|
||||
trades := []TradeRecord{
|
||||
{
|
||||
TradeID: "lo",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "BUY",
|
||||
PositionSide: "LONG",
|
||||
Price: 100.0,
|
||||
Quantity: 1.0,
|
||||
Time: testTime(0),
|
||||
},
|
||||
{
|
||||
TradeID: "so",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "SELL",
|
||||
PositionSide: "SHORT",
|
||||
Price: 100.0,
|
||||
Quantity: 1.0,
|
||||
Time: testTime(1),
|
||||
},
|
||||
{
|
||||
TradeID: "lc",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "SELL",
|
||||
PositionSide: "LONG",
|
||||
Price: 110.0,
|
||||
Quantity: 1.0,
|
||||
RealizedPnL: 10.0,
|
||||
Time: testTime(2),
|
||||
},
|
||||
{
|
||||
TradeID: "sc",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "BUY",
|
||||
PositionSide: "SHORT",
|
||||
Price: 90.0,
|
||||
Quantity: 1.0,
|
||||
RealizedPnL: 10.0,
|
||||
Time: testTime(3),
|
||||
},
|
||||
}
|
||||
|
||||
records := RebuildPositionsFromTrades(trades)
|
||||
if len(records) != 2 {
|
||||
t.Fatalf("got %d records, want 2", len(records))
|
||||
}
|
||||
|
||||
long := records[0]
|
||||
if long.Side != "long" {
|
||||
t.Fatalf("records[0].Side = %q, want long", long.Side)
|
||||
}
|
||||
if !floatsClose(long.EntryPrice, 100.0) || !floatsClose(long.ExitPrice, 110.0) {
|
||||
t.Errorf("long entry/exit = %v/%v, want 100/110", long.EntryPrice, long.ExitPrice)
|
||||
}
|
||||
|
||||
short := records[1]
|
||||
if short.Side != "short" {
|
||||
t.Fatalf("records[1].Side = %q, want short", short.Side)
|
||||
}
|
||||
if !floatsClose(short.EntryPrice, 100.0) || !floatsClose(short.ExitPrice, 90.0) {
|
||||
t.Errorf("short entry/exit = %v/%v, want 100/90", short.EntryPrice, short.ExitPrice)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRebuildPositionsFromTrades_OneWayModeShort(t *testing.T) {
|
||||
trades := []TradeRecord{
|
||||
{
|
||||
TradeID: "open1",
|
||||
Symbol: "SOLUSDT",
|
||||
Side: "SELL", // sell with zero PnL opens a short in one-way mode
|
||||
Price: 100.0,
|
||||
Quantity: 1.0,
|
||||
Fee: 0.05,
|
||||
Time: testTime(0),
|
||||
},
|
||||
{
|
||||
TradeID: "close1",
|
||||
Symbol: "SOLUSDT",
|
||||
Side: "BUY", // buy with non-zero PnL closes the short
|
||||
Price: 90.0,
|
||||
Quantity: 1.0,
|
||||
RealizedPnL: 10.0,
|
||||
Fee: 0.05,
|
||||
Time: testTime(1),
|
||||
},
|
||||
}
|
||||
|
||||
records := RebuildPositionsFromTrades(trades)
|
||||
if len(records) != 1 {
|
||||
t.Fatalf("got %d records, want 1", len(records))
|
||||
}
|
||||
|
||||
r := records[0]
|
||||
if r.Side != "short" {
|
||||
t.Errorf("Side = %q, want short", r.Side)
|
||||
}
|
||||
if !floatsClose(r.EntryPrice, 100.0) {
|
||||
t.Errorf("EntryPrice = %v, want 100", r.EntryPrice)
|
||||
}
|
||||
if !floatsClose(r.ExitPrice, 90.0) {
|
||||
t.Errorf("ExitPrice = %v, want 90", r.ExitPrice)
|
||||
}
|
||||
if !floatsClose(r.Fee, 0.1) {
|
||||
t.Errorf("Fee = %v, want 0.1", r.Fee)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRebuildPositionsFromTrades_PnLFallbackEntryPrice(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
trade TradeRecord
|
||||
wantSide string
|
||||
wantEntry float64
|
||||
}{
|
||||
{
|
||||
name: "long fallback: entry = exit - pnl/qty",
|
||||
trade: TradeRecord{
|
||||
TradeID: "lone-long",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "SELL",
|
||||
Price: 110.0,
|
||||
Quantity: 2.0,
|
||||
RealizedPnL: 20.0,
|
||||
Time: testTime(0),
|
||||
},
|
||||
wantSide: "long",
|
||||
wantEntry: 100.0, // 110 - 20/2
|
||||
},
|
||||
{
|
||||
name: "short fallback: entry = exit + pnl/qty",
|
||||
trade: TradeRecord{
|
||||
TradeID: "lone-short",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "BUY",
|
||||
Price: 95.0,
|
||||
Quantity: 1.0,
|
||||
RealizedPnL: 5.0,
|
||||
Time: testTime(0),
|
||||
},
|
||||
wantSide: "short",
|
||||
wantEntry: 100.0, // 95 + 5/1
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
records := RebuildPositionsFromTrades([]TradeRecord{tt.trade})
|
||||
if len(records) != 1 {
|
||||
t.Fatalf("got %d records, want 1", len(records))
|
||||
}
|
||||
r := records[0]
|
||||
if r.Side != tt.wantSide {
|
||||
t.Errorf("Side = %q, want %q", r.Side, tt.wantSide)
|
||||
}
|
||||
if !floatsClose(r.EntryPrice, tt.wantEntry) {
|
||||
t.Errorf("EntryPrice = %v, want %v", r.EntryPrice, tt.wantEntry)
|
||||
}
|
||||
// Without a matching open trade, entry time falls back to exit time.
|
||||
if !r.EntryTime.Equal(r.ExitTime) {
|
||||
t.Errorf("EntryTime = %v, want exit time %v", r.EntryTime, r.ExitTime)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRebuildPositionsFromTrades_InvalidTrades(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
trades []TradeRecord
|
||||
}{
|
||||
{
|
||||
name: "closing trade with zero quantity",
|
||||
trades: []TradeRecord{
|
||||
{
|
||||
TradeID: "zq",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "SELL",
|
||||
Price: 110.0,
|
||||
Quantity: 0,
|
||||
RealizedPnL: 10.0,
|
||||
Time: testTime(0),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "closing trade with zero price",
|
||||
trades: []TradeRecord{
|
||||
{
|
||||
TradeID: "zp",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "SELL",
|
||||
Price: 0,
|
||||
Quantity: 1.0,
|
||||
RealizedPnL: 10.0,
|
||||
Time: testTime(0),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "trade with unrecognized side is skipped",
|
||||
trades: []TradeRecord{
|
||||
{
|
||||
TradeID: "bad-side",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "HOLD",
|
||||
Price: 100.0,
|
||||
Quantity: 1.0,
|
||||
Time: testTime(0),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
records := RebuildPositionsFromTrades(tt.trades)
|
||||
if len(records) != 0 {
|
||||
t.Errorf("got %d records, want 0: %+v", len(records), records)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRebuildPositionsFromTrades_UnsortedInputUsesChronologicalFIFO(t *testing.T) {
|
||||
// Deliberately out of chronological order: close first, opens reversed.
|
||||
trades := []TradeRecord{
|
||||
{
|
||||
TradeID: "close1",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "SELL",
|
||||
Price: 120.0,
|
||||
Quantity: 1.0,
|
||||
RealizedPnL: 20.0,
|
||||
Time: testTime(2),
|
||||
},
|
||||
{
|
||||
TradeID: "open2",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "BUY",
|
||||
Price: 110.0,
|
||||
Quantity: 1.0,
|
||||
Time: testTime(1),
|
||||
},
|
||||
{
|
||||
TradeID: "open1",
|
||||
Symbol: "BTCUSDT",
|
||||
Side: "BUY",
|
||||
Price: 100.0,
|
||||
Quantity: 1.0,
|
||||
Time: testTime(0),
|
||||
},
|
||||
}
|
||||
|
||||
records := RebuildPositionsFromTrades(trades)
|
||||
if len(records) != 1 {
|
||||
t.Fatalf("got %d records, want 1", len(records))
|
||||
}
|
||||
|
||||
// FIFO after time sort: the earliest open (price 100) is matched first.
|
||||
if !floatsClose(records[0].EntryPrice, 100.0) {
|
||||
t.Errorf("EntryPrice = %v, want 100 (earliest open via FIFO)", records[0].EntryPrice)
|
||||
}
|
||||
if !records[0].EntryTime.Equal(testTime(0)) {
|
||||
t.Errorf("EntryTime = %v, want %v", records[0].EntryTime, testTime(0))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user