improve die parser: allow chained expressions

This commit is contained in:
Denis-Cosmin Nutiu 2024-01-24 12:14:14 +02:00
parent 19f0909907
commit 739316f903
6 changed files with 87 additions and 47 deletions

View file

@ -18,7 +18,7 @@ class YamlConfigSettingsSource(PydanticBaseSettingsSource):
at the project's root.
Here we happen to choose to use the `env_file_encoding` from Config
when reading `config.json`
when reading `config.yaml`
"""
@functools.lru_cache
@ -59,11 +59,19 @@ class YamlConfigSettingsSource(PydanticBaseSettingsSource):
class DiscordSettings(BaseModel):
"""
Holds all the settings needed to configure the bot for Discord usage.
"""
token: str = Field()
command_prefix: str = Field(default=".")
class Settings(BaseSettings):
"""
Settings class for the bot
"""
discord: DiscordSettings
@classmethod

View file

@ -16,11 +16,13 @@ class DiceCog(commands.Cog):
- 2d20 will roll a two d20 dies and multiply the result by two.
- 2d20+5 will roll a two d20 dies and multiply the result by two and ads 5.
"""
if dice_expression == "":
return
if dice_expression == "0/0": # easter eggs
return await ctx.send("What do you expect me to do, destroy the universe?")
try:
roll_result = DiceRoller.roll(dice_expression)
roll_result = DiceRoller.roll_simple(dice_expression)
await ctx.send(f"You rolled: {roll_result}")
except ValueError as e:
await ctx.send(f"Roll failed: {e}")

View file

@ -1,8 +1,31 @@
import dataclasses
import typing
from src.dice.parser import DieParser
@dataclasses.dataclass
class DieRollResult:
"""
DieRoll is the result of a die roll.
"""
result: int
modifier: int
rolls: typing.List[int]
type: str
@dataclasses.dataclass
class DieExpressionResult:
"""
DiceResult is the result of a dice roll expression.
"""
total: int
dies: typing.List[DieRollResult]
class DiceRoller:
"""
DiceRoller is a simple class that allows you to roll dices.
@ -10,45 +33,28 @@ class DiceRoller:
A die can be rolled using the following expression:
- 1d20 will roll a 20-faceted die and output the result a random number between 1 and 20.
- 1d100 will roll a 100 faceted die.
- 2d20 will roll a two d20 dies and multiply the result by two.
- 2d20+5 will roll a two d20 dies and multiply the result by two and ads 5.
- 2d20 will roll two d20 dies and multiply the result by two.
- 2d20+5 will roll two d20 dies add them together then add 5 to the result.
"""
_parser = DieParser.create()
@staticmethod
def roll(expression: str, *, advantage: typing.Optional[bool] = None) -> int:
def roll_simple(expression: str) -> int:
"""
Roll die and return the result.
:param expression: The die expression.
:param advantage: Optionally, rolls a die with advantage or disadvantage.
:return: The die result.
"""
if advantage is None:
return DiceRoller._parser.parse(expression)
elif advantage is True:
return DiceRoller.roll_with_advantage(expression)
elif advantage is False:
return DiceRoller.roll_with_disadvantage(expression)
result = DiceRoller._parser.parse(expression)
return result.get("total")
@staticmethod
def roll_with_advantage(expression: str) -> int:
def roll(expression: str) -> DieExpressionResult:
"""
Roll two dies and return the highest result.
Roll die and return the DiceResult.
:param expression: The die expression.
:return: The die result.
"""
one = DiceRoller._parser.parse(expression)
two = DiceRoller._parser.parse(expression)
return max(one, two)
@staticmethod
def roll_with_disadvantage(expression: str) -> int:
"""
Roll two dies and return the lowest result.
:param expression: The die expression.
:return: The die result.
"""
one = DiceRoller._parser.parse(expression)
two = DiceRoller._parser.parse(expression)
return min(one, two)
result = DiceRoller._parser.parse(expression)
return DieExpressionResult(**result)

View file

@ -8,12 +8,12 @@ DIE_GRAMMAR = """
@@grammar::Die
@@whitespace :: None
start = die:die $;
start = die:die ~ {op:operator die:die} $;
die = [number_of_dies:number] die_type:die_type die_number:number [modifier:die_modifier];
die_modifier = op:operator modifier:number;
operator = '+' | '-';
operator = '+' | '-' | 'adv' | 'dis';
die_type = 'd' | 'zd';
@ -35,7 +35,7 @@ class DieParser:
def create() -> "DieParser":
return DieParser()
def parse(self, expression: str) -> int:
def parse(self, expression: str) -> dict:
"""
Parses the die expression and returns the result.
"""

View file

@ -1,4 +1,6 @@
import copy
import random
from collections import deque
from tatsu.ast import AST
@ -8,7 +10,32 @@ class DieSemantics:
return int(ast)
def start(self, ast):
return ast.get("die").get("result")
die = ast.get("die")
if isinstance(die, dict):
return {"total": die.get("result"), "dies": [die]}
elif isinstance(die, list):
return_value = {"total": 0, "dies": copy.deepcopy(die)}
operators = deque(ast.get("op", []))
die_results = deque(map(lambda x: x.get("result"), die))
# Note: we may need to use a dequeue, the ops are quite inefficient.
while len(die_results) != 1:
left = die_results.popleft()
right = die_results.popleft()
operator = operators.popleft()
total = 0
if operator == "+":
total = left + right
if operator == "-":
total = left - right
if operator == "adv":
total = max(left, right)
if operator == "dis":
total = min(left, right)
die_results.appendleft(total)
return_value["total"] = die_results.pop()
return return_value
def die(self, ast):
if not isinstance(ast, AST):
@ -34,8 +61,8 @@ class DieSemantics:
return {
"result": max(sum(rolls) + die_modifier, minimum_value_for_die),
"die_type": die_type,
"roll_history": rolls,
"type": die_type,
"rolls": rolls,
"modifier": die_modifier,
}

View file

@ -48,10 +48,10 @@ def dice_roller():
("1d 4 +0", 1, 4),
],
)
def test_die_roller_die_roll(expression, range_min, range_max, dice_roller):
def test_die_roller_die_roll_simple(expression, range_min, range_max, dice_roller):
# let the dies roll...
for i in range(100):
result = dice_roller.roll(expression)
result = dice_roller.roll_simple(expression)
assert range_min <= result <= range_max
@ -95,10 +95,10 @@ def test_die_roller_die_roll(expression, range_min, range_max, dice_roller):
("1zd 4 +0", 0, 4),
],
)
def test_die_roller_zero_die_roll(expression, range_min, range_max, dice_roller):
def test_die_roller_zero_die_roll_simple(expression, range_min, range_max, dice_roller):
# let the dies roll...
for i in range(100):
result = dice_roller.roll(expression)
result = dice_roller.roll_simple(expression)
assert range_min <= result <= range_max
@ -119,14 +119,11 @@ def test_die_roller_zero_die_roll(expression, range_min, range_max, dice_roller)
)
def test_die_roller_die_parsing_fail(expression, dice_roller):
with pytest.raises(ValueError):
dice_roller.roll(expression)
dice_roller.roll_simple(expression)
def test_die_roller_roll_with_advantage(dice_roller):
assert 1 <= dice_roller.roll_with_advantage("d20") <= 20
assert 1 <= dice_roller.roll("d20", advantage=True) <= 20
def test_die_roller_roll_with_disadvantage(dice_roller):
assert 1 <= dice_roller.roll_with_advantage("d20") <= 20
assert 1 <= dice_roller.roll("d20", advantage=False) <= 20
def test_die_roller_roll(dice_roller):
for i in range(100):
result = dice_roller.roll("d20 + d20 adv d20+5 dis d12+3")
assert 1 <= result.total <= 15
assert len(result.dies) == 4