diff --git a/config.json.example b/config.json.example index ac9d5ac6..87b01edd 100644 --- a/config.json.example +++ b/config.json.example @@ -5,6 +5,7 @@ "altcoin_leverage": 5 }, "use_default_coins": true, + "inside_coins": true, "default_coins": [ "BTCUSDT", "ETHUSDT", diff --git a/config/config.go b/config/config.go index 97fcc84d..3b736d0e 100644 --- a/config/config.go +++ b/config/config.go @@ -11,7 +11,7 @@ import ( type TraderConfig struct { ID string `json:"id"` Name string `json:"name"` - Enabled bool `json:"enabled"` // 是否启用该trader + Enabled bool `json:"enabled"` // 是否启用该trader AIModel string `json:"ai_model"` // "qwen" or "deepseek" // 交易平台选择(二选一) @@ -54,6 +54,7 @@ type LeverageConfig struct { type Config struct { Traders []TraderConfig `json:"traders"` UseDefaultCoins bool `json:"use_default_coins"` // 是否使用默认主流币种列表 + InsideCoins bool `json:"inside_coins"` // 是否使用内置AI评分币种列表 DefaultCoins []string `json:"default_coins"` // 默认主流币种池 APIServerPort int `json:"api_server_port"` MaxDailyLoss float64 `json:"max_daily_loss"` diff --git a/main.go b/main.go index 1d9631a9..7929072b 100644 --- a/main.go +++ b/main.go @@ -8,12 +8,14 @@ import ( "nofx/auth" "nofx/config" "nofx/manager" + "nofx/market" "nofx/pool" "os" "os/signal" "strconv" "strings" "syscall" + "time" ) // LeverageConfig 杠杆配置 @@ -28,6 +30,7 @@ type ConfigFile struct { APIServerPort int `json:"api_server_port"` UseDefaultCoins bool `json:"use_default_coins"` DefaultCoins []string `json:"default_coins"` + InsideCoins bool `json:"inside_coins"` CoinPoolAPIURL string `json:"coin_pool_api_url"` OITopAPIURL string `json:"oi_top_api_url"` MaxDailyLoss float64 `json:"max_daily_loss"` @@ -35,6 +38,7 @@ type ConfigFile struct { StopTradingMinutes int `json:"stop_trading_minutes"` Leverage LeverageConfig `json:"leverage"` JWTSecret string `json:"jwt_secret"` + DataKLineTime string `json:"data_k_line_time"` } // syncConfigToDatabase 从config.json读取配置并同步到数据库 @@ -61,14 +65,15 @@ func syncConfigToDatabase(database *config.Database) error { // 同步各配置项到数据库 configs := map[string]string{ - "admin_mode": fmt.Sprintf("%t", configFile.AdminMode), - "api_server_port": strconv.Itoa(configFile.APIServerPort), - "use_default_coins": fmt.Sprintf("%t", configFile.UseDefaultCoins), - "coin_pool_api_url": configFile.CoinPoolAPIURL, - "oi_top_api_url": configFile.OITopAPIURL, - "max_daily_loss": fmt.Sprintf("%.1f", configFile.MaxDailyLoss), - "max_drawdown": fmt.Sprintf("%.1f", configFile.MaxDrawdown), - "stop_trading_minutes": strconv.Itoa(configFile.StopTradingMinutes), + "admin_mode": fmt.Sprintf("%t", configFile.AdminMode), + "api_server_port": strconv.Itoa(configFile.APIServerPort), + "use_default_coins": fmt.Sprintf("%t", configFile.UseDefaultCoins), + "inside_coins": fmt.Sprintf("%t", configFile.InsideCoins), + "coin_pool_api_url": configFile.CoinPoolAPIURL, + "oi_top_api_url": configFile.OITopAPIURL, + "max_daily_loss": fmt.Sprintf("%.1f", configFile.MaxDailyLoss), + "max_drawdown": fmt.Sprintf("%.1f", configFile.MaxDrawdown), + "stop_trading_minutes": strconv.Itoa(configFile.StopTradingMinutes), } // 同步default_coins(转换为JSON字符串存储) @@ -132,12 +137,14 @@ func main() { // 获取系统配置 useDefaultCoinsStr, _ := database.GetSystemConfig("use_default_coins") useDefaultCoins := useDefaultCoinsStr == "true" + InsideCoinsStr, _ := database.GetSystemConfig("inside_coins") + insideCoins := InsideCoinsStr == "true" apiPortStr, _ := database.GetSystemConfig("api_server_port") - + // 获取管理员模式配置 adminModeStr, _ := database.GetSystemConfig("admin_mode") adminMode := adminModeStr != "false" // 默认为true - + // 设置JWT密钥 jwtSecret, _ := database.GetSystemConfig("jwt_secret") if jwtSecret == "" { @@ -145,7 +152,7 @@ func main() { log.Printf("⚠️ 使用默认JWT密钥,建议在生产环境中配置") } auth.SetJWTSecret(jwtSecret) - + // 在管理员模式下,确保admin用户存在 if adminMode { err := database.EnsureAdminUser() @@ -156,7 +163,7 @@ func main() { } auth.SetAdminMode(true) } - + log.Printf("✓ 配置数据库初始化成功") fmt.Println() @@ -180,6 +187,25 @@ func main() { pool.SetDefaultCoins(defaultCoins) + //内置AI评分 + if insideCoins { + log.Printf("✓ 启用内置AI评分币种列表") + monitor := market.NewWSMonitor(150) + go func() { + monitor.Start() + // 定时器设置默认的币种列表 - 覆蓋defaultCoins设置 + for { + if len(monitor.FilterSymbol) > 0 { + for _, coin := range defaultCoins { + monitor.FilterSymbol = append(monitor.FilterSymbol, coin) + } + pool.SetDefaultCoins(monitor.FilterSymbol) + monitor.FilterSymbol = nil + } + time.Sleep(1 * time.Minute) + } + }() + } // 设置是否使用默认主流币种 pool.SetUseDefaultCoins(useDefaultCoins) if useDefaultCoins { @@ -192,7 +218,7 @@ func main() { pool.SetCoinPoolAPI(coinPoolAPIURL) log.Printf("✓ 已配置AI500币种池API") } - + oiTopAPIURL, _ := database.GetSystemConfig("oi_top_api_url") if oiTopAPIURL != "" { pool.SetOITopAPI(oiTopAPIURL) @@ -208,37 +234,26 @@ func main() { log.Fatalf("❌ 加载交易员失败: %v", err) } - // 获取所有用户的交易员配置(用于显示) - userIDs, err := database.GetAllUsers() + // 获取数据库中的所有交易员配置(用于显示,使用default用户) + traders, err := database.GetTraders("default") if err != nil { - log.Printf("⚠️ 获取用户列表失败: %v", err) - userIDs = []string{"default"} // 回退到default用户 - } - - var allTraders []*config.TraderRecord - for _, userID := range userIDs { - traders, err := database.GetTraders(userID) - if err != nil { - log.Printf("⚠️ 获取用户 %s 的交易员失败: %v", userID, err) - continue - } - allTraders = append(allTraders, traders...) + log.Fatalf("❌ 获取交易员列表失败: %v", err) } // 显示加载的交易员信息 fmt.Println() fmt.Println("🤖 数据库中的AI交易员配置:") - if len(allTraders) == 0 { + if len(traders) == 0 { fmt.Println(" • 暂无配置的交易员,请通过Web界面创建") } else { - for _, trader := range allTraders { + for _, trader := range traders { status := "停止" if trader.IsRunning { status = "运行中" } - fmt.Printf(" • %s (%s + %s) - 用户: %s - 初始资金: %.0f USDT [%s]\n", - trader.Name, strings.ToUpper(trader.AIModelID), strings.ToUpper(trader.ExchangeID), - trader.UserID, trader.InitialBalance, status) + fmt.Printf(" • %s (%s + %s) - 初始资金: %.0f USDT [%s]\n", + trader.Name, strings.ToUpper(trader.AIModelID), strings.ToUpper(trader.ExchangeID), + trader.InitialBalance, status) } } @@ -256,7 +271,7 @@ func main() { fmt.Println() // 获取API服务器端口 - apiPort := 8080 // 默认端口 + apiPort := 8080 // 默认端口 if apiPortStr != "" { if port, err := strconv.Atoi(apiPortStr); err == nil { apiPort = port diff --git a/market/api_client.go b/market/api_client.go new file mode 100644 index 00000000..70bb1150 --- /dev/null +++ b/market/api_client.go @@ -0,0 +1,150 @@ +package market + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strconv" + "time" +) + +const ( + baseURL = "https://fapi.binance.com" +) + +type APIClient struct { + client *http.Client +} + +func NewAPIClient() *APIClient { + return &APIClient{ + client: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +func (c *APIClient) GetExchangeInfo() (*ExchangeInfo, error) { + url := fmt.Sprintf("%s/fapi/v1/exchangeInfo", baseURL) + resp, err := c.client.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var exchangeInfo ExchangeInfo + err = json.Unmarshal(body, &exchangeInfo) + if err != nil { + return nil, err + } + + return &exchangeInfo, nil +} + +func (c *APIClient) GetKlines(symbol, interval string, limit int) ([]Kline, error) { + url := fmt.Sprintf("%s/fapi/v1/klines", baseURL) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + + q := req.URL.Query() + q.Add("symbol", symbol) + q.Add("interval", interval) + q.Add("limit", strconv.Itoa(limit)) + req.URL.RawQuery = q.Encode() + + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var klineResponses []KlineResponse + err = json.Unmarshal(body, &klineResponses) + if err != nil { + return nil, err + } + + var klines []Kline + for _, kr := range klineResponses { + kline, err := parseKline(kr) + if err != nil { + log.Printf("解析K线数据失败: %v", err) + continue + } + klines = append(klines, kline) + } + + return klines, nil +} + +func parseKline(kr KlineResponse) (Kline, error) { + var kline Kline + + if len(kr) < 11 { + return kline, fmt.Errorf("invalid kline data") + } + + // 解析各个字段 + kline.OpenTime = int64(kr[0].(float64)) + kline.Open, _ = strconv.ParseFloat(kr[1].(string), 64) + kline.High, _ = strconv.ParseFloat(kr[2].(string), 64) + kline.Low, _ = strconv.ParseFloat(kr[3].(string), 64) + kline.Close, _ = strconv.ParseFloat(kr[4].(string), 64) + kline.Volume, _ = strconv.ParseFloat(kr[5].(string), 64) + kline.CloseTime = int64(kr[6].(float64)) + kline.QuoteVolume, _ = strconv.ParseFloat(kr[7].(string), 64) + kline.Trades = int(kr[8].(float64)) + kline.TakerBuyBaseVolume, _ = strconv.ParseFloat(kr[9].(string), 64) + kline.TakerBuyQuoteVolume, _ = strconv.ParseFloat(kr[10].(string), 64) + + return kline, nil +} + +func (c *APIClient) GetCurrentPrice(symbol string) (float64, error) { + url := fmt.Sprintf("%s/fapi/v1/ticker/price", baseURL) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return 0, err + } + + q := req.URL.Query() + q.Add("symbol", symbol) + req.URL.RawQuery = q.Encode() + + resp, err := c.client.Do(req) + if err != nil { + return 0, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return 0, err + } + + var ticker PriceTicker + err = json.Unmarshal(body, &ticker) + if err != nil { + return 0, err + } + + price, err := strconv.ParseFloat(ticker.Price, 64) + if err != nil { + return 0, err + } + + return price, nil +} diff --git a/market/combined_streams.go b/market/combined_streams.go new file mode 100644 index 00000000..801d423e --- /dev/null +++ b/market/combined_streams.go @@ -0,0 +1,202 @@ +package market + +import ( + "encoding/json" + "fmt" + "log" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type CombinedStreamsClient struct { + conn *websocket.Conn + mu sync.RWMutex + subscribers map[string]chan []byte + reconnect bool + done chan struct{} + batchSize int // 每批订阅的流数量 +} + +func NewCombinedStreamsClient(batchSize int) *CombinedStreamsClient { + return &CombinedStreamsClient{ + subscribers: make(map[string]chan []byte), + reconnect: true, + done: make(chan struct{}), + batchSize: batchSize, + } +} + +func (c *CombinedStreamsClient) Connect() error { + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + } + + // 组合流使用不同的端点 + conn, _, err := dialer.Dial("wss://fstream.binance.com/stream", nil) + if err != nil { + return fmt.Errorf("组合流WebSocket连接失败: %v", err) + } + + c.mu.Lock() + c.conn = conn + c.mu.Unlock() + + log.Println("组合流WebSocket连接成功") + go c.readMessages() + + return nil +} + +// BatchSubscribeKlines 批量订阅K线 +func (c *CombinedStreamsClient) BatchSubscribeKlines(symbols []string, interval string) error { + // 将symbols分批处理 + batches := c.splitIntoBatches(symbols, c.batchSize) + + for i, batch := range batches { + log.Printf("订阅第 %d 批, 数量: %d", i+1, len(batch)) + + streams := make([]string, len(batch)) + for j, symbol := range batch { + streams[j] = fmt.Sprintf("%s@kline_%s", strings.ToLower(symbol), interval) + } + + if err := c.subscribeStreams(streams); err != nil { + return fmt.Errorf("第 %d 批订阅失败: %v", i+1, err) + } + + // 批次间延迟,避免被限制 + if i < len(batches)-1 { + time.Sleep(100 * time.Millisecond) + } + } + + return nil +} + +// splitIntoBatches 将切片分成指定大小的批次 +func (c *CombinedStreamsClient) splitIntoBatches(symbols []string, batchSize int) [][]string { + var batches [][]string + + for i := 0; i < len(symbols); i += batchSize { + end := i + batchSize + if end > len(symbols) { + end = len(symbols) + } + batches = append(batches, symbols[i:end]) + } + + return batches +} + +// subscribeStreams 订阅多个流 +func (c *CombinedStreamsClient) subscribeStreams(streams []string) error { + subscribeMsg := map[string]interface{}{ + "method": "SUBSCRIBE", + "params": streams, + "id": time.Now().UnixNano(), + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.conn == nil { + return fmt.Errorf("WebSocket未连接") + } + + log.Printf("订阅流: %v", streams) + return c.conn.WriteJSON(subscribeMsg) +} + +func (c *CombinedStreamsClient) readMessages() { + for { + select { + case <-c.done: + return + default: + c.mu.RLock() + conn := c.conn + c.mu.RUnlock() + + if conn == nil { + time.Sleep(1 * time.Second) + continue + } + + _, message, err := conn.ReadMessage() + if err != nil { + log.Printf("读取组合流消息失败: %v", err) + c.handleReconnect() + return + } + + c.handleCombinedMessage(message) + } + } +} + +func (c *CombinedStreamsClient) handleCombinedMessage(message []byte) { + var combinedMsg struct { + Stream string `json:"stream"` + Data json.RawMessage `json:"data"` + } + + if err := json.Unmarshal(message, &combinedMsg); err != nil { + log.Printf("解析组合消息失败: %v", err) + return + } + + c.mu.RLock() + ch, exists := c.subscribers[combinedMsg.Stream] + c.mu.RUnlock() + + if exists { + select { + case ch <- combinedMsg.Data: + default: + log.Printf("订阅者通道已满: %s", combinedMsg.Stream) + } + } +} + +func (c *CombinedStreamsClient) AddSubscriber(stream string, bufferSize int) <-chan []byte { + ch := make(chan []byte, bufferSize) + c.mu.Lock() + c.subscribers[stream] = ch + c.mu.Unlock() + return ch +} + +func (c *CombinedStreamsClient) handleReconnect() { + if !c.reconnect { + return + } + + log.Println("组合流尝试重新连接...") + time.Sleep(3 * time.Second) + + if err := c.Connect(); err != nil { + log.Printf("组合流重新连接失败: %v", err) + go c.handleReconnect() + } +} + +func (c *CombinedStreamsClient) Close() { + c.reconnect = false + close(c.done) + + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + + for stream, ch := range c.subscribers { + close(ch) + delete(c.subscribers, stream) + } +} diff --git a/market/data.go b/market/data.go index 97812e64..cd40be75 100644 --- a/market/data.go +++ b/market/data.go @@ -10,72 +10,20 @@ import ( "strings" ) -// Data 市场数据结构 -type Data struct { - Symbol string - CurrentPrice float64 - PriceChange1h float64 // 1小时价格变化百分比 - PriceChange4h float64 // 4小时价格变化百分比 - CurrentEMA20 float64 - CurrentMACD float64 - CurrentRSI7 float64 - OpenInterest *OIData - FundingRate float64 - IntradaySeries *IntradayData - LongerTermContext *LongerTermData -} - -// OIData Open Interest数据 -type OIData struct { - Latest float64 - Average float64 -} - -// IntradayData 日内数据(3分钟间隔) -type IntradayData struct { - MidPrices []float64 - EMA20Values []float64 - MACDValues []float64 - RSI7Values []float64 - RSI14Values []float64 -} - -// LongerTermData 长期数据(4小时时间框架) -type LongerTermData struct { - EMA20 float64 - EMA50 float64 - ATR3 float64 - ATR14 float64 - CurrentVolume float64 - AverageVolume float64 - MACDValues []float64 - RSI14Values []float64 -} - -// Kline K线数据 -type Kline struct { - OpenTime int64 - Open float64 - High float64 - Low float64 - Close float64 - Volume float64 - CloseTime int64 -} - // Get 获取指定代币的市场数据 func Get(symbol string) (*Data, error) { + var klines3m, klines4h []Kline + var err error // 标准化symbol symbol = Normalize(symbol) - // 获取3分钟K线数据 (最近10个) - klines3m, err := getKlines(symbol, "3m", 40) // 多获取一些用于计算 + klines3m, err = WSMonitorCli.GetCurrentKlines(symbol, "3m") // 多获取一些用于计算 if err != nil { return nil, fmt.Errorf("获取3分钟K线失败: %v", err) } // 获取4小时K线数据 (最近10个) - klines4h, err := getKlines(symbol, "4h", 60) // 多获取用于计算指标 + klines4h, err = WSMonitorCli.GetCurrentKlines(symbol, "4h") // 多获取用于计算指标 if err != nil { return nil, fmt.Errorf("获取4小时K线失败: %v", err) } @@ -136,51 +84,6 @@ func Get(symbol string) (*Data, error) { }, nil } -// getKlines 从Binance获取K线数据 -func getKlines(symbol, interval string, limit int) ([]Kline, error) { - url := fmt.Sprintf("https://fapi.binance.com/fapi/v1/klines?symbol=%s&interval=%s&limit=%d", - symbol, interval, limit) - - resp, err := http.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - var rawData [][]interface{} - if err := json.Unmarshal(body, &rawData); err != nil { - return nil, err - } - - klines := make([]Kline, len(rawData)) - for i, item := range rawData { - openTime := int64(item[0].(float64)) - open, _ := parseFloat(item[1]) - high, _ := parseFloat(item[2]) - low, _ := parseFloat(item[3]) - close, _ := parseFloat(item[4]) - volume, _ := parseFloat(item[5]) - closeTime := int64(item[6].(float64)) - - klines[i] = Kline{ - OpenTime: openTime, - Open: open, - High: high, - Low: low, - Close: close, - Volume: volume, - CloseTime: closeTime, - } - } - - return klines, nil -} - // calculateEMA 计算EMA func calculateEMA(klines []Kline, period int) float64 { if len(klines) < period { diff --git a/market/feature_engine.go b/market/feature_engine.go new file mode 100644 index 00000000..91540a29 --- /dev/null +++ b/market/feature_engine.go @@ -0,0 +1,229 @@ +package market + +import ( + "fmt" + "math" + "time" +) + +type FeatureEngine struct { + alertThresholds AlertThresholds +} + +func NewFeatureEngine(thresholds AlertThresholds) *FeatureEngine { + return &FeatureEngine{ + alertThresholds: thresholds, + } +} + +func (e *FeatureEngine) CalculateFeatures(symbol string, klines []Kline) *SymbolFeatures { + if len(klines) < 20 { + return nil + } + + features := &SymbolFeatures{ + Symbol: symbol, + Timestamp: time.Now(), + } + + // 提取价格和交易量数据 + closes := make([]float64, len(klines)) + volumes := make([]float64, len(klines)) + highs := make([]float64, len(klines)) + lows := make([]float64, len(klines)) + + for i, k := range klines { + closes[i] = k.Close + volumes[i] = k.Volume + highs[i] = k.High + lows[i] = k.Low + } + + // 价格特征 + features.Price = closes[len(closes)-1] + features.PriceChange15Min = (closes[len(closes)-1] - closes[len(closes)-2]) / closes[len(closes)-2] + + if len(closes) >= 5 { + features.PriceChange1H = (closes[len(closes)-1] - closes[len(closes)-5]) / closes[len(closes)-5] + } + if len(closes) >= 17 { + features.PriceChange4H = (closes[len(closes)-1] - closes[len(closes)-17]) / closes[len(closes)-17] + } + + // 交易量特征 + currentVolume := volumes[len(volumes)-1] + features.Volume = currentVolume + + // 5周期平均交易量 + if len(volumes) >= 6 { + avgVolume5 := e.calculateAverage(volumes[len(volumes)-6 : len(volumes)-1]) + features.VolumeRatio5 = currentVolume / avgVolume5 + } + + // 20周期平均交易量 + if len(volumes) >= 21 { + avgVolume20 := e.calculateAverage(volumes[len(volumes)-21 : len(volumes)-1]) + features.VolumeRatio20 = currentVolume / avgVolume20 + } + + // 交易量趋势 + if features.VolumeRatio20 > 0 { + features.VolumeTrend = features.VolumeRatio5 / features.VolumeRatio20 + } + + // 技术指标 + features.RSI14 = e.calculateRSI(closes, 14) + features.SMA5 = e.calculateSMA(closes, 5) + features.SMA10 = e.calculateSMA(closes, 10) + features.SMA20 = e.calculateSMA(closes, 20) + + // 波动特征 + currentHigh := highs[len(highs)-1] + currentLow := lows[len(lows)-1] + features.HighLowRatio = (currentHigh - currentLow) / features.Price + features.Volatility20 = e.calculateVolatility(closes, 20) + + // 价格在区间中的位置 + if currentHigh != currentLow { + features.PositionInRange = (features.Price - currentLow) / (currentHigh - currentLow) + } else { + features.PositionInRange = 0.5 + } + + return features +} + +func (e *FeatureEngine) calculateAverage(values []float64) float64 { + sum := 0.0 + for _, v := range values { + sum += v + } + return sum / float64(len(values)) +} + +func (e *FeatureEngine) calculateSMA(prices []float64, period int) float64 { + if len(prices) < period { + return 0 + } + return e.calculateAverage(prices[len(prices)-period:]) +} + +func (e *FeatureEngine) calculateRSI(prices []float64, period int) float64 { + if len(prices) <= period { + return 50 + } + + gains := make([]float64, 0) + losses := make([]float64, 0) + + for i := 1; i < len(prices); i++ { + change := prices[i] - prices[i-1] + if change > 0 { + gains = append(gains, change) + losses = append(losses, 0) + } else { + gains = append(gains, 0) + losses = append(losses, -change) + } + } + + // 只取最近period个数据点 + if len(gains) > period { + gains = gains[len(gains)-period:] + losses = losses[len(losses)-period:] + } + + avgGain := e.calculateAverage(gains) + avgLoss := e.calculateAverage(losses) + + if avgLoss == 0 { + return 100 + } + + rs := avgGain / avgLoss + return 100 - (100 / (1 + rs)) +} + +func (e *FeatureEngine) calculateVolatility(prices []float64, period int) float64 { + if len(prices) < period { + return 0 + } + + periodPrices := prices[len(prices)-period:] + mean := e.calculateAverage(periodPrices) + + variance := 0.0 + for _, price := range periodPrices { + variance += math.Pow(price-mean, 2) + } + variance /= float64(len(periodPrices)) + + return math.Sqrt(variance) / mean +} + +func (e *FeatureEngine) DetectAlerts(features *SymbolFeatures) []Alert { + var alerts []Alert + + // 交易量放大检测 + if features.VolumeRatio5 > e.alertThresholds.VolumeSpike { + alerts = append(alerts, Alert{ + Type: "VOLUME_SPIKE", + Symbol: features.Symbol, + Value: features.VolumeRatio5, + Threshold: e.alertThresholds.VolumeSpike, + Message: fmt.Sprintf("%s 交易量放大 %.2f 倍", features.Symbol, features.VolumeRatio5), + Timestamp: time.Now(), + }) + } + + // 15分钟价格异动 + if math.Abs(features.PriceChange15Min) > e.alertThresholds.PriceChange15Min { + direction := "上涨" + if features.PriceChange15Min < 0 { + direction = "下跌" + } + alerts = append(alerts, Alert{ + Type: "PRICE_CHANGE_15MIN", + Symbol: features.Symbol, + Value: features.PriceChange15Min, + Threshold: e.alertThresholds.PriceChange15Min, + Message: fmt.Sprintf("%s 15分钟%s %.2f%%", features.Symbol, direction, features.PriceChange15Min*100), + Timestamp: time.Now(), + }) + } + + // 交易量趋势 + if features.VolumeTrend > e.alertThresholds.VolumeTrend { + alerts = append(alerts, Alert{ + Type: "VOLUME_TREND", + Symbol: features.Symbol, + Value: features.VolumeTrend, + Threshold: e.alertThresholds.VolumeTrend, + Message: fmt.Sprintf("%s 交易量趋势增强 %.2f 倍", features.Symbol, features.VolumeTrend), + Timestamp: time.Now(), + }) + } + + // RSI超买超卖 + if features.RSI14 > e.alertThresholds.RSIOverbought { + alerts = append(alerts, Alert{ + Type: "RSI_OVERBOUGHT", + Symbol: features.Symbol, + Value: features.RSI14, + Threshold: e.alertThresholds.RSIOverbought, + Message: fmt.Sprintf("%s RSI超买: %.2f", features.Symbol, features.RSI14), + Timestamp: time.Now(), + }) + } else if features.RSI14 < e.alertThresholds.RSIOversold { + alerts = append(alerts, Alert{ + Type: "RSI_OVERSOLD", + Symbol: features.Symbol, + Value: features.RSI14, + Threshold: e.alertThresholds.RSIOversold, + Message: fmt.Sprintf("%s RSI超卖: %.2f", features.Symbol, features.RSI14), + Timestamp: time.Now(), + }) + } + + return alerts +} diff --git a/market/monitor.go b/market/monitor.go new file mode 100644 index 00000000..9837623e --- /dev/null +++ b/market/monitor.go @@ -0,0 +1,526 @@ +package market + +import ( + "encoding/json" + "fmt" + "log" + "math" + "sort" + "strings" + "sync" + "time" +) + +type WSMonitor struct { + wsClient *WSClient + combinedClient *CombinedStreamsClient + featureEngine *FeatureEngine + symbols []string + featuresMap sync.Map + alertsChan chan Alert + klineDataMap3m sync.Map // 存储每个交易对的K线历史数据 + klineDataMap4h sync.Map // 存储每个交易对的K线历史数据 + tickerDataMap sync.Map // 存储每个交易对的ticker数据 + batchSize int + filterSymbols sync.Map // 使用sync.Map来存储需要监控的币种和其状态 + symbolStats sync.Map // 存储币种统计信息 + FilterSymbol []string //经过筛选的币种 +} +type SymbolStats struct { + LastActiveTime time.Time + AlertCount int + VolumeSpikeCount int + LastAlertTime time.Time + Score float64 // 综合评分 +} + +var WSMonitorCli *WSMonitor + +func NewWSMonitor(batchSize int) *WSMonitor { + WSMonitorCli = &WSMonitor{ + wsClient: NewWSClient(), + combinedClient: NewCombinedStreamsClient(batchSize), + featureEngine: NewFeatureEngine(config.AlertThresholds), + alertsChan: make(chan Alert, 1000), + batchSize: batchSize, + } + return WSMonitorCli +} + +func (m *WSMonitor) Initialize() error { + log.Println("初始化WebSocket监控器...") + + // 获取交易对信息 + apiClient := NewAPIClient() + exchangeInfo, err := apiClient.GetExchangeInfo() + if err != nil { + return err + } + + // 筛选永续合约交易对 --仅测试时使用 + //exchangeInfo.Symbols = exchangeInfo.Symbols[0:2] + for _, symbol := range exchangeInfo.Symbols { + if symbol.Status == "TRADING" && symbol.ContractType == "PERPETUAL" { + m.symbols = append(m.symbols, Normalize(symbol.Symbol)) + } + } + log.Printf("找到 %d 个交易对", len(m.symbols)) + // 初始化历史数据 + if err := m.initializeHistoricalData(); err != nil { + log.Printf("初始化历史数据失败: %v", err) + } + + return nil +} + +func (m *WSMonitor) initializeHistoricalData() error { + apiClient := NewAPIClient() + + var wg sync.WaitGroup + semaphore := make(chan struct{}, 5) // 限制并发数 + + for _, symbol := range m.symbols { + wg.Add(1) + semaphore <- struct{}{} + + go func(s string) { + defer wg.Done() + defer func() { <-semaphore }() + + // 获取历史K线数据 + klines, err := apiClient.GetKlines(s, "3m", 100) + if err != nil { + log.Printf("获取 %s 历史数据失败: %v", s, err) + return + } + if len(klines) > 0 { + m.klineDataMap3m.Store(s, klines) + log.Printf("已加载 %s 的历史K线数据-3m: %d 条", s, len(klines)) + } + // 获取历史K线数据 + klines4h, err := apiClient.GetKlines(s, "4h", 100) + if err != nil { + log.Printf("获取 %s 历史数据失败: %v", s, err) + return + } + if len(klines4h) > 0 { + m.klineDataMap4h.Store(s, klines) + log.Printf("已加载 %s 的历史K线数据-4h: %d 条", s, len(klines)) + } + }(symbol) + } + + wg.Wait() + return nil +} + +func (m *WSMonitor) Start() { + log.Printf("启动WebSocket实时监控...") + // 初始化交易对 + err := m.Initialize() + if err != nil { + log.Fatalf("❌ 初始化币种: %v", err) + return + } + + err = m.combinedClient.Connect() + if err != nil { + log.Fatalf("❌ 批量订阅流: %v", err) + return + } + // 启动警报处理器 + go m.handleAlerts() + // 启动定期清理任务 + go m.cleanupInactiveSymbols() + // 输出监控统计 - 评分前十名 + go m.printFilterStats(50) + // 订阅所有交易对 + err = m.subscribeAll() + + if err != nil { + log.Fatalf("❌ 订阅币种交易对: %v", err) + return + } +} + +func (m *WSMonitor) subscribeAll() error { + // 执行批量订阅 + log.Println("开始订阅所有交易对...") + for _, symbol := range m.symbols { + stream3m := fmt.Sprintf("%s@kline_3m", strings.ToLower(symbol)) + ch3m := m.combinedClient.AddSubscriber(stream3m, 100) + go m.handleKlineData(symbol, ch3m, "3m") + + stream4h := fmt.Sprintf("%s@kline_4h", strings.ToLower(symbol)) + ch4h := m.combinedClient.AddSubscriber(stream4h, 100) + go m.handleKlineData(symbol, ch4h, "4h") + } + + err := m.combinedClient.BatchSubscribeKlines(m.symbols, "3m") + if err != nil { + log.Fatalf("❌ 订阅3m K线: %v", err) + return err + } + err = m.combinedClient.BatchSubscribeKlines(m.symbols, "4h") + if err != nil { + log.Fatalf("❌ 订阅4h K线: %v", err) + return err + } + log.Println("所有交易对订阅完成") + return nil +} + +func (m *WSMonitor) handleKlineData(symbol string, ch <-chan []byte, _time string) { + for data := range ch { + var klineData KlineWSData + if err := json.Unmarshal(data, &klineData); err != nil { + log.Printf("解析Kline数据失败: %v", err) + continue + } + m.processKlineUpdate(symbol, klineData, _time) + } +} + +func (m *WSMonitor) handleTickerData(symbol string, ch <-chan []byte) { + for data := range ch { + var tickerData TickerWSData + if err := json.Unmarshal(data, &tickerData); err != nil { + log.Printf("解析Ticker数据失败: %v", err) + continue + } + + m.processTickerUpdate(symbol, tickerData) + } +} +func (m *WSMonitor) handleTickerDatas(ch <-chan []byte) { + for data := range ch { + var tickerData []TickerWSData + if err := json.Unmarshal(data, &tickerData); err != nil { + log.Printf("解析Ticker数据失败: %v", err) + continue + } + log.Fatalln(tickerData) + //m.processTickerUpdate(symbol, tickerData) + } +} +func (m *WSMonitor) getKlineDataMap(_time string) *sync.Map { + var klineDataMap *sync.Map + if _time == "3m" { + klineDataMap = &m.klineDataMap3m + } else { + klineDataMap = &m.klineDataMap4h + } + return klineDataMap +} +func (m *WSMonitor) processKlineUpdate(symbol string, wsData KlineWSData, _time string) { + // 转换WebSocket数据为Kline结构 + kline := Kline{ + OpenTime: wsData.Kline.StartTime, + CloseTime: wsData.Kline.CloseTime, + Trades: wsData.Kline.NumberOfTrades, + } + kline.Open, _ = parseFloat(wsData.Kline.OpenPrice) + kline.High, _ = parseFloat(wsData.Kline.HighPrice) + kline.Low, _ = parseFloat(wsData.Kline.LowPrice) + kline.Close, _ = parseFloat(wsData.Kline.ClosePrice) + kline.Volume, _ = parseFloat(wsData.Kline.Volume) + kline.High, _ = parseFloat(wsData.Kline.HighPrice) + kline.QuoteVolume, _ = parseFloat(wsData.Kline.QuoteVolume) + kline.TakerBuyBaseVolume, _ = parseFloat(wsData.Kline.TakerBuyBaseVolume) + kline.TakerBuyQuoteVolume, _ = parseFloat(wsData.Kline.TakerBuyQuoteVolume) + // 更新K线数据 + var klineDataMap = m.getKlineDataMap(_time) + value, exists := klineDataMap.Load(symbol) + var klines []Kline + if exists { + klines = value.([]Kline) + + // 检查是否是新的K线 + if len(klines) > 0 && klines[len(klines)-1].OpenTime == kline.OpenTime { + // 更新当前K线 + klines[len(klines)-1] = kline + } else { + // 添加新K线 + klines = append(klines, kline) + + // 保持数据长度 + if len(klines) > 100 { + klines = klines[1:] + } + } + } else { + klines = []Kline{kline} + } + + klineDataMap.Store(symbol, klines) + // 计算特征并检测警报 + if len(klines) >= 20 { + features := m.featureEngine.CalculateFeatures(symbol, klines) + if features != nil { + m.featuresMap.Store(symbol, features) + + alerts := m.featureEngine.DetectAlerts(features) + hasAlert := len(alerts) > 0 + + // 更新统计信息 + m.updateSymbolStats(symbol, features, hasAlert) + + for _, alert := range alerts { + m.alertsChan <- alert + } + + // 实时日志输出重要特征 + if len(alerts) > 0 || features.VolumeRatio5 > 2.0 || math.Abs(features.PriceChange15Min) > 0.02 { + //log.Printf("📊 %s - 价格: %.4f, 15分钟变动: %.2f%%, 交易量倍数: %.2f, RSI: %.1f", + // symbol, features.Price, features.PriceChange15Min*100, + // features.VolumeRatio5, features.RSI14) + } + } + } +} + +func (m *WSMonitor) processTickerUpdate(symbol string, tickerData TickerWSData) { + // 存储ticker数据 + m.tickerDataMap.Store(symbol, tickerData) +} + +func (m *WSMonitor) handleAlerts() { + alertCounts := make(map[string]int) + lastReset := time.Now() + + for alert := range m.alertsChan { + // 重置计数器(每小时) + if time.Since(lastReset) > time.Hour { + alertCounts = make(map[string]int) + lastReset = time.Now() + } + + // 警报去重和频率控制 + alertKey := fmt.Sprintf("%s_%s", alert.Symbol, alert.Type) + alertCounts[alertKey]++ + m.filterSymbols.Store(alert.Symbol, true) + + //log.Printf("✅ 自动添加监控: %s (因警报: %s)", alert.Symbol, alert.Message) + if alertCounts[alertKey] <= 3 { // 每小时最多3次相同警报 + //log.Printf("🚨 实时警报: %s", alert.Message) + + // 这里可以添加其他警报处理逻辑 + } + } +} + +func (m *WSMonitor) GetCurrentKlines(symbol string, _time string) ([]Kline, error) { + value, exists := m.getKlineDataMap(_time).Load(symbol) + if !exists { + // 如果Ws数据未初始化完成时,单独使用api获取 - 兼容性代码 (防止在未初始化完成是,已经有交易员运行) + apiClient := NewAPIClient() + klines, err := apiClient.GetKlines(symbol, _time, 40) + if err != nil { + return nil, fmt.Errorf("获取%v分钟K线失败: %v", _time, err) + } + return klines, fmt.Errorf("symbol不存在") + } + return value.([]Kline), nil +} + +func (m *WSMonitor) GetCurrentFeatures(symbol string) (*SymbolFeatures, bool) { + value, exists := m.featuresMap.Load(symbol) + if !exists { + return nil, false + } + return value.(*SymbolFeatures), true +} + +func (m *WSMonitor) GetAllFeatures() map[string]*SymbolFeatures { + features := make(map[string]*SymbolFeatures) + m.featuresMap.Range(func(key, value interface{}) bool { + features[key.(string)] = value.(*SymbolFeatures) + return true + }) + return features +} + +func (m *WSMonitor) Close() { + m.wsClient.Close() + close(m.alertsChan) +} +func (m *WSMonitor) printFilterStats(nember int) { + ticker := time.NewTicker(2 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + var monitoredSymbols []string + m.filterSymbols.Range(func(key, value interface{}) bool { + monitoredSymbols = append(monitoredSymbols, key.(string)) + return true + }) + + log.Printf("🎯 监控统计 - 总数: %d, 币种: %v", + len(monitoredSymbols), monitoredSymbols) + + // 打印前5个评分最高的币种 + type symbolScore struct { + symbol string + score float64 + } + var topScores []symbolScore + + m.symbolStats.Range(func(key, value interface{}) bool { + symbol := key.(string) + stats := value.(*SymbolStats) + topScores = append(topScores, symbolScore{symbol, stats.Score}) + return true + }) + + // 按评分排序 + sort.Slice(topScores, func(i, j int) bool { + return topScores[i].score > topScores[j].score + }) + m.FilterSymbol = nil + if len(topScores) > 0 { + log.Printf("🏆 评分TOP%v:", nember) + for i := 0; i < len(topScores) && i < nember; i++ { + m.FilterSymbol = append(m.FilterSymbol, topScores[i].symbol) + log.Printf(" %d. %s: %.1f分", i+1, topScores[i].symbol, topScores[i].score) + } + } + } +} + +// evaluateSymbolScore 评估币种得分,决定是否保留 +func (m *WSMonitor) evaluateSymbolScore(symbol string, features *SymbolFeatures) float64 { + score := 0.0 + + // 交易量活跃度评分 (权重: 40%) + if features.VolumeRatio5 > 1.5 { + score += 40 * math.Min(features.VolumeRatio5/5.0, 1.0) + } + + // 价格波动评分 (权重: 30%) + volatilityScore := math.Abs(features.PriceChange15Min) * 1000 // 放大系数 + score += 30 * math.Min(volatilityScore/10.0, 1.0) // 最大10%波动得满分 + + // RSI活跃度评分 (权重: 20%) + if features.RSI14 < 30 || features.RSI14 > 70 { + score += 20 // RSI在极端区域 + } else if features.RSI14 < 40 || features.RSI14 > 60 { + score += 10 // RSI在活跃区域 + } + + // 交易量趋势评分 (权重: 10%) + if features.VolumeTrend > 1.2 { + score += 10 * math.Min(features.VolumeTrend/3.0, 1.0) + } + + return score +} + +// shouldRemoveFromFilter 判断是否应该从FilterSymbols中移除 +func (m *WSMonitor) shouldRemoveFromFilter(symbol string) bool { + value, exists := m.symbolStats.Load(symbol) + if !exists { + return true // 没有统计信息,移除 + } + + stats := value.(*SymbolStats) + + // 规则1: 超过30分钟没有活跃迹象 + if time.Since(stats.LastActiveTime) > 30*time.Minute { + log.Printf("🔻 %s 因长时间不活跃被移除", symbol) + return true + } + + // 规则2: 评分持续低于阈值 (最近5次评分平均) + if stats.Score < 15 { // 调整这个阈值 + log.Printf("🔻 %s 因评分过低(%.1f)被移除", symbol, stats.Score) + return true + } + + // 规则3: 超过2小时没有产生警报 + if time.Since(stats.LastAlertTime) > 2*time.Hour && stats.AlertCount > 0 { + log.Printf("🔻 %s 因长时间无新警报被移除", symbol) + return true + } + + return false +} + +// updateSymbolStats 更新币种统计信息 +func (m *WSMonitor) updateSymbolStats(symbol string, features *SymbolFeatures, hasAlert bool) { + now := time.Now() + + value, exists := m.symbolStats.Load(symbol) + var stats *SymbolStats + + if !exists { + stats = &SymbolStats{ + LastActiveTime: now, + Score: m.evaluateSymbolScore(symbol, features), + } + } else { + stats = value.(*SymbolStats) + stats.LastActiveTime = now + + // 平滑更新评分 (指数移动平均) + newScore := m.evaluateSymbolScore(symbol, features) + stats.Score = 0.7*stats.Score + 0.3*newScore + } + + if hasAlert { + stats.AlertCount++ + stats.LastAlertTime = now + } + + if features.VolumeRatio5 > 2.0 { + stats.VolumeSpikeCount++ + } + + m.symbolStats.Store(symbol, stats) +} + +// removeFromFilter 从FilterSymbols中移除币种 +func (m *WSMonitor) removeFromFilter(symbol string) { + + // 从filterSymbols中移除 + m.filterSymbols.Delete(symbol) + m.symbolStats.Delete(symbol) + + log.Printf("🗑️ 已移除币种监控: %s", symbol) +} + +// cleanupInactiveSymbols 定期清理不活跃的币种 +func (m *WSMonitor) cleanupInactiveSymbols() { + ticker := time.NewTicker(5 * time.Minute) // 每5分钟检查一次 + defer ticker.Stop() + + for range ticker.C { + var symbolsToRemove []string + + // 收集需要移除的币种 + m.filterSymbols.Range(func(key, value interface{}) bool { + symbol := key.(string) + if m.shouldRemoveFromFilter(symbol) { + symbolsToRemove = append(symbolsToRemove, symbol) + } + return true + }) + + // 执行移除操作 + for _, symbol := range symbolsToRemove { + m.removeFromFilter(symbol) + } + + if len(symbolsToRemove) > 0 { + log.Printf("🧹 清理完成,移除了 %d 个不活跃币种", len(symbolsToRemove)) + } + } +} + +// getSymbolScore 获取币种当前评分 +func (m *WSMonitor) getSymbolScore(symbol string) float64 { + value, exists := m.symbolStats.Load(symbol) + if !exists { + return 0 + } + return value.(*SymbolStats).Score +} diff --git a/market/types.go b/market/types.go new file mode 100644 index 00000000..82f44415 --- /dev/null +++ b/market/types.go @@ -0,0 +1,157 @@ +package market + +import "time" + +// Data 市场数据结构 +type Data struct { + Symbol string + CurrentPrice float64 + PriceChange1h float64 // 1小时价格变化百分比 + PriceChange4h float64 // 4小时价格变化百分比 + CurrentEMA20 float64 + CurrentMACD float64 + CurrentRSI7 float64 + OpenInterest *OIData + FundingRate float64 + IntradaySeries *IntradayData + LongerTermContext *LongerTermData +} + +// OIData Open Interest数据 +type OIData struct { + Latest float64 + Average float64 +} + +// IntradayData 日内数据(3分钟间隔) +type IntradayData struct { + MidPrices []float64 + EMA20Values []float64 + MACDValues []float64 + RSI7Values []float64 + RSI14Values []float64 +} + +// LongerTermData 长期数据(4小时时间框架) +type LongerTermData struct { + EMA20 float64 + EMA50 float64 + ATR3 float64 + ATR14 float64 + CurrentVolume float64 + AverageVolume float64 + MACDValues []float64 + RSI14Values []float64 +} + +// Binance API 响应结构 +type ExchangeInfo struct { + Symbols []SymbolInfo `json:"symbols"` +} + +type SymbolInfo struct { + Symbol string `json:"symbol"` + Status string `json:"status"` + BaseAsset string `json:"baseAsset"` + QuoteAsset string `json:"quoteAsset"` + ContractType string `json:"contractType"` + PricePrecision int `json:"pricePrecision"` + QuantityPrecision int `json:"quantityPrecision"` +} + +type Kline struct { + OpenTime int64 `json:"openTime"` + Open float64 `json:"open"` + High float64 `json:"high"` + Low float64 `json:"low"` + Close float64 `json:"close"` + Volume float64 `json:"volume"` + CloseTime int64 `json:"closeTime"` + QuoteVolume float64 `json:"quoteVolume"` + Trades int `json:"trades"` + TakerBuyBaseVolume float64 `json:"takerBuyBaseVolume"` + TakerBuyQuoteVolume float64 `json:"takerBuyQuoteVolume"` +} + +type KlineResponse []interface{} + +type PriceTicker struct { + Symbol string `json:"symbol"` + Price string `json:"price"` +} + +type Ticker24hr struct { + Symbol string `json:"symbol"` + PriceChange string `json:"priceChange"` + PriceChangePercent string `json:"priceChangePercent"` + Volume string `json:"volume"` + QuoteVolume string `json:"quoteVolume"` +} + +// 特征数据结构 +type SymbolFeatures struct { + Symbol string `json:"symbol"` + Timestamp time.Time `json:"timestamp"` + Price float64 `json:"price"` + PriceChange15Min float64 `json:"price_change_15min"` + PriceChange1H float64 `json:"price_change_1h"` + PriceChange4H float64 `json:"price_change_4h"` + Volume float64 `json:"volume"` + VolumeRatio5 float64 `json:"volume_ratio_5"` + VolumeRatio20 float64 `json:"volume_ratio_20"` + VolumeTrend float64 `json:"volume_trend"` + RSI14 float64 `json:"rsi_14"` + SMA5 float64 `json:"sma_5"` + SMA10 float64 `json:"sma_10"` + SMA20 float64 `json:"sma_20"` + HighLowRatio float64 `json:"high_low_ratio"` + Volatility20 float64 `json:"volatility_20"` + PositionInRange float64 `json:"position_in_range"` +} + +// 警报数据结构 +type Alert struct { + Type string `json:"type"` + Symbol string `json:"symbol"` + Value float64 `json:"value"` + Threshold float64 `json:"threshold"` + Message string `json:"message"` + Timestamp time.Time `json:"timestamp"` +} + +type Config struct { + AlertThresholds AlertThresholds `json:"alert_thresholds"` + UpdateInterval int `json:"update_interval"` // seconds + CleanupConfig CleanupConfig `json:"cleanup_config"` +} + +type AlertThresholds struct { + VolumeSpike float64 `json:"volume_spike"` + PriceChange15Min float64 `json:"price_change_15min"` + VolumeTrend float64 `json:"volume_trend"` + RSIOverbought float64 `json:"rsi_overbought"` + RSIOversold float64 `json:"rsi_oversold"` +} +type CleanupConfig struct { + InactiveTimeout time.Duration `json:"inactive_timeout"` // 不活跃超时时间 + MinScoreThreshold float64 `json:"min_score_threshold"` // 最低评分阈值 + NoAlertTimeout time.Duration `json:"no_alert_timeout"` // 无警报超时时间 + CheckInterval time.Duration `json:"check_interval"` // 检查间隔 +} + +var config = Config{ + AlertThresholds: AlertThresholds{ + VolumeSpike: 3.0, + PriceChange15Min: 0.05, + VolumeTrend: 2.0, + RSIOverbought: 70, + RSIOversold: 30, + }, + CleanupConfig: CleanupConfig{ + InactiveTimeout: 30 * time.Minute, + MinScoreThreshold: 15.0, + NoAlertTimeout: 20 * time.Minute, + CheckInterval: 5 * time.Minute, + }, + UpdateInterval: 60, // 1 minute +} diff --git a/market/websocket_client.go b/market/websocket_client.go new file mode 100644 index 00000000..ce151691 --- /dev/null +++ b/market/websocket_client.go @@ -0,0 +1,231 @@ +package market + +import ( + "encoding/json" + "fmt" + "log" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type WSClient struct { + conn *websocket.Conn + mu sync.RWMutex + subscribers map[string]chan []byte + reconnect bool + done chan struct{} +} + +type WSMessage struct { + Stream string `json:"stream"` + Data json.RawMessage `json:"data"` +} + +type KlineWSData struct { + EventType string `json:"e"` + EventTime int64 `json:"E"` + Symbol string `json:"s"` + Kline struct { + StartTime int64 `json:"t"` + CloseTime int64 `json:"T"` + Symbol string `json:"s"` + Interval string `json:"i"` + FirstTradeID int64 `json:"f"` + LastTradeID int64 `json:"L"` + OpenPrice string `json:"o"` + ClosePrice string `json:"c"` + HighPrice string `json:"h"` + LowPrice string `json:"l"` + Volume string `json:"v"` + NumberOfTrades int `json:"n"` + IsFinal bool `json:"x"` + QuoteVolume string `json:"q"` + TakerBuyBaseVolume string `json:"V"` + TakerBuyQuoteVolume string `json:"Q"` + } `json:"k"` +} + +type TickerWSData struct { + EventType string `json:"e"` + EventTime int64 `json:"E"` + Symbol string `json:"s"` + PriceChange string `json:"p"` + PriceChangePercent string `json:"P"` + WeightedAvgPrice string `json:"w"` + LastPrice string `json:"c"` + LastQty string `json:"Q"` + OpenPrice string `json:"o"` + HighPrice string `json:"h"` + LowPrice string `json:"l"` + Volume string `json:"v"` + QuoteVolume string `json:"q"` + OpenTime int64 `json:"O"` + CloseTime int64 `json:"C"` + FirstID int64 `json:"F"` + LastID int64 `json:"L"` + Count int `json:"n"` +} + +func NewWSClient() *WSClient { + return &WSClient{ + subscribers: make(map[string]chan []byte), + reconnect: true, + done: make(chan struct{}), + } +} + +func (w *WSClient) Connect() error { + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + } + + conn, _, err := dialer.Dial("wss://ws-fapi.binance.com/ws-fapi/v1", nil) + if err != nil { + return fmt.Errorf("WebSocket连接失败: %v", err) + } + + w.mu.Lock() + w.conn = conn + w.mu.Unlock() + + log.Println("WebSocket连接成功") + + // 启动消息读取循环 + go w.readMessages() + + return nil +} + +func (w *WSClient) SubscribeKline(symbol, interval string) error { + stream := fmt.Sprintf("%s@kline_%s", symbol, interval) + return w.subscribe(stream) +} + +func (w *WSClient) SubscribeTicker(symbol string) error { + stream := fmt.Sprintf("%s@ticker", symbol) + return w.subscribe(stream) +} + +func (w *WSClient) SubscribeMiniTicker(symbol string) error { + stream := fmt.Sprintf("%s@miniTicker", symbol) + return w.subscribe(stream) +} + +func (w *WSClient) subscribe(stream string) error { + subscribeMsg := map[string]interface{}{ + "method": "SUBSCRIBE", + "params": []string{stream}, + "id": time.Now().Unix(), + } + + w.mu.RLock() + defer w.mu.RUnlock() + + if w.conn == nil { + return fmt.Errorf("WebSocket未连接") + } + + err := w.conn.WriteJSON(subscribeMsg) + if err != nil { + return err + } + + log.Printf("订阅流: %s", stream) + return nil +} + +func (w *WSClient) readMessages() { + for { + select { + case <-w.done: + return + default: + w.mu.RLock() + conn := w.conn + w.mu.RUnlock() + + if conn == nil { + time.Sleep(1 * time.Second) + continue + } + + _, message, err := conn.ReadMessage() + if err != nil { + log.Printf("读取WebSocket消息失败: %v", err) + w.handleReconnect() + return + } + + w.handleMessage(message) + } + } +} + +func (w *WSClient) handleMessage(message []byte) { + var wsMsg WSMessage + if err := json.Unmarshal(message, &wsMsg); err != nil { + // 可能是其他格式的消息 + return + } + + w.mu.RLock() + ch, exists := w.subscribers[wsMsg.Stream] + w.mu.RUnlock() + + if exists { + select { + case ch <- wsMsg.Data: + default: + log.Printf("订阅者通道已满: %s", wsMsg.Stream) + } + } +} + +func (w *WSClient) handleReconnect() { + if !w.reconnect { + return + } + + log.Println("尝试重新连接...") + time.Sleep(3 * time.Second) + + if err := w.Connect(); err != nil { + log.Printf("重新连接失败: %v", err) + go w.handleReconnect() + } +} + +func (w *WSClient) AddSubscriber(stream string, bufferSize int) <-chan []byte { + ch := make(chan []byte, bufferSize) + w.mu.Lock() + w.subscribers[stream] = ch + w.mu.Unlock() + return ch +} + +func (w *WSClient) RemoveSubscriber(stream string) { + w.mu.Lock() + delete(w.subscribers, stream) + w.mu.Unlock() +} + +func (w *WSClient) Close() { + w.reconnect = false + close(w.done) + + w.mu.Lock() + defer w.mu.Unlock() + + if w.conn != nil { + w.conn.Close() + w.conn = nil + } + + // 关闭所有订阅者通道 + for stream, ch := range w.subscribers { + close(ch) + delete(w.subscribers, stream) + } +}