feat: initial commit

This commit is contained in:
Ashhhleyyy 2022-10-31 13:25:32 +00:00
commit 2fd10446b6
Signed by: ash
GPG key ID: 83B789081A0878FB
38 changed files with 4653 additions and 0 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
/target
/node_modules
/dist

1374
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

22
Cargo.toml Normal file
View file

@ -0,0 +1,22 @@
[package]
name = "chs"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = { version = "0.5.17", features = ["ws"] }
futures = "0.3.25"
mime_guess = "2.0.4"
rayon = "1.5.3"
rust-embed = "6.4.2"
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.87"
thiserror = "1.0.37"
tokio = { version = "1.21.2", features = ["full"] }
tower-http = { version = "0.3.4", features = ["trace", "fs"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
uuid = { version = "1.2.1", features = ["v4", "serde"] }
xtra = { git = "https://github.com/Restioson/xtra.git", version = "0.6.0", features = ["tokio", "instrumentation", "macros", "sink"] }

8
Makefile.toml Normal file
View file

@ -0,0 +1,8 @@
[tasks.build-ui]
command = "pnpm"
args = ["build"]
[tasks.build-release]
command = "cargo"
args = ["build", "--release"]
dependencies = ["build-ui"]

78
examples/perft_test.rs Normal file
View file

@ -0,0 +1,78 @@
use chs::constants::*;
use chs::chess::perft::run_perft;
#[derive(Default)]
struct TestRunner {
passed: usize,
failed: usize,
}
impl TestRunner {
fn run_test<F: Fn() -> bool>(&mut self, f: F) {
let result = f();
if result {
self.passed += 1;
} else {
self.failed += 1;
}
}
fn summarise(&self) {
println!("Ran {total} tests: {passed} passed, {failed} failed", total = self.passed + self.failed, passed = self.passed, failed = self.failed);
}
}
fn main() {
let mut runner = TestRunner::default();
let mut perft = |fen: &str, depth: u64, expected_positions: u64| runner.run_test(|| run_perft(fen, depth, expected_positions));
// Start position
println!("Start position");
perft(START_FEN, 0, 1);
perft(START_FEN, 1, 20);
perft(START_FEN, 2, 400);
perft(START_FEN, 3, 8_902);
perft(START_FEN, 4, 197_281);
// Other test positions (https://www.chessprogramming.org/Perft_Results)
// Position 2
println!("Position 2");
perft(POSITION_2, 1, 48);
perft(POSITION_2, 2, 2_039);
perft(POSITION_2, 3, 97_862);
perft(POSITION_2, 4, 4_085_603);
// Position 3
println!("Position 3");
perft(POSITION_3, 1, 14);
perft(POSITION_3, 2, 191);
perft(POSITION_3, 3, 2_812);
perft(POSITION_3, 4, 43_238);
perft(POSITION_3, 5, 674_624);
perft(POSITION_3, 6, 11_030_083);
// Position 4
println!("Position 4");
perft(POSITION_4, 1, 6);
perft(POSITION_4, 2, 264);
perft(POSITION_4, 3, 9_467);
perft(POSITION_4, 4, 422_333);
perft(POSITION_4, 5, 15_833_292);
// Position 5
println!("Position 5");
perft(POSITION_5, 1, 44);
perft(POSITION_5, 2, 1_486);
perft(POSITION_5, 3, 62_379);
perft(POSITION_5, 4, 2_103_487);
// Position 6
println!("Position 6");
perft(POSITION_6, 1, 46);
perft(POSITION_6, 2, 2_079);
perft(POSITION_6, 3, 89_890);
perft(POSITION_6, 4, 3_894_594);
runner.summarise();
}

View file

@ -0,0 +1,5 @@
use chs::{chess::perft::run_perft, constants::POSITION_6};
fn main() {
run_perft("rnb2k1r/pp1Pbppp/2p5/q7/2B5/P7/1PP1NnPP/RNBQK2R w KQ - 1 9", 1, 9);
}

15
index.html Normal file
View file

@ -0,0 +1,15 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="theme-color" content="#000000" />
<title>Chs</title>
</head>
<body>
<noscript>You need to enable JavaScript to run this app.</noscript>
<div id="root"></div>
<script src="/src-web/index.tsx" type="module"></script>
</body>
</html>

27
package.json Normal file
View file

@ -0,0 +1,27 @@
{
"name": "chs",
"version": "1.0.0",
"description": "",
"main": "index.js",
"scripts": {
"dev": "vite",
"build": "vite build",
"serve": "vite preview"
},
"keywords": [],
"author": "",
"license": "ISC",
"devDependencies": {
"@types/node": "^18.11.7",
"typescript": "^4.8.4",
"vite": "^3.2.1",
"vite-plugin-solid": "^2.3.10"
},
"dependencies": {
"@fontsource/noto-sans-symbols-2": "^4.5.10",
"@solid-primitives/websocket": "^0.3.3",
"solid-devtools": "^0.20.1",
"solid-js": "^1.6.0",
"zod": "^3.19.1"
}
}

1076
pnpm-lock.yaml Normal file

File diff suppressed because it is too large Load diff

61
src-web/App.tsx Normal file
View file

@ -0,0 +1,61 @@
import { Component, createEffect, createSignal, Match, Switch } from 'solid-js';
import createWebsocket from '@solid-primitives/websocket';
import { wsUrl } from './constants';
import Spinner from './components/Spinner';
import Welcome from './components/Welcome';
import { Board as BoardData, Move, ServerChessEvent } from './events';
import Board from './components/Board';
interface Props {
gameId: string;
}
const App: Component<Props> = (props) => {
const [board, setBoard] = createSignal<BoardData>(Array(64).fill(null));
const [possibleMoves, setPossibleMoves] = createSignal<Move[]>([]);
function handleEvent(e: MessageEvent<string>) {
const data = JSON.parse(e.data);
const event = ServerChessEvent.parse(data);
if (event.event === 'BoardUpdate') {
console.log(event.data.board);
setBoard(event.data.board);
} else if (event.event === 'PossibleMoves') {
console.log(event.data.moves);
setPossibleMoves(event.data.moves);
}
}
createEffect(() => {
console.log('board is now', board());
});
createEffect(() => {
console.log('moves is now', possibleMoves());
});
const [connect, disconnect, send, state, socket] = createWebsocket(wsUrl(props.gameId), handleEvent, console.error);
function makeMove(move: Move) {
send(JSON.stringify({
event: 'MakeMove',
data: {
move,
},
}));
}
return <Switch fallback={<><h1>Hello, World!</h1><Spinner /></>}>
<Match when={state() === WebSocket.CLOSED}>
<Welcome gameId={props.gameId} joinGame={() => connect()} />
</Match>
<Match when={state() === WebSocket.CONNECTING}>
<Spinner />
</Match>
<Match when={state() === WebSocket.OPEN}>
<Board board={board} moves={possibleMoves} makeMove={makeMove} />
</Match>
</Switch>
};
export default App;

View file

@ -0,0 +1,62 @@
.board {
font-family: 'Noto Sans Symbols 2', sans-serif;
display: grid;
grid-template-columns: repeat(8, auto);
}
.square {
width: 64px;
height: 64px;
font-size: 48px;
display: flex;
flex-direction: row;
align-items: baseline;
justify-content: center;
user-select: none;
cursor: default;
padding: auto;
position: relative;
border-radius: 8px;
}
.square.selectable {
cursor: pointer;
}
.light {
background-color: #c8a2c8;
/* color: #5d375d; */
}
.dark {
background-color: #5d375d;
/* color: #C8A2C8; */
}
.white {
color: white;
}
.black {
color: black;
}
.borderSelected {
position: absolute;
width: 64px;
height: 64px;
top: 0;
left: 0;
border: 4px solid yellowgreen;
border-radius: 8px;
}
.borderTarget {
position: absolute;
width: 32px;
height: 32px;
top: 16px;
left: 16px;
border: 4px solid yellowgreen;
border-radius: 999px;
}

View file

@ -0,0 +1,73 @@
import { Accessor, Component, createSignal, For, Show } from "solid-js";
import { Board as BoardData, Coordinate, Move, PieceType } from "../events";
import './Board.css';
interface Props {
board: Accessor<BoardData>;
moves: Accessor<Move[]>;
makeMove: (move: Move) => void;
}
const PIECE_CHARS: Record<PieceType, string> = {
'king': '♔',
'queen': '♕',
'rook': '♜',
'knight': '♞',
'bishop': '♝',
'pawn': '♙',
};
const Board: Component<Props> = (props) => {
const [selectedSquare, setSelectedSquare] = createSignal<number | null>(null);
const validMoves = () => {
const selected = selectedSquare();
if (selected !== null) {
return props.moves().filter(move => (move.from.rank * 8 + move.from.file) === selected);
} else {
return [];
}
};
return <div class="board">
{props.board().map((piece, i) => {
const coord: Coordinate = {
rank: Math.floor(i / 8),
file: i % 8,
};
const isLight = () => ((i % 8) + Math.floor(i / 8)) % 2 == 0;
const hasMoves = () => props.moves().find((move) => move.from.rank === coord.rank && move.from.file === coord.file) !== undefined;
const targetMove = () => validMoves().find(move => move.to.rank === coord.rank && move.to.file === coord.file);
const isTarget = () => targetMove() !== undefined;
return <div classList={{
square: true,
light: isLight(),
dark: !isLight(),
black: piece?.side === 'black',
white: piece?.side === 'white',
selectable: hasMoves() || isTarget(),
}} onClick={() => {
if (hasMoves()) {
if (selectedSquare() === i) {
setSelectedSquare(null);
} else {
setSelectedSquare(i);
}
} else {
const target = targetMove();
if (target && !target.promotions) {
props.makeMove(target);
setSelectedSquare(null);
}
}
}}>
<Show when={selectedSquare() === i}><div class="borderSelected" /></Show>
<Show when={isTarget()}><div class="borderTarget" /></Show>
<Show when={piece !== null}>{PIECE_CHARS[piece!.ty]}</Show>
</div>
})}
</div>
}
export default Board;

View file

@ -0,0 +1,16 @@
.button {
width: 100%;
font-size: 1.35rem;
padding: 8px;
border-radius: 8px;
border: 1px solid black;
background-color: #ddd;
}
.button:hover {
filter: brightness(.9);
}
.button:active {
filter: brightness(.8);
}

View file

@ -0,0 +1,16 @@
import { children, Component, JSX } from "solid-js";
import './Button.css';
interface Props {
onClick?: () => void;
children: JSX.Element;
}
const Button: Component<Props> = (props) => {
const c = children(() => props.children)
return <button class="button" onClick={props.onClick}>
{c()}
</button>
}
export default Button;

View file

View file

@ -0,0 +1,11 @@
import { Component } from 'solid-js';
import './Connecting.css';
import Spinner from './Spinner';
const Connecting: Component = () => {
return <div>
<Spinner />
</div>
}
export default Connecting;

View file

@ -0,0 +1,19 @@
.spinner {
display: inline-block;
width: var(--spinner-size);
height: var(--spinner-size);
border: 4px solid;
border-color: transparent transparent var(--spinner-colour) var(--spinner-colour);
border-radius: 99999px;
transform-origin: center;
animation: spinner-spin 500ms linear infinite;
}
@keyframes spinner-spin {
from {
transform: rotateZ(0deg);
}
to {
transform: rotateZ(360deg);
}
}

View file

@ -0,0 +1,16 @@
import { Component } from "solid-js";
import './Spinner.css';
interface Props {
colour?: string;
size?: string;
}
const Spinner: Component<Props> = (props) => {
return <div class="spinner" style={{
'--spinner-size': props.size ?? '32px',
'--spinner-colour': props.colour ?? '#f9027a',
}} />
}
export default Spinner;

View file

@ -0,0 +1,22 @@
import { Component } from "solid-js";
import Button from "./Button";
interface Props {
gameId: string;
joinGame: () => void;
}
const Welcome: Component<Props> = (props) => {
console.log(props.gameId);
return <main>
<h1>Welcome</h1>
<div>
Game ID: <pre>{props.gameId}</pre>
</div>
<Button onClick={() => props.joinGame()}>
Join
</Button>
</main>;
}
export default Welcome;

12
src-web/constants.ts Normal file
View file

@ -0,0 +1,12 @@
// const WS_BASE = 'ws://localhost:3000/ws/';
export function wsUrl(id: string) {
if (import.meta.env.WS_BASE) {
return import.meta.env.WS_BASE + id;
} else {
const loc = window.location;
let newUri = loc.protocol === "https:" ? "wss://" : "ws://";
newUri += loc.host + '/ws/' + id;
return newUri;
}
}

59
src-web/events.ts Normal file
View file

@ -0,0 +1,59 @@
import { z } from 'zod';
export const Coordinate = z.object({
rank: z.number(),
file: z.number(),
});
export type Coordinate = z.infer<typeof Coordinate>;
export const Side = z.enum(['white', 'black']);
export type Side = z.infer<typeof Side>;
export const PieceType = z.enum(['king', 'queen', 'rook', 'bishop', 'knight', 'pawn']);
export type PieceType = z.infer<typeof PieceType>;
export const Piece = z.object({
side: Side,
ty: PieceType,
});
export type Piece = z.infer<typeof Piece>;
export interface Move {
from: Coordinate;
to: Coordinate;
set_en_passant: Coordinate | null;
other: Move | null;
promotions: Piece[] | null;
}
export const Move: z.ZodType<Move> = z.lazy(() => z.object({
from: Coordinate,
to: Coordinate,
set_en_passant: z.nullable(Coordinate),
other: z.nullable(Move),
promotions: z.nullable(z.array(Piece)),
}));
export const Board = z.array(z.nullable(Piece));
export type Board = z.infer<typeof Board>;
export const BoardUpdateEvent = z.object({
board: Board,
});
export type BoardUpdateEvent = z.infer<typeof BoardUpdateEvent>;
export const PossibleMovesEvent = z.object({
moves: z.array(Move),
});
export type PossibleMovesEvent = z.infer<typeof PossibleMovesEvent>;
export const ServerChessEvent = z.discriminatedUnion("event", [
z.object({ event: z.literal('BoardUpdate'), data: BoardUpdateEvent }),
z.object({ event: z.literal('PossibleMoves'), data: PossibleMovesEvent }),
]);
export type ServerChessEvent = z.infer<typeof ServerChessEvent>;

14
src-web/index.tsx Normal file
View file

@ -0,0 +1,14 @@
/* @refresh reload */
import { render } from 'solid-js/web';
import 'solid-devtools';
import App from './App';
import './main.css';
import "@fontsource/noto-sans-symbols-2";
const search = new URLSearchParams(window.location.search);
if (search.has('game_id')) {
render(() => <App gameId={search.get('game_id')!} />, document.getElementById('root') as HTMLElement);
} else {
console.error('no game id');
}

19
src-web/main.css Normal file
View file

@ -0,0 +1,19 @@
html {
box-sizing: border-box;
}
html, body {
padding: 0;
margin: 0;
}
* {
box-sizing: inherit;
}
body {
display: grid;
place-items: center;
width: 100%;
min-height: 100vh;
}

26
src/assets.rs Normal file
View file

@ -0,0 +1,26 @@
use axum::{response::{IntoResponse, Response}, body::{boxed, Full}, http::{header, StatusCode}};
use rust_embed::RustEmbed;
#[derive(RustEmbed)]
#[folder = "dist/"]
struct Asset;
pub struct StaticFile<T>(pub T);
impl<T> IntoResponse for StaticFile<T>
where
T: Into<String>,
{
fn into_response(self) -> Response {
let path = self.0.into();
match Asset::get(path.as_str()) {
Some(content) => {
let body = boxed(Full::from(content.data));
let mime = mime_guess::from_path(path).first_or_octet_stream();
Response::builder().header(header::CONTENT_TYPE, mime.as_ref()).body(body).unwrap()
}
None => StatusCode::NOT_FOUND.into_response(),
}
}
}

636
src/chess.rs Normal file
View file

@ -0,0 +1,636 @@
use std::fmt::Display;
use serde::{Deserialize, Serialize};
use crate::error::{FENParseError, NotationError};
pub mod mv;
pub mod perft;
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Coordinate {
rank: isize,
file: isize,
}
impl Coordinate {
pub fn is_board_position(self) -> bool {
self.rank >= 0 && self.rank < 8 && self.file >= 0 && self.file < 8
}
pub fn to_index(self) -> usize {
if !self.is_board_position() {
panic!("called to_index on a coordinate with out-of-bounds values: {self:?}");
}
(self.rank as usize) * 8 + (self.file as usize)
}
pub fn from_index(index: usize) -> Self {
let rank = (index / 8) as isize;
let file = (index % 8) as isize;
return Self { rank, file };
}
pub fn parse_algebraic(s: &str) -> Result<Self, NotationError> {
if s.len() != 2 {
return Err(NotationError::InvalidLength {
length: s.len(),
expected: 2,
});
}
let file = s
.chars()
.next()
.unwrap()
.to_digit(18)
.map(|v| v - 10)
.ok_or_else(|| NotationError::Other(s.to_owned()))? as isize;
let rank = s
.chars()
.skip(1)
.next()
.unwrap()
.to_digit(9)
.map(|v| v - 1)
.ok_or_else(|| NotationError::Other(s.to_owned()))? as isize;
if file < 0 || rank < 0 {
return Err(NotationError::Other(s.to_owned()));
}
Ok(Self { file, rank })
}
pub fn between(from: Coordinate, to: Coordinate) -> BetweenCoordsIter {
if from.rank != to.rank && from.file != to.file {
panic!("Coordinate::between must be passed two values in the same rank or file");
}
if from.rank != to.rank {
let min_rank = from.rank.min(to.rank);
let max_rank = from.rank.max(to.rank);
BetweenCoordsIter {
start_file: from.file,
start_rank: min_rank,
rank: true,
offset: 1,
max_offset: (max_rank - min_rank) - 1,
}
} else {
let min_file = from.file.min(to.file);
let max_file = from.file.max(to.file);
BetweenCoordsIter {
start_rank: from.rank,
start_file: min_file,
rank: false,
offset: 1,
max_offset: (max_file - min_file) - 1,
}
}
}
}
const FILES: [char; 8] = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'];
impl Display for Coordinate {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if !self.is_board_position() {
panic!("cannot format non-board Coordinate in algebraic notation");
}
write!(f, "{}", FILES[self.file as usize])?;
write!(f, "{}", self.rank + 1)?;
Ok(())
}
}
#[derive(Debug)]
pub struct BetweenCoordsIter {
start_rank: isize,
start_file: isize,
rank: bool,
offset: isize,
max_offset: isize,
}
impl Iterator for BetweenCoordsIter {
type Item = Coordinate;
fn next(&mut self) -> Option<Self::Item> {
if self.offset > self.max_offset {
return None;
}
let rank = if self.rank {
self.start_rank + self.offset
} else {
self.start_rank
};
let file = if self.rank {
self.start_file
} else {
self.start_file + self.offset
};
self.offset += 1;
Some(Coordinate { rank, file })
}
}
impl std::ops::Add<Coordinate> for Coordinate {
type Output = Coordinate;
fn add(self, rhs: Coordinate) -> Self::Output {
Self {
file: self.file + rhs.file,
rank: self.rank + rhs.rank,
}
}
}
impl std::ops::Mul<isize> for Coordinate {
type Output = Coordinate;
fn mul(self, rhs: isize) -> Self::Output {
Self {
rank: self.rank * rhs,
file: self.file * rhs,
}
}
}
#[derive(Clone, Debug)]
pub struct Board {
pub board: [Option<Piece>; 64],
pub to_move: Side,
pub castling: BySide<Castling>,
pub en_passant_target: Option<Coordinate>,
pub halfmove_clock: u32,
pub fullmove_number: u32,
}
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Castling {
pub king: bool,
pub queen: bool,
}
/// FEN parsing implementation
impl Board {
pub fn from_fen(fen: &str) -> Result<Self, FENParseError> {
let mut board = [None; 64];
let sections = fen.split(' ').collect::<Vec<_>>();
if sections.len() < 6 {
return Err(FENParseError::NotEnoughSections);
}
Self::parse_fen_board(&mut board, sections[0])?;
let to_move = match sections[1] {
"w" => Side::White,
"b" => Side::Black,
c => return Err(NotationError::InvalidSide(c.to_owned()).into()),
};
let castling = Self::parse_fen_castling(sections[2])?;
let en_passant_target = match sections[3] {
"-" => None,
s => Some(Coordinate::parse_algebraic(s)?),
};
let halfmove_clock = sections[4].parse::<u32>()?;
let fullmove_number = sections[5].parse::<u32>()?;
Ok(Self {
board,
to_move,
castling,
en_passant_target,
halfmove_clock,
fullmove_number,
})
}
fn parse_fen_castling(fen: &str) -> Result<BySide<Castling>, NotationError> {
let mut castling = BySide::<Castling>::default();
if fen != "-" {
for c in fen.chars() {
let side = Side::from(c);
let castling = castling.get_mut(side);
match c {
'q' | 'Q' => {
castling.queen = true;
}
'k' | 'K' => {
castling.king = true;
}
_ => return Err(NotationError::InvalidPiece(c)),
}
}
}
Ok(castling)
}
fn parse_fen_board(target: &mut [Option<Piece>; 64], board: &str) -> Result<(), FENParseError> {
enum State {
ParsingRank { file: usize },
WaitingForSlash,
EndOfString,
}
let mut state = State::ParsingRank { file: 0 };
let mut rank = 7;
for c in board.chars() {
match &mut state {
State::ParsingRank { file } => {
match c {
'0'..='8' => {
*file += c.to_digit(9).unwrap() as usize;
}
c => {
target[rank * 8 + *file] = Some(Piece::try_from(c)?);
*file += 1;
}
}
if *file == 8 {
state = if rank == 0 {
State::EndOfString
} else {
State::WaitingForSlash
};
}
}
State::WaitingForSlash => {
if c != '/' {
return Err(FENParseError::ExpectedSlash);
} else {
state = State::ParsingRank { file: 0 };
rank -= 1;
}
}
State::EndOfString => {
// This case should never be reached on a valid FEN string
return Err(FENParseError::BoardTooLarge);
}
}
}
Ok(())
}
}
/// Accessors
impl Board {
pub fn calc_check_state(&self) -> BySide<bool> {
let white_moves = {
let mut moves = vec![];
mv::generate_pseudolegal_captures(&Board {
to_move: Side::White,
..*self
}, &mut moves);
moves
};
let black_moves = {
let mut moves = vec![];
mv::generate_pseudolegal_captures(&Board {
to_move: Side::Black,
..*self
}, &mut moves);
moves
};
let black_checked = white_moves.into_iter().any(|m| {
self.board[m.to.to_index()] == Some(Side::Black | PieceType::King)
});
let white_checked = black_moves.into_iter().any(|m| {
self.board[m.to.to_index()] == Some(Side::White | PieceType::King)
});
return BySide {
black: black_checked,
white: white_checked,
}
}
pub fn get(&self, c: Coordinate) -> &Option<Piece> {
&self.board[c.to_index()]
}
pub fn get_mut(&mut self, c: Coordinate) -> &mut Option<Piece> {
&mut self.board[c.to_index()]
}
pub fn print_display(&self) {
for rank in (0..8).rev() {
for file in 0..8 {
if let Some(piece) = self.get(Coordinate { rank, file }) {
print!("{}", piece.to_char());
} else {
print!(" ");
}
}
println!();
}
}
}
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Black,
White,
}
impl Side {
pub fn other(&self) -> Self {
match self {
Self::White => Self::Black,
Self::Black => Self::White,
}
}
pub fn back_rank(&self) -> isize {
match self {
Self::Black => 7,
Self::White => 0,
}
}
pub fn pawn_rank(&self) -> isize {
match self {
Self::Black => 6,
Self::White => 1,
}
}
}
impl From<char> for Side {
fn from(c: char) -> Self {
if c.is_uppercase() {
Self::White
} else {
Self::Black
}
}
}
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum PieceType {
King,
Queen,
Rook,
Bishop,
Knight,
Pawn,
}
impl PieceType {
pub fn to_char(self) -> char {
match self {
Self::King => 'k',
Self::Queen => 'q',
Self::Rook => 'r',
Self::Bishop => 'b',
Self::Knight => 'n',
Self::Pawn => 'p',
}
}
}
impl TryFrom<char> for PieceType {
type Error = NotationError;
fn try_from(value: char) -> Result<Self, Self::Error> {
match value.to_ascii_lowercase() {
'k' => Ok(Self::King),
'q' => Ok(Self::Queen),
'r' => Ok(Self::Rook),
'b' => Ok(Self::Bishop),
'n' => Ok(Self::Knight),
'p' => Ok(Self::Pawn),
_ => Err(NotationError::InvalidPiece(value)),
}
}
}
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Piece {
pub side: Side,
pub ty: PieceType,
}
impl Piece {
pub fn to_char(&self) -> char {
match self.side {
Side::Black => self.ty.to_char(),
Side::White => self.ty.to_char().to_ascii_uppercase(),
}
}
}
impl TryFrom<char> for Piece {
type Error = FENParseError;
fn try_from(value: char) -> Result<Self, Self::Error> {
let side = Side::from(value);
let ty = PieceType::try_from(value)?;
Ok(side | ty)
}
}
impl Display for Piece {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.to_char())
}
}
impl std::ops::BitOr<PieceType> for Side {
type Output = Piece;
fn bitor(self, rhs: PieceType) -> Self::Output {
Piece {
side: self,
ty: rhs,
}
}
}
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
pub struct BySide<T> {
pub white: T,
pub black: T,
}
impl<T> BySide<T> {
pub fn get(&self, side: Side) -> &T {
match side {
Side::Black => &self.black,
Side::White => &self.white,
}
}
pub fn get_mut(&mut self, side: Side) -> &mut T {
match side {
Side::Black => &mut self.black,
Side::White => &mut self.white,
}
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::prelude::*;
// === FEN LAYOUT PARSER ===
#[test]
fn start_position_board() {
let mut target = [None; 64];
let res =
Board::parse_fen_board(&mut target, "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR");
assert!(res.is_ok());
}
#[test]
fn start_position_board_extra() {
let mut target = [None; 64];
let res = Board::parse_fen_board(
&mut target,
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNRaaaaaaaaa",
);
assert!(matches!(res, Err(FENParseError::BoardTooLarge)));
}
#[test]
fn start_position_board_invalid_piece() {
let mut target = [None; 64];
let res =
Board::parse_fen_board(&mut target, "rnbqkbnr/vvvvvvvv/8/8/8/8/PPPPPPPP/RNBQKBNR");
assert!(matches!(
res,
Err(FENParseError::InvalidNotation(NotationError::InvalidPiece(
_
)))
));
}
#[test]
fn start_position_board_rank_too_short() {
let mut target = [None; 64];
let res = Board::parse_fen_board(&mut target, "rnbqkbnr/pp/8/8/8/8/PPPPPPPP/RNBQKBNR");
assert!(matches!(
res,
Err(FENParseError::InvalidNotation(NotationError::InvalidPiece(
'/'
)))
));
}
// === FULL FEN PARSER ===
#[test]
fn start_position() {
let board =
Board::from_fen(crate::constants::START_FEN).expect("FEN string should be parsed");
assert_eq!(
board.board,
[
Some(White | Rook),
Some(White | Knight),
Some(White | Bishop),
Some(White | Queen),
Some(White | King),
Some(White | Bishop),
Some(White | Knight),
Some(White | Rook),
Some(White | Pawn),
Some(White | Pawn),
Some(White | Pawn),
Some(White | Pawn),
Some(White | Pawn),
Some(White | Pawn),
Some(White | Pawn),
Some(White | Pawn),
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
Some(Black | Pawn),
Some(Black | Pawn),
Some(Black | Pawn),
Some(Black | Pawn),
Some(Black | Pawn),
Some(Black | Pawn),
Some(Black | Pawn),
Some(Black | Pawn),
Some(Black | Rook),
Some(Black | Knight),
Some(Black | Bishop),
Some(Black | Queen),
Some(Black | King),
Some(Black | Bishop),
Some(Black | Knight),
Some(Black | Rook),
]
);
assert_eq!(board.to_move, White);
assert_eq!(
board.castling,
BySide {
white: Castling {
king: true,
queen: true,
},
black: Castling {
king: true,
queen: true,
},
}
);
assert_eq!(board.en_passant_target, None);
assert_eq!(board.halfmove_clock, 0);
assert_eq!(board.fullmove_number, 1);
}
}

404
src/chess/mv.rs Normal file
View file

@ -0,0 +1,404 @@
use serde::{Serialize, Deserialize};
use crate::{prelude::*, chess::Side};
use super::{Board, Coordinate, Piece, PieceType, Castling};
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct Move {
pub from: Coordinate,
pub to: Coordinate,
pub set_en_passant: Option<Coordinate>,
pub other: Option<Box<Move>>,
pub promotions: Option<Vec<Piece>>,
}
impl Move {
pub fn new(from: Coordinate, to: Coordinate) -> Self {
Self {
from,
to,
set_en_passant: None,
other: None,
promotions: None,
}
}
pub fn make(&self, board: &Board) -> Board {
let mut board = self.make_inner(board);
board.to_move = board.to_move.other();
board
}
fn make_inner(&self, board: &Board) -> Board {
let mut board = board.clone();
let piece = board.get(self.from).expect("cannot make a move with no piece");
if piece.ty == King {
*board.castling.get_mut(board.to_move) = Castling {
king: false,
queen: false,
};
}
let start_rank = board.to_move.back_rank();
if piece.ty == Rook && self.from.rank == start_rank {
if self.from.file == 0 {
board.castling.get_mut(board.to_move).queen = false;
}
if self.from.file == 7 {
board.castling.get_mut(board.to_move).king = false;
}
}
*board.get_mut(self.from) = None;
{
let captured_piece = board.get(self.to);
if matches!(captured_piece, Some(piece) if piece.ty == Rook && self.to.rank == board.to_move.other().back_rank()) {
let castling = board.castling.get_mut(board.to_move.other());
match self.to.file {
0 => castling.queen = false,
7 => castling.king = false,
_ => {}
};
}
}
let new_piece = if let Some(promotions) = &self.promotions {
match promotions.len() {
0 => piece,
1 => promotions[0],
_ => panic!("tried to make a move with more than one promotion"),
}
} else {
piece
};
*board.get_mut(self.to) = Some(new_piece);
if piece.ty == Pawn {
if let Some(en_passant_target) = board.en_passant_target {
if en_passant_target == self.to {
*board.get_mut(Coordinate {
rank: self.from.rank,
file: self.to.file,
}) = None;
}
}
}
board.en_passant_target = self.set_en_passant;
if let Some(other) = &self.other {
other.make_inner(&board)
} else {
board
}
}
}
const CARDINALS: [Coordinate; 4] = [
Coordinate { rank: 1, file: 0 },
Coordinate { rank: -1, file: 0 },
Coordinate { rank: 0, file: 1 },
Coordinate { rank: 0, file: -1 },
];
const DIAGONALS: [Coordinate; 4] = [
Coordinate { rank: 1, file: 1 },
Coordinate { rank: -1, file: -1 },
Coordinate { rank: -1, file: 1 },
Coordinate { rank: 1, file: -1 },
];
// Wish there was a way to statically concatenate CARDIANALS and DIAGONALS so I didn't have to copy-paste these.
const ALL_DIRECTIONS: [Coordinate; 8] = [
Coordinate { rank: 1, file: 0 },
Coordinate { rank: -1, file: 0 },
Coordinate { rank: 0, file: 1 },
Coordinate { rank: 0, file: -1 },
Coordinate { rank: 1, file: 1 },
Coordinate { rank: -1, file: -1 },
Coordinate { rank: -1, file: 1 },
Coordinate { rank: 1, file: -1 },
];
// And no I can't be bothered to write a macro lol
const KNIGHT: [Coordinate; 8] = [
Coordinate { rank: 1, file: 2 },
Coordinate { rank: 2, file: 1 },
Coordinate { rank: -1, file: -2 },
Coordinate { rank: -2, file: -1 },
Coordinate { rank: -1, file: 2 },
Coordinate { rank: -2, file: 1 },
Coordinate { rank: 1, file: -2 },
Coordinate { rank: 2, file: -1 },
];
const PAWN_CAPTURES: [Coordinate; 2] = [
Coordinate { rank: 0, file: 1 },
Coordinate { rank: 0, file: -1 },
];
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum MoveResult {
NoCapture,
Invalid,
Capture,
}
fn add_if<F: Fn(MoveResult) -> bool>(board: &Board, moves: &mut Vec<Move>, mv: Move, predicate: F) -> MoveResult {
if !mv.to.is_board_position() {
return MoveResult::Invalid;
}
let target_piece = board.board[mv.to.to_index()];
let result = if let Some(target_piece) = target_piece {
if target_piece.side == board.to_move {
MoveResult::Invalid
} else {
MoveResult::Capture
}
} else if matches!(board.get(mv.from), Some(piece) if piece.ty == Pawn) && matches!(board.en_passant_target, Some(target) if target == mv.to) {
MoveResult::Capture
} else {
MoveResult::NoCapture
};
if predicate(result) {
moves.push(mv);
}
result
}
fn add_if_valid(board: &Board, moves: &mut Vec<Move>, mv: Move) -> MoveResult {
add_if(board, moves, mv, |r| r != MoveResult::Invalid)
}
fn add_if_not_capture(board: &Board, moves: &mut Vec<Move>, mv: Move) -> MoveResult {
add_if(board, moves, mv, |r| r == MoveResult::NoCapture)
}
fn add_if_capture(board: &Board, moves: &mut Vec<Move>, mv: Move) -> MoveResult {
add_if(board, moves, mv, |r| r == MoveResult::Capture)
}
fn generate_castling_moves(board: &Board, moves: &mut Vec<Move>) {
fn test_move(board: &Board, king_from: Coordinate, king_to: Coordinate, rook_from: Coordinate, rook_to: Coordinate) -> Option<Move> {
let king = board.get(king_from);
let rook = board.get(rook_from);
let (king, rook) = match (king, rook) {
(Some(king), Some(rook)) => (*king, *rook),
_ => return None,
};
if king.ty != King || rook.ty != Rook {
return None;
}
if board.get(king_to).is_some() || board.get(rook_to).is_some() {
return None;
}
for mid in Coordinate::between(king_from, king_to) {
if board.get(mid).is_some() {
return None;
}
let test_state = Move::new(king_from, mid).make(&board);
if *test_state.calc_check_state().get(board.to_move) {
return None;
}
}
for mid in Coordinate::between(rook_from, rook_to) {
if board.get(mid).is_some() {
return None;
}
}
Some(Move {
from: king_from,
to: king_to,
other: Some(Box::new(Move::new(rook_from, rook_to))),
promotions: None,
set_en_passant: None,
})
}
let castling = board.castling.get(board.to_move);
// Cannot castle out of check
if (castling.queen || castling.king) && *board.calc_check_state().get(board.to_move) {
return;
}
let rank = match board.to_move {
Black => 7,
White => 0,
};
let king = Coordinate {
rank,
file: 4,
};
if castling.queen {
let rook = Coordinate {
rank,
file: 0,
};
let king_target = Coordinate {
rank,
file: 2,
};
let rook_target = Coordinate {
rank,
file: 3,
};
if let Some(mv) = test_move(board, king, king_target, rook, rook_target) {
moves.push(mv);
}
}
if castling.king {
let rook = Coordinate {
rank,
file: 7,
};
let king_target = Coordinate {
rank,
file: 6,
};
let rook_target = Coordinate {
rank,
file: 5,
};
if let Some(mv) = test_move(board, king, king_target, rook, rook_target) {
moves.push(mv);
}
}
}
fn generate_pawn_moves(board: &Board, moves: &mut Vec<Move>) {
fn pawn_move(side: Side, from: Coordinate, to: Coordinate, set_en_passant: Option<Coordinate>) -> Move {
let promotion_rank = match side {
Black => 0,
White => 7,
};
let promotions = if to.rank == promotion_rank {
Some(vec![
side | Queen,
side | Rook,
side | Bishop,
side | Knight,
])
} else {
None
};
let mv = Move {
from,
to,
other: None,
promotions,
set_en_passant,
};
mv
}
let (start_rank, direction) = match board.to_move {
Black => (48..56, Coordinate { rank: -1, file: 0 }),
White => (8..16, Coordinate { rank: 1, file: 0 }),
};
let pawn = board.to_move | Pawn;
for (index, piece) in board.board.iter().enumerate() {
match piece {
None => continue,
Some(p) if *p != pawn => continue,
_ => {}
}
let from = Coordinate::from_index(index);
let forward_res = add_if_not_capture(board, moves, pawn_move(board.to_move, from, from + direction, None));
if forward_res == MoveResult::NoCapture && start_rank.contains(&index) {
add_if_not_capture(board, moves, pawn_move(board.to_move, from, from + (direction * 2), Some(from + direction)));
}
for capture in PAWN_CAPTURES {
add_if_capture(board, moves, pawn_move(board.to_move, from, from + direction + capture, None));
}
}
}
fn generate_line_moves(
board: &Board,
moves: &mut Vec<Move>,
ty: PieceType,
directions: &[Coordinate],
distance_limit: isize,
) {
let target_piece = board.to_move | ty;
for (index, piece) in board.board.iter().enumerate() {
if let Some(piece) = piece {
if piece != &target_piece {
continue;
}
let from = Coordinate::from_index(index);
for direction in directions {
let mut multiplier = 1;
while add_if_valid(
board,
moves,
Move::new(from, from + (*direction * multiplier)),
) == MoveResult::NoCapture
&& multiplier < distance_limit
{
multiplier += 1;
}
}
}
}
}
/// Same as `generate_pseudolegal`, but excludes castling moves
pub fn generate_pseudolegal_captures(board: &Board, moves: &mut Vec<Move>) {
generate_line_moves(board, moves, King, &ALL_DIRECTIONS, 1);
generate_line_moves(board, moves, Queen, &ALL_DIRECTIONS, 8);
generate_line_moves(board, moves, Rook, &CARDINALS, 8);
generate_line_moves(board, moves, Knight, &KNIGHT, 1);
generate_line_moves(board, moves, Bishop, &DIAGONALS, 8);
generate_pawn_moves(board, moves);
}
pub fn generate_pseudolegal(board: &Board, moves: &mut Vec<Move>) {
generate_pseudolegal_captures(board, moves);
generate_castling_moves(board, moves);
}
pub fn generate_legal(board: &Board) -> Vec<Move> {
let mut moves = vec![];
generate_pseudolegal(board, &mut moves);
moves.into_iter().filter(|mv| {
if mv.other.is_some() {
// Cannot castle out of check
if *board.calc_check_state().get(board.to_move) {
return false;
}
}
let test_board = mv.make(board);
!*test_board.calc_check_state().get(board.to_move)
}).collect::<Vec<_>>()
}
#[cfg(test)]
mod test {
use super::*;
use crate::constants::START_FEN;
#[test]
fn start_position_pseudolegal() {
let board = Board::from_fen(START_FEN).expect("valid board");
let mut moves = vec![];
generate_pseudolegal(&board, &mut moves);
assert_eq!(moves.len(), 20);
}
}

66
src/chess/perft.rs Normal file
View file

@ -0,0 +1,66 @@
use crate::prelude::*;
use std::io::Write;
use rayon::prelude::*;
fn perft_board(board: &Board, depth: u64, start_depth: u64) -> u64 {
if depth == 0 {
return 1;
}
let mut moves = vec![];
crate::chess::mv::generate_pseudolegal(board, &mut moves);
let count = moves.into_par_iter().map(|mv| {
let mut count = 0;
if let Some(promotions) = mv.promotions {
for promotion in promotions {
let mv = Move {
promotions: Some(vec![promotion]),
other: mv.other.clone(),
..mv
};
let new_board = mv.make(board);
let check = new_board.calc_check_state();
if *check.get(board.to_move) {
continue;
}
let this_count = perft_board(&new_board, depth - 1, start_depth);
count += this_count;
#[cfg(test)]
if depth == start_depth {
println!("{from}{to}: {this_count}", from = mv.from, to = mv.to);
}
}
} else {
let new_board = mv.make(board);
if *new_board.calc_check_state().get(board.to_move) {
return 0;
}
let this_count = perft_board(&new_board, depth - 1, start_depth);
count += this_count;
#[cfg(test)]
if depth == start_depth {
println!("{from}{to}: {this_count}", from = mv.from, to = mv.to);
// println!("mv={mv:?} {piece:?}", piece = board.get(mv.from));
}
}
count
}).sum();
count
}
pub fn run_perft(fen: &str, depth: u64, expected_positions: u64) -> bool {
let board = Board::from_fen(fen).expect("failed to parse position");
println!("Running perft on position {fen:?} with depth {depth}, expecting {expected_positions} positions...");
std::io::stdout().flush().unwrap();
let positions = perft_board(&board, depth, depth);
if positions == expected_positions {
println!("Passed perft on position {fen:?} with depth {depth}, expecting {expected_positions} positions! ");
} else {
println!("Failed perft on position {fen:?} with depth {depth}, expecting {expected_positions} positions, found {positions}!");
}
positions == expected_positions
}

6
src/constants.rs Normal file
View file

@ -0,0 +1,6 @@
pub const START_FEN: &str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
pub const POSITION_2: &str = "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0";
pub const POSITION_3: &str = "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0";
pub const POSITION_4: &str = "r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1";
pub const POSITION_5: &str = "rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8";
pub const POSITION_6: &str = "r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10";

25
src/error.rs Normal file
View file

@ -0,0 +1,25 @@
#[derive(Debug, thiserror::Error)]
pub enum FENParseError {
#[error("not enough sections")]
NotEnoughSections,
#[error("expected slash")]
ExpectedSlash,
#[error("board too large")]
BoardTooLarge,
#[error("invalid notation: {0}")]
InvalidNotation(#[from] NotationError),
#[error("{0}")]
InvalidCount(#[from] std::num::ParseIntError),
}
#[derive(Debug, thiserror::Error)]
pub enum NotationError {
#[error("expected length {length} for notation, expected: {expected}")]
InvalidLength { length: usize, expected: usize },
#[error("invalid piece character: {0}")]
InvalidPiece(char),
#[error("invalid side: {0}")]
InvalidSide(String),
#[error("{0:?} is invalid notation")]
Other(String),
}

144
src/game.rs Normal file
View file

@ -0,0 +1,144 @@
use std::collections::HashMap;
use uuid::Uuid;
use xtra::{prelude::*, WeakAddress};
use crate::{prelude::{Board, Move}, constants::START_FEN, chess::{Side, mv::generate_legal}, player::{Player, OutgoingPlayerEvent, IncomingPlayerEvent}};
#[derive(Actor)]
pub struct GameManager {
games: HashMap<Uuid, Address<ChessGame>>,
}
impl GameManager {
pub fn new() -> Self {
Self {
games: HashMap::new(),
}
}
}
#[derive(Clone)]
pub struct JoinGame {
pub game_id: Uuid,
pub player: WeakAddress<Player>,
}
pub enum JoinGameResponse {
Success {
side: Side,
game: Address<ChessGame>,
},
Full,
}
#[async_trait]
impl Handler<JoinGame> for GameManager {
type Return = Option<JoinGameResponse>;
async fn handle(&mut self, join_game: JoinGame, _ctx: &mut Context<Self>) -> Self::Return {
if let Some(game) = self.games.get(&join_game.game_id) {
let res = game.send(join_game.clone()).await.ok()
.map(|r| match r {
Some(side) => JoinGameResponse::Success { side, game: game.clone() },
None => JoinGameResponse::Full,
});
if res.is_none() {
self.games.remove(&join_game.game_id);
}
res
} else {
let game = ChessGame::new(join_game.player);
let game = xtra::spawn_tokio(game, Mailbox::unbounded());
self.games.insert(join_game.game_id, game.clone());
Some(JoinGameResponse::Success {
side: Side::White,
game,
})
}
}
}
#[derive(Actor)]
pub struct ChessGame {
board: Board,
white: WeakAddress<Player>,
black: Option<WeakAddress<Player>>,
}
#[async_trait]
impl Handler<JoinGame> for ChessGame {
type Return = Option<Side>;
async fn handle(&mut self, join_game: JoinGame, _ctx: &mut Context<Self>) -> Self::Return {
if self.black.is_some() {
None
} else {
self.black = Some(join_game.player);
self.broadcast_new_board();
self.send_possible_moves();
Some(Side::Black)
}
}
}
#[derive(Debug, serde::Deserialize)]
pub struct IncomingEvent {
pub data: IncomingPlayerEvent,
pub side: Side,
}
#[async_trait]
impl Handler<IncomingEvent> for ChessGame {
type Return = ();
async fn handle(&mut self, incoming_event: IncomingEvent, _ctx: &mut Context<Self>) -> Self::Return {
match incoming_event.data {
IncomingPlayerEvent::MakeMove { mv } => {
if incoming_event.side != self.board.to_move {
tracing::warn!(?incoming_event.side, ?mv, "other player tried to make move");
return;
}
let legal_moves = generate_legal(&self.board);
if legal_moves.contains(&mv) {
self.board = mv.make(&self.board);
self.broadcast_new_board();
self.send_possible_moves();
}
},
}
}
}
impl ChessGame {
pub fn new(white: WeakAddress<Player>) -> Self {
Self {
board: Board::from_fen(START_FEN).expect("start fen is invalid"),
white,
black: None,
}
}
fn broadcast_new_board(&self) {
// TODO: Handle players disconnecting
tokio::spawn(self.white.send(OutgoingPlayerEvent::BoardUpdate {
board: self.board.board.to_vec(),
}));
if let Some(black) = &self.black {
tokio::spawn(black.send(OutgoingPlayerEvent::BoardUpdate {
board: self.board.board.to_vec(),
}));
}
}
fn send_possible_moves(&self) {
let (current, other) = match &self.board.to_move {
Side::Black => (self.black.clone().unwrap(), self.white.clone()),
Side::White => (self.white.clone(), self.black.clone().unwrap()),
};
let moves = generate_legal(&self.board);
// TODO: handle player disconnects
tokio::spawn(current.send(OutgoingPlayerEvent::PossibleMoves { moves }));
tokio::spawn(other.send(OutgoingPlayerEvent::PossibleMoves { moves: Vec::new() }));
}
}

18
src/lib.rs Normal file
View file

@ -0,0 +1,18 @@
#[cfg(not(debug_assertions))]
pub mod assets;
pub mod chess;
pub mod constants;
pub mod error;
pub mod game;
pub mod player;
pub mod routes;
pub mod prelude {
pub use crate::chess::Board;
pub use crate::chess::Piece;
pub use crate::chess::PieceType::*;
pub use crate::chess::Side::*;
pub use crate::chess::mv::Move;
}

29
src/main.rs Normal file
View file

@ -0,0 +1,29 @@
use axum::Extension;
use chs::{prelude::*, game::GameManager};
use tower_http::trace::TraceLayer;
use tracing_subscriber::prelude::*;
use xtra::Mailbox;
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(tracing_subscriber::filter::EnvFilter::new(
std::env::var("RUST_LOG").unwrap_or_else(|_| "debug,hyper=info".into()),
))
.with(tracing_subscriber::fmt::layer())
.init();
tracing::info!("Hello, world!");
let game_manager = GameManager::new();
let game_manager = xtra::spawn_tokio(game_manager, Mailbox::unbounded());
let app = chs::routes::routes()
.layer(Extension(game_manager))
.layer(TraceLayer::new_for_http());
axum::Server::bind(&format!("0.0.0.0:3000").parse().unwrap())
.serve(app.into_make_service())
.await
.unwrap();
}

78
src/player.rs Normal file
View file

@ -0,0 +1,78 @@
use futures::{channel::mpsc::UnboundedSender, SinkExt};
use serde::{Deserialize, Serialize};
use xtra::prelude::*;
use crate::{prelude::*, chess::Side, game::{ChessGame, IncomingEvent}};
#[derive(Actor)]
pub struct Player {
sink: UnboundedSender<OutgoingPlayerEvent>,
game: Option<GameInfo>,
}
pub struct GameInfo {
pub side: Side,
pub game: Address<ChessGame>,
}
impl Player {
pub fn new(sink: UnboundedSender<OutgoingPlayerEvent>) -> Self {
Self {
sink,
game: None,
}
}
}
#[derive(Debug, Serialize)]
#[serde(tag = "event", content = "data")]
pub enum OutgoingPlayerEvent {
BoardUpdate {
board: Vec<Option<Piece>>,
},
PossibleMoves {
moves: Vec<Move>,
},
}
#[async_trait]
impl Handler<OutgoingPlayerEvent> for Player {
type Return = ();
async fn handle(&mut self, outgoing_event: OutgoingPlayerEvent, _ctx: &mut Context<Self>) -> Self::Return {
// TODO: Better error handling
self.sink.send(outgoing_event).await.expect("failed to send outgoing event");
}
}
#[derive(Debug, Deserialize)]
#[serde(tag = "event", content = "data")]
pub enum IncomingPlayerEvent {
MakeMove {
#[serde(rename = "move")]
mv: Move,
}
}
#[async_trait]
impl Handler<IncomingPlayerEvent> for Player {
type Return = ();
async fn handle(&mut self, incoming_event: IncomingPlayerEvent, _ctx: &mut Context<Self>) -> Self::Return {
if let Some(game) = &self.game {
game.game.send(IncomingEvent {
data: incoming_event,
side: game.side,
}).await.expect("game disconnected");
}
}
}
#[async_trait]
impl Handler<GameInfo> for Player {
type Return = ();
async fn handle(&mut self, game_info: GameInfo, _ctx: &mut Context<Self>) -> Self::Return {
self.game = Some(game_info);
}
}

121
src/routes.rs Normal file
View file

@ -0,0 +1,121 @@
use axum::{Router, extract::{WebSocketUpgrade, Path, ws::{Message, WebSocket}}, response::IntoResponse, routing::{get, get_service}, http::{StatusCode, Uri}, Extension};
use futures::{StreamExt, SinkExt, stream::{SplitStream, SplitSink}, channel::mpsc::{self, UnboundedSender}};
use tokio::task::JoinHandle;
use tower_http::services::ServeDir;
use uuid::Uuid;
use xtra::{Mailbox, Address};
use crate::{player::{Player, IncomingPlayerEvent, OutgoingPlayerEvent, GameInfo}, game::{GameManager, JoinGame, JoinGameResponse}};
#[cfg(not(debug_assertions))]
use crate::assets::StaticFile;
pub fn routes() -> Router {
let router = Router::new()
.route("/ws/:id", get(ws_handler));
#[cfg(not(debug_assertions))]
let router = router
.route("/", get(index))
.fallback(get(fallback));
router
}
#[cfg(not(debug_assertions))]
async fn index() -> impl IntoResponse {
StaticFile("index.html")
}
#[cfg(not(debug_assertions))]
async fn fallback(uri: Uri) -> impl IntoResponse {
StaticFile(uri.path().trim_start_matches('/').to_owned())
}
async fn ws_handler(
ws: WebSocketUpgrade,
Path(id): Path<Uuid>,
Extension(game_manager): Extension<Address<GameManager>>,
) -> impl IntoResponse {
ws.on_upgrade(move |ws| handle_socket(ws, id, game_manager))
}
async fn handle_socket(socket: WebSocket, id: Uuid, game_manager: Address<GameManager>) {
let (tx, rx) = socket.split();
let (tx, tx_task) = socket_send(tx);
let player = Player::new(tx);
let player = xtra::spawn_tokio(player, Mailbox::unbounded());
let res = game_manager.send(JoinGame { game_id: id, player: player.downgrade() }).await
.expect("game manager disconnected");
if let Some(res) = res {
if let JoinGameResponse::Success { side, game } = res {
player.send(GameInfo { side, game }).await.expect("player disconnected");
let rx_task = socket_recv(rx, player);
tokio::select! {
_ = rx_task => {},
_ = tx_task => {},
}
}
}
}
fn socket_send(mut tx: SplitSink<WebSocket, Message>) -> (UnboundedSender<OutgoingPlayerEvent>, JoinHandle<()>) {
let (message_tx, mut rx) = mpsc::unbounded();
let task = tokio::spawn(async move {
while let Some(message) = rx.next().await {
match serde_json::to_string(&message) {
Ok(json) => {
if tx.send(Message::Text(json)).await.is_err() {
return;
}
}
Err(e) => {
tracing::error!(?e, ?message, "failed to encode outgoing message");
}
}
}
});
(message_tx, task)
}
fn socket_recv(mut rx: SplitStream<WebSocket>, player: Address<Player>) -> JoinHandle<()> {
let task = tokio::spawn(async move {
while let Some(msg) = rx.next().await {
if let Ok(msg) = msg {
match msg {
Message::Text(t) => {
let message = serde_json::from_str::<IncomingPlayerEvent>(&t);
match message {
Ok(message) => {
if player.send(message).await.is_err() {
return;
}
}
Err(e) => {
tracing::error!(?e, "client send invalid data");
}
}
}
Message::Close(_) => {
tracing::info!("client disconnected");
return;
}
frame => {
tracing::warn!(?frame, "client sent invalid frame type");
}
}
} else {
tracing::info!("client disconnected");
return;
}
}
});
task
}

48
tests/perft.rs Normal file
View file

@ -0,0 +1,48 @@
macro_rules! perft_test {
($name:ident, $fen:expr, $depth:expr, $expected:expr) => {
#[test]
fn $name() {
let result = chs::chess::perft::run_perft($fen, $depth, $expected);
assert!(result);
}
};
}
mod perft {
use chs::constants::*;
perft_test!(start_fen_depth_0, START_FEN, 0, 1);
perft_test!(start_fen_depth_1, START_FEN, 1, 20);
perft_test!(start_fen_depth_2, START_FEN, 2, 400);
perft_test!(start_fen_depth_3, START_FEN, 3, 8_902);
perft_test!(start_fen_depth_4, START_FEN, 4, 197_281);
perft_test!(start_fen_depth_5, START_FEN, 5, 4_865_609);
perft_test!(position_2_depth_1, POSITION_2, 1, 48);
perft_test!(position_2_depth_2, POSITION_2, 2, 2_039);
perft_test!(position_2_depth_3, POSITION_2, 3, 97_862);
perft_test!(position_2_depth_4, POSITION_2, 4, 4_085_603);
perft_test!(position_3_depth_1, POSITION_3, 1, 14);
perft_test!(position_3_depth_2, POSITION_3, 2, 191);
perft_test!(position_3_depth_3, POSITION_3, 3, 2_812);
perft_test!(position_3_depth_4, POSITION_3, 4, 43_238);
perft_test!(position_3_depth_5, POSITION_3, 5, 674_624);
perft_test!(position_3_depth_6, POSITION_3, 6, 11_030_083);
perft_test!(position_4_depth_1, POSITION_4, 1, 6);
perft_test!(position_4_depth_2, POSITION_4, 2, 264);
perft_test!(position_4_depth_3, POSITION_4, 3, 9_467);
perft_test!(position_4_depth_4, POSITION_4, 4, 422_333);
perft_test!(position_4_depth_5, POSITION_4, 5, 15_833_292);
perft_test!(position_5_depth_1, POSITION_5, 1, 44);
perft_test!(position_5_depth_2, POSITION_5, 2, 1_486);
perft_test!(position_5_depth_3, POSITION_5, 3, 62_379);
perft_test!(position_5_depth_4, POSITION_5, 4, 2_103_487);
perft_test!(position_6_depth_1, POSITION_6, 1, 46);
perft_test!(position_6_depth_2, POSITION_6, 2, 2_079);
perft_test!(position_6_depth_3, POSITION_6, 3, 89_890);
perft_test!(position_6_depth_4, POSITION_6, 4, 3_894_594);
}

15
tsconfig.json Normal file
View file

@ -0,0 +1,15 @@
{
"compilerOptions": {
"target": "ESNext",
"module": "ESNext",
"moduleResolution": "node",
"allowSyntheticDefaultImports": true,
"esModuleInterop": true,
"jsx": "preserve",
"jsxImportSource": "solid-js",
"types": ["vite/client"],
"noEmit": true,
"isolatedModules": true,
"strict": true,
}
}

8
tsconfig.node.json Normal file
View file

@ -0,0 +1,8 @@
{
"compilerOptions": {
"composite": true,
"module": "esnext",
"moduleResolution": "node"
},
"include": ["vite.config.ts"]
}

21
vite.config.ts Normal file
View file

@ -0,0 +1,21 @@
import { defineConfig } from 'vite';
import solidPlugin from 'vite-plugin-solid';
import devtools from 'solid-devtools/vite';
export default defineConfig({
plugins: [
solidPlugin(),
devtools({
// Will automatically add names when creating signals, memos, stores, or mutables
name: true,
}),
],
server: {
hmr: {
clientPort: parseInt(process.env.CLIENT_PORT || '5173'),
},
},
build: {
target: 'esnext',
},
});