diff --git a/examples/perft_test.rs b/examples/perft_test.rs index 2396527..f6206ef 100644 --- a/examples/perft_test.rs +++ b/examples/perft_test.rs @@ -1,5 +1,5 @@ -use chs::constants::*; use chs::chess::perft::run_perft; +use chs::constants::*; #[derive(Default)] struct TestRunner { @@ -18,13 +18,20 @@ impl TestRunner { } fn summarise(&self) { - println!("Ran {total} tests: {passed} passed, {failed} failed", total = self.passed + self.failed, passed = self.passed, failed = self.failed); + println!( + "Ran {total} tests: {passed} passed, {failed} failed", + total = self.passed + self.failed, + passed = self.passed, + failed = self.failed + ); } } fn main() { let mut runner = TestRunner::default(); - let mut perft = |fen: &str, depth: u64, expected_positions: u64| runner.run_test(|| run_perft(fen, depth, expected_positions)); + let mut perft = |fen: &str, depth: u64, expected_positions: u64| { + runner.run_test(|| run_perft(fen, depth, expected_positions)) + }; // Start position println!("Start position"); diff --git a/examples/single_perft_test.rs b/examples/single_perft_test.rs index a4d3d5b..ae3a71d 100644 --- a/examples/single_perft_test.rs +++ b/examples/single_perft_test.rs @@ -1,5 +1,9 @@ use chs::{chess::perft::run_perft, constants::POSITION_6}; fn main() { - run_perft("rnb2k1r/pp1Pbppp/2p5/q7/2B5/P7/1PP1NnPP/RNBQK2R w KQ - 1 9", 1, 9); + run_perft( + "rnb2k1r/pp1Pbppp/2p5/q7/2B5/P7/1PP1NnPP/RNBQK2R w KQ - 1 9", + 1, + 9, + ); } diff --git a/src/assets.rs b/src/assets.rs index 979489b..21f7f5c 100644 --- a/src/assets.rs +++ b/src/assets.rs @@ -1,4 +1,8 @@ -use axum::{response::{IntoResponse, Response}, body::{boxed, Full}, http::{header, StatusCode}}; +use axum::{ + body::{boxed, Full}, + http::{header, StatusCode}, + response::{IntoResponse, Response}, +}; use rust_embed::RustEmbed; #[derive(RustEmbed)] @@ -18,7 +22,10 @@ where Some(content) => { let body = boxed(Full::from(content.data)); let mime = mime_guess::from_path(path).first_or_octet_stream(); - Response::builder().header(header::CONTENT_TYPE, mime.as_ref()).body(body).unwrap() + Response::builder() + .header(header::CONTENT_TYPE, mime.as_ref()) + .body(body) + .unwrap() } None => StatusCode::NOT_FOUND.into_response(), } diff --git a/src/chess.rs b/src/chess.rs index 4d47b1f..33af4ba 100644 --- a/src/chess.rs +++ b/src/chess.rs @@ -181,7 +181,6 @@ pub struct Castling { pub queen: bool, } - /// FEN parsing implementation impl Board { pub fn from_fen(fen: &str) -> Result { @@ -296,26 +295,32 @@ impl Board { pub fn calc_check_state(&self) -> BySide { let white_moves = { let mut moves = vec![]; - mv::generate_pseudolegal_captures(&Board { - to_move: Side::White, - ..*self - }, &mut moves); + mv::generate_pseudolegal_captures( + &Board { + to_move: Side::White, + ..*self + }, + &mut moves, + ); moves }; let black_moves = { let mut moves = vec![]; - mv::generate_pseudolegal_captures(&Board { - to_move: Side::Black, - ..*self - }, &mut moves); + mv::generate_pseudolegal_captures( + &Board { + to_move: Side::Black, + ..*self + }, + &mut moves, + ); moves }; - let black_checked = white_moves.into_iter().any(|m| { - self.board[m.to.to_index()] == Some(Side::Black | PieceType::King) - }); - let white_checked = black_moves.into_iter().any(|m| { - self.board[m.to.to_index()] == Some(Side::White | PieceType::King) - }); + let black_checked = white_moves + .into_iter() + .any(|m| self.board[m.to.to_index()] == Some(Side::Black | PieceType::King)); + let white_checked = black_moves + .into_iter() + .any(|m| self.board[m.to.to_index()] == Some(Side::White | PieceType::King)); BySide { black: black_checked, white: white_checked, diff --git a/src/chess/mv.rs b/src/chess/mv.rs index 96e72c0..79bc500 100644 --- a/src/chess/mv.rs +++ b/src/chess/mv.rs @@ -1,7 +1,7 @@ -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; -use crate::{prelude::*, chess::Side}; -use super::{Board, Coordinate, Piece, PieceType, Castling}; +use super::{Board, Castling, Coordinate, Piece, PieceType}; +use crate::{chess::Side, prelude::*}; #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct Move { @@ -31,7 +31,9 @@ impl Move { fn make_inner(&self, board: &Board) -> Board { let mut board = board.clone(); - let piece = board.get(self.from).expect("cannot make a move with no piece"); + let piece = board + .get(self.from) + .expect("cannot make a move with no piece"); if piece.ty == King { *board.castling.get_mut(board.to_move) = Castling { king: false, @@ -53,7 +55,8 @@ impl Move { *board.get_mut(self.from) = None; { let captured_piece = board.get(self.to); - if matches!(captured_piece, Some(piece) if piece.ty == Rook && self.to.rank == board.to_move.other().back_rank()) { + if matches!(captured_piece, Some(piece) if piece.ty == Rook && self.to.rank == board.to_move.other().back_rank()) + { let castling = board.castling.get_mut(board.to_move.other()); match self.to.file { 0 => castling.queen = false, @@ -146,7 +149,12 @@ enum MoveResult { Capture, } -fn add_if bool>(board: &Board, moves: &mut Vec, mv: Move, predicate: F) -> MoveResult { +fn add_if bool>( + board: &Board, + moves: &mut Vec, + mv: Move, + predicate: F, +) -> MoveResult { if !mv.to.is_board_position() { return MoveResult::Invalid; } @@ -159,7 +167,9 @@ fn add_if bool>(board: &Board, moves: &mut Vec, mv: M } else { MoveResult::Capture } - } else if matches!(board.get(mv.from), Some(piece) if piece.ty == Pawn) && matches!(board.en_passant_target, Some(target) if target == mv.to) { + } else if matches!(board.get(mv.from), Some(piece) if piece.ty == Pawn) + && matches!(board.en_passant_target, Some(target) if target == mv.to) + { MoveResult::Capture } else { MoveResult::NoCapture @@ -185,7 +195,13 @@ fn add_if_capture(board: &Board, moves: &mut Vec, mv: Move) -> MoveResult } fn generate_castling_moves(board: &Board, moves: &mut Vec) { - fn test_move(board: &Board, king_from: Coordinate, king_to: Coordinate, rook_from: Coordinate, rook_to: Coordinate) -> Option { + fn test_move( + board: &Board, + king_from: Coordinate, + king_to: Coordinate, + rook_from: Coordinate, + rook_to: Coordinate, + ) -> Option { let king = board.get(king_from); let rook = board.get(rook_from); let (king, rook) = match (king, rook) { @@ -236,40 +252,19 @@ fn generate_castling_moves(board: &Board, moves: &mut Vec) { Black => 7, White => 0, }; - let king = Coordinate { - rank, - file: 4, - }; + let king = Coordinate { rank, file: 4 }; if castling.queen { - let rook = Coordinate { - rank, - file: 0, - }; - let king_target = Coordinate { - rank, - file: 2, - }; - let rook_target = Coordinate { - rank, - file: 3, - }; + let rook = Coordinate { rank, file: 0 }; + let king_target = Coordinate { rank, file: 2 }; + let rook_target = Coordinate { rank, file: 3 }; if let Some(mv) = test_move(board, king, king_target, rook, rook_target) { moves.push(mv); } } if castling.king { - let rook = Coordinate { - rank, - file: 7, - }; - let king_target = Coordinate { - rank, - file: 6, - }; - let rook_target = Coordinate { - rank, - file: 5, - }; + let rook = Coordinate { rank, file: 7 }; + let king_target = Coordinate { rank, file: 6 }; + let rook_target = Coordinate { rank, file: 5 }; if let Some(mv) = test_move(board, king, king_target, rook, rook_target) { moves.push(mv); } @@ -277,7 +272,12 @@ fn generate_castling_moves(board: &Board, moves: &mut Vec) { } fn generate_pawn_moves(board: &Board, moves: &mut Vec) { - fn pawn_move(side: Side, from: Coordinate, to: Coordinate, set_en_passant: Option) -> Move { + fn pawn_move( + side: Side, + from: Coordinate, + to: Coordinate, + set_en_passant: Option, + ) -> Move { let promotion_rank = match side { Black => 0, White => 7, @@ -315,13 +315,30 @@ fn generate_pawn_moves(board: &Board, moves: &mut Vec) { _ => {} } let from = Coordinate::from_index(index); - let forward_res = add_if_not_capture(board, moves, pawn_move(board.to_move, from, from + direction, None)); + let forward_res = add_if_not_capture( + board, + moves, + pawn_move(board.to_move, from, from + direction, None), + ); if forward_res == MoveResult::NoCapture && start_rank.contains(&index) { - add_if_not_capture(board, moves, pawn_move(board.to_move, from, from + (direction * 2), Some(from + direction))); + add_if_not_capture( + board, + moves, + pawn_move( + board.to_move, + from, + from + (direction * 2), + Some(from + direction), + ), + ); } for capture in PAWN_CAPTURES { - add_if_capture(board, moves, pawn_move(board.to_move, from, from + direction + capture, None)); + add_if_capture( + board, + moves, + pawn_move(board.to_move, from, from + direction + capture, None), + ); } } } @@ -374,16 +391,19 @@ pub fn generate_pseudolegal(board: &Board, moves: &mut Vec) { pub fn generate_legal(board: &Board) -> Vec { let mut moves = vec![]; generate_pseudolegal(board, &mut moves); - moves.into_iter().filter(|mv| { - if mv.other.is_some() { - // Cannot castle out of check - if *board.calc_check_state().get(board.to_move) { - return false; + moves + .into_iter() + .filter(|mv| { + if mv.other.is_some() { + // Cannot castle out of check + if *board.calc_check_state().get(board.to_move) { + return false; + } } - } - let test_board = mv.make(board); - !*test_board.calc_check_state().get(board.to_move) - }).collect::>() + let test_board = mv.make(board); + !*test_board.calc_check_state().get(board.to_move) + }) + .collect::>() } #[cfg(test)] diff --git a/src/chess/perft.rs b/src/chess/perft.rs index 9e35459..fdd50da 100644 --- a/src/chess/perft.rs +++ b/src/chess/perft.rs @@ -1,6 +1,6 @@ use crate::prelude::*; -use std::io::Write; use rayon::prelude::*; +use std::io::Write; fn perft_board(board: &Board, depth: u64, start_depth: u64) -> u64 { if depth == 0 { @@ -9,42 +9,45 @@ fn perft_board(board: &Board, depth: u64, start_depth: u64) -> u64 { let mut moves = vec![]; crate::chess::mv::generate_pseudolegal(board, &mut moves); - moves.into_par_iter().map(|mv| { - let mut count = 0; - if let Some(promotions) = mv.promotions { - for promotion in promotions { - let mv = Move { - promotions: Some(vec![promotion]), - other: mv.other.clone(), - ..mv - }; + moves + .into_par_iter() + .map(|mv| { + let mut count = 0; + if let Some(promotions) = mv.promotions { + for promotion in promotions { + let mv = Move { + promotions: Some(vec![promotion]), + other: mv.other.clone(), + ..mv + }; + let new_board = mv.make(board); + let check = new_board.calc_check_state(); + if *check.get(board.to_move) { + continue; + } + let this_count = perft_board(&new_board, depth - 1, start_depth); + count += this_count; + #[cfg(test)] + if depth == start_depth { + println!("{from}{to}: {this_count}", from = mv.from, to = mv.to); + } + } + } else { let new_board = mv.make(board); - let check = new_board.calc_check_state(); - if *check.get(board.to_move) { - continue; + if *new_board.calc_check_state().get(board.to_move) { + return 0; } let this_count = perft_board(&new_board, depth - 1, start_depth); count += this_count; #[cfg(test)] if depth == start_depth { println!("{from}{to}: {this_count}", from = mv.from, to = mv.to); + // println!("mv={mv:?} {piece:?}", piece = board.get(mv.from)); } } - } else { - let new_board = mv.make(board); - if *new_board.calc_check_state().get(board.to_move) { - return 0; - } - let this_count = perft_board(&new_board, depth - 1, start_depth); - count += this_count; - #[cfg(test)] - if depth == start_depth { - println!("{from}{to}: {this_count}", from = mv.from, to = mv.to); - // println!("mv={mv:?} {piece:?}", piece = board.get(mv.from)); - } - } - count - }).sum() + count + }) + .sum() } pub fn run_perft(fen: &str, depth: u64, expected_positions: u64) -> bool { diff --git a/src/constants.rs b/src/constants.rs index 6ac3192..6967b7a 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -3,4 +3,5 @@ pub const POSITION_2: &str = "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/ pub const POSITION_3: &str = "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0"; pub const POSITION_4: &str = "r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1"; pub const POSITION_5: &str = "rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8"; -pub const POSITION_6: &str = "r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10"; +pub const POSITION_6: &str = + "r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10"; diff --git a/src/game.rs b/src/game.rs index 20b267e..37ad9f9 100644 --- a/src/game.rs +++ b/src/game.rs @@ -3,7 +3,12 @@ use std::collections::HashMap; use uuid::Uuid; use xtra::{prelude::*, WeakAddress}; -use crate::{prelude::Board, constants::START_FEN, chess::{Side, mv::generate_legal}, player::{Player, OutgoingPlayerEvent, IncomingPlayerEvent}}; +use crate::{ + chess::{mv::generate_legal, Side}, + constants::START_FEN, + player::{IncomingPlayerEvent, OutgoingPlayerEvent, Player}, + prelude::Board, +}; #[derive(Actor, Default)] pub struct GameManager { @@ -36,11 +41,13 @@ impl Handler for GameManager { async fn handle(&mut self, join_game: JoinGame, _ctx: &mut Context) -> Self::Return { if let Some(game) = self.games.get(&join_game.game_id) { - let res = game.send(join_game.clone()).await.ok() - .map(|r| match r { - Some(side) => JoinGameResponse::Success { side, game: game.clone() }, - None => JoinGameResponse::Full, - }); + let res = game.send(join_game.clone()).await.ok().map(|r| match r { + Some(side) => JoinGameResponse::Success { + side, + game: game.clone(), + }, + None => JoinGameResponse::Full, + }); if res.is_none() { self.games.remove(&join_game.game_id); } @@ -90,7 +97,11 @@ pub struct IncomingEvent { impl Handler for ChessGame { type Return = (); - async fn handle(&mut self, incoming_event: IncomingEvent, _ctx: &mut Context) -> Self::Return { + async fn handle( + &mut self, + incoming_event: IncomingEvent, + _ctx: &mut Context, + ) -> Self::Return { match incoming_event.data { IncomingPlayerEvent::MakeMove { mv } => { if incoming_event.side != self.board.to_move { @@ -103,7 +114,7 @@ impl Handler for ChessGame { self.broadcast_new_board(); self.send_possible_moves(); } - }, + } } } } diff --git a/src/player.rs b/src/player.rs index 6af8ccf..cbe6b42 100644 --- a/src/player.rs +++ b/src/player.rs @@ -2,7 +2,11 @@ use futures::{channel::mpsc::UnboundedSender, SinkExt}; use serde::{Deserialize, Serialize}; use xtra::prelude::*; -use crate::{prelude::*, chess::Side, game::{ChessGame, IncomingEvent}}; +use crate::{ + chess::Side, + game::{ChessGame, IncomingEvent}, + prelude::*, +}; #[derive(Actor)] pub struct Player { @@ -17,31 +21,31 @@ pub struct GameInfo { impl Player { pub fn new(sink: UnboundedSender) -> Self { - Self { - sink, - game: None, - } + Self { sink, game: None } } } #[derive(Debug, Serialize)] #[serde(tag = "event", content = "data")] pub enum OutgoingPlayerEvent { - BoardUpdate { - board: Vec>, - }, - PossibleMoves { - moves: Vec, - }, + BoardUpdate { board: Vec> }, + PossibleMoves { moves: Vec }, } #[async_trait] impl Handler for Player { type Return = (); - async fn handle(&mut self, outgoing_event: OutgoingPlayerEvent, _ctx: &mut Context) -> Self::Return { + async fn handle( + &mut self, + outgoing_event: OutgoingPlayerEvent, + _ctx: &mut Context, + ) -> Self::Return { // TODO: Better error handling - self.sink.send(outgoing_event).await.expect("failed to send outgoing event"); + self.sink + .send(outgoing_event) + .await + .expect("failed to send outgoing event"); } } @@ -51,19 +55,26 @@ pub enum IncomingPlayerEvent { MakeMove { #[serde(rename = "move")] mv: Move, - } + }, } #[async_trait] impl Handler for Player { type Return = (); - async fn handle(&mut self, incoming_event: IncomingPlayerEvent, _ctx: &mut Context) -> Self::Return { + async fn handle( + &mut self, + incoming_event: IncomingPlayerEvent, + _ctx: &mut Context, + ) -> Self::Return { if let Some(game) = &self.game { - game.game.send(IncomingEvent { - data: incoming_event, - side: game.side, - }).await.expect("game disconnected"); + game.game + .send(IncomingEvent { + data: incoming_event, + side: game.side, + }) + .await + .expect("game disconnected"); } } } diff --git a/src/routes.rs b/src/routes.rs index d27e2bc..04709c1 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,22 +1,34 @@ -use axum::{Router, extract::{WebSocketUpgrade, Path, ws::{Message, WebSocket}}, response::IntoResponse, routing::get, Extension}; -use futures::{StreamExt, SinkExt, stream::{SplitStream, SplitSink}, channel::mpsc::{self, UnboundedSender}}; +use axum::{ + extract::{ + ws::{Message, WebSocket}, + Path, WebSocketUpgrade, + }, + response::IntoResponse, + routing::get, + Extension, Router, +}; +use futures::{ + channel::mpsc::{self, UnboundedSender}, + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; use tokio::task::JoinHandle; use uuid::Uuid; -use xtra::{Mailbox, Address}; +use xtra::{Address, Mailbox}; -use crate::{player::{Player, IncomingPlayerEvent, OutgoingPlayerEvent, GameInfo}, game::{GameManager, JoinGame, JoinGameResponse}}; #[cfg(not(debug_assertions))] use crate::assets::StaticFile; +use crate::{ + game::{GameManager, JoinGame, JoinGameResponse}, + player::{GameInfo, IncomingPlayerEvent, OutgoingPlayerEvent, Player}, +}; #[allow(clippy::let_and_return)] pub fn routes() -> Router { - let router = Router::new() - .route("/ws/:id", get(ws_handler)); + let router = Router::new().route("/ws/:id", get(ws_handler)); #[cfg(not(debug_assertions))] - let router = router - .route("/", get(index)) - .fallback(get(fallback)); + let router = router.route("/", get(index)).fallback(get(fallback)); router } @@ -47,11 +59,19 @@ async fn handle_socket(socket: WebSocket, id: Uuid, game_manager: Address) -> (UnboundedSender, JoinHandle<()>) { +fn socket_send( + mut tx: SplitSink, +) -> (UnboundedSender, JoinHandle<()>) { let (message_tx, mut rx) = mpsc::unbounded(); let task = tokio::spawn(async move {