Files
nofx/store/equity.go
2025-12-08 01:43:22 +08:00

258 lines
7.6 KiB
Go

package store
import (
"database/sql"
"fmt"
"time"
)
// EquityStore persists account equity snapshots (used for plotting return
// curves) in the SQLite table trader_equity_snapshots.
type EquityStore struct {
db *sql.DB // database handle used for all queries; not nil-checked by the methods
}
// EquitySnapshot is one point-in-time record of a trader's account equity,
// corresponding to a row of trader_equity_snapshots.
type EquitySnapshot struct {
ID int64 `json:"id"` // Row id (AUTOINCREMENT); filled in by Save after insert
TraderID string `json:"trader_id"` // Trader the snapshot belongs to
Timestamp time.Time `json:"timestamp"` // Snapshot time; Save converts it to UTC and stores it as RFC 3339 (second precision)
TotalEquity float64 `json:"total_equity"` // Account equity (balance + unrealized PnL)
Balance float64 `json:"balance"` // Account balance
UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized profit and loss
PositionCount int `json:"position_count"` // Position count
MarginUsedPct float64 `json:"margin_used_pct"` // Margin usage percentage (NOTE(review): 0-100 vs 0-1 scale not visible here — confirm)
}
// initTables creates the trader_equity_snapshots table and its two indexes
// when they do not exist yet. The statements run in order and the first one
// that fails aborts initialization with a wrapped error.
func (s *EquityStore) initTables() error {
	// Equity snapshot table - dedicated to return curves.
	const createSnapshotTable = `CREATE TABLE IF NOT EXISTS trader_equity_snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trader_id TEXT NOT NULL,
timestamp DATETIME NOT NULL,
total_equity REAL NOT NULL DEFAULT 0,
balance REAL NOT NULL DEFAULT 0,
unrealized_pnl REAL NOT NULL DEFAULT 0,
position_count INTEGER DEFAULT 0,
margin_used_pct REAL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`
	// Per-trader time-ordered lookups (GetLatest, GetByTimeRange).
	const createTraderTimeIndex = `CREATE INDEX IF NOT EXISTS idx_equity_trader_time ON trader_equity_snapshots(trader_id, timestamp DESC)`
	// Global time-ordered lookups.
	const createTimeIndex = `CREATE INDEX IF NOT EXISTS idx_equity_timestamp ON trader_equity_snapshots(timestamp DESC)`

	statements := [...]string{createSnapshotTable, createTraderTimeIndex, createTimeIndex}
	for i := range statements {
		if _, execErr := s.db.Exec(statements[i]); execErr != nil {
			return fmt.Errorf("failed to execute SQL: %w", execErr)
		}
	}
	return nil
}
// Save inserts snapshot as a new row of trader_equity_snapshots.
//
// snapshot is updated in place:
//   - a zero Timestamp is replaced with the current time;
//   - the Timestamp is always converted to UTC;
//   - after a successful insert, ID receives the value from LastInsertId.
//
// The timestamp is stored as an RFC 3339 string, i.e. at second precision.
func (s *EquityStore) Save(snapshot *EquitySnapshot) error {
	ts := snapshot.Timestamp
	if ts.IsZero() {
		ts = time.Now()
	}
	snapshot.Timestamp = ts.UTC()

	res, execErr := s.db.Exec(`
INSERT INTO trader_equity_snapshots (
trader_id, timestamp, total_equity, balance,
unrealized_pnl, position_count, margin_used_pct
) VALUES (?, ?, ?, ?, ?, ?, ?)
`,
		snapshot.TraderID,
		snapshot.Timestamp.Format(time.RFC3339),
		snapshot.TotalEquity,
		snapshot.Balance,
		snapshot.UnrealizedPnL,
		snapshot.PositionCount,
		snapshot.MarginUsedPct,
	)
	if execErr != nil {
		return fmt.Errorf("failed to save equity snapshot: %w", execErr)
	}

	// The insert has already succeeded; an error from LastInsertId is ignored
	// and whatever id it reports is stored.
	snapshot.ID, _ = res.LastInsertId()
	return nil
}
// GetLatest returns up to limit of the most recent equity snapshots for
// traderID, ordered oldest to newest (the order used to plot a return curve).
//
// limit is passed unchanged to the SQL LIMIT clause. Rows that fail to scan
// are skipped. A timestamp that does not parse as RFC 3339 leaves Timestamp at
// its zero value. An error is returned when the query fails or when row
// iteration stops because of an error. When no rows match, the result is a nil
// slice.
func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot, error) {
	// Query newest-first so LIMIT keeps the most recent rows; the slice is
	// reversed into chronological order below.
	rows, err := s.db.Query(`
SELECT id, trader_id, timestamp, total_equity, balance,
unrealized_pnl, position_count, margin_used_pct
FROM trader_equity_snapshots
WHERE trader_id = ?
ORDER BY timestamp DESC
LIMIT ?
`, traderID, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to query equity records: %w", err)
	}
	defer rows.Close()

	var snapshots []*EquitySnapshot
	for rows.Next() {
		snap := &EquitySnapshot{}
		var timestampStr string
		if err := rows.Scan(
			&snap.ID, &snap.TraderID, &timestampStr, &snap.TotalEquity,
			&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
		); err != nil {
			continue // best-effort: skip rows that cannot be scanned
		}
		snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
		snapshots = append(snapshots, snap)
	}
	// rows.Next returns false both when rows run out and when iteration fails;
	// without this check an I/O or driver error would look like a short result.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate equity records: %w", err)
	}

	// Reverse in place so time runs from old to new.
	for i, j := 0, len(snapshots)-1; i < j; i, j = i+1, j-1 {
		snapshots[i], snapshots[j] = snapshots[j], snapshots[i]
	}
	return snapshots, nil
}
// GetByTimeRange returns traderID's equity snapshots whose timestamps fall
// within [start, end] (both inclusive), ordered oldest to newest.
//
// Save stores timestamps as RFC 3339 strings in UTC, and the WHERE clause
// compares those strings with the formatted bounds. The bounds are therefore
// converted to UTC first; a bound in another zone would otherwise shift the
// window by that zone's offset. Bounds are formatted at second precision.
//
// Rows that fail to scan are skipped. An error is returned when the query
// fails or when row iteration stops because of an error.
func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*EquitySnapshot, error) {
	rows, err := s.db.Query(`
SELECT id, trader_id, timestamp, total_equity, balance,
unrealized_pnl, position_count, margin_used_pct
FROM trader_equity_snapshots
WHERE trader_id = ? AND timestamp >= ? AND timestamp <= ?
ORDER BY timestamp ASC
`, traderID, start.UTC().Format(time.RFC3339), end.UTC().Format(time.RFC3339))
	if err != nil {
		return nil, fmt.Errorf("failed to query equity records: %w", err)
	}
	defer rows.Close()

	var snapshots []*EquitySnapshot
	for rows.Next() {
		snap := &EquitySnapshot{}
		var timestampStr string
		if err := rows.Scan(
			&snap.ID, &snap.TraderID, &timestampStr, &snap.TotalEquity,
			&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
		); err != nil {
			continue // best-effort: skip rows that cannot be scanned
		}
		snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
		snapshots = append(snapshots, snap)
	}
	// Distinguish "no more rows" from an iteration failure.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate equity records: %w", err)
	}
	return snapshots, nil
}
// GetAllTradersLatest returns the most recent equity snapshot of every trader
// (e.g. for leaderboards), keyed by trader ID.
//
// Rows that fail to scan are skipped. An error is returned when the query
// fails or when row iteration stops because of an error.
func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error) {
	rows, err := s.db.Query(`
SELECT e.id, e.trader_id, e.timestamp, e.total_equity, e.balance,
e.unrealized_pnl, e.position_count, e.margin_used_pct
FROM trader_equity_snapshots e
INNER JOIN (
SELECT trader_id, MAX(timestamp) as max_ts
FROM trader_equity_snapshots
GROUP BY trader_id
) latest ON e.trader_id = latest.trader_id AND e.timestamp = latest.max_ts
`)
	if err != nil {
		return nil, fmt.Errorf("failed to query latest equity: %w", err)
	}
	defer rows.Close()

	result := make(map[string]*EquitySnapshot)
	for rows.Next() {
		snap := &EquitySnapshot{}
		var timestampStr string
		if err := rows.Scan(
			&snap.ID, &snap.TraderID, &timestampStr, &snap.TotalEquity,
			&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
		); err != nil {
			continue // best-effort: skip rows that cannot be scanned
		}
		snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
		// Timestamps have second precision, so several rows can share a
		// trader's MAX(timestamp). Keep the highest (AUTOINCREMENT) id, i.e. the
		// last one inserted, so the result does not depend on row order.
		if prev, ok := result[snap.TraderID]; ok && prev.ID > snap.ID {
			continue
		}
		result[snap.TraderID] = snap
	}
	// Distinguish "no more rows" from an iteration failure.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate latest equity: %w", err)
	}
	return result, nil
}
// CleanOldRecords deletes traderID's equity snapshots whose timestamp is more
// than days days before now, and returns the number of rows deleted.
//
// Save stores timestamps as RFC 3339 strings in UTC and the DELETE compares
// them as strings, so the cutoff must also be formatted in UTC. A cutoff in the
// local zone would move the boundary by the local UTC offset.
func (s *EquityStore) CleanOldRecords(traderID string, days int) (int64, error) {
	cutoffTime := time.Now().UTC().AddDate(0, 0, -days).Format(time.RFC3339)
	result, err := s.db.Exec(`
DELETE FROM trader_equity_snapshots
WHERE trader_id = ? AND timestamp < ?
`, traderID, cutoffTime)
	if err != nil {
		return 0, fmt.Errorf("failed to clean old records: %w", err)
	}
	return result.RowsAffected()
}
// GetCount reports how many equity snapshots are stored for traderID.
// Any query or scan error is returned unwrapped alongside the count.
func (s *EquityStore) GetCount(traderID string) (int, error) {
	row := s.db.QueryRow(`
SELECT COUNT(*) FROM trader_equity_snapshots WHERE trader_id = ?
`, traderID)

	var total int
	scanErr := row.Scan(&total)
	return total, scanErr
}
// MigrateFromDecision backfills trader_equity_snapshots from the legacy
// decision_records + decision_account_snapshots tables.
//
// It returns (0, nil) without migrating when either:
//   - trader_equity_snapshots already contains rows, or
//   - decision_account_snapshots does not exist.
//
// Otherwise it copies one row per joined decision record and returns the
// number of rows inserted. Errors from the existence checks are returned
// instead of being mistaken for "nothing to migrate".
func (s *EquityStore) MigrateFromDecision() (int64, error) {
	// Migrate only into an empty table, so the backfill is skipped once data exists.
	var count int
	if err := s.db.QueryRow(`SELECT COUNT(*) FROM trader_equity_snapshots`).Scan(&count); err != nil {
		return 0, fmt.Errorf("failed to count equity snapshots: %w", err)
	}
	if count > 0 {
		return 0, nil // Already has data, skip migration
	}

	// COUNT(*) yields 0 for a missing table, so any error here is a genuine
	// query failure rather than "no rows".
	var oldTableCount int
	if err := s.db.QueryRow(`
SELECT COUNT(*) FROM sqlite_master
WHERE type='table' AND name='decision_account_snapshots'
`).Scan(&oldTableCount); err != nil {
		return 0, fmt.Errorf("failed to check for legacy snapshot table: %w", err)
	}
	if oldTableCount == 0 {
		return 0, nil // Old table doesn't exist, skip
	}

	// Column mapping: total_balance -> total_equity and
	// available_balance -> balance.
	// NOTE(review): dr.timestamp is copied verbatim. If decision_records does
	// not store RFC 3339 UTC strings, the string comparisons in the range and
	// cleanup queries will not order these rows correctly — confirm the format.
	result, err := s.db.Exec(`
INSERT INTO trader_equity_snapshots (
trader_id, timestamp, total_equity, balance,
unrealized_pnl, position_count, margin_used_pct
)
SELECT
dr.trader_id,
dr.timestamp,
das.total_balance,
das.available_balance,
das.total_unrealized_profit,
das.position_count,
das.margin_used_pct
FROM decision_records dr
JOIN decision_account_snapshots das ON dr.id = das.decision_id
ORDER BY dr.timestamp ASC
`)
	if err != nil {
		return 0, fmt.Errorf("failed to migrate data: %w", err)
	}
	return result.RowsAffected()
}