improve die parser: allow chained expressions
This commit is contained in:
parent
19f0909907
commit
739316f903
6 changed files with 87 additions and 47 deletions
|
@ -18,7 +18,7 @@ class YamlConfigSettingsSource(PydanticBaseSettingsSource):
|
|||
at the project's root.
|
||||
|
||||
Here we happen to choose to use the `env_file_encoding` from Config
|
||||
when reading `config.json`
|
||||
when reading `config.yaml`
|
||||
"""
|
||||
|
||||
@functools.lru_cache
|
||||
|
@ -59,11 +59,19 @@ class YamlConfigSettingsSource(PydanticBaseSettingsSource):
|
|||
|
||||
|
||||
class DiscordSettings(BaseModel):
|
||||
"""
|
||||
Holds all the settings needed to configure the bot for Discord usage.
|
||||
"""
|
||||
|
||||
token: str = Field()
|
||||
command_prefix: str = Field(default=".")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""
|
||||
Settings class for the bot
|
||||
"""
|
||||
|
||||
discord: DiscordSettings
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -16,11 +16,13 @@ class DiceCog(commands.Cog):
|
|||
- 2d20 will roll a two d20 dies and multiply the result by two.
|
||||
- 2d20+5 will roll a two d20 dies and multiply the result by two and ads 5.
|
||||
"""
|
||||
if dice_expression == "":
|
||||
return
|
||||
if dice_expression == "0/0": # easter eggs
|
||||
return await ctx.send("What do you expect me to do, destroy the universe?")
|
||||
|
||||
try:
|
||||
roll_result = DiceRoller.roll(dice_expression)
|
||||
roll_result = DiceRoller.roll_simple(dice_expression)
|
||||
await ctx.send(f"You rolled: {roll_result}")
|
||||
except ValueError as e:
|
||||
await ctx.send(f"Roll failed: {e}")
|
||||
|
|
|
@ -1,8 +1,31 @@
|
|||
import dataclasses
|
||||
import typing
|
||||
|
||||
from src.dice.parser import DieParser
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DieRollResult:
|
||||
"""
|
||||
DieRoll is the result of a die roll.
|
||||
"""
|
||||
|
||||
result: int
|
||||
modifier: int
|
||||
rolls: typing.List[int]
|
||||
type: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DieExpressionResult:
|
||||
"""
|
||||
DiceResult is the result of a dice roll expression.
|
||||
"""
|
||||
|
||||
total: int
|
||||
dies: typing.List[DieRollResult]
|
||||
|
||||
|
||||
class DiceRoller:
|
||||
"""
|
||||
DiceRoller is a simple class that allows you to roll dices.
|
||||
|
@ -10,45 +33,28 @@ class DiceRoller:
|
|||
A die can be rolled using the following expression:
|
||||
- 1d20 will roll a 20-faceted die and output the result a random number between 1 and 20.
|
||||
- 1d100 will roll a 100 faceted die.
|
||||
- 2d20 will roll a two d20 dies and multiply the result by two.
|
||||
- 2d20+5 will roll a two d20 dies and multiply the result by two and ads 5.
|
||||
- 2d20 will roll two d20 dies and multiply the result by two.
|
||||
- 2d20+5 will roll two d20 dies add them together then add 5 to the result.
|
||||
"""
|
||||
|
||||
_parser = DieParser.create()
|
||||
|
||||
@staticmethod
|
||||
def roll(expression: str, *, advantage: typing.Optional[bool] = None) -> int:
|
||||
def roll_simple(expression: str) -> int:
|
||||
"""
|
||||
Roll die and return the result.
|
||||
:param expression: The die expression.
|
||||
:param advantage: Optionally, rolls a die with advantage or disadvantage.
|
||||
:return: The die result.
|
||||
"""
|
||||
if advantage is None:
|
||||
return DiceRoller._parser.parse(expression)
|
||||
elif advantage is True:
|
||||
return DiceRoller.roll_with_advantage(expression)
|
||||
elif advantage is False:
|
||||
return DiceRoller.roll_with_disadvantage(expression)
|
||||
result = DiceRoller._parser.parse(expression)
|
||||
return result.get("total")
|
||||
|
||||
@staticmethod
|
||||
def roll_with_advantage(expression: str) -> int:
|
||||
def roll(expression: str) -> DieExpressionResult:
|
||||
"""
|
||||
Roll two dies and return the highest result.
|
||||
Roll die and return the DiceResult.
|
||||
:param expression: The die expression.
|
||||
:return: The die result.
|
||||
"""
|
||||
one = DiceRoller._parser.parse(expression)
|
||||
two = DiceRoller._parser.parse(expression)
|
||||
return max(one, two)
|
||||
|
||||
@staticmethod
|
||||
def roll_with_disadvantage(expression: str) -> int:
|
||||
"""
|
||||
Roll two dies and return the lowest result.
|
||||
:param expression: The die expression.
|
||||
:return: The die result.
|
||||
"""
|
||||
one = DiceRoller._parser.parse(expression)
|
||||
two = DiceRoller._parser.parse(expression)
|
||||
return min(one, two)
|
||||
result = DiceRoller._parser.parse(expression)
|
||||
return DieExpressionResult(**result)
|
||||
|
|
|
@ -8,12 +8,12 @@ DIE_GRAMMAR = """
|
|||
@@grammar::Die
|
||||
@@whitespace :: None
|
||||
|
||||
start = die:die $;
|
||||
start = die:die ~ {op:operator die:die} $;
|
||||
|
||||
die = [number_of_dies:number] die_type:die_type die_number:number [modifier:die_modifier];
|
||||
die_modifier = op:operator modifier:number;
|
||||
|
||||
operator = '+' | '-';
|
||||
operator = '+' | '-' | 'adv' | 'dis';
|
||||
|
||||
die_type = 'd' | 'zd';
|
||||
|
||||
|
@ -35,7 +35,7 @@ class DieParser:
|
|||
def create() -> "DieParser":
|
||||
return DieParser()
|
||||
|
||||
def parse(self, expression: str) -> int:
|
||||
def parse(self, expression: str) -> dict:
|
||||
"""
|
||||
Parses the die expression and returns the result.
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import copy
|
||||
import random
|
||||
from collections import deque
|
||||
from tatsu.ast import AST
|
||||
|
||||
|
||||
|
@ -8,7 +10,32 @@ class DieSemantics:
|
|||
return int(ast)
|
||||
|
||||
def start(self, ast):
|
||||
return ast.get("die").get("result")
|
||||
die = ast.get("die")
|
||||
if isinstance(die, dict):
|
||||
return {"total": die.get("result"), "dies": [die]}
|
||||
elif isinstance(die, list):
|
||||
return_value = {"total": 0, "dies": copy.deepcopy(die)}
|
||||
operators = deque(ast.get("op", []))
|
||||
|
||||
die_results = deque(map(lambda x: x.get("result"), die))
|
||||
# Note: we may need to use a dequeue, the ops are quite inefficient.
|
||||
while len(die_results) != 1:
|
||||
left = die_results.popleft()
|
||||
right = die_results.popleft()
|
||||
operator = operators.popleft()
|
||||
total = 0
|
||||
if operator == "+":
|
||||
total = left + right
|
||||
if operator == "-":
|
||||
total = left - right
|
||||
if operator == "adv":
|
||||
total = max(left, right)
|
||||
if operator == "dis":
|
||||
total = min(left, right)
|
||||
die_results.appendleft(total)
|
||||
|
||||
return_value["total"] = die_results.pop()
|
||||
return return_value
|
||||
|
||||
def die(self, ast):
|
||||
if not isinstance(ast, AST):
|
||||
|
@ -34,8 +61,8 @@ class DieSemantics:
|
|||
|
||||
return {
|
||||
"result": max(sum(rolls) + die_modifier, minimum_value_for_die),
|
||||
"die_type": die_type,
|
||||
"roll_history": rolls,
|
||||
"type": die_type,
|
||||
"rolls": rolls,
|
||||
"modifier": die_modifier,
|
||||
}
|
||||
|
||||
|
|
|
@ -48,10 +48,10 @@ def dice_roller():
|
|||
("1d 4 +0", 1, 4),
|
||||
],
|
||||
)
|
||||
def test_die_roller_die_roll(expression, range_min, range_max, dice_roller):
|
||||
def test_die_roller_die_roll_simple(expression, range_min, range_max, dice_roller):
|
||||
# let the dies roll...
|
||||
for i in range(100):
|
||||
result = dice_roller.roll(expression)
|
||||
result = dice_roller.roll_simple(expression)
|
||||
assert range_min <= result <= range_max
|
||||
|
||||
|
||||
|
@ -95,10 +95,10 @@ def test_die_roller_die_roll(expression, range_min, range_max, dice_roller):
|
|||
("1zd 4 +0", 0, 4),
|
||||
],
|
||||
)
|
||||
def test_die_roller_zero_die_roll(expression, range_min, range_max, dice_roller):
|
||||
def test_die_roller_zero_die_roll_simple(expression, range_min, range_max, dice_roller):
|
||||
# let the dies roll...
|
||||
for i in range(100):
|
||||
result = dice_roller.roll(expression)
|
||||
result = dice_roller.roll_simple(expression)
|
||||
assert range_min <= result <= range_max
|
||||
|
||||
|
||||
|
@ -119,14 +119,11 @@ def test_die_roller_zero_die_roll(expression, range_min, range_max, dice_roller)
|
|||
)
|
||||
def test_die_roller_die_parsing_fail(expression, dice_roller):
|
||||
with pytest.raises(ValueError):
|
||||
dice_roller.roll(expression)
|
||||
dice_roller.roll_simple(expression)
|
||||
|
||||
|
||||
def test_die_roller_roll_with_advantage(dice_roller):
|
||||
assert 1 <= dice_roller.roll_with_advantage("d20") <= 20
|
||||
assert 1 <= dice_roller.roll("d20", advantage=True) <= 20
|
||||
|
||||
|
||||
def test_die_roller_roll_with_disadvantage(dice_roller):
|
||||
assert 1 <= dice_roller.roll_with_advantage("d20") <= 20
|
||||
assert 1 <= dice_roller.roll("d20", advantage=False) <= 20
|
||||
def test_die_roller_roll(dice_roller):
|
||||
for i in range(100):
|
||||
result = dice_roller.roll("d20 + d20 adv d20+5 dis d12+3")
|
||||
assert 1 <= result.total <= 15
|
||||
assert len(result.dies) == 4
|
||||
|
|
Loading…
Reference in a new issue