improve die parser: allow chained expressions

This commit is contained in:
Denis-Cosmin Nutiu 2024-01-24 12:14:14 +02:00
parent 19f0909907
commit 739316f903
6 changed files with 87 additions and 47 deletions

View file

@ -18,7 +18,7 @@ class YamlConfigSettingsSource(PydanticBaseSettingsSource):
at the project's root. at the project's root.
Here we happen to choose to use the `env_file_encoding` from Config Here we happen to choose to use the `env_file_encoding` from Config
when reading `config.json` when reading `config.yaml`
""" """
@functools.lru_cache @functools.lru_cache
@ -59,11 +59,19 @@ class YamlConfigSettingsSource(PydanticBaseSettingsSource):
class DiscordSettings(BaseModel): class DiscordSettings(BaseModel):
"""
Holds all the settings needed to configure the bot for Discord usage.
"""
token: str = Field() token: str = Field()
command_prefix: str = Field(default=".") command_prefix: str = Field(default=".")
class Settings(BaseSettings): class Settings(BaseSettings):
"""
Settings class for the bot
"""
discord: DiscordSettings discord: DiscordSettings
@classmethod @classmethod

View file

@ -16,11 +16,13 @@ class DiceCog(commands.Cog):
- 2d20 will roll a two d20 dies and multiply the result by two. - 2d20 will roll a two d20 dies and multiply the result by two.
- 2d20+5 will roll a two d20 dies and multiply the result by two and ads 5. - 2d20+5 will roll a two d20 dies and multiply the result by two and ads 5.
""" """
if dice_expression == "":
return
if dice_expression == "0/0": # easter eggs if dice_expression == "0/0": # easter eggs
return await ctx.send("What do you expect me to do, destroy the universe?") return await ctx.send("What do you expect me to do, destroy the universe?")
try: try:
roll_result = DiceRoller.roll(dice_expression) roll_result = DiceRoller.roll_simple(dice_expression)
await ctx.send(f"You rolled: {roll_result}") await ctx.send(f"You rolled: {roll_result}")
except ValueError as e: except ValueError as e:
await ctx.send(f"Roll failed: {e}") await ctx.send(f"Roll failed: {e}")

View file

@ -1,8 +1,31 @@
import dataclasses
import typing import typing
from src.dice.parser import DieParser from src.dice.parser import DieParser
@dataclasses.dataclass
class DieRollResult:
"""
DieRoll is the result of a die roll.
"""
result: int
modifier: int
rolls: typing.List[int]
type: str
@dataclasses.dataclass
class DieExpressionResult:
"""
DiceResult is the result of a dice roll expression.
"""
total: int
dies: typing.List[DieRollResult]
class DiceRoller: class DiceRoller:
""" """
DiceRoller is a simple class that allows you to roll dices. DiceRoller is a simple class that allows you to roll dices.
@ -10,45 +33,28 @@ class DiceRoller:
A die can be rolled using the following expression: A die can be rolled using the following expression:
- 1d20 will roll a 20-faceted die and output the result a random number between 1 and 20. - 1d20 will roll a 20-faceted die and output the result a random number between 1 and 20.
- 1d100 will roll a 100 faceted die. - 1d100 will roll a 100 faceted die.
- 2d20 will roll a two d20 dies and multiply the result by two. - 2d20 will roll two d20 dies and multiply the result by two.
- 2d20+5 will roll a two d20 dies and multiply the result by two and ads 5. - 2d20+5 will roll two d20 dies add them together then add 5 to the result.
""" """
_parser = DieParser.create() _parser = DieParser.create()
@staticmethod @staticmethod
def roll(expression: str, *, advantage: typing.Optional[bool] = None) -> int: def roll_simple(expression: str) -> int:
""" """
Roll die and return the result. Roll die and return the result.
:param expression: The die expression. :param expression: The die expression.
:param advantage: Optionally, rolls a die with advantage or disadvantage.
:return: The die result. :return: The die result.
""" """
if advantage is None: result = DiceRoller._parser.parse(expression)
return DiceRoller._parser.parse(expression) return result.get("total")
elif advantage is True:
return DiceRoller.roll_with_advantage(expression)
elif advantage is False:
return DiceRoller.roll_with_disadvantage(expression)
@staticmethod @staticmethod
def roll_with_advantage(expression: str) -> int: def roll(expression: str) -> DieExpressionResult:
""" """
Roll two dies and return the highest result. Roll die and return the DiceResult.
:param expression: The die expression. :param expression: The die expression.
:return: The die result. :return: The die result.
""" """
one = DiceRoller._parser.parse(expression) result = DiceRoller._parser.parse(expression)
two = DiceRoller._parser.parse(expression) return DieExpressionResult(**result)
return max(one, two)
@staticmethod
def roll_with_disadvantage(expression: str) -> int:
"""
Roll two dies and return the lowest result.
:param expression: The die expression.
:return: The die result.
"""
one = DiceRoller._parser.parse(expression)
two = DiceRoller._parser.parse(expression)
return min(one, two)

View file

@ -8,12 +8,12 @@ DIE_GRAMMAR = """
@@grammar::Die @@grammar::Die
@@whitespace :: None @@whitespace :: None
start = die:die $; start = die:die ~ {op:operator die:die} $;
die = [number_of_dies:number] die_type:die_type die_number:number [modifier:die_modifier]; die = [number_of_dies:number] die_type:die_type die_number:number [modifier:die_modifier];
die_modifier = op:operator modifier:number; die_modifier = op:operator modifier:number;
operator = '+' | '-'; operator = '+' | '-' | 'adv' | 'dis';
die_type = 'd' | 'zd'; die_type = 'd' | 'zd';
@ -35,7 +35,7 @@ class DieParser:
def create() -> "DieParser": def create() -> "DieParser":
return DieParser() return DieParser()
def parse(self, expression: str) -> int: def parse(self, expression: str) -> dict:
""" """
Parses the die expression and returns the result. Parses the die expression and returns the result.
""" """

View file

@ -1,4 +1,6 @@
import copy
import random import random
from collections import deque
from tatsu.ast import AST from tatsu.ast import AST
@ -8,7 +10,32 @@ class DieSemantics:
return int(ast) return int(ast)
def start(self, ast): def start(self, ast):
return ast.get("die").get("result") die = ast.get("die")
if isinstance(die, dict):
return {"total": die.get("result"), "dies": [die]}
elif isinstance(die, list):
return_value = {"total": 0, "dies": copy.deepcopy(die)}
operators = deque(ast.get("op", []))
die_results = deque(map(lambda x: x.get("result"), die))
# Note: we may need to use a dequeue, the ops are quite inefficient.
while len(die_results) != 1:
left = die_results.popleft()
right = die_results.popleft()
operator = operators.popleft()
total = 0
if operator == "+":
total = left + right
if operator == "-":
total = left - right
if operator == "adv":
total = max(left, right)
if operator == "dis":
total = min(left, right)
die_results.appendleft(total)
return_value["total"] = die_results.pop()
return return_value
def die(self, ast): def die(self, ast):
if not isinstance(ast, AST): if not isinstance(ast, AST):
@ -34,8 +61,8 @@ class DieSemantics:
return { return {
"result": max(sum(rolls) + die_modifier, minimum_value_for_die), "result": max(sum(rolls) + die_modifier, minimum_value_for_die),
"die_type": die_type, "type": die_type,
"roll_history": rolls, "rolls": rolls,
"modifier": die_modifier, "modifier": die_modifier,
} }

View file

@ -48,10 +48,10 @@ def dice_roller():
("1d 4 +0", 1, 4), ("1d 4 +0", 1, 4),
], ],
) )
def test_die_roller_die_roll(expression, range_min, range_max, dice_roller): def test_die_roller_die_roll_simple(expression, range_min, range_max, dice_roller):
# let the dies roll... # let the dies roll...
for i in range(100): for i in range(100):
result = dice_roller.roll(expression) result = dice_roller.roll_simple(expression)
assert range_min <= result <= range_max assert range_min <= result <= range_max
@ -95,10 +95,10 @@ def test_die_roller_die_roll(expression, range_min, range_max, dice_roller):
("1zd 4 +0", 0, 4), ("1zd 4 +0", 0, 4),
], ],
) )
def test_die_roller_zero_die_roll(expression, range_min, range_max, dice_roller): def test_die_roller_zero_die_roll_simple(expression, range_min, range_max, dice_roller):
# let the dies roll... # let the dies roll...
for i in range(100): for i in range(100):
result = dice_roller.roll(expression) result = dice_roller.roll_simple(expression)
assert range_min <= result <= range_max assert range_min <= result <= range_max
@ -119,14 +119,11 @@ def test_die_roller_zero_die_roll(expression, range_min, range_max, dice_roller)
) )
def test_die_roller_die_parsing_fail(expression, dice_roller): def test_die_roller_die_parsing_fail(expression, dice_roller):
with pytest.raises(ValueError): with pytest.raises(ValueError):
dice_roller.roll(expression) dice_roller.roll_simple(expression)
def test_die_roller_roll_with_advantage(dice_roller): def test_die_roller_roll(dice_roller):
assert 1 <= dice_roller.roll_with_advantage("d20") <= 20 for i in range(100):
assert 1 <= dice_roller.roll("d20", advantage=True) <= 20 result = dice_roller.roll("d20 + d20 adv d20+5 dis d12+3")
assert 1 <= result.total <= 15
assert len(result.dies) == 4
def test_die_roller_roll_with_disadvantage(dice_roller):
assert 1 <= dice_roller.roll_with_advantage("d20") <= 20
assert 1 <= dice_roller.roll("d20", advantage=False) <= 20